Source code for dendrotweaks.model

# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0

# Imports
from typing import List, Union, Callable
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
from numpy import nan
import pandas as pd
import quantities as pq

# DendroTweaks imports
from dendrotweaks.simulators import NeuronSimulator
from dendrotweaks.biophys.io import MODFileLoader
from dendrotweaks.morphology import Domain
from dendrotweaks.biophys.groups import SegmentGroup
from dendrotweaks.biophys.distributions import Distribution
from dendrotweaks.path_manager import PathManager
import dendrotweaks.morphology.reduce as rdc
from dendrotweaks.utils import DOMAINS_TO_GROUPS
from dendrotweaks.utils import DEFAULT_FIT_MODELS

# Mixins
from dendrotweaks.model_io import IOMixin
from dendrotweaks.model_simulation import SimulationMixin

# Warnings configuration
import warnings

def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    return f"WARNING: {message}\n({os.path.basename(filename)}, line {lineno})\n"

warnings.formatwarning = custom_warning_formatter



[docs] class Model(IOMixin, SimulationMixin): """ A model object that represents a neuron model. The class incorporates various mixins to separate concerns while maintaining a flat interface. Parameters ---------- name : str The name of the model. simulator_name : str The name of the simulator to use (either 'NEURON' or 'Jaxley'). path_to_data : str The path to the data files where swc and mod files are stored. Attributes ---------- path_to_model : str The path to the model directory. path_manager : PathManager The path manager for the model. mod_loader : MODFileLoader The MOD file loader. simulator_name : str The name of the simulator to use. Default is 'NEURON'. point_tree : PointTree The point tree representing the morphological reconstruction. sec_tree : SectionTree The section tree representing the morphology on the section level. mechanisms : dict A dictionary of mechanisms available for the model. domains_to_mechs : dict A dictionary mapping domains to mechanisms inserted in them. params : dict A dictionary mapping parameters to their distributions. d_lambda : float The spatial discretization parameter. seg_tree : SegmentTree The segment tree representing the morphology on the segment level. iclamps : dict A dictionary of current clamps in the model. populations : dict A dictionary of "virtual" populations forming synapses on the model. simulator : Simulator The simulator object to use. """ def __init__(self, path_to_model, simulator_name='NEURON',) -> None: # Metadata self.path_to_model = path_to_model self._name = os.path.basename(os.path.normpath(path_to_model)) self.morphology_name = '' self.version = '' self.path_manager = PathManager(path_to_model) self.simulator_name = simulator_name self._verbose = False # File managers self.mod_loader = MODFileLoader() # Morphology self.point_tree = None self.sec_tree = None self.domains = {} # Mechanisms self.mechanisms = {} self.domains_to_mechs = {} # Parameters self.params = { 'cm': {'all': Distribution('constant', value=1)}, # uF/cm2 'Ra': {'all': Distribution('constant', value=35.4)}, # Ohm cm } self.params_to_units = { 'cm': pq.uF/pq.cm**2, 'Ra': pq.ohm*pq.cm, } # Groups self._groups = [] # Distributions # self.distributed_params = {} # Segmentation self.d_lambda = 0.1 self.seg_tree = None # Stimuli self.iclamps = {} self.populations = {} # Simulator if simulator_name == 'NEURON': self.simulator = NeuronSimulator() elif simulator_name == 'Jaxley': self.simulator = JaxleySimulator() else: raise ValueError( 'Simulator name not recognized. Use NEURON or Jaxley.') # ----------------------------------------------------------------------- # PROPERTIES # ----------------------------------------------------------------------- @property def name(self): """ The name of the directory containing the model. """ return self._name @property def verbose(self): """ Whether to print verbose output. """ return self._verbose @verbose.setter def verbose(self, value): self._verbose = value self.mod_loader.verbose = value @property def mechs_to_domains(self): """ The dictionary mapping mechanisms to domains where they are inserted. """ mechs_to_domains = defaultdict(set) for domain_name, mech_names in self.domains_to_mechs.items(): for mech_name in mech_names: mechs_to_domains[mech_name].add(domain_name) return dict(mechs_to_domains) @property def groups(self): """ The dictionary of segment groups in the model. """ return {group.name: group for group in self._groups} @property def groups_to_parameters(self): """ The dictionary mapping segment groups to parameters. """ groups_to_parameters = {} for group in self._groups: groups_to_parameters[group.name] = {} for mech_name, params in self.mechs_to_params.items(): if mech_name not in group.mechanisms: continue groups_to_parameters[group.name] = params return groups_to_parameters @property def parameters_to_groups(self): """ The dictionary mapping parameters to groups where they are distributed. """ parameters_to_groups = defaultdict(list) for group in self._groups: for mech_name, params in self.mechs_to_params.items(): if mech_name not in group.mechanisms: continue for param in params: parameters_to_groups[param].append(group.name) return dict(parameters_to_groups) @property def params_to_mechs(self): """ The dictionary mapping parameters to mechanisms to which they belong. """ params_to_mechs = {} # Sort mechanisms by length (longer first) to ensure specific matches sorted_mechs = sorted(self.mechanisms, key=len, reverse=True) for param in self.params: matched = False for mech in sorted_mechs: suffix = f"_{mech}" # Define exact suffix if param.endswith(suffix): params_to_mechs[param] = mech matched = True break if not matched: params_to_mechs[param] = "Independent" # No match found return params_to_mechs @property def mechs_to_params(self): """ The dictionary mapping mechanisms to parameters they contain. """ mechs_to_params = defaultdict(list) for param, mech_name in self.params_to_mechs.items(): mechs_to_params[mech_name].append(param) return dict(mechs_to_params) @property def conductances(self): """ A filtered dictionary of parameters that represent conductances. """ return {param: value for param, value in self.params.items() if param.startswith('gbar')} @property def df_params(self): """ A DataFrame of parameters and their distributions. """ data = [] for mech_name, params in self.mechs_to_params.items(): for param in params: for group_name, distribution in self.params[param].items(): data.append({ 'Mechanism': mech_name, 'Parameter': param, 'Group': group_name, 'Distribution': distribution if isinstance(distribution, str) else distribution.function_name, 'Distribution params': {} if isinstance(distribution, str) else distribution.parameters, }) df = pd.DataFrame(data) return df # ======================================================================== # DOMAINS # ========================================================================
[docs] def add_domain(self, name, type_idx, color, sections, distribute=True): """ Adds a new empty domain to the model. Parameters ---------- name : str The name of the domain. color : str The color assigned to the domain. type_idx : int The type index of the domain. sections : list[Section] The sections to include in the domain. distribute : bool, optional Whether to re-distribute the parameters after defining the domain. Default is True. Notes ----- This method does not automatically insert mechanisms into the newly created domain. It is the user's responsibility to insert mechanisms into the domain after its creation. Suggested type indices and colors: 1: soma: orange 2: axon: gold 3: dend: forestgreen 31: basal: seagreen 4: apic: steelblue 41: trunk: skyblue 42: tuft: plum 43: oblique: rosybrown """ if name in self.domains: raise ValueError(f"Domain '{name}' already exists.") if not name or not name.strip(): raise ValueError("Domain name cannot be empty.") if not sections: raise ValueError('No sections provided to define the domain.') complement_sections = set(self.sec_tree.sections) - set(sections) self._validate_domain_type_idx(complement_sections, type_idx) self._validate_domain_color(complement_sections, color) domain = Domain(name, type_idx, color) self._add_domain_groups(domain.name) self.domains[domain.name] = domain self.domains_to_mechs[domain.name] = set() self.extend_domain(name, sections, distribute=distribute)
def _validate_domain_type_idx(self, complement_sections, type_idx): unique_complement_type_ids = set(sec.type_idx for sec in complement_sections) if type_idx in unique_complement_type_ids: raise ValueError(f'Type index {type_idx} is already used by another domain.') def _validate_domain_color(self, complement_sections, color): unique_complement_colors = set(sec.domain_color for sec in complement_sections) if color in unique_complement_colors: raise ValueError(f'Color {color} is already used by another domain.')
[docs] def update_domain_name(self, old_name, new_name): """ Update the name of a domain. Parameters ---------- old_name : str The current name of the domain. new_name : str The new name to assign to the domain. """ if new_name in self.domains: raise ValueError(f'Domain {new_name} already exists.') domain = self.domains[old_name] domain.name = new_name self.domains[domain.name] = domain self.domains.pop(old_name) # Update groups self._add_domain_groups(domain.name) self._remove_domain_groups(old_name) # Update domains_to_mechs self.domains_to_mechs[domain.name] = self.domains_to_mechs.pop(old_name) self._remove_empty()
[docs] def update_domain_type_idx(self, name, new_type_idx): """ Update the type index of a domain. Notes ----- Suggested type indices: 1: soma 2: axon 3: dend 31: basal 4: apic 41: trunk 42: tuft 43: oblique """ domain = self.domains[name] if new_type_idx is not None: if new_type_idx == domain.type_idx: return existing_type_ids = set(domain.type_idx for domain in self.domains.values()) - {domain.type_idx} if new_type_idx in existing_type_ids: raise ValueError(f'Type index {new_type_idx} is already used by another domain.') domain.type_idx = new_type_idx
[docs] def update_domain_color(self, name, new_color, force=False): """ Update the color of a domain. Notes ----- Suggested colors: soma: orange axon: gold dend: forestgreen basal: seagreen apic: steelblue trunk: skyblue tuft: plum oblique: rosybrown """ domain = self.domains[name] if new_color is not None: if new_color == domain.color: return existing_colors = set(domain.color for domain in self.domains.values()) - {domain.color} if new_color in existing_colors and not force: raise ValueError(f'Color {new_color} is already used by another domain.') domain.color = new_color
[docs] def extend_domain(self, name, sections, distribute=True): """ Extends an existing domain by adding sections to it. Parameters ---------- name : str The name of the domain to extend. sections : list[Section] The sections to add to the domain. distribute : bool, optional Whether to re-distribute the parameters after extending the domain. Default is True. Notes ----- If the domain already exists and is being extended, mechanisms will be inserted automatically into the newly added sections. """ domain = self.domains.get(name) if domain is None: raise ValueError(f'Domain {name} does not exist.') if not sections: raise ValueError('No sections provided to extend the domain.') # Find sections that are not in the domain yet sections_to_move = [sec for sec in sections if sec.domain_name != name] if not sections_to_move: warnings.warn(f'Sections already in domain {name}.') return # Remove sections from their old domains for sec in sections_to_move: old_domain = self.domains[sec.domain_name] old_domain.remove_section(sec) for mech_name in self.domains_to_mechs[old_domain.name]: # TODO: What if section is already in domain? Can't be as # we use a filtered list of sections. mech = self.mechanisms[mech_name] sec.uninsert_mechanism(mech) # Add sections to the new domain for sec in sections_to_move: domain.add_section(sec) # Important: here we insert mechanisms only if we extend the domain, # i.e. the domain already exists and has mechanisms. # If the domain is new, we DO NOT insert mechanisms automatically # and leave it to the user to do so. for mech_name in self.domains_to_mechs.get(domain.name, set()): mech = self.mechanisms[mech_name] sec.insert_mechanism(mech) self._remove_empty() self.sec_tree.sort(sort_children=True, force=True) if distribute: self.distribute_all()
def _add_domain_groups(self, domain_name): """ Manage groups when a domain is added. """ # Add new domain to `all` group if self.groups.get('all'): self.groups['all'].domains.append(domain_name) # Create a new group for the domain group_name = DOMAINS_TO_GROUPS.get(domain_name, domain_name) self.add_group(group_name, domains=[domain_name]) def _remove_domain_groups(self, domain_name): """ Manage groups when a domain is removed. """ for group in self._groups: if domain_name in group.domains: group.domains.remove(domain_name) def _remove_empty(self): self._remove_empty_domains() self._remove_uninserted_mechanisms() self._remove_empty_groups() def _remove_empty_domains(self): """ """ empty_domains = [domain for domain in self.domains.values() if domain.is_empty()] for domain in empty_domains: warnings.warn(f'Domain {domain.name} is empty and will be removed.') self.domains.pop(domain.name) self.domains_to_mechs.pop(domain.name) self._remove_domain_groups(domain.name) def _remove_uninserted_mechanisms(self): mech_names = list(self.mechs_to_params.keys()) mechs = [self.mechanisms[mech_name] for mech_name in mech_names if mech_name != 'Independent'] uninserted_mechs = [mech for mech in mechs if mech.name not in self.mechs_to_domains] for mech in uninserted_mechs: warnings.warn(f'Mechanism {mech.name} is not inserted in any domain and will be removed.') self._remove_mechanism_params(mech) def _remove_empty_groups(self): empty_groups = [group for group in self._groups if not any(seg in group for seg in self.seg_tree)] for group in empty_groups: warnings.warn(f'Group {group.name} is empty and will be removed.') self.remove_group(group.name) # ======================================================================== # MECHANISMS # ========================================================================
[docs] def insert_mechanism(self, mechanism_name: str, domain_name: str, distribute=True): """ Insert a mechanism into all sections in a domain. Parameters ---------- mechanism_name : str The name of the mechanism to insert. domain_name : str The name of the domain to insert the mechanism into. distribute : bool, optional Whether to distribute the parameters after inserting the mechanism. """ mech = self.mechanisms[mechanism_name] domain = self.domains[domain_name] # domain.insert_mechanism(mech) self.domains_to_mechs[domain_name].add(mech.name) for sec in domain.sections: sec.insert_mechanism(mech) self._add_mechanism_params(mech) # TODO: Redistribute parameters if any group contains this domain if distribute: for param_name in self.params: self.distribute(param_name)
def _add_mechanism_params(self, mech): """ Update the parameters when a mechanism is inserted. By default each parameter is set to a constant value through the entire cell. """ for param_name, value in mech.range_params_with_suffix.items(): self.params[param_name] = {'all': Distribution('constant', value=value)} if hasattr(mech, 'ion') and mech.ion in ['na', 'k', 'ca']: self._add_equilibrium_potentials_on_mech_insert(mech.ion) def _add_equilibrium_potentials_on_mech_insert(self, ion: str) -> None: """ """ if ion == 'na' and not self.params.get('ena'): self.params['ena'] = {'all': Distribution('constant', value=50)} elif ion == 'k' and not self.params.get('ek'): self.params['ek'] = {'all': Distribution('constant', value=-77)} elif ion == 'ca' and not self.params.get('eca'): self.params['eca'] = {'all': Distribution('constant', value=140)}
[docs] def uninsert_mechanism(self, mechanism_name: str, domain_name: str): """ Uninsert a mechanism from all sections in a domain Parameters ---------- mechanism_name : str The name of the mechanism to uninsert. domain_name : str The name of the domain to uninsert the mechanism from. """ mech = self.mechanisms[mechanism_name] domain = self.domains[domain_name] # domain.uninsert_mechanism(mech) for sec in domain.sections: sec.uninsert_mechanism(mech) self.domains_to_mechs[domain_name].remove(mech.name) if not self.mechs_to_domains.get(mech.name): warnings.warn(f'Mechanism {mech.name} is not inserted in any domain and will be removed.') self._remove_mechanism_params(mech)
def _remove_mechanism_params(self, mech): for param_name in self.mechs_to_params.get(mech.name, []): self.params.pop(param_name) if hasattr(mech, 'ion') and mech.ion in ['na', 'k', 'ca']: self._remove_equilibrium_potentials_on_mech_uninsert(mech.ion) def _remove_equilibrium_potentials_on_mech_uninsert(self, ion: str) -> None: """ """ for mech_name, mech in self.mechanisms.items(): if hasattr(mech, 'ion'): if mech.ion == mech.ion: return if ion == 'na': self.params.pop('ena', None) elif ion == 'k': self.params.pop('ek', None) elif ion == 'ca': self.params.pop('eca', None) # ======================================================================== # PARAMETERS # ======================================================================== # ----------------------------------------------------------------------- # SEGMENT GROUPS (Where) # -----------------------------------------------------------------------
[docs] def add_group(self, name, domains, select_by=None, min_value=None, max_value=None): """ Add a group of sections to the model. Parameters ---------- name : str The name of the group. domains : list[str] The domains to include in the group. select_by : str, optional The parameter to select the sections by. Can be 'diam', 'distance', 'domain_distance'. min_value : float, optional The minimum value of the parameter. max_value : float, optional The maximum value of the """ if self.verbose: print(f'Adding group {name}...') group = SegmentGroup(name, domains, select_by, min_value, max_value) self._groups.append(group)
[docs] def remove_group(self, group_name): """ Remove a group from the model. Parameters ---------- group_name : str The name of the group to remove. """ # Remove group from the list of groups self._groups = [group for group in self._groups if group.name != group_name] # Remove distributions that refer to this group for param_name, groups_to_distrs in self.params.items(): groups_to_distrs.pop(group_name, None)
[docs] def move_group_down(self, name): """ Move a group down in the list of groups. Parameters ---------- name : str The name of the group to move down. """ idx = next(i for i, group in enumerate(self._groups) if group.name == name) if idx > 0: self._groups[idx-1], self._groups[idx] = self._groups[idx], self._groups[idx-1] for param_name in self.distributed_params: self.distribute(param_name)
[docs] def move_group_up(self, name): """ Move a group up in the list of groups. Parameters ---------- name : str The name of the group to move up. """ idx = next(i for i, group in enumerate(self._groups) if group.name == name) if idx < len(self._groups) - 1: self._groups[idx+1], self._groups[idx] = self._groups[idx], self._groups[idx+1] for param_name in self.distributed_params: self.distribute(param_name)
# ----------------------------------------------------------------------- # DISTRIBUTIONS (How) # -----------------------------------------------------------------------
[docs] def set_param(self, param_name: str, group_name: str = 'all', distr_type: str = 'constant', **distr_params): """ Set a parameter for a group of segments. Parameters ---------- param_name : str The name of the parameter to set. group_name : str, optional The name of the group to set the parameter for. Default is 'all'. distr_type : str, optional The type of the distribution to use. Default is 'constant'. distr_params : dict The parameters of the distribution. """ if 'group' in distr_params: raise ValueError("Did you mean 'group_name' instead of 'group'?") if param_name in ['temperature', 'v_init']: setattr(self.simulator, param_name, distr_params['value']) return for key, value in distr_params.items(): if not isinstance(value, (int, float)) or value is nan: raise ValueError(f"Parameter '{key}' must be a numeric value and not NaN, got {type(value).__name__} instead.") self.set_distribution(param_name, group_name, distr_type, **distr_params) self.distribute(param_name)
[docs] def set_distribution(self, param_name: str, group_name: None, distr_type: str = 'constant', **distr_params): """ Set a distribution for a parameter. Parameters ---------- param_name : str The name of the parameter to set. group_name : str, optional The name of the group to set the parameter for. Default is 'all'. distr_type : str, optional The type of the distribution to use. Default is 'constant'. distr_params : dict The parameters of the distribution. """ if distr_type == 'inherit': distribution = 'inherit' else: distribution = Distribution(distr_type, **distr_params) self.params[param_name][group_name] = distribution
[docs] def distribute_all(self): """ Distribute all parameters to the segments. """ groups_to_segments = {group.name: [seg for seg in self.seg_tree if seg in group] for group in self._groups} for param_name in self.params: self.distribute(param_name, groups_to_segments)
[docs] def distribute(self, param_name: str, precomputed_groups=None): """ Distribute a parameter to the segments. Parameters ---------- param_name : str The name of the parameter to distribute. precomputed_groups : dict, optional A dictionary mapping group names to segments. Default is None. """ if param_name == 'Ra': self._distribute_Ra(precomputed_groups) return groups_to_segments = precomputed_groups if groups_to_segments is None: groups_to_segments = {group.name: [seg for seg in self.seg_tree if seg in group] for group in self._groups} param_distributions = self.params[param_name] for group_name, distribution in param_distributions.items(): filtered_segments = groups_to_segments[group_name] if distribution == 'inherit': for seg in filtered_segments: value = seg.parent.get_param_value(param_name) seg.set_param_value(param_name, value) else: for seg in filtered_segments: value = distribution(seg.path_distance()) seg.set_param_value(param_name, value)
def _distribute_Ra(self, precomputed_groups=None): """ Distribute the axial resistance to the segments. """ groups_to_segments = precomputed_groups if groups_to_segments is None: groups_to_segments = {group.name: [seg for seg in self.seg_tree if seg in group] for group in self._groups} param_distributions = self.params['Ra'] for group_name, distribution in param_distributions.items(): filtered_segments = groups_to_segments[group_name] if distribution == 'inherit': raise NotImplementedError("Inheritance of Ra is not implemented.") else: for seg in filtered_segments: value = distribution(seg._section.path_distance(relative_position=0.5)) seg._section._ref.Ra = value
[docs] def remove_distribution(self, param_name, group_name): """ Remove a distribution for a parameter. Parameters ---------- param_name : str The name of the parameter to remove the distribution for. group_name : str The name of the group to remove the distribution for. """ self.params[param_name].pop(group_name, None) self.distribute(param_name)
# ----------------------------------------------------------------------- # FITTING # -----------------------------------------------------------------------
[docs] def fit_distribution(self, param_name: str, segments, candidate_models=None, plot=True): if candidate_models is None: candidate_models = DEFAULT_FIT_MODELS from dendrotweaks.utils import mse values = [seg.get_param_value(param_name) for seg in segments] if all(np.isnan(values)): return None distances = [seg.path_distance() for seg in segments] distances, values = zip(*sorted(zip(distances, values))) best_score = float('inf') best_model = None best_params = None best_pred = None results = [] for name, model in candidate_models.items(): try: params, pred_values = model['fit'](distances, values) score = model.get('score', mse)(values, pred_values) complexity = model.get('complexity', 1)(params) results.append((name, score, params, complexity, pred_values)) except Exception as e: warnings.warn(f"Model {name} failed to fit: {e}") # Sort results by score and complexity results.sort(key=lambda x: (np.round(x[1], 10), x[3])) best_model, best_score, best_params, _, best_pred = results[0] if plot: self.plot_param(param_name, show_nan=False) plt.plot(distances, best_pred, label=f'Best Fit: {best_model}', color='red', linestyle='--') plt.legend() return {'model': best_model, 'params': best_params, 'score': best_score}
def _set_distribution(self, param_name, group_name, fit_result, plot=False): if fit_result is None: warnings.warn(f"No valid fit found for parameter {param_name}. Skipping distribution assignment.") return model_type = fit_result['model'] params = fit_result['params'] if model_type == 'poly': coeffs = np.array(params) coeffs = np.where(np.round(coeffs) == 0, coeffs, np.round(coeffs, 10)) if len(coeffs) == 1: self.params[param_name][group_name] = Distribution('constant', value=coeffs[0]) elif len(coeffs) == 2: self.params[param_name][group_name] = Distribution('linear', slope=coeffs[0], intercept=coeffs[1]) else: self.params[param_name][group_name] = Distribution('polynomial', coeffs=coeffs.tolist()) elif model_type == 'step': start, end, min_value, max_value = params self.params[param_name][group_name] = Distribution('step', max_value=max_value, min_value=min_value, start=start, end=end) # ----------------------------------------------------------------------- # PLOTTING # -----------------------------------------------------------------------
[docs] def plot_param(self, param_name, ax=None, show_nan=True): """ Plot the distribution of a parameter in the model. Parameters ---------- param_name : str The name of the parameter to plot. ax : matplotlib.axes.Axes, optional The axes to plot on. Default is None. show_nan : bool, optional Whether to show NaN values. Default is True. """ if ax is None: fig, ax = plt.subplots(figsize=(10, 2)) if param_name not in self.params: warnings.warn(f'Parameter {param_name} not found.') values = [(seg.path_distance(), seg.get_param_value(param_name)) for seg in self.seg_tree] colors = [seg.domain_color for seg in self.seg_tree] valid_values = [(x, y) for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0] zero_values = [(x, y) for (x, y), color in zip(values, colors) if y == 0] nan_values = [(x, 0) for (x, y), color in zip(values, colors) if pd.isna(y)] valid_colors = [color for (x, y), color in zip(values, colors) if not pd.isna(y) and y != 0] zero_colors = [color for (x, y), color in zip(values, colors) if y == 0] nan_colors = [color for (x, y), color in zip(values, colors) if pd.isna(y)] if valid_values: ax.scatter(*zip(*valid_values), c=valid_colors) if zero_values: ax.scatter(*zip(*zero_values), edgecolors=zero_colors, facecolors='none', marker='.') if nan_values and show_nan: ax.scatter(*zip(*nan_values), c=nan_colors, marker='x', alpha=0.5, zorder=0) ax.axhline(y=0, color='k', linestyle='--') ax.set_xlabel('Path distance') ax.set_ylabel(param_name) ax.set_title(f'{param_name} distribution')
# ======================================================================== # MORPHOLOGY # ========================================================================
[docs] def get_sections(self, filter_function): """Filter sections using a lambda function. Parameters ---------- filter_function : Callable The lambda function to filter sections. """ return [sec for sec in self.sec_tree.sections if filter_function(sec)]
[docs] def get_segments(self, group_names=None): """ Get the segments in specified groups. Parameters ---------- group_names : List[str] The names of the groups to get segments from. """ if not isinstance(group_names, list): raise ValueError('Group names must be a list.') return [seg for group_name in group_names for seg in self.seg_tree.segments if seg in self.groups[group_name]]
[docs] def remove_subtree(self, section): """ Remove a subtree from the model. Parameters ---------- section : Section The root section of the subtree to remove. """ for domain in self.domains.values(): for sec in section.subtree: if sec in domain.sections: domain.remove_section(sec) self.sec_tree.remove_subtree(section) self._remove_empty()
[docs] def merge_domains(self, domain_names: List[str]): """ Merge two domains into one. """ domains = [self.domains[domain_name] for domain_name in domain_names] for domain in domains[1:]: domains[0].merge(domain) self.remove_empty()
[docs] def reduce_subtree(self, root, reduction_frequency=0, total_segments_manual=-1, fit=True): """ Reduce a subtree to a single section. Parameters ---------- root : Section The root section of the subtree to reduce. reduction_frequency : float, optional The frequency of the reduction. Default is 0. total_segments_manual : int, optional The number of segments in the reduced subtree. Default is -1 (automatic). fit : bool, optional Whether to create distributions for the reduced subtree by fitting the calculated average values. Default is True. """ domain_name = root.domain_name parent = root.parent domains_in_subtree = [self.domains[domain_name] for domain_name in set([sec.domain_name for sec in root.subtree])] if len(domains_in_subtree) > 1: # ensure the domains have the same mechanisms using self.domains_to_mechs domains_to_mechs = {domain_name: mech_names for domain_name, mech_names in self.domains_to_mechs.items() if domain_name in [domain.name for domain in domains_in_subtree]} common_mechs = set.intersection(*domains_to_mechs.values()) if not all(mech_names == common_mechs for mech_names in domains_to_mechs.values()): raise ValueError( 'The domains in the subtree have different mechanisms. ' 'Please ensure that all domains in the subtree have the same mechanisms. ' 'You may need to insert the missing mechanisms and set their conductances to 0 where they are not needed.' ) elif len(domains_in_subtree) == 1: common_mechs = self.domains_to_mechs[domain_name].copy() inserted_mechs = {mech_name: mech for mech_name, mech in self.mechanisms.items() if mech_name in self.domains_to_mechs[domain_name] } subtree_without_root = [sec for sec in root.subtree if sec is not root] # Map original segment names to their parameters segs_to_params = rdc.map_segs_to_params(root, inserted_mechs) # Temporarily remove active mechanisms for mech_name in inserted_mechs: if mech_name == 'Leak': continue for sec in root.subtree: mech = self.mechanisms[mech_name] sec.uninsert_mechanism(mech) # Disconnect root.disconnect_from_parent() # Calculate new properties of a reduced subtree new_cable_properties = rdc.get_unique_cable_properties(root._ref, reduction_frequency) new_nseg = rdc.calculate_nsegs(new_cable_properties, total_segments_manual) print(new_cable_properties) # Map segment names to their new locations in the reduced cylinder segs_to_locs = rdc.map_segs_to_locs(root, reduction_frequency, new_cable_properties) # Reconnect root.connect_to_parent(parent) # Delete the original subtree children = root.children[:] for child_sec in children: self.remove_subtree(child_sec) # Set passive mechanisms for the reduced cylinder: rdc.apply_params_to_section(root, new_cable_properties, new_nseg) # Reinsert active mechanisms for mech_name in inserted_mechs: if mech_name == 'Leak': continue for sec in root.subtree: mech = self.mechanisms[mech_name] sec.insert_mechanism(mech) # Replace locs with corresponding segs segs_to_reduced_segs = rdc.map_segs_to_reduced_segs(segs_to_locs, root) # Map reduced segments to lists of parameters of corresponding original segments reduced_segs_to_params = rdc.map_reduced_segs_to_params(segs_to_reduced_segs, segs_to_params) # Set new values of parameters rdc.set_avg_params_to_reduced_segs(reduced_segs_to_params) rdc.interpolate_missing_values(reduced_segs_to_params, root) data = { 'segs_to_params': segs_to_params, 'segs_to_locs': segs_to_locs, 'segs_to_reduced_segs': segs_to_reduced_segs, 'reduced_segs_to_params': reduced_segs_to_params, } if not fit: return data root_segs = [seg for seg in root.segments] params_to_fits = {} # for param_name in self.params: common_mechs.add('Independent') for mech in common_mechs: for param_name in self.mechs_to_params[mech]: fit_result = self.fit_distribution(param_name, segments=root_segs, plot=False) params_to_fits[param_name] = fit_result # Create new domain reduced_domains = [domain_name for domain_name in self.domains if domain_name.startswith('reduced')] new_reduced_domain_name = f'reduced_8{len(reduced_domains)}' new_reduced_domain_type_idx = int(f'8{len(reduced_domains)}') group_name = new_reduced_domain_name old_domain = root.domain_name self.update_domain_name(old_domain, new_reduced_domain_name) self.update_domain_type_idx(new_reduced_domain_name, new_reduced_domain_type_idx) self.update_domain_color(new_reduced_domain_name, 'palevioletred') # # Fit distributions to data for the group for param_name, fit_result in params_to_fits.items(): self._set_distribution(param_name, group_name, fit_result, plot=True) # # Distribute parameters self.distribute_all() data.update({ 'params_to_fits': params_to_fits, 'domain_name': new_reduced_domain_name, 'group_name': group_name,}) return data