Source code for dendrotweaks.morphology.io.factories

# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0

from dendrotweaks.morphology.trees import Node, Tree
from dendrotweaks.morphology.point_trees import Point, PointTree
from dendrotweaks.morphology.sec_trees import NeuronSection, Section, SectionTree
from dendrotweaks.morphology.seg_trees import NeuronSegment, Segment, SegmentTree
from dendrotweaks.morphology.domains import Domain

from dendrotweaks.morphology.io.reader import SWCReader

from typing import List, Union
from pandas import DataFrame, isna

from dendrotweaks.morphology.io.validation import validate_tree



[docs] def create_point_tree(source: Union[str, DataFrame]) -> PointTree: """ Create a point tree from either a file path or a DataFrame. Parameters ---------- source : Union[str, DataFrame] The source of the SWC data. Can be a file path or a DataFrame. Returns ------- PointTree The point tree representing the reconstruction of the neuron morphology. """ if isinstance(source, str): reader = SWCReader() df = reader.read_file(source) elif isinstance(source, DataFrame): df = source else: raise ValueError("Source must be a file path (str) or a DataFrame.") nodes = [ Point(row.Index, row.Type, row.X, row.Y, row.Z, row.R, row.Parent, domain_name=None if isna(row.Domain) else row.Domain, domain_color=None if isna(row.Color) else row.Color ) for row in df.itertuples(index=False) ] point_tree = PointTree(nodes) point_tree.remove_overlaps() return point_tree
[docs] def create_section_tree(point_tree: PointTree): """ Create a section tree from a point tree. Parameters ---------- point_tree : PointTree The point tree to create the section tree from by splitting it into sections. Returns ------- SectionTree The section tree created representing the neuron morphology on a more abstract level. """ point_tree.extend_sections() point_tree.sort() sections = _split_to_sections(point_tree) sec_tree = SectionTree(sections) sec_tree._point_tree = point_tree return sec_tree
def _split_to_sections(point_tree: PointTree) -> List[Section]: """ Split the point tree into sections. """ sections = [] bifurcation_children = [ child for b in point_tree.bifurcations for child in b.children] bifurcation_children = [point_tree.root] + bifurcation_children # Filter out the bifurcation children to enforce the original order bifurcation_children = [node for node in point_tree._nodes if node in bifurcation_children] # Assign a section to each bifurcation child for i, child in enumerate(bifurcation_children): section = NeuronSection(idx=i, parent_idx=-1, points=[child]) sections.append(section) child._section = section # Propagate the section to the children until the next # bifurcation or termination point is reached while child.children: next_child = child.children[0] if next_child in bifurcation_children: break next_child._section = section section.points.append(next_child) child = next_child section.parent = section.points[0].parent._section if section.points[0].parent else None section.parent_idx = section.parent.idx if section.parent else -1 if point_tree.soma_notation == '3PS': sections = _merge_soma(sections, point_tree) return sections def _merge_soma(sections: List[Section], point_tree: PointTree): """ If soma has 3PS notation, merge it into one section. """ true_soma = point_tree.root._section true_soma.idx = 0 true_soma.parent_idx = -1 # Find false soma sections false_somas = [sec for sec in sections if sec.domain_name == 'soma' and sec is not true_soma] if len(false_somas) != 2: print(false_somas) raise ValueError('Soma must have exactly 2 children of domain soma.') # Reassign points from false somas to true soma for i, sec in enumerate(false_somas): if len(sec.points) != 1: raise ValueError('False somas must have exactly 1 point.') for pt in sec.points: pt._section = true_soma # Sort points in true soma according to the 3PS convention true_soma.points = [ false_somas[0].points[0], true_soma.points[0], false_somas[1].points[0] ] # Rebuild section list without false somas kept_sections = [s for s in sections if s not in false_somas] kept_sections.sort(key=lambda s: s.idx) # Create mapping from old to new indices old_to_new = {sec.idx: i for i, sec in enumerate(kept_sections)} # Update indices for sec in kept_sections: sec.idx = old_to_new[sec.idx] # Update parent indices for sec in kept_sections: if sec is not true_soma: sec.parent_idx = sec.points[0].parent._section.idx return kept_sections
[docs] def create_segment_tree(sec_tree): """ Create a segment tree from a section tree. Parameters ---------- sec_tree : SectionTree The section tree to create the segment tree from by splitting it into segments. Returns ------- SegmentTree The segment tree representing spatial discretization of the neuron morphology for numerical simulations. """ segments = _create_segments(sec_tree) seg_tree = SegmentTree(segments) sec_tree._seg_tree = seg_tree return seg_tree
def _create_segments(sec_tree) -> List[Segment]: """ Create a list of Segment objects from a SectionTree object. """ segments = [] # TODO: Refactor this to use a stack instead of recursion def add_segments(sec, parent_idx, idx_counter): segs = {seg: idx + idx_counter for idx, seg in enumerate(sec._ref)} sec.segments = [] for seg, idx in segs.items(): segment = NeuronSegment( idx=idx, parent_idx=parent_idx, sim_seg=seg, section=sec) segments.append(segment) sec.segments.append(segment) parent_idx = idx idx_counter += len(segs) for child in sec.children: # IMPORTANT: This is needed since 0 and 1 segments are not explicitly # defined in the section segments list if child._ref.parentseg().x == 1: new_parent_idx = list(segs.values())[-1] elif child._ref.parentseg().x == 0: new_parent_idx = list(segs.values())[0] else: new_parent_idx = segs[child._ref.parentseg()] # Recurse for the child section idx_counter = add_segments(child, new_parent_idx, idx_counter) return idx_counter # Start with the root section of the sec_tree add_segments(sec_tree.root, parent_idx=-1, idx_counter=0) return segments
[docs] def create_domains(section_tree): """ Create domains using the data from the sections (from the points in the sections). """ unique_domain_precursors = set([ (sec.domain_name, sec.type_idx, sec.domain_color) for sec in section_tree.sections ]) domains = { name: Domain(name, type_idx, color) for name, type_idx, color in sorted(unique_domain_precursors) } for sec in section_tree.sections: domains[sec.domain_name].add_section(sec) return domains