# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dendrotweaks.utils import timeit
from typing import Union
from functools import cached_property
# ========================================================================
# NODE
# ========================================================================
class Node():
    """
    Represents a node in a tree.

    A node can be a 3D point in a neuron morphology,
    a segment, a section, or even a tree.

    Parameters
    ----------
    idx : Union[int, str]
        The index of the node. Converted to ``int``.
    parent_idx : Union[int, str]
        The index of the parent node. Converted to ``int``.
    """

    def __init__(self, idx: Union[int, str], parent_idx: Union[int, str]) -> None:
        """
        Creates a node in a tree.

        Args:
            idx (Union[int, str]): The index of the node.
            parent_idx (Union[int, str]): The index of the parent node.
        """
        self.idx = int(idx)
        self.parent_idx = int(parent_idx)
        self._parent = None
        self.children = []
        # Back-reference to the owning tree; ``Tree`` assigns it (None until then).
        self._tree = None

    def __repr__(self):
        class_name = type(self).__name__
        return f"{class_name}(idx={self.idx!r})"

    @property
    def parent(self):
        """The parent node, or None if the node has no parent."""
        return self._parent

    @parent.setter
    def parent(self, parent):
        self._parent = parent
        # Keep the stored index in sync; -1 marks "no parent". Compare with
        # None explicitly: a subclass defining __len__/__bool__ may be falsy
        # and must still be treated as a real parent.
        self.parent_idx = parent.idx if parent is not None else -1
        # Invalidate the cached ``is_root`` value when the parent changes.
        self.__dict__.pop('is_root', None)

    @cached_property
    def is_root(self) -> bool:
        """
        Check if the node is the root of the tree.

        The result is cached; assigning ``parent`` clears the cache.

        Returns:
            bool: True if the node has no parent, False otherwise.
        """
        return self.parent is None

    @property
    def topological_type(self) -> str:
        """
        The topological type of the node based on the number of children.

        Returns:
            str: 'termination' (no children), 'continuation' (one child)
            or 'bifurcation' (two or more children).
        """
        types = {0: 'termination', 1: 'continuation'}
        return types.get(len(self.children), 'bifurcation')

    @property
    def subtree(self) -> list:
        """
        Gets the subtree of the node (including the node itself) using
        an iterative depth-first traversal.

        The node itself comes first and every node appears before all of
        its descendants.

        Returns:
            list: A list of nodes in the subtree.
        """
        subtree = []
        stack = [self]
        while stack:
            node = stack.pop()
            subtree.append(node)
            stack.extend(node.children)
        return subtree

    @property
    def subtree_size(self) -> int:
        """
        Gets the size of the subtree of the node (including the node itself).

        Returns:
            int: The number of nodes in the subtree.
        """
        return len(self.subtree)

    @property
    def depth(self) -> int:
        """
        Number of edges between the node and the root (0 for the root),
        computed iteratively.
        """
        depth = 0
        node = self
        while node.parent is not None:
            depth += 1
            node = node.parent
        return depth

    @property
    def siblings(self) -> list:
        """
        Gets the siblings of the node.

        Returns:
            list: The other children of this node's parent (empty for the root).
        """
        if self.is_root:
            return []
        return [child for child in self.parent.children if child is not self]

    @property
    def nearest_neighbours(self) -> list:
        """
        Gets the nearest neighbours of the node.

        Returns:
            list: The parent (if any) followed by the children of the node.
            The root has no parent, so only its children are returned.
        """
        # Never put None in the list for the root.
        neighbours = [self.parent] if self.parent is not None else []
        return neighbours + self.children

    def _iter_to_root(self):
        """Yield this node, then each ancestor up to and including the root."""
        node = self
        while node is not None:
            yield node
            node = node.parent

    @property
    def ancestors(self) -> list:
        """Get this node followed by all of its ancestors, up to the root."""
        return list(self._iter_to_root())

    def path_to_ancestor(self, ancestor=None, include_self=True, include_ancestor=True):
        """Get the path from this node to a given ancestor node.

        Parameters
        ----------
        ancestor : Node, optional
            The ancestor node to get the path to. If None, the path to the root is returned.
        include_self : bool, optional
            Whether to include this node in the path. Defaults to True.
        include_ancestor : bool, optional
            Whether to include the ancestor node in the path. Defaults to True.

        Returns
        -------
        list
            A list of nodes representing the path from this node to the ancestor node.

        Raises
        ------
        ValueError
            If ``ancestor`` is given but is not on the path to the root.
        """
        path = []
        for node in self._iter_to_root():
            path.append(node)
            if node is ancestor:
                break
        else:
            # Reached the root without meeting ``ancestor``; this is only an
            # error when a specific ancestor was requested.
            if ancestor is not None:
                raise ValueError(f'{ancestor} is not an ancestor of {self}')
        if not include_self and path:
            path = path[1:]
        if not include_ancestor and path:
            path = path[:-1]
        return path

    def find_common_ancestor(self, other):
        """
        Find the lowest common ancestor between this node and another node.

        Raises
        ------
        ValueError
            If the two nodes share no ancestor.
        """
        if self is other:
            return self
        ancestors = set(self._iter_to_root())
        for node in other._iter_to_root():
            if node in ancestors:
                return node
        raise ValueError('No common ancestor found.')

    def path(self, other, include_self=True, include_other=True, include_ancestor=True):
        """
        Get the path between this node and another node.

        Parameters
        ----------
        other : Node
            The other node to get the path to.
        include_self : bool
            Whether to include this node in the path.
        include_other : bool
            Whether to include the other node in the path.
        include_ancestor : bool
            Whether to include the lowest common ancestor (LCA) in the path.
            Only relevant when self and other are in parallel subtrees, ignored otherwise.

        Returns
        -------
        list
            A list of nodes representing the path from this node to the other node.
        """
        if self is other:
            return [self] if (include_self and include_other) else []
        common_ancestor = self.find_common_ancestor(other)
        # When one endpoint is the LCA, its own include flag decides whether
        # the LCA appears in the path.
        if common_ancestor is self:
            include_ancestor = include_self
        elif common_ancestor is other:
            include_ancestor = include_other
        path_from_self = self.path_to_ancestor(common_ancestor,
                                               include_self=include_self,
                                               include_ancestor=include_ancestor)
        # The LCA is already accounted for on the ``self`` side.
        path_from_other = other.path_to_ancestor(common_ancestor,
                                                 include_self=include_other,
                                                 include_ancestor=False)
        return path_from_self + path_from_other[::-1]

    def connect_to_parent(self, parent):
        """
        Attach the node to a parent node.

        Warning
        -------
        This method should not be used directly when working with trees
        as it doesn't add the node to the tree's list of nodes.
        Use the `Tree` class to insert nodes into the tree.

        It also does not remove the node from a previous parent's children;
        call ``disconnect_from_parent()`` first when re-parenting.

        Args:
            parent (Node): The parent node to attach the node to.

        Raises:
            ValueError: If ``parent`` is this node or one of its descendants.
        """
        if parent in self.subtree:
            raise ValueError(f'Attaching a node will create a loop in the tree: {self} -> {parent}')
        self.parent = parent
        if self not in parent.children:
            parent.children.append(self)

    def disconnect_from_parent(self):
        """
        Detach the node from its parent (no-op if it has none).

        Examples
        --------
        for child in node.children[:]: child.disconnect_from_parent()
        """
        if self.parent is not None:
            self.parent.children.remove(self)
        self.parent = None
