Skip to content

Commit

Permalink
add attributes to spiking input ports
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Jan 7, 2025
1 parent 03013f9 commit 597d2db
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 13 deletions.
4 changes: 2 additions & 2 deletions models/neurons/terub_gpe_neuron.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ model terub_gpe_neuron:
inline g_k_Ca real = 15.0 #Report:15, Terman Rubin 2002: 20.0
inline g_k1 real = 30.0

inline I_exc_mod real = -convolve(g_exc, exc_spikes) * V_m
inline I_inh_mod real = convolve(g_inh, inh_spikes) * (V_m - E_gg)
inline I_exc_mod real = -convolve(g_exc, exc_spikes.weight) * V_m
inline I_inh_mod real = convolve(g_inh, inh_spikes.weight) * (V_m - E_gg)

inline tau_n real = g_tau_n_0 + g_tau_n_1 / (1. + exp(-(V_m-g_theta_n_tau)/g_sigma_n_tau))
inline tau_h real = g_tau_h_0 + g_tau_h_1 / (1. + exp(-(V_m-g_theta_h_tau)/g_sigma_h_tau))
Expand Down
4 changes: 2 additions & 2 deletions pynestml/codegeneration/printers/python_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def _print_python_name(cls, variable_name: str) -> str:
"""
differential_order = variable_name.count("\"")
if differential_order > 0:
return variable_name.replace("\"", "").replace("$", "__DOLLAR") + "__" + "d" * differential_order
return variable_name.replace(".", "__DOT__").replace("\"", "").replace("$", "__DOLLAR") + "__" + "d" * differential_order

return variable_name.replace("$", "__DOLLAR")
return variable_name.replace(".", "__DOT__").replace("$", "__DOLLAR")

def print_variable(self, variable: ASTVariable) -> str:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ class Neuron_{{neuronName}}(Neuron):
{{ port.get_symbol_name() }}: List[float] = []
spike_received_{{ port.get_symbol_name() }}: List[bool] = []
{%- else %}
{{ port.get_symbol_name() }}: float = 0.
{% set ast_input_port = utils.get_input_port_by_name(astnode.get_input_blocks(), port.name) %}
{%- for attribute in ast_input_port.get_parameters() %}
{{ port.get_symbol_name() }}__DOT__{{ attribute.name }}: float = 0.
{%- endfor %}
{{ port.get_symbol_name() }}: float = 0. # buffer for the port name by itself (train of unweighted delta pulses)
spike_received_{{ port.get_symbol_name() }}: bool = False
{%- endif %}
{%- endfor %}
Expand Down Expand Up @@ -337,9 +341,17 @@ class Neuron_{{neuronName}}(Neuron):

def handle(self, t_spike: float, w: float, port_name: str) -> None:
{%- for port in neuron.get_spike_input_ports() %}
if port_name == "{{port.name}}":
self.B_.{{port.get_symbol_name()}} += abs(w)
self.B_.spike_received_{{port.get_symbol_name()}} = True
if port_name == "{{ port.name }}":
self.B_.{{ port.get_symbol_name() }} += 1. # unweighted spike port
{% set ast_input_port = utils.get_input_port_by_name(astnode.get_input_blocks(), port.name) %}
{%- for attribute in ast_input_port.get_parameters() %}
{% if attribute.name == "weight" %}
self.B_.{{ port.get_symbol_name() }}__DOT__{{ attribute.name }} += abs(w) # unweighted spike port
{% else %}
{{ raise('The Python-standalone code generator only supports \'weight\' spike input port attribute for now') }}
{% endif %}
{% endfor %}
self.B_.spike_received_{{ port.get_symbol_name() }} = True
return
{%- endfor %}
raise Exception("Received a spike on unknown input port \"" + port_name + "\" at t = " + "{0:E}".format(t_spike))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{%- if tracing %}# generated by {{self._TemplateReference__context.name}}{% endif %}
{%- for spike_updates_for_port in spike_updates.values() %}
{%- for ast in spike_updates_for_port -%}
{%- include "directives_py/Assignment.jinja2" %}
{%- endfor %}
{%- for ast in spike_updates_for_port -%}
{%- include "directives_py/Assignment.jinja2" %}
{%- endfor %}
{%- endfor %}
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class TestSimulator:
sg_inh = simulator.add_neuron(SpikeGenerator(interval=50.))
{% for neuron in neurons %}
neuron = simulator.add_neuron(Neuron_{{neuron.get_name()}}(timestep=simulator.timestep))
simulator.connect(sg_exc, neuron, "exc_spikes",w=1000.)
simulator.connect(sg_inh, neuron, "inh_spikes",w=4000.)
simulator.connect(sg_exc, neuron, "spike_in_port",w=1000.)
simulator.connect(sg_inh, neuron, "spike_in_port",w=-4000.)
{% endfor %}

simulator.run(t_stop)
Expand Down

0 comments on commit 597d2db

Please sign in to comment.