# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation
from dendrotweaks.utils import timeit
from dendrotweaks.morphology.trees import Node, Tree
from dendrotweaks.utils import timeit
from contextlib import contextmanager
import random
class Point(Node):
    """
    A class representing a single point in a morphological reconstruction.

    Parameters
    ----------
    idx : str
        The unique identifier of the node.
    type_idx : int
        The type of the node according to the SWC specification
        (e.g. soma-1, axon-2, dendrite-3). Stored as ``int``.
    x : float
        The x-coordinate of the node.
    y : float
        The y-coordinate of the node.
    z : float
        The z-coordinate of the node.
    r : float
        The radius of the node.
    parent_idx : str
        The identifier of the parent node.
    domain_name : str
        The name of the domain the node belongs to.
    domain_color : str
        The color associated with the node's domain (used for plotting).

    Attributes
    ----------
    idx : str
        The unique identifier of the node.
    type_idx : int
        The type of the node according to the SWC specification
        (e.g. soma-1, axon-2, dendrite-3).
    x : float
        The x-coordinate of the node.
    y : float
        The y-coordinate of the node.
    z : float
        The z-coordinate of the node.
    r : float
        The radius of the node.
    parent_idx : str
        The identifier of the parent node.
    domain_name : str
        The name of the domain the node belongs to.
    domain_color : str
        The color associated with the node's domain.
    """

    def __init__(self, idx: str, type_idx: int,
                 x: float, y: float, z: float, r: float, parent_idx: str,
                 domain_name: str, domain_color: str) -> None:
        super().__init__(idx, parent_idx)
        self.type_idx = int(type_idx)
        self.x = x
        self.y = y
        self.z = z
        self.r = r
        self.domain_name = domain_name
        self.domain_color = domain_color
        # Back-reference to the section containing this point; None until assigned.
        self._section = None

    @property
    def distance_to_parent(self):
        """
        The Euclidean distance from this node to its parent.

        Returns 0 for a node without a parent (e.g. the root).
        """
        # TODO: could this be cached? Any cache would have to be invalidated
        # whenever this node or its parent is moved (e.g. by a rotation).
        parent = self.parent
        if parent is None:
            return 0
        return np.sqrt((self.x - parent.x)**2 +
                       (self.y - parent.y)**2 +
                       (self.z - parent.z)**2)

    def path_distance(self, other=None):
        """
        Calculate the path distance from this node to another node or to the root.

        Parameters
        ----------
        other : Point, optional
            The other node to calculate the path distance to.
            If None, calculates the distance to the root, by default None.

        Returns
        -------
        float
            The sum of the parent-edge lengths along the path.
        """
        if self is other:
            return 0
        if other is None:
            path = self.path_to_ancestor(include_self=True,
                                         include_ancestor=False)
        else:
            # The common ancestor is excluded: its own edge to its parent
            # is not part of the path between the two nodes.
            path = self.path(other,
                             include_self=True,
                             include_other=True,
                             include_ancestor=False)
        return sum(pt.distance_to_parent for pt in path)

    def copy(self):
        """
        Create a copy of the node.

        The copy is built from the same constructor arguments; the section
        reference (``_section``) is not copied.

        Returns
        -------
        Point
            A copy of the node with the same attributes.
        """
        return Point(self.idx, self.type_idx,
                     self.x, self.y, self.z, self.r, self.parent_idx,
                     self.domain_name, self.domain_color)

    def overlaps_with(self, other, **kwargs) -> bool:
        """
        Check if the coordinates of this node overlap with another node.

        Parameters
        ----------
        other : Point
            The other node to compare with.
        **kwargs
            Additional keyword arguments passed to ``np.allclose``
            (e.g. ``atol``, ``rtol``).

        Returns
        -------
        bool
            True if the x, y, z coordinates are close, False otherwise.
            The radius is not compared.
        """
        return np.allclose(
            [self.x, self.y, self.z],
            [other.x, other.y, other.z],
            **kwargs
        )
