# Source code for dendrotweaks.morphology.trees

# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0

from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from dendrotweaks.utils import timeit
from typing import Union
from functools import cached_property

# ========================================================================
# NODE
# ========================================================================

class Node:
    """
    Represents a node in a tree.

    A node can be a 3D point in a neuron morphology, a segment, a section,
    or even a tree.

    Parameters
    ----------
    idx : Union[int, str]
        The index of the node.
    parent_idx : Union[int, str]
        The index of the parent node.
    """

    def __init__(self, idx: Union[int, str], parent_idx: Union[int, str]) -> None:
        """
        Creates a node in a tree.

        Args:
            idx (Union[int, str]): The index of the node (converted to int).
            parent_idx (Union[int, str]): The index of the parent node
                (converted to int).
        """
        self.idx = int(idx)
        self.parent_idx = int(parent_idx)
        self._parent = None
        self.children = []
        self._tree = None

    def __repr__(self):
        class_name = type(self).__name__
        return f"{class_name}(idx={self.idx!r})"

    @property
    def parent(self):
        """The parent node, or None if the node is detached / the root."""
        return self._parent

    @parent.setter
    def parent(self, parent):
        self._parent = parent
        # Keep parent_idx in sync; -1 marks "no parent".
        self.parent_idx = parent.idx if parent is not None else -1
        # Invalidate the cached ``is_root`` value because it depends on the parent.
        self.__dict__.pop('is_root', None)

    @cached_property
    def is_root(self) -> bool:
        """
        Check if the node is the root of the tree.

        The value is cached and invalidated whenever ``parent`` is assigned.

        Returns:
            bool: True if the node has no parent, False otherwise.
        """
        return self.parent is None

    @property
    def topological_type(self) -> str:
        """The topological type of the node based on the number of children.

        Returns:
            str: 'termination' (0 children), 'continuation' (1 child) or
            'bifurcation' (2 or more children).
        """
        types = {0: 'termination', 1: 'continuation'}
        return types.get(len(self.children), 'bifurcation')

    @property
    def subtree(self) -> list:
        """
        Gets the subtree of the node (including the node itself) using an
        iterative depth-first traversal.

        Returns:
            list: A list of nodes in the subtree, starting with this node.
        """
        subtree = []
        stack = [self]
        while stack:
            node = stack.pop()
            subtree.append(node)
            stack.extend(node.children)
        return subtree

    @property
    def subtree_size(self):
        """
        Gets the size of the subtree of the node.

        Returns:
            int: The number of nodes in the subtree, including this node.
        """
        return len(self.subtree)

    @property
    def depth(self):
        """
        Computes the depth of the node in the tree iteratively.

        Returns:
            int: The number of edges between this node and the root (0 for the root).
        """
        depth = 0
        node = self
        # Explicit None check: subclasses may define __len__ and be falsy.
        while node.parent is not None:
            depth += 1
            node = node.parent
        return depth

    @property
    def siblings(self):
        """
        Gets the siblings of the node.

        Returns:
            list: The other children of this node's parent (empty for the root).
        """
        if self.is_root:
            return []
        return [child for child in self.parent.children if child is not self]

    @property
    def nearest_neighbours(self):
        """
        Gets the nearest neighbours of the node.

        Returns:
            list: The parent (if any) followed by the children of the node.
            The root node has no parent, so only its children are returned.
        """
        neighbours = [] if self.parent is None else [self.parent]
        return neighbours + self.children

    def _iter_to_root(self):
        """Iterate from this node up to the root (both inclusive)."""
        node = self
        while node is not None:
            yield node
            node = node.parent

    @property
    def ancestors(self):
        """Get this node and all its ancestors up to the root, in that order."""
        return list(self._iter_to_root())

    def path_to_ancestor(self, ancestor=None, include_self=True, include_ancestor=True):
        """Get the path from this node to a given ancestor node.

        Parameters
        ----------
        ancestor : Node, optional
            The ancestor node to get the path to. If None, the path to the
            root is returned.
        include_self : bool, optional
            Whether to include this node in the path. Defaults to True.
        include_ancestor : bool, optional
            Whether to include the ancestor node in the path. Defaults to True.

        Returns
        -------
        list
            A list of nodes representing the path from this node to the
            ancestor node.

        Raises
        ------
        ValueError
            If ``ancestor`` is given but is not on the path to the root.
        """
        path = []
        for node in self._iter_to_root():
            path.append(node)
            if node is ancestor:
                break
        else:
            # Reached the root without meeting the requested ancestor.
            if ancestor is not None:
                raise ValueError(f'{ancestor} is not an ancestor of {self}')
        if not include_self and path:
            path = path[1:]
        if not include_ancestor and path:
            path = path[:-1]
        return path

    def find_common_ancestor(self, other):
        """Find the lowest common ancestor between this node and another node.

        Raises
        ------
        ValueError
            If the two nodes do not share a root.
        """
        if self is other:
            return self
        ancestors = set(self._iter_to_root())
        for node in other._iter_to_root():
            if node in ancestors:
                return node
        raise ValueError('No common ancestor found.')

    def path(self, other, include_self=True, include_other=True, include_ancestor=True):
        """
        Get the path between this node and another node.

        Parameters
        ----------
        other : Node
            The other node to get the path to.
        include_self : bool
            Whether to include this node in the path.
        include_other : bool
            Whether to include the other node in the path.
        include_ancestor : bool
            Whether to include the lowest common ancestor (LCA) in the path.
            Only relevant when self and other are in parallel subtrees,
            ignored otherwise.

        Returns
        -------
        list
            A list of nodes representing the path from this node to the other node.
        """
        if self is other:
            return [self] if (include_self and include_other) else []
        common_ancestor = self.find_common_ancestor(other)
        # When one endpoint is the LCA, its own include flag decides inclusion.
        if common_ancestor is self:
            include_ancestor = include_self
        elif common_ancestor is other:
            include_ancestor = include_other
        path_from_self = self.path_to_ancestor(common_ancestor,
                                               include_self=include_self,
                                               include_ancestor=include_ancestor)
        # The LCA is already in path_from_self (if wanted), so exclude it here.
        path_from_other = other.path_to_ancestor(common_ancestor,
                                                 include_self=include_other,
                                                 include_ancestor=False)
        return path_from_self + path_from_other[::-1]

    def connect_to_parent(self, parent):
        """
        Attach the node to a parent node.

        If the node is currently attached to a different parent, it is first
        detached from it so that the old parent does not keep a stale
        reference to this node among its children.

        Warning
        -------
        This method should not be used directly when working with trees as it
        doesn't add the node to the tree's list of nodes. Use the `Tree` class
        to insert nodes into the tree.

        Args:
            parent (Node): The parent node to attach the node to.

        Raises:
            ValueError: If ``parent`` is None, or if attaching would create a loop.
        """
        if parent is None:
            raise ValueError(
                f'Cannot connect {self} to None; use disconnect_from_parent() instead.'
            )
        if parent in self.subtree:
            raise ValueError(f'Attaching a node will create a loop in the tree: {self} -> {parent}')
        if self.parent is not None and self.parent is not parent:
            self.disconnect_from_parent()
        self.parent = parent
        if self not in parent.children:
            parent.children.append(self)

    def disconnect_from_parent(self):
        """
        Detach the node from its parent (no-op if it has none).

        Examples
        --------
        for child in list(node.children):
            child.disconnect_from_parent()
        """
        if self.parent is not None:
            self.parent.children.remove(self)
            self.parent = None