# ========================================================================
# TREE
# ========================================================================
class Tree:
    """
    A class to represent a tree data structure.

    A tree graph is a hierarchical data structure that consists of nodes
    connected by edges.

    Parameters
    ----------
    nodes : list
        A list of nodes in the tree. Exactly one node must have a parent
        index of ``-1`` (or ``None`` / ``'-1'``); it becomes the root. The
        list object is stored as-is (not copied).

    Attributes
    ----------
    root : Node
        The root node of the tree.
    """

    def __init__(self, nodes: list) -> None:
        for node in nodes:
            node._tree = self
        self._nodes = nodes
        self.root = self._find_root()
        # Build parent/child links from the parent indices if they are missing.
        if not self.is_connected:
            self._connect_nodes()

    # MAGIC METHODS

    def __repr__(self):
        return f"Tree(root={self.root!r}, num_nodes={len(self._nodes)})"

    def __getitem__(self, idx):
        # Positional access into the node list; it coincides with node.idx
        # only when the tree is sorted (see ``is_sorted``).
        return self._nodes[idx]

    def __len__(self):
        return len(self._nodes)

    def __iter__(self):
        yield from self._nodes

    def __contains__(self, node):
        return node in self._nodes

    # PROPERTIES

    @property
    def is_connected(self):
        """
        Whether the tree is connected i.e. each node can be reached from the root.

        Returns
        -------
        bool
            True if the root node's subtree contains exactly the same nodes
            as the entire tree. False otherwise.
        """
        nodes_set = set(self._nodes)
        subtree_set = set(self.root.subtree)
        return nodes_set == subtree_set

    @property
    def is_sorted(self):
        """
        Whether the nodes in the tree are sorted by index.

        Returns
        -------
        bool
            True if every node's idx equals its position in the node list and
            a depth-first traversal visits indices in increasing order.
        """
        if not all(node.idx == i for i, node in enumerate(self._nodes)):
            return False
        traversal_indices = [node.idx for node in self.traverse()]
        return traversal_indices == sorted(traversal_indices)

    @property
    def bifurcations(self):
        """
        The bifurcation nodes in the tree.

        Returns
        -------
        list
            A list of nodes with more than one child.
        """
        return [node for node in self._nodes if len(node.children) > 1]

    @property
    def terminations(self):
        """
        The termination nodes in the tree.

        Returns
        -------
        list
            A list of nodes without children.
        """
        return [node for node in self._nodes if len(node.children) == 0]

    @property
    def edges(self) -> list:
        """
        Returns a list of edges in the tree.

        Returns:
            list[tuple[Node, Node]]: A list of (parent, child) pairs.
        """
        return [(node.parent, node) for node in self._nodes if node.parent is not None]

    # TREE CONSTRUCTION METHODS

    def _find_root(self):
        """
        Find the root node.

        Returns
        -------
        Node
            The single node whose parent index is None, -1 or '-1'.

        Raises
        ------
        ValueError
            If there is not exactly one such node.
        """
        ROOT_PARENT = {None, -1, '-1'}
        root_nodes = [node for node in self._nodes if node.parent_idx in ROOT_PARENT]
        if len(root_nodes) != 1:
            raise ValueError(f'Tree must have exactly one root node. Found: {root_nodes}')
        return root_nodes[0]

    def _connect_nodes(self):
        """
        Build the parent/child links of the nodes from their parent indices,
        using a dictionary for O(1) parent lookups.

        Raises
        ------
        ValueError
            If the tree is still not connected afterwards.
        """
        if self.is_connected:
            print(' Tree already connected.')
            return
        node_map = {node.idx: node for node in self._nodes}
        for node in self._nodes:
            if node is not self.root and node.parent_idx in node_map:
                node.connect_to_parent(node_map[node.parent_idx])
        if not self.is_connected:
            raise ValueError('Tree is not connected.')

    # TRAVERSAL METHODS

    def traverse(self, root=None):
        """
        Iterate over the nodes in the tree using a stack-based
        depth-first (pre-order) traversal.

        Parameters
        ----------
        root : Node, optional
            The root node to start the traversal from. Defaults to None,
            meaning the tree's root.

        Yields
        ------
        Node
            Each reachable node once, children in their list order.
        """
        if root is None:
            root = self.root
        stack = [root]
        visited = set()
        while stack:
            node = stack.pop()
            if node in visited:
                continue
            yield node
            visited.add(node)
            # Push in reverse so the first child is popped (visited) first.
            stack.extend(reversed(node.children))

    # SORTING METHODS

    def _sort_children(self):
        """
        Sort the children of every node by the number of bifurcations (nodes
        with more than one child) in each child's subtree. Children with fewer
        bifurcations come first; ties keep their current order (the sort is
        stable).
        """
        # Count bifurcations for every subtree in one bottom-up pass instead
        # of re-walking each child's subtree for every node. Reversed
        # pre-order lists every node after all of its descendants, so the
        # children's counts are ready when a parent is summed.
        n_bifurcations = {}
        for node in reversed(self.root.subtree):
            n_bifurcations[node] = int(len(node.children) > 1) + sum(
                n_bifurcations[child] for child in node.children
            )

        def key(child):
            # Fallback for nodes unreachable from the root (not expected in a
            # connected tree): count directly.
            if child in n_bifurcations:
                return n_bifurcations[child]
            return sum(1 for n in child.subtree if len(n.children) > 1)

        for node in self._nodes:
            node.children = sorted(node.children, key=key)

    # @timeit
    def sort(self, sort_children=True, force=False):
        """
        Sort the nodes in the tree using a stack-based depth-first traversal.

        Node indices are renumbered in traversal order, parent indices are
        updated, the node list is reordered by index and a message is printed.

        Parameters
        ----------
        sort_children : bool, optional
            Whether to sort the children of each node
            based on the number of bifurcations in their subtrees. Defaults to True.
        force : bool, optional
            Whether to force the sorting of the tree even if it is already sorted. Defaults to False.

        Raises
        ------
        ValueError
            If the tree is not sorted after renumbering.
        """
        if self.is_sorted and not force:
            return
        if sort_children:
            self._sort_children()
        # Pre-order guarantees a parent is renumbered before its children.
        for count, node in enumerate(self.traverse()):
            node.idx = count
            node.parent_idx = node.parent.idx if node.parent is not None else -1
        self._nodes = sorted(self._nodes, key=lambda x: x.idx)
        if not self.is_sorted:
            raise ValueError('Tree is not sorted.')
        print(f'Sorted {self}.')

    # INSERTION AND REMOVAL METHODS

    def remove_node(self, node):
        """
        Remove a node from the tree, re-attaching its children to its parent.

        Parameters
        ----------
        node : Node
            The node to remove.

        Raises
        ------
        ValueError
            If the node is the root node.
        """
        if node.is_root:
            raise ValueError('Cannot remove the root node.')
        parent = node.parent
        # Copy: disconnecting mutates node.children.
        for child in list(node.children):
            child.disconnect_from_parent()
            child.connect_to_parent(parent)
        node.disconnect_from_parent()
        self._nodes.remove(node)

    def remove_subtree(self, node):
        """
        Remove a subtree from the tree.

        Parameters
        ----------
        node : Node
            The root node of the subtree to remove.
        """
        node.disconnect_from_parent()
        for n in node.subtree:
            self._nodes.remove(n)

    def add_subtree(self, node, parent):
        """
        Add a subtree to the tree.

        Parameters
        ----------
        node : Node
            The root node of the subtree to add.
        parent : Node
            The parent node to attach the subtree to.
        """
        node.connect_to_parent(parent)
        subtree = node.subtree
        # Keep the back-reference consistent with nodes passed to __init__.
        for n in subtree:
            n._tree = self
        self._nodes.extend(subtree)

    def insert_node_after(self, new_node, existing_node):
        """
        Insert a node after a given node in the tree: ``new_node`` becomes the
        only child of ``existing_node`` and adopts all of its former children.

        Parameters
        ----------
        new_node : Node
            The new node to insert.
        existing_node : Node
            The existing node after which to insert the new node.

        Raises
        ------
        ValueError
            If ``new_node`` is already in the tree.
        """
        if new_node in self._nodes:
            raise ValueError('Node already exists in the tree.')
        # Iterate over a copy: disconnect_from_parent() removes the child from
        # existing_node.children, and mutating a list while iterating over it
        # would skip every other child.
        for child in list(existing_node.children):
            child.disconnect_from_parent()
            child.connect_to_parent(new_node)
        new_node.connect_to_parent(existing_node)
        new_node._tree = self
        self._nodes.append(new_node)

    def insert_node_before(self, new_node, existing_node):
        """
        Insert a node before a given node in the tree: ``new_node`` takes
        ``existing_node``'s place under its parent and becomes its parent.

        Parameters
        ----------
        new_node : Node
            The new node to insert.
        existing_node : Node
            The existing node before which to insert the new node.

        Raises
        ------
        ValueError
            If ``new_node`` is already in the tree, or ``existing_node`` is
            the root (it has no parent to insert under).
        """
        if new_node in self._nodes:
            raise ValueError('Node already exists in the tree.')
        if existing_node.is_root:
            raise ValueError('Cannot insert a node before the root node.')
        new_node.connect_to_parent(existing_node.parent)
        existing_node.disconnect_from_parent()
        existing_node.connect_to_parent(new_node)
        new_node._tree = self
        self._nodes.append(new_node)

    def reposition_subtree(self, node, new_parent_node, origin=None):
        """
        Reposition a subtree in the tree.

        The subtree's coordinates are shifted by (new_parent_node - origin)
        and the subtree is re-attached under ``new_parent_node``.

        Parameters
        ----------
        node : Node
            The root node of the subtree to reposition.
        new_parent_node : Node
            The new parent node of the subtree.
        origin : Node, optional
            The origin node to use as the reference point for the repositioning.
            Defaults to None, meaning the current parent of ``node``.

        Raises
        ------
        ValueError
            If ``node`` is the root node.

        Note
        -----
        NOTE(review): earlier docs said children of the root node are treated
        differently; no such special case is visible here — confirm whether a
        subclass overrides this.
        """
        if node.is_root:
            raise ValueError('Cannot reposition the root node.')
        # Resolve the origin before detaching, while node.parent is still set.
        if origin is None:
            origin = node.parent
        self.remove_subtree(node)
        shift_coordinates(node.subtree,
                          origin=origin,
                          target=new_parent_node)
        self.add_subtree(node, new_parent_node)

    # VISUALIZATION METHODS

    def topology(self):
        """
        Print the topology of the tree with a visual tree structure.

        Each line shows the parent index followed by the node index drawn
        with box-drawing branches. The printing is recursive, so very deep
        trees may hit Python's recursion limit.
        """
        def print_node(node, prefix="", is_last=True):
            """Recursively print ``node`` and its children with branch glyphs."""
            root_str = f"{node.parent_idx:6} | "
            prefix = root_str + prefix
            print(prefix + '•' + str(node.idx))
            num_children = len(node.children)
            for i, child in enumerate(node.children):
                is_last_child = (i == num_children - 1)
                branch = "└─" if is_last_child else "├─"
                # Turn this node's branch glyphs into vertical guides (or
                # blanks below a last child) and drop the parent-index column
                # before recursing. Both replacements are two characters wide
                # so that branch columns stay aligned.
                prefix = prefix.replace("└─", "  ").replace("├─", "│ ")
                prefix = prefix.replace(root_str, "")
                print_node(child, prefix + branch, is_last_child)

        print('parent | idx')
        print('-'*15)
        print_node(self.root)
def shift_coordinates(points, origin, target):
    """
    Translate every point by the displacement from ``origin`` to ``target``.

    Each of ``x``, ``y`` and ``z`` becomes ``value - origin + target``. The
    reference coordinates are read once before any point is moved, so the
    shift stays constant even if ``origin`` or ``target`` is among ``points``.
    The points are modified in place.
    """
    axes = ('x', 'y', 'z')
    start = [getattr(origin, axis) for axis in axes]
    end = [getattr(target, axis) for axis in axes]
    for point in points:
        for axis, from_value, to_value in zip(axes, start, end):
            setattr(point, axis, getattr(point, axis) - from_value + to_value)