Source code for dendrotweaks.simulators

# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0

from collections import defaultdict
import warnings
from functools import cached_property

import matplotlib.pyplot as plt
import neuron
from neuron import h
from neuron.units import ms, mV
h.load_file('stdrun.hoc')
# h.load_file('import3d.hoc')
# h.load_file('nrngui.hoc')
# h.load_file('import3d')
import numpy as np

import contextlib

@contextlib.contextmanager
def push_section(section):
    section.push()
    yield
    h.pop_section()

def reset_neuron():

    # h('forall delete_section()')
    # h('forall delete_all()')
    # h('forall delete()')

    for sec in h.allsec():
        with push_section(sec):
            h.delete_section()

reset_neuron()            

# -------------------------------------------------------
# SIMULATOR
# -------------------------------------------------------

class Simulator:
    """
    A generic simulator class.
    """
    def __init__(self):
        self._t = None
        self.dt = None
        self._recordings = {'v': {}}

    def plot_var(self, var='v', ax=None, segments=None, **kwargs):
        if self._t is None:
            raise ValueError('Simulation has not been run yet.')
        if var not in self.recordings:
            raise ValueError(f'Variable {var} not recorded.')
        if ax is None:
            fig, ax = plt.subplots()
        if segments is None:
            segments = self.recordings[var].keys()
        for seg, x in self.recordings[var].items():
            if segments and seg not in segments:
                continue
            ax.plot(self.t, x, label=f'{var} {seg.domain_name} {seg.idx}', **kwargs)
        if len(segments) < 10:
            ax.legend()
        ax.set_xlabel('Time (ms)')
        if var == 'v':
            ax.set_ylabel('Voltage (mV)')
        elif var.startswith('i_'):
            ax.set_ylabel('Current (nA)')
        return ax

    def plot_voltage(self, **kwargs):
        """
        Plot the recorded voltages.
        """
        self.plot_var('v', **kwargs)
    
    def plot_currents(self, **kwargs):
        """
        Plot the recorded currents.
        """
        ax = kwargs.pop('ax', None)
        for var in self.recordings:
            if var.startswith('i_'):
                ax = self.plot_var(var, ax=ax, **kwargs)
            
        

[docs] class NeuronSimulator(Simulator): """ A class to represent the NEURON simulator. Parameters ---------- temperature : float The temperature of the simulation in Celsius. v_init : float The initial membrane potential of the neuron in mV. dt : float The time step of the simulation in ms. cvode : bool Whether to use the CVode variable time step integrator. Attributes ---------- temperature : float The temperature of the simulation in Celsius. v_init : float The initial membrane potential of the neuron in mV. dt : float The time step of the simulation in ms. """ def __init__(self, temperature=37, v_init=-70, dt=0.025, cvode=False): super().__init__() self.temperature = temperature self.v_init = v_init * mV self._duration = 300 self.dt = dt self._cvode = cvode @cached_property def recordings(self): return { var:{ seg: vec.to_python() for seg, vec in recs.items() } for var, recs in self._recordings.items() } @cached_property def t(self): return self._t.to_python() def _clean_cache(self): """ Clean the cache of the simulator. """ try: del self.recordings del self.t except AttributeError: # Property hasn't been accessed yet, so no need to delete pass
[docs] def add_recording(self, sec, loc, var='v'): """ Add a recording to the simulator. Parameters ---------- sec : Section The section to record from. loc : float The location along the normalized section length to record from. var : str The variable to record. Default is 'v' (voltage). """ seg = sec(loc) if not hasattr(seg._ref, f'_ref_{var}'): raise ValueError(f'Segment {seg} does not have variable {var}.') if self._recordings.get(var, {}).get(seg): self.remove_recording(sec, loc, var) if var not in self._recordings: self._recordings[var] = {} self._recordings[var][seg] = h.Vector().record(getattr(seg._ref, f'_ref_{var}')) self._clean_cache()
[docs] def remove_recording(self, sec, loc, var='v'): """ Remove a recording from the simulator. Parameters ---------- sec : Section The section to remove the recording from. loc : float The location along the normalized section length to remove the recording from. """ seg = sec(loc) if seg in self._recordings[var]: self._recordings[var][seg] = None self._recordings[var].pop(seg) if not self._recordings[var]: self._recordings.pop(var) self._clean_cache()
[docs] def remove_all_recordings(self, var=None): """ Remove all recordings from the simulator. """ variables = [var] if var else list(self._recordings.keys()) for variable in variables: for seg in list(self._recordings.get(variable, {}).keys()): self.remove_recording(seg._section, seg.x, variable) if self._recordings.get(variable): warnings.warn(f'Not all recordings were removed for variable {variable}: {self._recordings}')
def _init_simulation(self): h.celsius = self.temperature if self._cvode: h.cvode.active(1) else: h.cvode.active(0) h.dt = self.dt h.finitialize(self.v_init) if self._cvode: h.cvode.re_init() else: h.fcurrent() h.frecord_init()
[docs] def run(self, duration=300): """ Run a simulation. Parameters ---------- duration : float The duration of the simulation in milliseconds. """ self._clean_cache() self._duration = duration self._t = h.Vector().record(h._ref_t) self._init_simulation() h.continuerun(duration * ms)
[docs] def to_dict(self): """ Convert the simulator to a dictionary. Returns ------- dict A dictionary representation of the simulator. """ return { 'temperature': self.temperature, 'v_init': self.v_init, 'dt': self.dt, 'duration': self._duration }
[docs] def from_dict(self, data): """ Create a simulator from a dictionary. Parameters ---------- data : dict The dictionary representation of the simulator. """ self.temperature = data['temperature'] self.v_init = data['v_init'] self.dt = data['dt'] self._duration = data['duration']
class JaxleySimulator(Simulator): """ A class to represent a Jaxley simulator. """ def __init__(self): super().__init__() ...