Source code for dendrotweaks.biophys.io.code_generators

# SPDX-FileCopyrightText: 2025 Poirazi Lab <dendrotweaks@dendrites.gr>
# SPDX-License-Identifier: MPL-2.0

import re
import os
from jinja2 import Template

from abc import ABC, abstractmethod
from jinja2 import Environment, FileSystemLoader
from dendrotweaks.utils import write_file

# Configure the Jinja2 environment
# env = Environment(
#     loader=FileSystemLoader('static/data/templates'),  # Load templates from the 'templates' directory
#     trim_blocks=False,                      # Trim newlines after Jinja blocks
#     lstrip_blocks=False,                     # Strip leading whitespace from Jinja blocks
# )

EQUILIBRIUM_POTENTIALS = {
    'na': 60,
    'k': -80,
    'ca': 140
}

class CodeGenerator(ABC):

    @abstractmethod
    def generate(self, ast, path_to_template):
        pass

    def write_file(self, path_to_file):
        write_file(self.content, path_to_file)


[docs] class PythonCodeGenerator(CodeGenerator): """ A class to generate Python code from an AST using a Jinja2 template. """ def __init__(self): self.content = None # MAIN METHOD
[docs] def generate(self, ast, path_to_template): """ Generate a Python class from the AST using a Jinja2 template. Parameters ---------- ast : dict The AST representation of the channel path_to_template : str The path to the Jinja2 template file Returns ------- str The Python code generated from the AST """ # Read the template file with open(path_to_template, 'r') as file: template_string = file.read() # # Create a Jinja2 template from the string template = Template(template_string) # template = env.get_template(self.path_to_template) # Define the variables for the template variables = { 'title': ast.title, # 'comment': ast.comment, 'class_name': ast.suffix, 'suffix': ast.suffix, 'ion': ast.ion, 'independent_var_name': ast.independent_var_name, 'channel_params': ast.params, 'range_params': ast.range_params, 'current_available': ast.current_available, 'state_vars': ast.state_vars, 'functions': self._generate_functions(ast), 'procedures': self._generate_procedures(ast), 'procedure_calls': self._generate_procedure_calls(ast), 'E_ion': EQUILIBRIUM_POTENTIALS.get(ast.ion, None) } # Render the template with the variables content = template.render(variables) if re.search(r'\bjnp\b', template_string): content = content.replace('np', 'jnp') self.content = content
# HELPER METHODS def _generate_functions(self, ast, indent=8): functions = [] for function in ast.functions: # Generate the signature signature_str = self._generate_signature(function.signature) # Generate the body body_str = self._generate_body(function.statements) for name in [function.name for function in ast.functions if function != function]: body_str = re.sub(r'\b' + re.escape(name) + r'\b', f"self.{name}", body_str) body_str = re.sub(r'\b' + re.escape('tadj') + r'\b', f"self.tadj", body_str) body_str = re.sub(r'\b' + re.escape('celsius') + r'\b', f"self.temperature", body_str) body_str += f"return {function.name}" body_str = '\n'.join(' ' * indent + line for line in body_str.splitlines()) # Find the parameters that are used in the body params = [param for param in ast.params if param not in function.local_vars and param not in function.params and re.search(r'\b' + re.escape(param) + r'\b', body_str)] functions.append({ 'signature': signature_str, 'params': params, 'body': body_str.strip() }) return functions def _generate_procedures(self, ast, indent=8): if len(ast.procedures) != 1: raise ValueError("Only one procedure is supported") ast.procedures[0].name = 'compute_kinetic_variables' procedures = [] for procedure in ast.procedures: # Generate the signature signature_str = self._generate_signature(procedure.signature, is_method=True, extra_params=['celsius'], default_params=[ast.independent_var_name]) # Generate the body body_str = self._generate_body(procedure.statements) for name in [function.name for function in ast.functions]: body_str = re.sub(r'\b' + re.escape(name) + r'\b', f"self.{name}", body_str) body_str = re.sub(r'\b' + re.escape('tadj') + r'\b', f"self.tadj", body_str) body_str = re.sub(r'\b' + re.escape('celsius') + r'\b', f"self.temperature", body_str) body_str += 'return ' + ', '.join([f"{state_var}Inf, {state_var}Tau" for state_var in ast.state_vars]) body_str = '\n'.join(' ' * indent + line for line in body_str.splitlines()) # Find the parameters that are used in the body params = [param for param in ast.params if param not in procedure.local_vars and re.search(r'\b' + re.escape(param) + r'\b', body_str)] procedures.append({ 'signature': signature_str, 'params': params, 'body': body_str.strip() }) return procedures def _generate_signature(self, signature, is_method=True, extra_params=None, default_params=None): """ Generate the signature string for a function using a Jinja2 template. The function AST representation is used to retrieve the function name and parameters: >>> def f_name(self, arg1, arg2, ...): Parameters ---------- signature : dict The function signature as an AST dictionary is_method : bool Whether the function is a class method or not """ signature_template = ( "def {{ name }}({% if params %}{{ params | join(', ') }}{% endif %}):" ) template = Template(signature_template) name = signature['name'] params = [param['name'] for param in signature.get('params', [])] default_params = default_params or [] if params == [] and default_params: print(f"Warning: Procedure {name} has no parameters! Expected 'v' or 'cai'. Defaulting to '{default_params[0]}'.") params = default_params if is_method: params = ['self'] + params return template.render(name=name, params=params) def _generate_body(self, statements, indent=12, skip_vars=['tadj']): python_code = "" # Add statements for statement in statements: # If the statement is an if-else statement if statement.get('condition', False): python_code += self._generate_conditionals(statement) else: if statement['assigned_var'] in skip_vars: continue python_code += (f"{statement['assigned_var']} = {statement['expression']}\n") return python_code def _generate_conditionals(self, statement): """ Generate the conditional statement for an if-else block using a Jinja2 template. """ condition = statement['condition'] if_statements = statement['if_statements'] else_statements = statement.get('else_statements', []) else_statements = {statement['assigned_var']: statement['expression'] for statement in else_statements} conditional_code = "" for if_statement in statement['if_statements']: # Default to variable name if not in else_expressions else_statement = else_statements.get( if_statement['assigned_var'], if_statement["assigned_var"] ) # Use a Jinja2 template to generate the conditional code conditional_template = ( "conditions = [{{ condition }}, ~({{ condition }})]" "\nchoices = [{{ if_statement }}, {{ else_statement }}]" "\n{{ assigned_var }} = np.select(conditions, choices)" ) template = Template(conditional_template) conditional_code += template.render( condition=condition, if_statement=if_statement['expression'], else_statement=else_statement, assigned_var=if_statement['assigned_var'] ) conditional_code += "\n" # Add a newline between blocks return conditional_code def _generate_procedure_calls(self, ast): """ Generate procedure call statements from AST procedures. Used only for Jaxley-compatible code generation. """ for procedure in ast.procedures: name = procedure.signature['name'] params = [param['name'] for param in procedure.signature.get('params', [])] if params == []: params = [ast.independent_var_name] state_vars = list(ast.state_vars.keys()) procedure_call_template = """{%- for state_var in state_vars -%} {{ state_var }}Inf, {{ state_var }}Tau{% if not loop.last %}, {% endif -%} {% endfor %} = self.{{ name }}({% if params %}{{ params | join(', ') }}{% endif %}) """ template = Template(procedure_call_template.strip()) return template.render( name=name, params=params, state_vars=state_vars )
[docs] class NMODLCodeGenerator(CodeGenerator): """ A class to generate NMODL code from a StandardIonChannel""" def __init__(self): self.content = None
[docs] def generate(self, channel, path_to_template: str) -> None: """ Generate NMODL code for a standardized ion channel using a Jinja2 template. Parameters ---------- channel : StandardIonChannel The standardized ion channel. path_to_template : str The path to the Jinja2 template file. """ # Read the template file with open(path_to_template, 'r') as file: template_string = file.read() # Create a Jinja2 template from the string template = Template(template_string) def get_unit(param): if param.startswith('vhalf_'): return 'mV' elif param.startswith('sigma_'): return 'mV' elif param.startswith('k_'): return '1/ms' elif param.startswith('delta_'): return '1' elif param.startswith('tau0_'): return 'ms' elif param.startswith('temp'): return 'degC' elif param.startswith('q10'): return '1' elif param.startswith('gbar'): return 'S/cm2' else: return '1' # Define the variables for the template variables = { 'suffix': channel.name, 'ion': channel.ion, 'params': [ (param, channel.params[param], get_unit(param)) for param in channel.params ], 'range_params': [ (param, channel.range_params[param], get_unit(param)) for param in channel.range_params ], 'has_tadj': ('q10' in channel.params and 'temp' in channel.params), 'state_vars': { var: params['power'] for var, params in channel._state_powers.items() }, } # Render the template with the variables content = template.render(variables) self.content = content return content