class PointTree(Tree):
    """
    A class representing a tree graph of points in a morphological reconstruction.

    Parameters
    ----------
    nodes : list[Point]
        A list of points in the tree.
    """

    def __init__(self, nodes: list[Point]) -> None:
        super().__init__(nodes)
        # NOTE(review): presumably populated by section-building code elsewhere.
        self._sections = []
        # True after extend_sections() has duplicated bifurcation points at
        # the start of their child sections; reset by remove_overlaps().
        self._is_extended = False

    def __repr__(self):
        return f"PointTree(root={self.root!r}, num_nodes={len(self._nodes)})"

    # PROPERTIES

    @property
    def points(self):
        """
        The list of points in the tree. An alias for self._nodes.
        """
        return self._nodes

    @property
    def soma_points(self):
        """
        The list of points representing the soma (type 1).
        """
        return [pt for pt in self.points if pt.type_idx == 1]

    @property
    def soma_center(self):
        """
        The center of the soma as the average of the coordinates of the soma points.
        """
        # NOTE(review): with no soma points np.mean returns NaN (with a
        # RuntimeWarning) instead of raising.
        return np.mean([[pt.x, pt.y, pt.z]
                        for pt in self.soma_points], axis=0)

    @property
    def apical_center(self):
        """
        The center of the apical dendrite as the average of the coordinates
        of the apical points (type 4), or None if there are no apical points.
        """
        apical_points = [pt for pt in self.points
                         if pt.type_idx == 4]
        if not apical_points:
            return None
        return np.mean([[pt.x, pt.y, pt.z]
                        for pt in apical_points], axis=0)

    @property
    def soma_notation(self):
        """
        The type of soma notation used in the tree.

        - '1PS': One-point soma
        - '2PS': Two-point soma
        - '3PS': Three-point soma
        - 'contour': Soma represented as a contour (any other number of
          soma points, including zero)
        """
        return {1: '1PS', 2: '2PS', 3: '3PS'}.get(len(self.soma_points),
                                                   'contour')

    @property
    def df(self):
        """
        A DataFrame representation of the tree.

        One row per point, in the order of ``self.points``, with columns
        idx, type_idx, domain_name, domain_color, x, y, z, r, parent_idx.
        """
        columns = ['idx', 'type_idx', 'domain_name', 'domain_color',
                   'x', 'y', 'z', 'r', 'parent_idx']
        return pd.DataFrame(
            {col: [getattr(node, col) for node in self._nodes]
             for col in columns}
        )

    # STANDARDIZATION METHODS

    def change_soma_notation(self, notation):
        """
        Change the soma notation of the tree.

        Only conversion from a one-point soma ('1PS') to a three-point
        soma ('3PS') is currently implemented. Two new soma points are
        added on either side of the original point along the x-axis, at a
        distance equal to its radius.

        Parameters
        ----------
        notation : str
            The target notation. Currently only '3PS' is supported as a
            target (unless the tree already has the requested notation).

        Raises
        ------
        NotImplementedError
            If the requested conversion is not supported.
        """
        current_notation = self.soma_notation
        if current_notation == notation:
            return
        # Contour -> 3PS could use the average distance of the contour
        # points from the soma center as the radius; not implemented yet.
        if notation != '3PS' or current_notation != '1PS':
            raise NotImplementedError(
                f'Conversion from {current_notation} to {notation} notation '
                'is not implemented yet.'
            )
        pt = self.soma_points[0]
        # NOTE(review): the new points get hard-coded indices 2 and 3, which
        # may clash with existing indices until the tree is re-sorted/re-indexed.
        pt_left = Point(
            idx=2,
            type_idx=1,
            x=pt.x - pt.r,
            y=pt.y,
            z=pt.z,
            r=pt.r,
            parent_idx=pt.idx,
            domain_name=pt.domain_name,
            domain_color=pt.domain_color)
        pt_right = Point(
            idx=3,
            type_idx=1,
            x=pt.x + pt.r,
            y=pt.y,
            z=pt.z,
            r=pt.r,
            parent_idx=pt.idx,
            domain_name=pt.domain_name,
            domain_color=pt.domain_color)
        self.add_subtree(pt_right, pt)
        self.add_subtree(pt_left, pt)
        print('Converted soma to 3PS notation.')

    # GEOMETRICAL METHODS

    def round_coordinates(self, decimals=8):
        """
        Round the coordinates and radii of all nodes to the specified number of decimals.

        Parameters
        ----------
        decimals : int, optional
            The number of decimals to round to, by default 8.
        """
        for pt in self.points:
            pt.x = round(pt.x, decimals)
            pt.y = round(pt.y, decimals)
            pt.z = round(pt.z, decimals)
            pt.r = round(pt.r, decimals)

    def shift_coordinates_to_soma_center(self):
        """
        Shift all coordinates so that the soma center is at the origin (0, 0, 0).

        The shifted coordinates are rounded to 8 decimals.
        """
        soma_x, soma_y, soma_z = self.soma_center
        for pt in self.points:
            pt.x = round(pt.x - soma_x, 8)
            pt.y = round(pt.y - soma_y, 8)
            pt.z = round(pt.z - soma_z, 8)

    @timeit
    def rotate(self, angle_deg, axis='Y'):
        """
        Rotate the point cloud around the specified axis at the soma center using numpy.

        Parameters
        ----------
        angle_deg : float
            The rotation angle in degrees.
        axis : str, optional
            The rotation axis ('X', 'Y', or 'Z', case-insensitive), by default 'Y'.

        Raises
        ------
        ValueError
            If ``axis`` is not one of 'X', 'Y', 'Z'.
        """
        axis = str(axis).upper()
        angle = np.radians(angle_deg)
        cos_a, sin_a = np.cos(angle), np.sin(angle)
        if axis == 'X':
            rotation_matrix = np.array([
                [1, 0, 0],
                [0, cos_a, -sin_a],
                [0, sin_a, cos_a]
            ])
        elif axis == 'Y':
            rotation_matrix = np.array([
                [cos_a, 0, sin_a],
                [0, 1, 0],
                [-sin_a, 0, cos_a]
            ])
        elif axis == 'Z':
            rotation_matrix = np.array([
                [cos_a, -sin_a, 0],
                [sin_a, cos_a, 0],
                [0, 0, 1]
            ])
        else:
            raise ValueError("Axis must be 'X', 'Y', or 'Z'")

        rotation_point = self.soma_center
        # Float dtype avoids casting errors for integer coordinates; reshape
        # keeps the (n, 3) shape even for an empty tree.
        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points],
                          dtype=float).reshape(-1, 3)
        # Translate to the origin, rotate (row vectors -> multiply by R^T), translate back.
        rotated_coords = (coords - rotation_point) @ rotation_matrix.T + rotation_point
        for pt, (x, y, z) in zip(self.points, rotated_coords):
            pt.x, pt.y, pt.z = x, y, z

    def align_apical_dendrite(self, axis='Y', facing='up'):
        """
        Align the apical dendrite with the specified axis.

        The whole tree is rotated around the soma center so that the vector
        from the soma center to the apical center points along the target
        axis. Does nothing if the tree has no apical points (type 4) or if
        the apical center coincides with the soma center.

        Parameters
        ----------
        axis : str, optional
            The axis to align the apical dendrite with ('X', 'Y', or 'Z'), by default 'Y'.
        facing : str, optional
            The direction the apical dendrite should face ('up' or 'down'), by default 'up'.
            Any value other than 'down' is treated as 'up'.

        Raises
        ------
        ValueError
            If ``axis`` is not one of 'X', 'Y', 'Z'.
        """
        soma_center = self.soma_center
        apical_center = self.apical_center
        if apical_center is None:
            return

        target_vector = {
            'X': np.array([1.0, 0.0, 0.0]),
            'Y': np.array([0.0, 1.0, 0.0]),
            'Z': np.array([0.0, 0.0, 1.0])
        }.get(str(axis).upper(), None)
        if target_vector is None:
            raise ValueError("Axis must be 'X', 'Y', or 'Z'")
        if facing == 'down':
            target_vector = -target_vector

        current_vector = apical_center - soma_center
        current_norm = np.linalg.norm(current_vector)
        # A zero (or NaN, e.g. no soma points) vector has no direction to align.
        if not np.isfinite(current_norm) or current_norm == 0:
            print('Cannot determine apical dendrite orientation; skipping alignment.')
            return
        current_unit = current_vector / current_norm

        if np.allclose(current_unit, target_vector):
            print('Apical dendrite is already aligned.')
            return

        if np.allclose(current_unit, -target_vector):
            # Anti-parallel: the cross product vanishes, so rotate by pi around
            # an axis perpendicular to the target (a rolled basis vector is
            # always perpendicular to the basis-vector target).
            rotation_axis = np.roll(target_vector, 1)
            rotation_angle = np.pi
        else:
            rotation_axis = np.cross(current_unit, target_vector)
            rotation_axis = rotation_axis / np.linalg.norm(rotation_axis)
            # Clip guards against |cos| slightly > 1 from rounding.
            rotation_angle = np.arccos(
                np.clip(np.dot(current_unit, target_vector), -1.0, 1.0)
            )

        rotation_matrix = Rotation.from_rotvec(rotation_angle * rotation_axis).as_matrix()

        coords = np.array([[pt.x, pt.y, pt.z] for pt in self.points],
                          dtype=float).reshape(-1, 3)
        rotated_coords = (coords - soma_center) @ rotation_matrix.T + soma_center
        for pt, (x, y, z) in zip(self.points, rotated_coords):
            pt.x, pt.y, pt.z = x, y, z

    # TOPOLOGY METHODS

    def remove_overlaps(self):
        """
        Remove nodes whose coordinates overlap with those of their parent.

        Afterwards the tree is marked as not extended. Prints the number of
        removed nodes, if any.
        """
        n_nodes_before = len(self.points)
        # Collect first, then remove, so the traversal is not mutated mid-iteration.
        overlapping_nodes = [
            pt for pt in self.traverse()
            if pt.parent is not None and pt.overlaps_with(pt.parent)
        ]
        for pt in overlapping_nodes:
            self.remove_node(pt)
        self._is_extended = False
        n_nodes_after = len(self.points)
        if n_nodes_before != n_nodes_after:
            print(f'Removed {n_nodes_before - n_nodes_after} overlapping nodes.')

    def extend_sections(self):
        """
        Extend each section by adding a node in the beginning
        overlapping with the parent node for geometrical continuity.

        For every bifurcation point other than the root, a copy of the
        bifurcation point is inserted before each of its children; the copy
        takes the child's type and domain. If the tree is already extended,
        only a message is printed.

        Raises
        ------
        ValueError
            If a child already overlaps with its bifurcation parent.
        """
        n_nodes_before = len(self.points)
        if self._is_extended:
            print('Tree is already extended.')
            return
        bifurcations_excluding_root = [
            b for b in self.bifurcations if b != self.root
        ]
        for pt in bifurcations_excluding_root:
            # Iterate over a snapshot; insert_node_before presumably rewires pt.children.
            children = pt.children[:]
            for child in children:
                if child.overlaps_with(pt):
                    raise ValueError(f'Child {child} already overlaps with parent {pt}.')
                new_node = pt.copy()
                new_node.type_idx = child.type_idx
                new_node.domain_name = child.domain_name
                new_node.domain_color = child.domain_color
                if child._section is not None:
                    new_node._section = child._section
                    # NOTE(review): this overwrites the section's first point
                    # rather than prepending; presumably slot 0 still holds the
                    # duplicate dropped by remove_overlaps() — confirm.
                    if new_node not in new_node._section.points:
                        new_node._section.points[0] = new_node
                self.insert_node_before(new_node, child)
        self._is_extended = True
        n_nodes_after = len(self.points)
        print(f'Extended {n_nodes_after - n_nodes_before} nodes.')

    # I/O METHODS

    def to_swc(self, path_to_file):
        """
        Save the tree to an SWC file.

        Overlapping nodes are removed for the duration of the export and
        restored afterwards if the tree was extended. Indices are shifted to
        the 1-based SWC convention, and the type_idx -> domain name/color
        mapping is written as header comments.

        Parameters
        ----------
        path_to_file : str or path-like
            Destination file; it is overwritten.
        """
        with remove_overlaps(self):
            df = self.df.drop(
                columns=['domain_name', 'domain_color']
            ).astype({
                'idx': int,
                'type_idx': int,
                'x': float,
                'y': float,
                'z': float,
                'r': float,
                'parent_idx': int
            })
            # Shift to 1-based indexing (SWC standard); negative parent
            # indices (presumably -1 for the root) are left unchanged.
            df['idx'] += 1
            df.loc[df['parent_idx'] >= 0, 'parent_idx'] += 1

            # Collect mapping: type_idx -> domain / color (last point of a type wins).
            domain_map = {}
            color_map = {}
            for pt in self.points:
                domain_map[pt.type_idx] = pt.domain_name
                color_map[pt.type_idx] = pt.domain_color
            sorted_types = sorted(domain_map)
            domain_info = " ".join(f"{t}:{domain_map[t]}" for t in sorted_types)
            color_info = " ".join(f"{t}:{color_map[t]}" for t in sorted_types)

            from dendrotweaks import __version__
            # Same encoding for header and data so the file is consistently UTF-8.
            with open(path_to_file, "w", encoding="utf-8") as f:
                f.write(f"# Generated by DendroTweaks {__version__}\n")
                f.write(f"# DOMAIN_NAMES {domain_info}\n")
                f.write(f"# DOMAIN_COLORS {color_info}\n")
                f.write("# ID TYPE_ID X Y Z R PARENT_ID\n")
            # Append data for SWC table
            df.to_csv(
                path_to_file,
                sep=" ",
                index=False,
                header=False,
                mode="a",
                encoding="utf-8"
            )

    # PLOTTING METHODS

    def plot(self, ax=None,
             show_nodes=True, show_edges=True, show_domains=True,
             annotate=False, projection='XY',
             highlight_nodes=None, focus_nodes=None):
        """
        Plot a 2D projection of the tree.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            The axes to plot on, by default None (a new figure is created).
        show_nodes : bool, optional
            Whether to plot the nodes, by default True
        show_edges : bool, optional
            Whether to plot the edges, by default True
        show_domains : bool, optional
            Whether to color the nodes based on their domains, by default True
        annotate : bool, optional
            Whether to annotate the nodes with their indices (only when
            fewer than 50 points are plotted), by default False
        projection : str, optional
            The projection plane ('XY', 'XZ', or 'YZ', case-insensitive), by default 'XY'
        highlight_nodes : list, optional
            A list of nodes to highlight, by default None
        focus_nodes : list, optional
            A list of nodes to focus on, by default None
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(10, 10))

        projection = projection.upper()
        x_attr, y_attr = projection[0].lower(), projection[1].lower()

        # Convert focus/highlight to sets for faster lookup
        focus_nodes = set(focus_nodes) if focus_nodes else None
        highlight_nodes = set(highlight_nodes) if highlight_nodes else None

        points_to_plot = (self.points if focus_nodes is None
                          else [pt for pt in self.points if pt in focus_nodes])
        xs = [getattr(pt, x_attr) for pt in points_to_plot]
        ys = [getattr(pt, y_attr) for pt in points_to_plot]

        if show_edges:
            point_set = set(points_to_plot)
            for pt1, pt2 in self.edges:
                if pt1 in point_set and pt2 in point_set:
                    ax.plot(
                        [getattr(pt1, x_attr), getattr(pt2, x_attr)],
                        [getattr(pt1, y_attr), getattr(pt2, y_attr)],
                        color='C1'
                    )

        # Built once (was rebuilt per point); also defined when there are no points.
        colors = ([pt.domain_color for pt in points_to_plot]
                  if show_domains else 'C0')

        if show_nodes:
            ax.scatter(xs, ys, s=10, c=colors, marker='.', zorder=2)

        if annotate and len(points_to_plot) < 50:
            for pt, x, y in zip(points_to_plot, xs, ys):
                ax.annotate(f'{pt.idx}', (x, y), fontsize=8)

        if highlight_nodes:
            for pt, x, y in zip(points_to_plot, xs, ys):
                if pt in highlight_nodes:
                    ax.plot(x, y, 'o', color='C3', markersize=5)

        ax.set_xlabel(projection[0])
        ax.set_ylabel(projection[1])
        if projection in {"XY", "XZ", "YZ"}:
            ax.set_aspect('equal')

    def plot_radii_distribution(self, ax=None, highlight=None,
                                domains=True, show_soma=False):
        """
        Plot the radius of each point against its path distance from the root.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            The axes to plot on, by default None (a new figure is created).
        highlight : collection, optional
            Point indices (``pt.idx``) to draw in red on top of the others.
        domains : bool, optional
            Whether to color points by their domain color, by default True.
            If False, non-highlighted points use the default color 'C0'.
        show_soma : bool, optional
            Whether to include points whose domain_name is 'soma', by default False.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=(8, 3))
        for pt in self.points:
            if not show_soma and pt.domain_name == 'soma':
                continue
            if highlight and pt.idx in highlight:
                color, zorder = 'red', 2
            else:
                color = pt.domain_color if domains else 'C0'
                zorder = 1
            ax.plot(
                pt.path_distance(),
                pt.r,
                marker='.',
                color=color,
                zorder=zorder
            )
        ax.set_xlabel('Distance from root')
        ax.set_ylabel('Radius')
@contextmanager
def remove_overlaps(point_tree):
    """
    Temporarily strip overlapping nodes from ``point_tree``.

    On entry, the tree's overlapping nodes are removed and the tree is
    sorted. This is mainly used to write an SWC file without duplicated
    points. On exit, the tree is re-extended and re-sorted, but only if it
    was extended before entry. The exit step runs even when the with-body
    raises, so geometrical continuity is restored.

    Yields
    ------
    None
    """
    restore_extension = point_tree._is_extended

    point_tree.remove_overlaps()
    point_tree.sort()

    try:
        yield
    finally:
        if restore_extension:
            point_tree.extend_sections()
            point_tree.sort()