From 9345e09b8e94befa839861c2552bcef58259e4ba Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Mon, 9 Dec 2024 16:07:13 +0100 Subject: [PATCH 1/3] clean up Python code generator --- .../codegeneration/nest_code_generator.py | 6 +- .../nest_desktop_code_generator.py | 4 +- .../printers/python_variable_printer.py | 14 +- .../python_standalone_code_generator.py | 10 +- .../point_neuron/@NEURON_NAME@.py.jinja2 | 49 +-- .../point_neuron/@SYNAPSE_NAME@.py.jinja2 | 396 ++++++++++++++++++ .../point_neuron/directives_py/Block.jinja2 | 6 +- pynestml/frontend/pynestml_frontend.py | 15 +- .../transformers/synapse_remove_post_port.py | 8 +- 9 files changed, 453 insertions(+), 55 deletions(-) create mode 100644 pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index c50d75b21..aad4f60cd 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -679,9 +679,11 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: namespace["state_vars_that_need_continuous_buffering_transformed_iv"][var_name] = self._nest_printer.print(neuron.get_initial_value(var_name_transformed)) else: namespace["state_vars_that_need_continuous_buffering"] = [] - namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse + if "extra_on_emit_spike_stmts_from_synapse" in dir(neuron): + namespace["extra_on_emit_spike_stmts_from_synapse"] = neuron.extra_on_emit_spike_stmts_from_synapse namespace["paired_synapse"] = neuron.paired_synapse - namespace["paired_synapse_original_model"] = neuron.paired_synapse_original_model + if "paired_synapse_original_model" in dir(neuron): + namespace["paired_synapse_original_model"] = neuron.paired_synapse_original_model namespace["paired_synapse_name"] = neuron.paired_synapse.get_name() namespace["post_spike_updates"] = neuron.post_spike_updates namespace["transferred_variables"] = neuron._transferred_variables diff --git a/pynestml/codegeneration/nest_desktop_code_generator.py b/pynestml/codegeneration/nest_desktop_code_generator.py index 930c11fd2..dddf431b4 100644 --- a/pynestml/codegeneration/nest_desktop_code_generator.py +++ b/pynestml/codegeneration/nest_desktop_code_generator.py @@ -18,7 +18,7 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -import os + from typing import Sequence, Optional, Mapping, Any, Dict from pynestml.codegeneration.code_generator import CodeGenerator @@ -62,8 +62,10 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict: :return: a map from name to functionality. """ from pynestml.codegeneration.nest_tools import NESTTools + namespace = dict() namespace["neuronName"] = neuron.get_name() namespace["neuron"] = neuron namespace["parameters"] = NESTTools.get_neuron_parameters(neuron.get_name()) + return namespace diff --git a/pynestml/codegeneration/printers/python_variable_printer.py b/pynestml/codegeneration/printers/python_variable_printer.py index ab137efb5..d03bdadd0 100644 --- a/pynestml/codegeneration/printers/python_variable_printer.py +++ b/pynestml/codegeneration/printers/python_variable_printer.py @@ -68,8 +68,20 @@ def print_variable(self, variable: ASTVariable) -> str: """ assert isinstance(variable, ASTVariable) + # print external variables (such as a variable in the synapse that needs to call the getter method on the postsynaptic partner) if isinstance(variable, ASTExternalVariable): - raise Exception("Python-standalone target does not support synapses") + _name = str(variable) + if variable.get_alternate_name(): + if not variable._altscope: + # get the value from the postsynaptic partner continuous-time buffer (for post_connected_continuous_input_ports); this has been buffered in a local temp variable starting with "__" + return variable.get_alternate_name() + + # get the value from the postsynaptic partner (without time specified) + # the disadvantage of this approach is that the time the value is to be obtained is not explicitly specified, so we will actually get the value at the end of the min_delay timestep + return "__target.get_" + variable.get_alternate_name() + "()" + + # grab the value from the postsynaptic spiking history buffer + return "start.get_" + _name + "()" if variable.get_name() == PredefinedVariables.E_CONSTANT: return "math.e" diff --git a/pynestml/codegeneration/python_standalone_code_generator.py b/pynestml/codegeneration/python_standalone_code_generator.py index d6afaa095..56f32f771 100644 --- a/pynestml/codegeneration/python_standalone_code_generator.py +++ b/pynestml/codegeneration/python_standalone_code_generator.py @@ -59,12 +59,16 @@ class PythonStandaloneCodeGenerator(NESTCodeGenerator): "templates": { "path": "resources_python_standalone/point_neuron", "model_templates": { - "neuron": ["@NEURON_NAME@.py.jinja2"] + "neuron": ["@NEURON_NAME@.py.jinja2"], + "synapse": ["@SYNAPSE_NAME@.py.jinja2"] }, - "module_templates": ["simulator.py.jinja2", "test_python_standalone_module.py.jinja2", "neuron.py.jinja2", "spike_generator.py.jinja2", "utils.py.jinja2"] + "module_templates": ["simulator.py.jinja2", "test_python_standalone_module.py.jinja2", "neuron.py.jinja2", "synapse.py.jinja2", "spike_generator.py.jinja2", "utils.py.jinja2"] }, "solver": "analytic", - "numeric_solver": "rk45" + "numeric_solver": "rk45", + "neuron_synapse_pairs": [], + "delay_variable": {}, + "weight_variable": {} } def __init__(self, options: Optional[Mapping[str, Any]] = None): diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 index 9e3d1e404..3b7ee7bbf 100644 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 +++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 @@ -153,29 +153,6 @@ class Neuron_{{neuronName}}(Neuron): self._timestep = timestep self.recompute_internal_variables(self._timestep) -{%- if paired_synapse is defined %} - # ----------------------------- - # code for paired synapse - # ----------------------------- - - # state variables for archiving state for paired synapse - self.n_incoming_ = 0. - self.max_delay_ = 0. - self.last_spike_ = -1. - - # cache initial values -{%- for var in transferred_variables %} -{%- set variable_symbol = transferred_variables_syms[var] %} -{%- set variable = utils.get_variable_by_name(astnode, variable_name) %} -{%- if not var == variable_symbol.get_symbol_name() %} - {{ raise('Error in resolving variable to symbol') }} -{%- endif %} - {{var}}__iv = get_{{ printer.print(variable) }}() -{%- endfor %} - - self.clear_history() -{%- endif %} - def get_model(self) -> str: return "{{neuronName}}" @@ -297,13 +274,13 @@ class Neuron_{{neuronName}}(Neuron): # NESTML generated code for the update block # ------------------------------------------------------------------------- -{% if neuron.get_update_blocks()|length > 0 %} -{%- filter indent(4) %} -{%- for dynamics in neuron.get_update_blocks() %} -{%- set ast = dynamics.get_block() %} -{%- include "directives_py/Block.jinja2" %} -{%- endfor %} -{%- endfilter %} +{% if neuron.get_update_blocks() | length > 0 %} +{%- filter indent(4) %} +{%- for dynamics in neuron.get_update_blocks() %} +{%- set ast = dynamics.get_block() %} +{%- include "directives_py/Block.jinja2" %} +{%- endfor %} +{%- endfilter %} {%- endif %} # ------------------------------------------------------------------------- @@ -356,27 +333,27 @@ class Neuron_{{neuronName}}(Neuron): {%- endfor %} -{% if has_spike_input %} +{% if has_spike_input %} # ------------------------------------------------------------------------- # Spiking input handlers # ------------------------------------------------------------------------- def handle(self, t_spike: float, w: float, port_name: str) -> None: -{%- for port in neuron.get_spike_input_ports() %} +{%- for port in neuron.get_spike_input_ports() %} if port_name == "{{port.name}}": self.B_.{{port.get_symbol_name()}} += abs(w) self.B_.spike_received_{{port.get_symbol_name()}} = True return -{%- endfor %} -{%- endif %} +{%- endfor %} raise Exception("Received a spike on unknown input port \"" + port_name + "\" at t = " + "{0:E}".format(t_spike)) def get_spiking_input_ports(self) -> List[str]: return [ -{%- for port in neuron.get_spike_input_ports() %} +{%- for port in neuron.get_spike_input_ports() %} "{{port.name}}", -{%- endfor %} +{%- endfor %} ] +{%- endif %} # ------------------------------------------------------------------------- # Methods corresponding to event handlers diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 new file mode 100644 index 000000000..2ac10b4a7 --- /dev/null +++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 @@ -0,0 +1,396 @@ +{#- +@SYNAPSE_NAME@.py.jinja2 + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . +#} +{%- import 'directives_py/FunctionDeclaration.jinja2' as function_declaration with context %} +""" +{{ astnode.name }}.py + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . + +Generated from NESTML at time: {{now}} +""" + +{% if tracing %}# generated by {{self._TemplateReference__context.name}} +{% endif -%} + +from typing import Any, List, Mapping, Tuple + +import math +from math import * +import numpy as np +import scipy +import scipy.integrate + +from .synapse import Synapse +from .utils import steps + +DEBUG = 1 + + +{%- set stateSize = astnode.get_non_inline_state_symbols()|length %} + +class Synapse_{{ astnode.name }}(Synapse): + + class Parameters_: +{%- filter indent(4, False) %} +{%- for variable_symbol in astnode.get_parameter_symbols() %} +{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include 'directives_py/MemberDeclaration.jinja2' %} +{%- endfor %} +{%- endfilter %} + + class State_: +{%- if numeric_state_variables|length > 0 %} + ode_state = np.nan * np.ones({{ numeric_state_variables|length }}) + ode_state_variable_name_to_index = { +{%- for var_name in numeric_state_variables %} + "{{ var_name }}" : {{ loop.index - 1 }}, +{%- endfor %} + } + +{% endif %} +{%- filter indent(4, False) %} +{%- for variable_symbol in astnode.get_state_symbols() %} +{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include 'directives_py/MemberDeclaration.jinja2' %} +{%- endfor %} +{%- endfilter %} + + class Variables_: +{%- filter indent(4, False) %} +{%- for variable_symbol in astnode.get_internal_symbols() %} +{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_py/MemberDeclaration.jinja2" %} +{%- endfor %} +{%- endfilter %} + + class Buffers_: +{%- if astnode.get_spike_input_ports() | length > 0 %} + # spiking input ports +{%- endif %} +{%- for port in astnode.get_spike_input_ports() %} +{%- if port.has_vector_parameter() %} + {{ port.get_symbol_name() }}: List[float] = [] + spike_received_{{ port.get_symbol_name() }}: List[bool] = [] +{%- else %} + {{ port.get_symbol_name() }}: float = 0. + spike_received_{{ port.get_symbol_name() }}: bool = False +{%- endif %} +{%- endfor %} + +{%- if astnode.get_continuous_input_ports() | length > 0 %} + # continuous input ports +{%- endif %} +{%- for port in astnode.get_continuous_input_ports() %} +{%- if port.has_vector_parameter() %} + {{ port.get_symbol_name() }}: List[float] = [] +{%- else %} + {{ port.get_symbol_name() }}: float = 0. +{%- endif %} +{%- endfor %} + + + def __init__(self, timestep: float): + super().__init__() + + self.P_ = self.Parameters_() + self.S_ = self.State_() + self.V_ = self.Variables_() + self.B_ = self.Buffers_() + +{%- filter indent(4, True) %} +{%- for variable_symbol in synapse.get_parameter_symbols() %} +{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- set isHomogeneous = PyNestMLLexer["DECORATOR_HOMOGENEOUS"] in variable_symbol.get_decorators() %} +{%- if not isHomogeneous %} +{%- if variable.get_name() != nest_codegen_opt_delay_variable %} +{%- include "directives_py/MemberInitialization.jinja2" %} +{%- endif %} +{%- endif %} +{%- endfor %} +{%- endfilter %} + +{%- if astnode.get_state_symbols()|length > 0 %} + # initial values for state variables +{%- filter indent(4) %} +{%- for variable_symbol in astnode.get_state_symbols() %} +{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_py/MemberInitialization.jinja2" %} +{%- endfor %} +{%- endfilter %} +{%- endif %} + + self._timestep = timestep + self.recompute_internal_variables(self._timestep) + +{%- if paired_synapse is defined %} + # ----------------------------- + # code for paired synapse + # ----------------------------- + + # state variables for archiving state for paired synapse + self.n_incoming_ = 0. + self.max_delay_ = 0. + self.last_spike_ = -1. + + # cache initial values +{%- for var in transferred_variables %} +{%- set variable_symbol = transferred_variables_syms[var] %} +{%- set variable = utils.get_variable_by_name(astnode, var) %} +{%- if not var == variable_symbol.get_symbol_name() %} + {{ raise('Error in resolving variable to symbol') }} +{%- endif %} + {{ var }}__iv = get_{{ printer.print(variable) }}() +{%- endfor %} + + self.clear_history() +{%- endif %} + + def get_model(self) -> str: + return "{{ astnode.name }}" + + def recompute_internal_variables(self, timestep: float, exclude_timestep: bool = False): + __timestep: float = timestep # do not remove, this is necessary for the timestep() function + + if exclude_timestep: + {%- filter indent(6,True) %} + {%- for variable_symbol in astnode.get_internal_symbols() %} + {%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} + {%- if variable.name != "__h" %} + {%- include "directives_py/MemberInitialization.jinja2" %} + {%- endif %} + {%- endfor %} + {%- endfilter %} + else: + # internals V_ + {%- filter indent(6) %} + {%- for variable_symbol in astnode.get_internal_symbols() %} + {%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} + {%- include "directives_py/MemberInitialization.jinja2" %} + {%- endfor %} + {%- endfilter %} + +{%- if astnode.get_functions()|length > 0 %} + + # --------------------------------------------------------------------------- + # Functions defined in the NESTML model + # --------------------------------------------------------------------------- +{% for function in astnode.get_functions() -%} + {{ function_declaration.FunctionDeclaration(function, astnode.name) }}: +{%- filter indent(4,True) %} +{%- with ast = function.get_block() %} +{%- include "directives_py/Block.jinja2" %} +{%- endwith %} +{%- endfilter %} +{%- endfor %} +{%- endif %} + + # ------------------------------------------------------------------------- + # Getters/setters for state block + # ------------------------------------------------------------------------- +{% filter indent(2, True) -%} +{%- for variable_symbol in astnode.get_state_symbols() %} +{%- if not is_delta_kernel(astnode.get_kernel_by_name(variable_symbol.get_symbol_name())) %} +{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_py/MemberVariableGetterSetter.jinja2" %} +{%- endif %} +{%- endfor %} +{%- endfilter %} + + # ------------------------------------------------------------------------- + # Getters/setters for parameters block + # ------------------------------------------------------------------------- +{% filter indent(2, True) -%} +{%- for variable_symbol in astnode.get_parameter_symbols() %} +{%- set variable = utils.get_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_py/MemberVariableGetterSetter.jinja2" %} +{%- endfor %} +{%- endfilter %} + +{% if astnode.get_equations_blocks()|length > 0 %} + # ------------------------------------------------------------------------- + # Numeric + analytic solver stepping function + # ------------------------------------------------------------------------- + +{%- for ast in utils.get_all_integrate_odes_calls_unique(astnode) %} + +{%- if uses_numeric_solver %} +{% filter indent(2) %} +{%- include "directives_py/GSLDifferentiationFunction.jinja2" %} +{%- endfilter %} +{%- endif %} + + def _integrate_odes{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}(self, origin: float, timestep: float): + r"""Integrate {% if ast.get_args() | length > 0 %}a subset of {% endif %}ODE(s) defined in the model equation block by one timestep. +{%- if ast.get_args() | length > 0 %} + + The variables that will be integrated are: {{ ", ".join(utils.integrate_odes_args_strs_from_function_call(ast)) }} +{%- endif %} + """ +{%- filter indent(4) %} + +{%- set analytic_state_variables_ = analytic_state_variables %} +{%- if ast.get_args() | length > 0 %} +{%- set analytic_state_variables_ = utils.filter_variables_list(analytic_state_variables_, ast.get_args()) %} +{%- endif %} + +{#- always integrate convolutions in time #} +{%- for var in analytic_state_variables %} +{%- if "__X__" in var %} +{%- set tmp = analytic_state_variables_.append(var) %} +{%- endif %} +{%- endfor %} + +{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %} + +{%- if uses_numeric_solver %} +{%- include "directives_py/GSLIntegrationStep.jinja2" %} +{%- endif %} + +{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %} +{%- endfilter %} +{%- endfor %} +{%- endif %} + + def step(self, origin: float, timestep: float) -> None: + __timestep: float = timestep # do not remove, this is necessary for the timestep() function + + assert False, "Synapsese are not yet supported for the Python-standalone code generation target." + + # ------------------------------------------------------------------------- + # integrate variables related to convolutions + # ------------------------------------------------------------------------- + +{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %} +{%- include "directives_py/AnalyticIntegrationStep_begin.jinja2" %} +{%- endwith %} + + # ------------------------------------------------------------------------- + # NESTML generated code for the update block + # ------------------------------------------------------------------------- + +{% if astnode.get_update_blocks() | length > 0 %} +{%- filter indent(4) %} +{%- for dynamics in astnode.get_update_blocks() %} +{%- set ast = dynamics.get_block() %} +{%- include "directives_py/Block.jinja2" %} +{%- endfor %} +{%- endfilter %} + +{%- endif %} + + # ------------------------------------------------------------------------- + # integrate variables related to convolutions + # ------------------------------------------------------------------------- + +{%- with analytic_state_variables_ = analytic_state_variables_from_convolutions %} +{%- include "directives_py/AnalyticIntegrationStep_end.jinja2" %} +{%- endwith %} + + # ------------------------------------------------------------------------- + # process spikes from buffers + # ------------------------------------------------------------------------- +{%- filter indent(4, True) -%} +{%- include "directives_py/ApplySpikesFromBuffers.jinja2" %} +{%- endfilter %} + + # ------------------------------------------------------------------------- + # begin NESTML generated code for the onReceive block(s) + # ------------------------------------------------------------------------- + +{% for blk in astnode.get_on_receive_blocks() %} +{%- set inport = blk.get_port_name() %} + if self.B_.spike_received_{{ inport }}: + self.on_receive_block_{{ blk.get_port_name() }}() +{%- endfor %} + + # ------------------------------------------------------------------------- + # Clear spike buffers at end of timestep + # ------------------------------------------------------------------------- + +{%- for port in astnode.get_spike_input_ports() %} + self.B_.{{port.get_symbol_name()}} = 0. + self.B_.spike_received_{{port.get_symbol_name()}} = False +{%- endfor %} + + + # ------------------------------------------------------------------------- + # Begin NESTML generated code for the onCondition block(s) + # ------------------------------------------------------------------------- +{%- for block in astnode.get_on_condition_blocks() %} + if {{ printer.print(block.get_cond_expr()) }}: +{%- set ast = block.get_block() %} +{%- if ast.print_comment('#') | length > 1 %} +# {{ast.print_comment('#')}} +{%- endif %} +{%- filter indent(6) %} +{%- include "directives_py/Block.jinja2" %} +{%- endfilter %} +{%- endfor %} + + + # ------------------------------------------------------------------------- + # Spiking input handlers + # ------------------------------------------------------------------------- + + def handle(self, t_spike: float, port_name: str) -> None: + assert False, "Synapsese are not yet supported for the Python-standalone code generation target." + +{%- for port in astnode.get_spike_input_ports() %} + if port_name == "{{port.name}}": + self.B_.{{port.get_symbol_name()}} += abs(w) + self.B_.spike_received_{{port.get_symbol_name()}} = True + return +{%- endfor %} + raise Exception("Received a spike on unknown input port \"" + port_name + "\" at t = " + "{0:E}".format(t_spike)) + + def get_spiking_input_ports(self) -> List[str]: + return [ +{%- for port in astnode.get_spike_input_ports() %} + "{{port.name}}", +{%- endfor %} + ] + + # ------------------------------------------------------------------------- + # Methods corresponding to event handlers + # ------------------------------------------------------------------------- + +{%- for blk in astnode.get_on_receive_blocks() %} +{%- set ast = blk.get_block() %} + def on_receive_block_{{ blk.get_port_name() }}(self): +{%- filter indent(4, True) -%} +{%- include "directives_py/Block.jinja2" %} +{%- endfilter %} +{% endfor %} diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/Block.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/Block.jinja2 index 985179281..c16005c64 100644 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/Block.jinja2 +++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/directives_py/Block.jinja2 @@ -5,7 +5,7 @@ #} {%- if tracing %}# generated by {{self._TemplateReference__context.name}}{% endif %} {%- for statement in ast.get_stmts() %} -{%- with stmt = statement %} -{%- include "directives_py/Statement.jinja2" %} -{%- endwith %} +{%- with stmt = statement %} +{%- include "directives_py/Statement.jinja2" %} +{%- endwith %} {%- endfor %} diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py index 058fe8ca8..365f44a66 100644 --- a/pynestml/frontend/pynestml_frontend.py +++ b/pynestml/frontend/pynestml_frontend.py @@ -70,7 +70,7 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st "goto", "if", "inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr", "operator", "or", "or_eq", "private", "protected", "public", "register", "reinterpret_cast", "requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct", "switch", "template", "this", "thread_local", "throw", "true", "try", "typedef", "typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while", "xor", "xor_eq"]}) transformers.append(variable_name_rewriter) - if target_name.upper() in ["SPINNAKER"]: + elif target_name.upper() in ["SPINNAKER"]: from pynestml.transformers.synapse_remove_post_port import SynapseRemovePostPortTransformer # co-generate neuron and synapse @@ -78,7 +78,7 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st options = synapse_post_neuron_co_generation.set_options(options) transformers.append(synapse_post_neuron_co_generation) - if target_name.upper() == "NEST": + elif target_name.upper() == "NEST": from pynestml.transformers.synapse_post_neuron_transformer import SynapsePostNeuronTransformer # co-generate neuron and synapse @@ -86,14 +86,21 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st options = synapse_post_neuron_co_generation.set_options(options) transformers.append(synapse_post_neuron_co_generation) - if target_name.upper() in ["PYTHON_STANDALONE"]: + elif target_name.upper() in ["PYTHON_STANDALONE"]: from pynestml.transformers.illegal_variable_name_transformer import IllegalVariableNameTransformer # rewrite all Python keywords # from: ``import keyword; print(keyword.kwlist)`` - variable_name_rewriter = IllegalVariableNameTransformer({"forbidden_names": ['False', 'None', 'True', 'and', 'as', 'assert', 'async', 'await', 'break', 'class', 'continue', 'def', 'del', 'elif', 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in', 'is', 'lambda', 'nonlocal', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', 'with', 'yield']}) + variable_name_rewriter = IllegalVariableNameTransformer({"forbidden_names": ["False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue", "def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", "import", "in", "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while", "with", "yield"]}) transformers.append(variable_name_rewriter) + # co-generate neuron and synapse + from pynestml.transformers.synapse_remove_post_port import SynapseRemovePostPortTransformer + + synapse_post_neuron_co_generation = SynapseRemovePostPortTransformer() + options = synapse_post_neuron_co_generation.set_options(options) + transformers.append(synapse_post_neuron_co_generation) + return transformers, options diff --git a/pynestml/transformers/synapse_remove_post_port.py b/pynestml/transformers/synapse_remove_post_port.py index 8284e4402..cf92166a7 100644 --- a/pynestml/transformers/synapse_remove_post_port.py +++ b/pynestml/transformers/synapse_remove_post_port.py @@ -23,15 +23,13 @@ from typing import Optional, Mapping, Any, Union, Sequence +from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.meta_model.ast_node import ASTNode - -from pynestml.utils.logger import Logger, LoggingLevel from pynestml.transformers.transformer import Transformer -from pynestml.utils.ast_utils import ASTUtils from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor -from pynestml.frontend.frontend_configuration import FrontendConfiguration - +from pynestml.utils.ast_utils import ASTUtils +from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.string_utils import removesuffix From 8774472796eac4e56cddbbbce54f3c9c8b55a2e2 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Tue, 10 Dec 2024 14:35:35 +0100 Subject: [PATCH 2/3] clean up Python code generator --- pynestml/frontend/pynestml_frontend.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py index 365f44a66..334199bba 100644 --- a/pynestml/frontend/pynestml_frontend.py +++ b/pynestml/frontend/pynestml_frontend.py @@ -70,7 +70,7 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st "goto", "if", "inline", "int", "long", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr", "operator", "or", "or_eq", "private", "protected", "public", "register", "reinterpret_cast", "requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct", "switch", "template", "this", "thread_local", "throw", "true", "try", "typedef", "typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while", "xor", "xor_eq"]}) transformers.append(variable_name_rewriter) - elif target_name.upper() in ["SPINNAKER"]: + if target_name.upper() in ["SPINNAKER"]: from pynestml.transformers.synapse_remove_post_port import SynapseRemovePostPortTransformer # co-generate neuron and synapse @@ -78,7 +78,7 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st options = synapse_post_neuron_co_generation.set_options(options) transformers.append(synapse_post_neuron_co_generation) - elif target_name.upper() == "NEST": + if target_name.upper() == "NEST": from pynestml.transformers.synapse_post_neuron_transformer import SynapsePostNeuronTransformer # co-generate neuron and synapse @@ -86,7 +86,7 @@ def transformers_from_target_name(target_name: str, options: Optional[Mapping[st options = synapse_post_neuron_co_generation.set_options(options) transformers.append(synapse_post_neuron_co_generation) - elif target_name.upper() in ["PYTHON_STANDALONE"]: + if target_name.upper() in ["PYTHON_STANDALONE"]: from pynestml.transformers.illegal_variable_name_transformer import IllegalVariableNameTransformer # rewrite all Python keywords From b46c2ae6318f202a64e3f576db8a5b0064b84659 Mon Sep 17 00:00:00 2001 From: "C.A.P. Linssen" Date: Wed, 11 Dec 2024 13:49:28 +0100 Subject: [PATCH 3/3] clean up Python code generator --- .../point_neuron/@NEURON_NAME@.py.jinja2 | 2 -- .../point_neuron/@SYNAPSE_NAME@.py.jinja2 | 6 ++---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 index ee8e48b64..33574910a 100644 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 +++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/@NEURON_NAME@.py.jinja2 @@ -56,8 +56,6 @@ import scipy.integrate from .neuron import Neuron from .utils import steps -DEBUG = 1 - {%- set stateSize = neuron.get_non_inline_state_symbols()|length %} diff --git a/pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 b/pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 index 2ac10b4a7..fee9771d0 100644 --- a/pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 +++ b/pynestml/codegeneration/resources_python_standalone/point_neuron/@SYNAPSE_NAME@.py.jinja2 @@ -56,8 +56,6 @@ import scipy.integrate from .synapse import Synapse from .utils import steps -DEBUG = 1 - {%- set stateSize = astnode.get_non_inline_state_symbols()|length %} @@ -287,7 +285,7 @@ class Synapse_{{ astnode.name }}(Synapse): def step(self, origin: float, timestep: float) -> None: __timestep: float = timestep # do not remove, this is necessary for the timestep() function - assert False, "Synapsese are not yet supported for the Python-standalone code generation target." + assert False, "Synapses are not yet supported for the Python-standalone code generation target." # ------------------------------------------------------------------------- # integrate variables related to convolutions @@ -366,7 +364,7 @@ class Synapse_{{ astnode.name }}(Synapse): # ------------------------------------------------------------------------- def handle(self, t_spike: float, port_name: str) -> None: - assert False, "Synapsese are not yet supported for the Python-standalone code generation target." + assert False, "Synapses are not yet supported for the Python-standalone code generation target." {%- for port in astnode.get_spike_input_ports() %} if port_name == "{{port.name}}":