# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0
from typing import Callable, Dict, List
from numpy import ndarray, full_like
from numpy import exp, sin, polyval
import warnings
# Define simple functions and store them alongside their defaults in FUNCTIONS
def constant(position, value=0):
"""
Constant function that returns a constant value for any position.
Parameters
----------
position : float or numpy.ndarray
The position at which to evaluate the function.
value : float
The constant value to return.
Returns
-------
float or numpy.ndarray
The value of the constant function at the given position.
"""
if isinstance(position, ndarray):
return full_like(position, value, dtype=float)
else:
return value
def uniform(position, value=0):
"""
Constant function that returns a constant value for any position.
Parameters
----------
position : float or numpy.ndarray
The position at which to evaluate the function.
value : float
The constant value to return.
Returns
-------
float or numpy.ndarray
The value of the constant function at the given position.
"""
if isinstance(position, ndarray):
return full_like(position, value, dtype=float)
else:
return value
def linear(position, slope=1, intercept=0):
"""
Linear function that returns a linearly changing value for any position.
Parameters
----------
position : float or numpy.ndarray
The position at which to evaluate the function.
slope : float
The slope of the linear function.
intercept : float
The intercept of the linear function.
Returns
-------
float or numpy.ndarray
The value of the linear function at the given position.
"""
return slope * position + intercept
def power(position, vertical_shift=0, scale_factor=1, exponent=-1, horizontal_shift=0):
"""
Power function that returns a value raised to a given exponent for any position.
Parameters
----------
position : float or numpy.ndarray
The position at which to evaluate the function.
vertical_shift : float
The vertical shift to be applied to the result.
scale_factor : float
The scale factor to be applied to the result.
exponent : float
The exponent to which the position is raised.
horizontal_shift : float
The horizontal shift to be applied to the position.
Returns
-------
float or numpy.ndarray
The value of the power function at the given position.
"""
return vertical_shift + scale_factor * ((position + horizontal_shift) ** exponent)
def exponential(distance: float, vertical_shift:float = 0, scale_factor: float =1, growth_rate: float=1, horizontal_shift: float = 0) -> float:
"""
Exponential distribution function.
Args:
distance (float): The distance parameter.
vertical_shift (float): The vertical shift parameter.
scale_factor (float): The scale factor parameter.
growth_rate (float): The growth rate parameter.
horizontal_shift (float): The horizontal shift parameter.
Returns:
The result of the exponential equation: vertical_shift + scale_factor * exp(growth_rate * (distance - horizontal_shift)).
"""
return vertical_shift + scale_factor * exp(growth_rate * (distance - horizontal_shift))
def sigmoid(distance: float, vertical_shift=0, scale_factor=1, growth_rate=1, horizontal_shift=0) -> float:
"""
Sigmoid distribution function.
Args:
distance (float): The distance parameter.
vertical_shift (float): The vertical shift parameter.
scale_factor (float): The scale factor parameter.
growth_rate (float): The growth rate parameter.
horizontal_shift (float): The horizontal shift parameter.
Returns:
The result of the sigmoid equation: vertical_shift + scale_factor / (1 + exp(-growth_rate * (distance - horizontal_shift))).
"""
return vertical_shift + scale_factor / (1 + exp(-growth_rate*(distance - horizontal_shift)))
def sinusoidal(distance: float, amplitude: float, frequency: float, phase: float) -> float:
"""
Sinusoidal distribution function.
Args:
distance (float): The distance parameter.
amplitude (float): The amplitude parameter.
frequency (float): The frequency parameter.
phase (float): The phase parameter.
Returns:
The result of the sinusoidal equation: amplitude * sin(frequency * distance + phase).
"""
return amplitude * sin(frequency * distance + phase)
def gaussian(distance: float, amplitude: float, mean: float, std: float) -> float:
"""
Gaussian distribution function.
Args:
distance (float): The distance parameter.
amplitude (float): The amplitude parameter.
mean (float): The mean parameter.
std (float): The standard deviation parameter.
Returns:
The result of the gaussian equation: amplitude * exp(-((distance - mean) ** 2) / (2 * std ** 2)).
"""
return amplitude * exp(-((distance - mean) ** 2) / (2 * std ** 2))
def step(distance: float, start: float, end: float, min_value: float, max_value: float) -> float:
"""
Step distribution function.
Args:
distance (float): The distance parameter.
start (float): The start parameter.
end (float): The end parameter.
min_value (float): The minimum value parameter.
max_value (float): The maximum value parameter.
Returns:
The result of the step equation: min_value if distance < start, max_value if distance > end, and a linear interpolation between min_value and max_value if start <= distance <= end.
"""
if start < distance < end:
return max_value
else:
return min_value
def polynomial(distance: float, coeffs: List[float]) -> float:
"""
Polynomial distribution function.
Args:
distance (float): The distance parameter.
coefficients (List[float]): The coefficients of the polynomial.
Returns:
The result of the polynomial equation: sum(coefficients[i] * distance ** i for i in range(len(coefficients))).
"""
return polyval(coeffs, distance)
# aka ParametrizedFunction
[docs]
class Distribution:
"""
A callable class for creating and managing distribution functions.
Parameters
----------
function_name : str
The name of the function to use.
\**parameters
The parameters to use for the function.
Attributes
----------
function : Callable
The function to use for evaluation.
parameters : dict
The parameters to use for the function.
Examples
--------
>>> func = Distribution('uniform', value=0)
>>> func(5)
0
"""
FUNCTIONS = {
'constant': {'func': constant, 'defaults': {'value': 0}},
'uniform': {'func': uniform, 'defaults': {'value': 0}},
'linear': {'func': linear, 'defaults': {'slope': 1, 'intercept': 0}},
'exponential': {'func': exponential, 'defaults': {'vertical_shift': 0, 'scale_factor': 1, 'growth_rate': 1, 'horizontal_shift': 0}},
'sigmoid': {'func': sigmoid, 'defaults': {'vertical_shift': 0, 'scale_factor': 1, 'growth_rate': 1, 'horizontal_shift': 0}},
'sinusoidal': {'func': sinusoidal, 'defaults': {'amplitude': 1, 'frequency': 1, 'phase': 0}},
'gaussian': {'func': gaussian, 'defaults': {'amplitude': 1, 'mean': 0, 'std': 1}},
'step': {'func': step, 'defaults': {'max_value': 1, 'min_value': 0, 'start': 0, 'end': 1}},
'polynomial': {'func': polynomial, 'defaults': {'coeffs': [1, 0]}},
}
[docs]
@staticmethod
def from_dict(data: Dict[str, any]) -> 'Distribution':
"""
Creates a new Distribution from a dictionary.
Parameters
----------
data : dict
The dictionary containing the function data.
Returns
-------
Distribution
The new Distribution instance.
"""
return Distribution(data['function'], **data['parameters'])
def __init__(self, function_name: str, **parameters: Dict[str, float]) -> None:
"""
Creates a new parameterized function.
Parameters
----------
function_name : str
The name of the function to use.
\**parameters
The parameters to use for the function.
"""
func_data = self.FUNCTIONS[function_name]
self.function = func_data['func']
# Merge defaults with user parameters
valid_params = {k: v for k, v in parameters.items()
if k in func_data['defaults']}
self.parameters = {**func_data['defaults'], **valid_params}
def __repr__(self):
"""
Returns a string representation of the function.
Returns
-------
str
The string representation of the function.
"""
return f'{self.function.__name__}({self.parameters})'
def __call__(self, position):
"""
Calls the function with a given position.
Parameters
----------
position : float or numpy.ndarray
The position at which to evaluate the function.
Returns
-------
float or numpy.ndarray
The value of the function at the given position.
"""
return self.function(position, **self.parameters)
@property
def function_name(self):
"""
Returns the name of the function.
Returns
-------
str
The name of the function.
"""
return self.function.__name__
@property
def degree(self):
"""
Returns the degree of the polynomial function (if applicable).
Returns
-------
int
The degree of the function.
"""
if self.function_name == 'polynomial':
return len(self.parameters['coeffs']) - 1
else:
return None
[docs]
def update_parameters(self, **new_params):
"""
Updates the parameters of the function.
Parameters
----------
\**new_params
The new parameters to update the function with.
"""
valid_params = {k: v for k, v in new_params.items()
if k in self.parameters}
# if any of the new parameters are invalid, raise an error
if len(valid_params) != len(new_params):
invalid_params = set(new_params) - set(valid_params)
warnings.warn(f'\nIgnoring invalid parameters: {invalid_params}.\nSupported parameters: {list(self.parameters.keys())}')
self.parameters.update(valid_params)
[docs]
def to_dict(self):
"""
Exports the function to a dictionary format.
Returns
-------
dict
A dictionary representation of the function.
"""
return {
'function': self.function.__name__,
'parameters': self.parameters
}