From f4b5bc71e164161c508d0cecefc65a7ca225db7a Mon Sep 17 00:00:00 2001 From: JCGoran Date: Tue, 3 Dec 2024 11:44:04 +0100 Subject: [PATCH 1/2] Fix custom function handling in sympy solve visitor (#1563) --- python/nmodl/ode.py | 4 +++- test/unit/ode/test_ode.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index f34c9d026..dce3fdeb7 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -543,10 +543,12 @@ def integrate2c(diff_string, dt_var, vars, use_pade_approx=False): if _a1 == 0 and _a2 == 0: solution = _a0 + custom_fcts = {str(f.func): str(f.func) for f in solution.atoms(sp.Function)} + # return result as C code in NEURON format: # - in the lhs x_0 refers to the state var at time (t+dt) # - in the rhs x_0 refers to the state var at time t - return f"{sp.ccode(x)} = {sp.ccode(solution.evalf())}" + return f"{sp.ccode(x)} = {sp.ccode(solution.evalf(), user_functions=custom_fcts)}" def forwards_euler2c(diff_string, dt_var, vars, function_calls): diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 8be63e02e..021a90eb5 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -170,6 +170,8 @@ def test_integrate2c(): ("a", "x + a*dt"), ("a*x", "x*exp(a*dt)"), ("a*x+b", "(-b + (a*x + b)*exp(a*dt))/a"), + # assume custom_function is defined in mod file + ("custom_function(a)*x", "x*exp(custom_function(a)*dt)"), ] for eq, sol in test_cases: assert _equivalent( From 2c3c7ddc0ddf6e27d6a659d4003e0abbff3dc634 Mon Sep 17 00:00:00 2001 From: JCGoran Date: Tue, 3 Dec 2024 13:57:03 +0100 Subject: [PATCH 2/2] Remove units when parsing CVODE statements (#1561) --- src/visitors/cvode_visitor.cpp | 23 +++++++++++++++++++++++ test/usecases/cvode/derivative.mod | 8 +++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 617e1e6c6..13ef1e819 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -13,6 +13,7 @@ #include "utils/logger.hpp" #include "visitors/visitor_utils.hpp" #include +#include #include namespace pywrap = nmodl::pybind_wrappers; @@ -35,6 +36,25 @@ static void remove_conserve_statements(ast::StatementBlock& node) { } } +// remove units from CVODE block so sympy can parse it properly +static void remove_units(ast::BinaryExpression& node) { + // matches either an int or a float, followed by any (including zero) + // number of spaces, followed by an expression in parentheses, that only + // has letters of the alphabet + std::regex unit_pattern(R"((\d+\.?\d*|\.\d+)\s*\([a-zA-Z]+\))"); + auto rhs_string = to_nmodl(node.get_rhs()); + auto rhs_string_no_units = fmt::format("{} = {}", + to_nmodl(node.get_lhs()), + std::regex_replace(rhs_string, unit_pattern, "$1")); + logger->debug("CvodeVisitor :: removing units from statement {}", to_nmodl(node)); + logger->debug("CvodeVisitor :: result: {}", rhs_string_no_units); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(rhs_string_no_units)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); +} + static std::pair> parse_independent_var( std::shared_ptr node) { auto variable = std::make_pair(node->get_node_name(), std::optional()); @@ -152,7 +172,10 @@ class StiffVisitor: public CvodeHelperVisitor { program_symtab->insert(symbol); } + remove_units(node); + auto rhs = node.get_rhs(); + // all indexed variables (need special treatment in SymPy) auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; diff --git a/test/usecases/cvode/derivative.mod b/test/usecases/cvode/derivative.mod index d3715352f..2a8ba6ca6 100644 --- a/test/usecases/cvode/derivative.mod +++ b/test/usecases/cvode/derivative.mod @@ -2,6 +2,10 @@ NEURON { SUFFIX scalar } +UNITS { + (um) = (micron) +} + PARAMETER { freq = 10 a = 5 @@ -14,7 +18,7 @@ PARAMETER { k = 0.2 } -STATE {var1 var2 var3} +STATE {var1 var2 var3 var4} INITIAL { var1 = v1 @@ -34,4 +38,6 @@ DERIVATIVE equation { var2' = -var2 * a : logistic ODE var3' = r * var3 * (1 - var3 / k) + : ODE with some units + var4' = 1(um) * var4 + a * .1(um) + r * 1.(um) + 1.0 (um) }