# ========================================================================
# TREE
# ========================================================================
class Tree:
    """
    A class to represent a tree data structure.

    A tree graph is a hierarchical data structure that consists of nodes
    connected by edges.

    Parameters
    ----------
    nodes : list
        A list of nodes in the tree.

    Attributes
    ----------
    root : Node
        The root node of the tree.
    """

    def __init__(self, nodes: list) -> None:
        for node in nodes:
            node._tree = self
        self._nodes = nodes
        self.root = self._find_root()
        if not self.is_connected:
            self._connect_nodes()

    # MAGIC METHODS

    def __repr__(self):
        return f"Tree(root={self.root!r}, num_nodes={len(self._nodes)})"

    def __getitem__(self, idx):
        return self._nodes[idx]

    def __len__(self):
        return len(self._nodes)

    def __iter__(self):
        for node in self._nodes:
            yield node

    def __contains__(self, node):
        return node in self._nodes

    # PROPERTIES

    @property
    def is_connected(self):
        """
        Whether the tree is connected i.e. each node can be reached from the root.

        Returns
        -------
        bool
            True if the root node's subtree contains exactly the same nodes
            as the entire tree. False otherwise.
        """
        nodes_set = set(self._nodes)
        subtree_set = set(self.root.subtree)
        return nodes_set == subtree_set

    @property
    def is_sorted(self):
        """
        Whether the nodes in the tree are sorted by index.

        Returns
        -------
        bool
            True if each node's idx equals its list position and the
            depth-first traversal visits indices in increasing order.
        """
        if not all(node.idx == i for i, node in enumerate(self._nodes)):
            return False
        traversal_indices = [node.idx for node in self.traverse()]
        return traversal_indices == sorted(traversal_indices)

    @property
    def bifurcations(self):
        """
        The bifurcation nodes in the tree.

        Returns
        -------
        list
            A list of nodes with more than one child.
        """
        return [node for node in self._nodes if len(node.children) > 1]

    @property
    def terminations(self):
        """
        The termination (leaf) nodes in the tree.

        Returns
        -------
        list
            A list of nodes without children.
        """
        return [node for node in self._nodes if len(node.children) == 0]

    @property
    def edges(self) -> list:
        """
        Returns a list of edges in the tree.

        Returns:
            list[tuple[Node, Node]]: (parent, child) pairs for every non-root node.
        """
        return [(node.parent, node) for node in self._nodes if node.parent is not None]

    # TREE CONSTRUCTION METHODS

    def _find_root(self):
        """
        Find the root node.

        Returns
        -------
        Node
            The root node of the tree.

        Raises
        ------
        ValueError
            If there is not exactly one node without a parent.
        """
        ROOT_PARENT = {None, -1, '-1'}
        root_nodes = [node for node in self._nodes if node.parent_idx in ROOT_PARENT]
        if len(root_nodes) != 1:
            raise ValueError(f'Tree must have exactly one root node. Found: {root_nodes}')
        return root_nodes[0]

    def _connect_nodes(self):
        """
        Efficiently builds the hierarchical tree structure for the nodes
        using a dictionary for fast parent lookups.

        Raises
        ------
        ValueError
            If some nodes remain unreachable from the root afterwards.
        """
        if self.is_connected:
            print('  Tree already connected.')
            return
        # Map idx -> node for O(1) parent lookups.
        node_map = {node.idx: node for node in self._nodes}
        for node in self._nodes:
            if node is not self.root and node.parent_idx in node_map:
                node.connect_to_parent(node_map[node.parent_idx])
        if not self.is_connected:
            raise ValueError('Tree is not connected.')

    # TRAVERSAL METHODS

    def traverse(self, root=None):
        """
        Iterate over the nodes in the tree using a stack-based depth-first
        traversal (pre-order, children visited in list order).

        Parameters
        ----------
        root : Node, optional
            The root node to start the traversal from. Defaults to the tree root.
        """
        root = root if root is not None else self.root
        stack = [root]
        visited = set()
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            yield node
            visited.add(node)
            # Push in reverse so the first child is popped (visited) first.
            for child in reversed(node.children):
                stack.append(child)

    # SORTING METHODS

    def _bifurcation_counts(self):
        """Map every node reachable from the tree's nodes to the number of
        bifurcation nodes in its subtree (the node itself included)."""
        counts = {}
        for start in self._nodes:
            if start in counts:
                continue
            order = []
            stack = [start]
            while stack:
                node = stack.pop()
                order.append(node)
                stack.extend(c for c in node.children if c not in counts)
            # Reversed pre-order visits children before their parents.
            for node in reversed(order):
                counts[node] = int(len(node.children) > 1) + sum(
                    counts[c] for c in node.children
                )
        return counts

    def _sort_children(self):
        """
        Sort the children of every node by the number of bifurcations
        (nodes with more than one child) in each child's subtree, in
        ascending order. Ties keep their current relative order.

        The counts are computed once in a single O(n) pass instead of
        re-walking each child's subtree.
        """
        counts = self._bifurcation_counts()
        for node in self._nodes:
            node.children = sorted(node.children, key=counts.__getitem__)

    # @timeit
    def sort(self, sort_children=True, force=False):
        """
        Sort the nodes in the tree using a stack-based depth-first traversal.

        Node indices are reassigned in traversal order and ``parent_idx``
        values are updated accordingly.

        Parameters
        ----------
        sort_children : bool, optional
            Whether to sort the children of each node based on the number of
            bifurcations in their subtrees. Defaults to True.
        force : bool, optional
            Whether to force the sorting of the tree even if it is already
            sorted. Defaults to False.

        Raises
        ------
        ValueError
            If the tree is still not sorted afterwards.
        """
        if self.is_sorted and not force:
            return
        if sort_children:
            self._sort_children()
        count = 0
        for node in self.traverse():
            node.idx = count
            node.parent_idx = node.parent.idx if node.parent else -1
            count += 1
        self._nodes = sorted(self._nodes, key=lambda x: x.idx)
        if not self.is_sorted:
            raise ValueError('Tree is not sorted.')
        print(f'Sorted {self}.')

    # INSERTION AND REMOVAL METHODS

    def remove_node(self, node):
        """
        Remove a node from the tree, reattaching its children to its parent.

        Parameters
        ----------
        node : Node
            The node to remove.

        Raises
        ------
        ValueError
            If the node is the root node.
        """
        if node.is_root:
            raise ValueError('Cannot remove the root node.')
        parent = node.parent
        # Copy: disconnecting mutates node.children.
        for child in list(node.children):
            child.disconnect_from_parent()
            child.connect_to_parent(parent)
        node.disconnect_from_parent()
        self._nodes.remove(node)

    def remove_subtree(self, node):
        """
        Remove a subtree from the tree.

        Parameters
        ----------
        node : Node
            The root node of the subtree to remove.

        Raises
        ------
        ValueError
            If ``node`` is the tree root, or if any node of the subtree is not
            part of the tree (checked before anything is modified).
        """
        if node is self.root:
            raise ValueError('Cannot remove the subtree of the root node.')
        subtree = node.subtree
        present = set(self._nodes)
        missing = [n for n in subtree if n not in present]
        if missing:
            raise ValueError(f'Nodes {missing} are not in the tree.')
        node.disconnect_from_parent()
        to_remove = set(subtree)
        # Filter in place (single pass) so the list object keeps its identity.
        self._nodes[:] = [n for n in self._nodes if n not in to_remove]

    def add_subtree(self, node, parent):
        """
        Add a subtree to the tree.

        Parameters
        ----------
        node : Node
            The root node of the subtree to add.
        parent : Node
            The parent node to attach the subtree to.
        """
        node.connect_to_parent(parent)
        subtree = node.subtree
        # Mirror __init__: every node held by the tree points back to it.
        for n in subtree:
            n._tree = self
        self._nodes.extend(subtree)

    def insert_node_after(self, new_node, existing_node):
        """
        Insert a node after a given node in the tree.

        All children of ``existing_node`` are moved under ``new_node``, which
        becomes the only child of ``existing_node``.

        Parameters
        ----------
        new_node : Node
            The new node to insert.
        existing_node : Node
            The existing node after which to insert the new node.

        Raises
        ------
        ValueError
            If ``new_node`` is already in the tree.
        """
        if new_node in self._nodes:
            raise ValueError('Node already exists in the tree.')
        # Iterate over a copy: disconnect_from_parent() removes the child from
        # existing_node.children, and mutating a list while iterating it skips items.
        for child in list(existing_node.children):
            child.disconnect_from_parent()
            child.connect_to_parent(new_node)
        new_node.connect_to_parent(existing_node)
        new_node._tree = self
        self._nodes.append(new_node)

    def insert_node_before(self, new_node, existing_node):
        """
        Insert a node before a given node in the tree.

        ``new_node`` is attached to the parent of ``existing_node``, and
        ``existing_node`` is re-attached as a child of ``new_node``.

        Parameters
        ----------
        new_node : Node
            The new node to insert.
        existing_node : Node
            The existing node before which to insert the new node.

        Raises
        ------
        ValueError
            If ``new_node`` is already in the tree, or if ``existing_node``
            is the root (it has no parent to attach ``new_node`` to).
        """
        if new_node in self._nodes:
            raise ValueError('Node already exists in the tree.')
        if existing_node.is_root:
            raise ValueError('Cannot insert a node before the root node.')
        new_node.connect_to_parent(existing_node.parent)
        existing_node.disconnect_from_parent()
        existing_node.connect_to_parent(new_node)
        new_node._tree = self
        self._nodes.append(new_node)

    def reposition_subtree(self, node, new_parent_node, origin=None):
        """
        Reposition a subtree in the tree.

        The subtree is detached, its coordinates are shifted by the vector
        from ``origin`` to ``new_parent_node``, and it is re-attached under
        ``new_parent_node``.

        Parameters
        ----------
        node : Node
            The root node of the subtree to reposition.
        new_parent_node : Node
            The new parent node of the subtree.
        origin : Node, optional
            The origin node to use as the reference point for the
            repositioning. Defaults to the current parent of ``node``.

        Raises
        ------
        ValueError
            If ``node`` is the root node.
        """
        if node.is_root:
            raise ValueError('Cannot reposition the root node.')
        origin = node.parent if origin is None else origin
        self.remove_subtree(node)
        shift_coordinates(node.subtree, origin=origin, target=new_parent_node)
        self.add_subtree(node, new_parent_node)

    # VISUALIZATION METHODS

    def topology(self):
        """
        Print the topology of the tree with a visual tree structure.

        Uses an explicit stack instead of recursion, so very deep trees
        (e.g. long chains of points) do not hit Python's recursion limit.
        """
        print('parent | idx')
        print('-'*15)
        stack = [(self.root, "")]
        while stack:
            node, prefix = stack.pop()
            print(f"{node.parent_idx:6} | " + prefix + '•' + str(node.idx))
            # Convert this node's branch marks into continuation marks for its children.
            base = prefix.replace("└─", " ").replace("├─", "│ ")
            last = len(node.children) - 1
            # Push in reverse so children are printed in list order.
            for i in range(last, -1, -1):
                branch = "└─" if i == last else "├─"
                stack.append((node.children[i], base + branch))
def shift_coordinates(points, origin, target):
    """
    Translate points in place by the vector that moves ``origin`` onto ``target``.

    Each point's ``x``, ``y`` and ``z`` attributes are rewritten as
    ``coord - origin_coord + target_coord``.

    Parameters
    ----------
    points : iterable
        Objects with mutable ``x``, ``y`` and ``z`` attributes.
    origin : object
        Reference point (read via ``x``, ``y``, ``z``).
    target : object
        Destination point (read via ``x``, ``y``, ``z``).
    """
    ox, oy, oz = origin.x, origin.y, origin.z
    tx, ty, tz = target.x, target.y, target.z
    for point in points:
        point.x = point.x - ox + tx
        point.y = point.y - oy + ty
        point.z = point.z - oz + tz