Source code for dendrotweaks.morphology.sec_trees

# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors

from typing import Callable, List
from neuron import h

from dendrotweaks.morphology.trees import Node, Tree
from dendrotweaks.morphology.domains import Domain
from dataclasses import dataclass, field
from bisect import bisect_left
from functools import cached_property

import warnings


def custom_warning_formatter(message, category, filename, lineno, file=None, line=None):
    return f"{category.__name__}: {message} ({os.path.basename(filename)}, line {lineno})\n"

warnings.formatwarning = custom_warning_formatter



[docs] class Section(Node): """ A class representing a section in a neuron morphology. A section is continuous part of a neuron's morphology between two branching points. Parameters ---------- idx : str The index of the section. parent_idx : str The index of the parent section. points : List[Point] The points that define the section. Attributes ---------- points : List[Point] The points that define the section. segments : List[Segment] The segments to which the section is divided. _ref : h.Section The reference to the NEURON section. domain : Domain The domain to which the section belongs. idx_within_domain : int The index of the section within its domain. """ def __init__(self, idx: str, parent_idx: str, points: List[Node]) -> None: super().__init__(idx, parent_idx) self.points = points self.segments = [] self._domain = None self.idx_within_domain = None self._ref = None self._nseg = None self._cell = None if not all(pt.domain_name == self.domain_name for pt in points): raise ValueError('All points in a section must belong to the same domain.') # MAGIC METHODS def __call__(self, x: float): """ Return the segment at a given position. """ if self._ref is None: raise ValueError('Section is not referenced in NEURON.') if x < 0 or x > 1: raise ValueError('Location x must be in the range [0, 1].') elif x == 0: # TODO: Decide how to handle sec(0) and sec(1) # as they are not shown in the seg_graph return self.segments[0] elif x == 1: return self.segments[-1] matching_segs = [self._ref(x) == seg._ref for seg in self.segments] if any(matching_segs): return self.segments[matching_segs.index(True)] raise ValueError(f'No segment found at location {x}') def __iter__(self): """ Iterate over the segments in the section. """ for seg in self.segments: yield seg # PROPERTIES @property def domain_name(self): """ The morphological or functional domain of the node. """ return self.points[0].domain_name @property def type_idx(self): """ The type index of the section based on its domain. """ return self.points[0].type_idx @property def domain_color(self): """ The color of the section based on its domain. """ return self.points[0].domain_color @property def df_points(self): """ A DataFrame of the points in the section. """ # concatenate the dataframes of the nodes return pd.concat([pt.df for pt in self.points]) @property def df(self): """ A DataFrame of the section. """ return pd.DataFrame({'idx': [self.idx], 'parent_idx': [self.parent_idx]}) @property def radii(self): """ Radii of the points in the section. """ return [pt.r for pt in self.points] @property def diameters(self): """ Diameters of the points in the section. """ return [2 * pt.r for pt in self.points] @property def xs (self): """ X-coordinates of the points in the section. """ return [pt.x for pt in self.points] @property def ys(self): """ Y-coordinates of the points in the section. """ return [pt.y for pt in self.points] @property def zs(self): """ Z-coordinates of the points in the section. """ return [pt.z for pt in self.points]
[docs] def get_location_coordinates(self, loc): """ Get point at a normalized location along the section. Parameters ---------- loc : float The normalized location along the section (0 = start, 1 = end). Returns ------- tuple The (x, y, z) coordinates of the point at the given location. """ # Validate input if loc < 0 or loc > 1: raise ValueError(f"Location must be between 0 and 1, got {loc}") # Get cumulative distances and normalize cumulative_distances = self.distances total_length = self.length normalized_distances = cumulative_distances / total_length # Handle edge cases if loc == 0: pt = self.points[0] return pt.x, pt.y, pt.z if loc == 1: pt = self.points[-1] return pt.x, pt.y, pt.z # Find which segment contains the target location segment_idx = np.searchsorted(normalized_distances, loc) - 1 # Calculate position within that segment segment_start_norm = normalized_distances[segment_idx] segment_end_norm = normalized_distances[segment_idx + 1] # Local parameter t within the segment (0 to 1) t = (loc - segment_start_norm) / (segment_end_norm - segment_start_norm) # Get start and end points of the segment start_point = self.points[segment_idx] end_point = self.points[segment_idx + 1] # Linear interpolation for each coordinate x = start_point.x + t * (end_point.x - start_point.x) y = start_point.y + t * (end_point.y - start_point.y) z = start_point.z + t * (end_point.z - start_point.z) # Return new Point object return (x, y, z)
@property def seg_centers(self): """ The list of segment centers in the section with normalized length. """ if self._ref is None: raise ValueError('Section is not referenced in NEURON.') return (np.array([(2*i - 1) / (2 * self.nseg) for i in range(1, self.nseg + 1)]) * self.L).tolist() @property def seg_borders(self): """ The list of segment borders in the section with normalized length. """ if self._ref is None: raise ValueError('Section is not referenced in NEURON.') nseg = int(self.nseg) return [i / nseg for i in range(nseg + 1)] @property def distances(self): """ The list of cumulative euclidean distances of the points in the section. """ coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points]) deltas = np.diff(coords, axis=0) frusta_distances = np.sqrt(np.sum(deltas**2, axis=1)) cumulative_frusta_distances = np.insert(np.cumsum(frusta_distances), 0, 0) return cumulative_frusta_distances @property def center(self): """ The coordinates of the center of the section. """ return np.mean(self.xs), np.mean(self.ys), np.mean(self.zs) @cached_property def length(self): """ The length of the section calculated as the sum of the distances between the points. """ return self.distances[-1] @property def area(self): """ The surface area of the section calculated as the sum of the areas of the frusta segments. """ areas = [np.pi * (r1 + r2) * np.sqrt((r1 - r2)**2 + h**2) for r1, r2, h in zip(self.radii[:-1], self.radii[1:], np.diff(self.distances))] return sum(areas) def _invalidate_geometry_cache(self): """Call after modifying the geometry of the section to invalidate cached properties.""" self.__dict__.pop('length', None) def _invalidate_topology_cache(self): """Call after modifying the geometry of the section to invalidate cached properties.""" self.__dict__.pop('path_distance_to_root', None) # MECHANISM METHODS
[docs] def insert_mechanism(self, mech): """ Inserts a mechanism in the section if it is not already inserted. """ if self._ref.has_membrane(mech.name): return self._ref.insert(mech.name)
[docs] def uninsert_mechanism(self, mech): """ Uninserts a mechanism in the section if it was inserted. """ if not self._ref.has_membrane(mech.name): return self._ref.uninsert(mech.name)
# PARAMETER METHODS
[docs] def get_param_value(self, param_name): """ Get the average parameter of the section's segments. Parameters ---------- param_name : str The name of the parameter to get. Returns ------- float The average value of the parameter in the section's segments. """ # if param_name in ['Ra', 'diam', 'L', 'nseg', 'domain', 'subtree_size']: # return getattr(self, param_name) # if param_name in ['dist']: # return self.distance_to_root(0.5) seg_values = [seg.get_param_value(param_name) for seg in self.segments] return round(np.mean(seg_values), 16)
@cached_property def _legacy_path_distance_to_root(self) -> float: """ Calculate the total distance from the section start to the root. Returns ------- float The distance from the section start to the root. """ distance = 0 node = self.parent # Start from the parent node if node is None: return 0 while node.parent: distance += node.length node = node.parent return distance @cached_property def path_distance_to_root(self) -> float: """ Calculate the total distance from the section start to the root. Returns ------- float The distance from the section start to the root. """ path = self.path_to_ancestor(include_self=False, include_ancestor=False) return sum(sec.length for sec in path) @cached_property def _legacy_path_distance_within_domain(self) -> float: """ Calculate the distance from the section start to the root within the same domain. Returns ------- float The distance from the section start to the root within the same domain. """ distance = 0 node = self.parent # Start from the parent node if node is None: return 0 while node.parent: if node.domain_name != self.domain_name: break distance += node.length node = node.parent return distance @property def domain_root(self): """ Get the root section of the current domain. The domain root is the shallowest (closest to tree root) section that still belongs to this domain. Returns ------- Section The root section of the current domain. Raises ------ ValueError If the section has no parent (is the tree root). """ for sec in self._iter_to_root(): if sec.is_root or sec.parent.domain_name != self.domain_name: return sec raise RuntimeError('Unexpected: reached end of tree without finding domain root') def _legacy_path_distance(self, relative_position: float = 0, within_domain: bool = False) -> float: """ Get the distance from the section to the root at a given relative position. Parameters ---------- relative_position : float The position along the section's normalized length [0, 1]. within_domain : bool Whether to stop at the domain boundary. Returns ------- float The distance from the section to the root. """ if not (0 <= relative_position <= 1): raise ValueError('Relative position must be between 0 and 1.') if self.parent is None: # Soma section # relative_position = abs(relative_position - 0.5) return 0 base_distance = self._legacy_path_distance_within_domain if within_domain else self._legacy_path_distance_to_root return base_distance + relative_position * self.length
[docs] def path_distance(self, other=None, relative_position: float = None, relative_position_other: float = None) -> float: """ Calculate the path distance between positions on two sections. Parameters ---------- other : Section, optional The other section. If None, calculates distance to root. relative_position : float, optional Position along self (0 = start, 1 = end). Default None (context-dependent). relative_position_other : float, optional Position along other (0 = start, 1 = end). Default None (context-dependent). Returns ------- float The path distance between the two positions. """ if relative_position is not None and not (0 <= relative_position <= 1): raise ValueError(f'relative_position must be in [0, 1], got {relative_position}') if relative_position_other is not None and not (0 <= relative_position_other <= 1): raise ValueError(f'relative_position_other must be in [0, 1], got {relative_position_other}') # Calculate distance to root if other is not provided if other is None: if relative_position is None: relative_position = 0 if self.is_root: return self.length * abs(relative_position - 0.5) path_length = self.path_distance_to_root path_length += relative_position * self.length return path_length if self is other: if relative_position is None: relative_position = 0 if relative_position_other is None: relative_position_other = 0 return abs(relative_position_other - relative_position) * self.length # Get path between sections (excluding both endpoints and common ancestor) path = self.path(other, include_self=False, include_other=False, include_ancestor=False) path_length = sum(sec.length for sec in path) # Add contribution from self if self in other.ancestors: # If self is an ancestor of other, traverse backwards from self if relative_position is None: relative_position = 1 path_length += (1 - relative_position) * self.length else: # Normal case: traverse forwards from self if relative_position is None: relative_position = 0 path_length += relative_position * self.length # Add contribution from other if other in self.ancestors: # If other is an ancestor of self, traverse backwards from other if relative_position_other is None: relative_position_other = 1 path_length += (1 - relative_position_other) * other.length else: # Normal case: traverse forwards from other if relative_position_other is None: relative_position_other = 0 path_length += relative_position_other * other.length return path_length
[docs] def disconnect_from_parent(self): """ Detach the section from its parent section. """ # In SectionTree super().disconnect_from_parent() # In NEURON if self._ref: h.disconnect(sec=self._ref) #from parent # In PointTree self.points[0].disconnect_from_parent() # In SegmentTree if self.segments: self.segments[0].disconnect_from_parent()
[docs] def connect_to_parent(self, parent): """ Attach the section to a parent section. Parameters ---------- parent : Section The parent section to attach to. """ # In SectionTree super().connect_to_parent(parent) # In NEURON if self._ref: if self.parent is not None: if self.parent.is_root: # if parent is soma self._ref.connect(self.parent._ref(0.5)) else: self._ref.connect(self.parent._ref(1)) # In PointTree if self.parent is not None: if self.parent.is_root: # if parent is soma parent_sec = self.parent parent_pt = parent_sec.points[1] if len(parent_sec.points) > 1 else parent_sec.points[0] self.points[0].connect_to_parent(parent_pt) # attach to the middle of the parent else: self.points[0].connect_to_parent(parent.points[-1]) # attach to the end of the parent # In SegmentTree if self.segments: self.segments[0].connect_to_parent(parent.segments[-1])
# PLOTTING METHODS
[docs] def plot(self, ax=None, plot_parent=True, section_color=None, parent_color='gray', show_labels=True, aspect_equal=True): """ Plot section morphology in 3D projections (XZ, YZ, XY) and radii distribution. Parameters ---------- ax : list or array of matplotlib.axes.Axes, optional Four axes for plotting (XZ, YZ, XY, radii). If None, creates a new figure with axes. plot_parent : bool, optional Whether to include parent section in the visualization. section_color : str or None, optional Color for the current section. If None, assigns based on section domain. parent_color : str, optional Color for the parent section. show_labels : bool, optional Whether to show axis labels and titles. aspect_equal : bool, optional Whether to set aspect ratio to 'equal' for the projections. Returns ------- ax : list of matplotlib.axes.Axes The axes containing the plots. """ # Create figure and axes if not provided if ax is None: fig = plt.figure(figsize=(10, 8)) gs = GridSpec(2, 3, width_ratios=[1, 1, 1.2], figure=fig) # Create the three projection axes and one radius axis ax_xz = fig.add_subplot(gs[0, 0]) ax_yz = fig.add_subplot(gs[0, 1]) ax_xy = fig.add_subplot(gs[1, 0]) ax_radii = fig.add_subplot(gs[1, 1:]) ax = [ax_xz, ax_yz, ax_xy, ax_radii] else: # Use provided axes if len(ax) != 4: # flatten ax = [ai for a in ax for ai in a] ax_xz, ax_yz, ax_xy, ax_radii = ax # Determine section color based on domain if not provided if section_color is None: section_color = self.domain_color # Extract coordinates xs = np.array([p.x for p in self.points]) ys = np.array([p.y for p in self.points]) zs = np.array([p.z for p in self.points]) # Plot section projections self._plot_projection(ax_xz, xs, zs, 'X', 'Z', 'XZ Projection', section_color, show_labels, aspect_equal) self._plot_projection(ax_yz, ys, zs, 'Y', 'Z', 'YZ Projection', section_color, show_labels, aspect_equal) self._plot_projection(ax_xy, xs, ys, 'X', 'Y', 'XY Projection', section_color, show_labels, aspect_equal) # Plot radius distribution self._plot_radii_distribution(ax_radii, plot_parent, section_color, parent_color) # Plot parent section if requested if plot_parent and self.parent: # Only plot parent projections, radii are handled in _plot_radii_distribution parent_xs = np.array([p.x for p in self.parent.points]) parent_ys = np.array([p.y for p in self.parent.points]) parent_zs = np.array([p.z for p in self.parent.points]) self.parent._plot_projection(ax_xz, parent_xs, parent_zs, None, None, None, parent_color, False, aspect_equal) self.parent._plot_projection(ax_yz, parent_ys, parent_zs, None, None, None, parent_color, False, aspect_equal) self.parent._plot_projection(ax_xy, parent_xs, parent_ys, None, None, None, parent_color, False, aspect_equal) # Add overall title if we created the figure if ax is not None and show_labels: fig = ax_xz.get_figure() fig.suptitle(f"Section {self.idx} ({self.domain_name})", fontsize=14) fig.tight_layout() return ax
def _plot_projection(self, ax, x_coords, y_coords, x_label, y_label, title, color, show_labels, aspect_equal): """Helper method to plot a 2D projection of the section.""" ax.plot(x_coords, y_coords, 'o-', color=color, markerfacecolor=color, markeredgecolor='black', markersize=4, linewidth=1.5) if show_labels: if x_label: ax.set_xlabel(x_label) if y_label: ax.set_ylabel(y_label) if title: ax.set_title(title) if aspect_equal: ax.set_aspect('equal') def _plot_radii_distribution(self, ax, plot_parent, section_color, parent_color): """Helper method to plot radius distribution along the section.""" # Get section length for normalization section_length = self.distances[-1] - self.distances[0] # Normalize distances to start at 0 and end at section_length normalized_distances = self.distances - self.distances[0] # Plot section radii ax.plot(normalized_distances, self.radii, 'o-', color=section_color, label=f"{self.domain_name} ({self.idx})", linewidth=2) # Plot reference NEURON segments if available if hasattr(self, '_ref') and self._ref: # Calculate normalized segment centers normalized_seg_centers = np.array(self.seg_centers) - self.distances[0] # Extract radii from segments seg_radii = np.array([seg.diam / 2 for seg in self._ref]) # Use the specified bar width calculation from original code bar_width = [self.L / self._nseg] * self._nseg # Plot segment radii as bars ax.bar(normalized_seg_centers, seg_radii, width=bar_width, alpha=0.5, color=section_color, edgecolor='white', label=f"{self.domain_name} segments") # Plot parent section if requested if plot_parent and self.parent: parent_length = self.parent.distances[-1] - self.parent.distances[0] # Normalize parent distances to end at 0 (connecting to child) # Parent section goes from -parent_length to 0 normalized_parent_distances = self.parent.distances - self.parent.distances[-1] # Plot parent radii ax.plot(normalized_parent_distances, self.parent.radii, 'o-', color=parent_color, linewidth=2, label=f"Parent {self.parent.domain_name} ({self.parent.idx})") # Plot parent reference segments if available if hasattr(self.parent, '_ref') and self.parent._ref: # Normalize parent segment centers to the same scale normalized_parent_seg_centers = (np.array(self.parent.seg_centers) - self.parent.distances[-1]) # Extract parent segment radii parent_seg_radii = np.array([seg.diam / 2 for seg in self.parent._ref]) # Use the specified bar width calculation for parent parent_bar_width = [self.parent.L / self.parent._nseg] * self.parent._nseg # Plot parent segment radii as bars ax.bar(normalized_parent_seg_centers, parent_seg_radii, width=parent_bar_width, alpha=0.5, color=parent_color, edgecolor='white', label=f"Parent segments") # Set plot labels and legend ax.set_xlabel('Distance (µm)') ax.set_ylabel('Radius (µm)') ax.set_title('Radius Distribution') # Ensure y-axis starts at 0 ax.set_ylim(bottom=0) # Adjust x-axis to show the full section(s) if plot_parent and self.parent: parent_length = self.parent.distances[-1] - self.parent.distances[0] ax.set_xlim(-parent_length * 1.05, section_length * 1.05) else: ax.set_xlim(-section_length * 0.05, section_length * 1.05) # Add legend if we have multiple data series if ((hasattr(self, '_ref') and self._ref) or (plot_parent and self.parent)): ax.legend(loc='best', frameon=True, framealpha=0.8)
[docs] def plot_radii(self, ax=None, include_parent=False, section_color=None, parent_color='gray'): """ Plot just the radius distribution for the section. Parameters ---------- ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates a new figure and axes. include_parent : bool, optional Whether to include parent section in the plot. section_color : str or None, optional Color for current section. If None, assigns based on section domain. parent_color : str, optional Color for parent section if included. Returns ------- ax : matplotlib.axes.Axes The axes containing the plot. """ # Create new figure and axes if not provided if ax is None: fig, ax = plt.subplots(figsize=(10, 5)) # Determine section color based on domain if not provided if section_color is None: domain_colors = { 'soma': 'black', 'axon': 'red', 'dend': 'blue', 'apic': 'green' } section_color = domain_colors.get(self.domain_name, 'purple') # Plot radius distribution self._plot_radii_distribution(ax, include_parent, section_color, parent_color) # Add title if creating a standalone plot if ax.get_figure().get_axes()[0] == ax: # If this is the only axes in the figure ax.set_title(f"Radius Distribution - Section {self.idx} ({self.domain_name})") plt.tight_layout() return ax
# -------------------------------------------------------------- # NEURON SECTION # -------------------------------------------------------------- class NeuronSection(Section): def __init__(self, idx, parent_idx, points) -> None: super().__init__(idx, parent_idx, points) @property def diam(self): """ Diameter of the central segment of the section (from NEURON). """ return self._ref.diam @property def L(self): """ Length of the section (from NEURON). """ return self._ref.L @property def cm(self): """ Specific membrane capacitance of the section (from NEURON). """ return self._ref.cm @property def Ra(self): """ Axial resistance of the section (from NEURON). """ return self._ref.Ra def has_mechanism(self, mech_name): """ Check if the section has a mechanism inserted. Parameters ---------- mech_name : str The name of the mechanism to check. """ return self._ref.has_membrane(mech_name) @property def nseg(self): """ Number of segments in the section (from NEURON). """ return self._nseg @nseg.setter def nseg(self, value): if value < 1: raise ValueError('Number of segments must be at least 1.') if value % 2 == 0: raise ValueError('Number of segments must be odd.') # Set the number in NEURON self._nseg = self._ref.nseg = value # Get the new NEURON segments nrnsegs = [seg for seg in self._ref] # Create new DendroTweaks segments from dendrotweaks.morphology.seg_trees import NeuronSegment old_segments = self.segments new_segments = [NeuronSegment(idx=0, parent_idx=0, sim_seg=seg, section=self) for seg in nrnsegs] seg_tree = self._tree._seg_tree first_segment = self.segments[0] parent = first_segment.parent for seg in new_segments: seg_tree.insert_node_before(seg, first_segment) # for i, seg in enumerate(new_segments[:]): # if i == 0: # seg_tree.insert_node_before(seg, first_segment) # else: # seg_tree.insert_node_after(seg, new_segments[i-1]) for seg in old_segments: seg_tree.remove_node(seg) # Sort the tree self._tree._seg_tree.sort() # Update the section's segments self.segments = new_segments # REFERENCING METHODS def create_and_reference(self): """ Create a NEURON section. """ self._ref = h.Section() # name=f'Sec_{self.idx}' self._nseg = self._ref.nseg if self.parent is not None: # TODO: Attaching basal to soma 0 if self.parent.is_root: # if parent is soma self._ref.connect(self.parent._ref(0.5)) else: self._ref.connect(self.parent._ref(1)) # Add 3D points to the section self._ref.pt3dclear() for pt in self.points: diam = 2*pt.r diam = round(diam, 16) self._ref.pt3dadd(pt.x, pt.y, pt.z, diam) # ======================================================================== # SECTION TREE # ========================================================================
[docs] class SectionTree(Tree): """ A class representing a tree graph of sections in a neuron morphology. Parameters ---------- sections : List[Section] A list of sections in the tree. Attributes ---------- domains : Dict[str, Domain] A dictionary of domains in the tree. """ def __init__(self, sections: list[Section]) -> None: super().__init__(sections) # self._create_domains() self._point_tree = None self._seg_tree = None def __repr__(self): return f"SectionTree(root={self.root!r}, num_nodes={len(self._nodes)})" # def _create_domains(self): # """ # Create domains using the data from the sections (from the points in the sections). # """ # unique_domain_precursors = set([ # (sec.domain_name, sec.type_idx, sec.domain_color) # for sec in self.sections # ]) # self.domains = { # name: Domain(name, type_idx, color) # for name, type_idx, color in sorted(unique_domain_precursors) # } # for sec in self.sections: # self.domains[sec.domain_name].add_section(sec) # PROPERTIES @property def sections(self): """ A list of sections in the tree. Alias for self._nodes. """ return self._nodes @property def soma(self): """ The soma section of the tree. Alias for self.root. """ return self.root @property def sections_by_depth(self): """ A dictionary of sections grouped by depth in the tree (depth is the number of edges from the root). """ sections_by_depth = {} for sec in self.sections: if sec.depth not in sections_by_depth: sections_by_depth[sec.depth] = [] sections_by_depth[sec.depth].append(sec) return sections_by_depth @property def df(self): """ A DataFrame of the sections in the tree. """ data = { 'idx': [], 'domain': [], 'x': [], 'y': [], 'z': [], 'r': [], 'parent_idx': [], 'section_idx': [], 'parent_section_idx': [], } for sec in self.sections: points = sec.points if sec.is_root or sec.parent.is_root else sec.points[1:] for pt in points: data['idx'].append(pt.idx) data['domain'].append(pt.domain_name) data['x'].append(pt.x) data['y'].append(pt.y) data['z'].append(pt.z) data['r'].append(pt.r) data['parent_idx'].append(pt.parent_idx) data['section_idx'].append(sec.idx) data['parent_section_idx'].append(sec.parent_idx) return pd.DataFrame(data) # SORTING METHODS
[docs] def sort(self, **kwargs): """ Sort the sections in the tree using a depth-first traversal. Parameters ---------- sort_children : bool, optional Whether to sort the children of each node based on the number of bifurcations in their subtrees. Defaults to True. force : bool, optional Whether to force the sorting of the tree even if it is already sorted. Defaults to False. """ super().sort(**kwargs) self._point_tree.sort(**kwargs) if self._seg_tree: self._seg_tree.sort(**kwargs)
# STRUCTURE METHODS
[docs] def remove_subtree(self, section): """ Remove a section and its subtree from the tree. Parameters ---------- section : Section The section to remove. """ super().remove_subtree(section) # Domains # for domain in self.domains.values(): # for sec in section.subtree: # if sec in domain.sections: # domain.remove_section(sec) # Points self._point_tree.remove_subtree(section.points[0]) # Segments if self._seg_tree: self._seg_tree.remove_subtree(section.segments[0]) # NEURON if section._ref: h.disconnect(sec=section._ref) for sec in section.subtree: h.delete_section(sec=sec._ref) self.sort()
[docs] def reposition_subtree(self, section, new_parent_section, origin=None): raise NotImplementedError( 'Repositioning subtrees of a SectionTree is not implemented yet. Use this method on the PointTree before (!) creating the SectionTree.' )
[docs] def remove_zero_length_sections(self): """ Remove sections with zero length. """ for sec in self.sections: if sec.length == 0: for pt in sec.points: self._point_tree.remove_node(pt) for seg in sec.segments: self._seg_tree.remove_node(seg) self.remove_node(sec)
[docs] def downsample(self, factor: float): """ Downsample the SWC tree by reducing the number of points in each section based on the given factor, while preserving the first and last points. :param factor: The proportion of points to keep (e.g., 0.5 keeps 50% of points) If factor is 0, keep only the first and last points. """ for sec in self.sections: if sec is self.soma: continue if len(sec.points) < 3: # Keep sections with only start & end points continue num_points = len(sec.points) if factor == 0: num_to_keep = 2 else: num_to_keep = max(2, int(num_points * factor)) # Ensure at least start & end remain # Select indices to keep (first, last, and spaced indices in between) keep_indices = np.linspace(0, num_points - 1, num_to_keep, dtype=int) keep_set = set(keep_indices) points_to_remove = [pt for i, pt in enumerate(sec.points) if i not in keep_set] print(f'Removing {len(points_to_remove)} points from section {sec.idx}') for pt in points_to_remove: self._point_tree.remove_node(pt) sec.points.remove(pt) self._point_tree.sort()
# PLOTTING METHODS
[docs] def plot(self, ax=None, show_points=False, show_lines=True, show_domains=True, annotate=False, projection='XY', highlight_sections=None, focus_sections=None): """ Plot the sections in the tree in a 2D projection. Parameters ---------- ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates a new figure and axes. show_points : bool, optional Whether to show the points in the sections. show_lines : bool, optional Whether to show the lines connecting the points. show_domains : bool, optional Whether to color sections based on their domain. annotate : bool, optional Whether to annotate the sections with their index. projection : str or tuple, optional The projection to use for the plot. Can be 'XY', 'XZ', 'YZ', or a tuple of two axes. highlight_sections : list of Section, optional Sections to highlight in the plot. focus_sections : list of Section, optional Sections to focus on in the plot. """ if ax is None: fig, ax = plt.subplots(figsize=(10, 10)) highlight_sections = set(highlight_sections) if highlight_sections else None focus_sections = set(focus_sections) if focus_sections else None x_attr, y_attr = projection[0].lower(), projection[1].lower() section_count = len(self.sections) # Avoid recalculating for sec in self.sections: # Skip sections that are not in the focus set (if focus is specified) if focus_sections and sec not in focus_sections: continue xs = [getattr(pt, x_attr) for pt in sec.points] ys = [getattr(pt, y_attr) for pt in sec.points] # Assign colors based on domains or section index color = plt.cm.jet(1 - sec.idx / section_count) if show_domains: color = sec.domain_color if highlight_sections and sec in highlight_sections: color = 'red' # Plot section points and lines if show_points: ax.plot(xs, ys, '.', color=color, markersize=7, markeredgecolor='black') if show_lines: ax.plot(xs, ys, color=color, zorder=0) # Annotate section index if needed if annotate: mean_x, mean_y = np.mean(xs), np.mean(ys) ax.annotate( f'{sec.idx}', (mean_x, mean_y), fontsize=8, color='white', bbox=dict(facecolor='black', edgecolor='white', boxstyle='round,pad=0.3') ) ax.set_xlabel(projection[0]) ax.set_ylabel(projection[1]) ax.set_aspect('equal')
[docs] def plot_radii_distribution(self, ax=None, highlight=None, domains=True, show_soma=False): """ Plot the radius distribution of the sections in the tree. Parameters ---------- ax : matplotlib.axes.Axes, optional Axes to plot on. If None, creates a new figure and axes. highlight : list of int, optional Indices of sections to highlight in the plot. domains : bool, optional Whether to color sections based on their domain. show_soma : bool, optional Whether to show the soma section in the plot. """ if ax is None: fig, ax = plt.subplots(figsize=(8, 3)) for sec in self.sections: if not show_soma and sec.is_root: continue color = sec.domain_color if highlight and sec.idx in highlight: ax.plot( [pt.path_distance() for pt in sec.points], sec.radii, marker='.', color='red', zorder=2 ) else: ax.plot( [pt.path_distance() for pt in sec.points], sec.radii, marker='.', color=color, zorder=1 ) ax.set_xlabel('Distance from root') ax.set_ylabel('Radius')