From 417acac77a6c63f706e8e942a410cad7d7b58362 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 24 Sep 2024 18:03:56 +0200 Subject: [PATCH 01/47] Add `DerivativeOriginalFunctionBlock` and `DerivativeVisitor` --- src/language/code_generator.cmake | 1 + src/language/codegen.yaml | 25 +++- src/main.cpp | 8 ++ src/visitors/CMakeLists.txt | 1 + src/visitors/derivative_original_visitor.cpp | 129 +++++++++++++++++++ src/visitors/derivative_original_visitor.hpp | 64 +++++++++ src/visitors/sympy_solver_visitor.cpp | 4 + src/visitors/sympy_solver_visitor.hpp | 2 + 8 files changed, 233 insertions(+), 1 deletion(-) create mode 100644 src/visitors/derivative_original_visitor.cpp create mode 100644 src/visitors/derivative_original_visitor.hpp diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index a3dea0767f..992d5b0cb1 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -74,6 +74,7 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/define.hpp ${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp ${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp ${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 477df7fa65..ac92afb517 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -87,7 +87,30 @@ type: StatementBlock - finalize_block: brief: "Statement block to be executed after calling linear solver" - type: StatementBlock + type: StatementBlock + - DerivativeOriginalFunctionBlock: + nmodl: "DERIVATIVE_ORIGINAL_FUNCTION " + members: + - name: + brief: "Name of the derivative block" + type: Name + node_name: true + suffix: {value: " "} + - statement_block: + brief: "Block with statements vector" + type: StatementBlock + getter: {override: true} + brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + description: | + The original `DERIVATIVE` block in NMODL is + replaced in-place if the system of ODEs is + solvable analytically. Therefore, this + block's sole purpose is to keep the + original, unsolved block in the AST. This is + primarily useful when we need to solve the + ODE system using implicit methods, for + instance, CVode. + - WrappedExpression: brief: "Wrap any other expression type" members: diff --git a/src/main.cpp b/src/main.cpp index f12bfe35dd..f150753479 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,6 +25,7 @@ #include "visitors/after_cvode_to_cnexp_visitor.hpp" #include "visitors/ast_visitor.hpp" #include "visitors/constant_folder_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" #include "visitors/function_callpath_visitor.hpp" #include "visitors/global_var_visitor.hpp" #include "visitors/implicit_argument_visitor.hpp" @@ -497,6 +498,13 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_linear = node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK); const bool sympy_sparse = solver_exists(*ast, "sparse"); + if (neuron_code) { + logger->info("Running derivative visitor"); + DerivativeOriginalVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("derivative_original")); + } + if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || sympy_linear) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() diff --git a/src/visitors/CMakeLists.txt b/src/visitors/CMakeLists.txt index 262b6a623a..ede77671eb 100644 --- a/src/visitors/CMakeLists.txt +++ b/src/visitors/CMakeLists.txt @@ -11,6 +11,7 @@ add_library( visitor STATIC after_cvode_to_cnexp_visitor.cpp constant_folder_visitor.cpp + derivative_original_visitor.cpp defuse_analyze_visitor.cpp function_callpath_visitor.cpp global_var_visitor.cpp diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp new file mode 100644 index 0000000000..bd6d135350 --- /dev/null +++ b/src/visitors/derivative_original_visitor.cpp @@ -0,0 +1,129 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "visitors/derivative_original_visitor.hpp" + +#include "ast/all.hpp" +#include "lexer/token_mapping.hpp" +#include "pybind/pyembed.hpp" +#include "utils/logger.hpp" +#include "visitors/visitor_utils.hpp" +#include +#include + +namespace pywrap = nmodl::pybind_wrappers; + +namespace nmodl { +namespace visitor { + +static int get_index(const ast::IndexedName& node) { + return std::stoi(to_nmodl(node.get_length())); +} + +static auto get_name_map(const ast::Expression& node, const std::string& name) { + std::unordered_map name_map; + // all of the "reserved" symbols + auto reserved_symbols = get_external_functions(); + // all indexed vars + auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); + for (const auto& var: indexed_vars) { + if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && + std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { + return var->get_node_name() == item; + })) { + logger->debug( + "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "node_map", + var->get_node_name()); + name_map[var->get_node_name()] = get_index( + *std::dynamic_pointer_cast(var)); + } + } + return name_map; +} + + +void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { + node.visit_children(*this); + der_block_function = node.clone(); +} + + +void DerivativeOriginalVisitor::visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) { + derivative_block = true; + node_type = node.get_node_type(); + node.visit_children(*this); + node_type = ast::AstNodeType::NODE; + derivative_block = false; +} + +void DerivativeOriginalVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { + differential_equation = true; + node.visit_children(*this); + differential_equation = false; +} + + +void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); + + /// we have to only solve ODEs under original derivative block where lhs is variable + if (!derivative_block || !differential_equation || !lhs->is_var_name()) { + return; + } + + auto name = std::dynamic_pointer_cast(lhs)->get_name(); + + if (name->is_prime_name() || name->is_indexed_name()) { + std::string varname; + if (name->is_prime_name()) { + varname = "D" + name->get_node_name(); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + name->get_node_name(), + varname, + to_nmodl(node)); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } + } else { + varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); + // we discard the RHS here so it can be anything (as long as NMODL considers it valid) + auto statement = fmt::format("{} = {}", varname, varname); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + to_nmodl(node.get_lhs()), + varname, + to_nmodl(node)); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + // TODO add symbol? + } + } +} + +void DerivativeOriginalVisitor::visit_program(ast::Program& node) { + program_symtab = node.get_symbol_table(); + node.visit_children(*this); + if (der_block_function) { + auto der_node = + new ast::DerivativeOriginalFunctionBlock(der_block_function->get_name(), + der_block_function->get_statement_block()); + node.emplace_back_node(der_node); + } + + // re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block + node.visit_children(*this); +} + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp new file mode 100644 index 0000000000..d483ab845b --- /dev/null +++ b/src/visitors/derivative_original_visitor.hpp @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Blue Brain Project, EPFL. + * See the top-level LICENSE file for details. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +/** + * \file + * \brief \copybrief nmodl::visitor::DerivativeOriginalVisitor + */ + +#include "symtab/decl.hpp" +#include "visitors/ast_visitor.hpp" +#include + +namespace nmodl { +namespace visitor { + +/** + * \addtogroup visitor_classes + * \{ + */ + +/** + * \class DerivativeOriginalVisitor + * \brief Make a copy of the `DERIVATIVE` block (if it exists), and insert back as + * `DERIVATIVE_ORIGINAL_FUNCTION` block. + * + * If \ref SympySolverVisitor runs successfully, it replaces the original + * solution. This block is inserted before that to prevent losing access to + * information about the block. + */ +class DerivativeOriginalVisitor: public AstVisitor { + private: + /// The copy of the derivative block we are solving + ast::DerivativeBlock* der_block_function = nullptr; + + /// true while visiting differential equation + bool differential_equation = false; + + /// global symbol table + symtab::SymbolTable* program_symtab = nullptr; + + /// visiting derivative block + bool derivative_block = false; + + ast::AstNodeType node_type = ast::AstNodeType::NODE; + + public: + void visit_derivative_block(ast::DerivativeBlock& node) override; + void visit_program(ast::Program& node) override; + void visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) override; + void visit_diff_eq_expression(ast::DiffEqExpression& node) override; + void visit_binary_expression(ast::BinaryExpression& node) override; +}; + +/** \} */ // end of visitor_classes + +} // namespace visitor +} // namespace nmodl diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index f2d6260c21..e7b955a5c0 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -399,6 +399,10 @@ void SympySolverVisitor::visit_var_name(ast::VarName& node) { } } +// Skip visiting DERIVATIVE_ORIGINAL block +void SympySolverVisitor::visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) {} + void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { const auto& lhs = node.get_expression()->get_lhs(); diff --git a/src/visitors/sympy_solver_visitor.hpp b/src/visitors/sympy_solver_visitor.hpp index ecb326ab63..627451d4b7 100644 --- a/src/visitors/sympy_solver_visitor.hpp +++ b/src/visitors/sympy_solver_visitor.hpp @@ -185,6 +185,8 @@ class SympySolverVisitor: public AstVisitor { void visit_expression_statement(ast::ExpressionStatement& node) override; void visit_statement_block(ast::StatementBlock& node) override; void visit_program(ast::Program& node) override; + void visit_derivative_original_function_block( + ast::DerivativeOriginalFunctionBlock& node) override; }; /** @} */ // end of visitor_classes From d33a594575291e4b7cc1fcb10f76501c17ec6143 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 24 Sep 2024 18:18:33 +0200 Subject: [PATCH 02/47] Remove unused functions --- src/visitors/derivative_original_visitor.cpp | 26 -------------------- 1 file changed, 26 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index bd6d135350..2e7b6942a2 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -20,32 +20,6 @@ namespace pywrap = nmodl::pybind_wrappers; namespace nmodl { namespace visitor { -static int get_index(const ast::IndexedName& node) { - return std::stoi(to_nmodl(node.get_length())); -} - -static auto get_name_map(const ast::Expression& node, const std::string& name) { - std::unordered_map name_map; - // all of the "reserved" symbols - auto reserved_symbols = get_external_functions(); - // all indexed vars - auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); - for (const auto& var: indexed_vars) { - if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && - std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { - return var->get_node_name() == item; - })) { - logger->debug( - "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " - "node_map", - var->get_node_name()); - name_map[var->get_node_name()] = get_index( - *std::dynamic_pointer_cast(var)); - } - } - return name_map; -} - void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); From b9f08d05225b37c9cd288c3471112cda51861bb4 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 09:27:13 +0200 Subject: [PATCH 03/47] Add test for DerivativeOriginalVisitor --- test/unit/CMakeLists.txt | 1 + test/unit/visitor/derivative_original.cpp | 55 +++++++++++++++++++++++ 2 files changed, 56 insertions(+) create mode 100644 test/unit/visitor/derivative_original.cpp diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 44d57fe91f..9ed95d8aff 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -45,6 +45,7 @@ add_executable( visitor/kinetic_block.cpp visitor/localize.cpp visitor/localrename.cpp + visitor/derivative_original.cpp visitor/local_to_assigned.cpp visitor/lookup.cpp visitor/loop_unroll.cpp diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp new file mode 100644 index 0000000000..d2f5e17cf2 --- /dev/null +++ b/test/unit/visitor/derivative_original.cpp @@ -0,0 +1,55 @@ +#include + +#include "ast/program.hpp" +#include "parser/nmodl_driver.hpp" +#include "test/unit/utils/test_utils.hpp" +#include "visitors/checkparent_visitor.hpp" +#include "visitors/nmodl_visitor.hpp" +#include "visitors/symtab_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" +#include "visitors/visitor_utils.hpp" + +using namespace nmodl; +using namespace visitor; +using namespace test; +using namespace test_utils; + +using nmodl::parser::NmodlDriver; + + +auto run_derivative_original_visitor(const std::string& text) { + NmodlDriver driver; + const auto& ast = driver.parse_string(text); + SymtabVisitor().visit_program(*ast); + DerivativeOriginalVisitor().visit_program(*ast); + + return ast; +} + + +TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { + GIVEN("DERIVATIVE block") { + std::string nmodl_text = R"( + NEURON { + SUFFIX example + } + + STATE {x z[2]} + + DERIVATIVE equation { + x' = -x + z'[0] = x + z'[1] = x + z[0] + } +)"; + auto ast = run_derivative_original_visitor(nmodl_text); + THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { + auto block = collect_nodes(*ast, {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + REQUIRE(!block.empty()); + THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + REQUIRE(primed_vars.empty()); + } + } + } +} From c5dc45e9a3a63e2f4e63996d4254426f6e2f6fa6 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 09:28:40 +0200 Subject: [PATCH 04/47] Fmt --- test/unit/visitor/derivative_original.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index d2f5e17cf2..58c0bd9c4f 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -4,9 +4,9 @@ #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" #include "visitors/checkparent_visitor.hpp" +#include "visitors/derivative_original_visitor.hpp" #include "visitors/nmodl_visitor.hpp" #include "visitors/symtab_visitor.hpp" -#include "visitors/derivative_original_visitor.hpp" #include "visitors/visitor_utils.hpp" using namespace nmodl; @@ -28,8 +28,8 @@ auto run_derivative_original_visitor(const std::string& text) { TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { - GIVEN("DERIVATIVE block") { - std::string nmodl_text = R"( + GIVEN("DERIVATIVE block") { + std::string nmodl_text = R"( NEURON { SUFFIX example } @@ -42,14 +42,15 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative z'[1] = x + z[0] } )"; - auto ast = run_derivative_original_visitor(nmodl_text); - THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { - auto block = collect_nodes(*ast, {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); - REQUIRE(!block.empty()); - THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { - auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); - REQUIRE(primed_vars.empty()); - } + auto ast = run_derivative_original_visitor(nmodl_text); + THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { + auto block = collect_nodes(*ast, + {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + REQUIRE(!block.empty()); + THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + REQUIRE(primed_vars.empty()); } } + } } From 1dadd7a21b43b97b9caec984e1fb48ea2cf001db Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 10:44:01 +0200 Subject: [PATCH 05/47] Fix leak --- src/visitors/derivative_original_visitor.cpp | 2 +- src/visitors/derivative_original_visitor.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 2e7b6942a2..9377641851 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -23,7 +23,7 @@ namespace visitor { void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); - der_block_function = node.clone(); + der_block_function = std::shared_ptr(node.clone()); } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index d483ab845b..7178390ca8 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -36,7 +36,7 @@ namespace visitor { class DerivativeOriginalVisitor: public AstVisitor { private: /// The copy of the derivative block we are solving - ast::DerivativeBlock* der_block_function = nullptr; + std::shared_ptr der_block_function = nullptr; /// true while visiting differential equation bool differential_equation = false; From 1125fdf43c886b9250faabd7dbb55241f610ecac Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 11:32:07 +0200 Subject: [PATCH 06/47] Remove unused stuff `DERIVATIVE` blocks can't have array variables in NOCMODL by default, so let's go with that. --- src/visitors/derivative_original_visitor.cpp | 41 ++++++-------------- src/visitors/derivative_original_visitor.hpp | 2 - test/unit/visitor/derivative_original.cpp | 7 ++-- 3 files changed, 14 insertions(+), 36 deletions(-) diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 9377641851..43b3eb6df8 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -30,9 +30,7 @@ void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& nod void DerivativeOriginalVisitor::visit_derivative_original_function_block( ast::DerivativeOriginalFunctionBlock& node) { derivative_block = true; - node_type = node.get_node_type(); node.visit_children(*this); - node_type = ast::AstNodeType::NODE; derivative_block = false; } @@ -53,34 +51,17 @@ void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& n auto name = std::dynamic_pointer_cast(lhs)->get_name(); - if (name->is_prime_name() || name->is_indexed_name()) { - std::string varname; - if (name->is_prime_name()) { - varname = "D" + name->get_node_name(); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", - name->get_node_name(), - varname, - to_nmodl(node)); - node.set_lhs(std::make_shared(new ast::String(varname))); - if (program_symtab->lookup(varname) == nullptr) { - auto symbol = std::make_shared(varname, ModToken()); - symbol->set_original_name(name->get_node_name()); - program_symtab->insert(symbol); - } - } else { - varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); - // we discard the RHS here so it can be anything (as long as NMODL considers it valid) - auto statement = fmt::format("{} = {}", varname, varname); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", - to_nmodl(node.get_lhs()), - varname, - to_nmodl(node)); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); - // TODO add symbol? + if (name->is_prime_name()) { + auto varname = "D" + name->get_node_name(); + logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + name->get_node_name(), + varname, + to_nmodl(node)); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); } } } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index 7178390ca8..2fb3b26297 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -47,8 +47,6 @@ class DerivativeOriginalVisitor: public AstVisitor { /// visiting derivative block bool derivative_block = false; - ast::AstNodeType node_type = ast::AstNodeType::NODE; - public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index 58c0bd9c4f..4533de36d5 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -34,12 +34,11 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative SUFFIX example } - STATE {x z[2]} + STATE {x z} DERIVATIVE equation { - x' = -x - z'[0] = x - z'[1] = x + z[0] + x' = -x + z * z + z' = z * x } )"; auto ast = run_derivative_original_visitor(nmodl_text); From 0267fbdf9c1f5a2fca08da574bd40a5d7949305c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 25 Sep 2024 11:45:09 +0200 Subject: [PATCH 07/47] Update block description --- src/language/codegen.yaml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index ac92afb517..e31cb4aca4 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -100,16 +100,15 @@ brief: "Block with statements vector" type: StatementBlock getter: {override: true} - brief: "Represents the original, unmodified `DERIVATIVE` block in the NMODL" + brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars" description: | The original `DERIVATIVE` block in NMODL is replaced in-place if the system of ODEs is solvable analytically. Therefore, this - block's sole purpose is to keep the - original, unsolved block in the AST. This is - primarily useful when we need to solve the - ODE system using implicit methods, for - instance, CVode. + block's sole purpose is to keep the unsolved + block in the AST. This is primarily useful + when we need to solve the ODE system using + implicit methods, for instance, CVode. - WrappedExpression: brief: "Wrap any other expression type" From e58070f1b398139120336e4c8af74ba81b41d0e8 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Fri, 27 Sep 2024 15:29:28 +0200 Subject: [PATCH 08/47] Rename DERIVATIVE_ORIGINAL to CVODE --- src/language/code_generator.cmake | 2 +- src/language/codegen.yaml | 24 ++++++++------------ src/main.cpp | 6 ++--- src/visitors/derivative_original_visitor.cpp | 21 ++++++++--------- src/visitors/derivative_original_visitor.hpp | 7 +++--- src/visitors/sympy_solver_visitor.cpp | 5 ++-- src/visitors/sympy_solver_visitor.hpp | 3 +-- test/unit/visitor/derivative_original.cpp | 15 ++++++------ 8 files changed, 36 insertions(+), 47 deletions(-) diff --git a/src/language/code_generator.cmake b/src/language/code_generator.cmake index 992d5b0cb1..83ecff2eac 100644 --- a/src/language/code_generator.cmake +++ b/src/language/code_generator.cmake @@ -72,9 +72,9 @@ set(AST_GENERATED_SOURCES ${PROJECT_BINARY_DIR}/src/ast/constant_statement.hpp ${PROJECT_BINARY_DIR}/src/ast/constant_var.hpp ${PROJECT_BINARY_DIR}/src/ast/constructor_block.hpp + ${PROJECT_BINARY_DIR}/src/ast/cvode_block.hpp ${PROJECT_BINARY_DIR}/src/ast/define.hpp ${PROJECT_BINARY_DIR}/src/ast/derivative_block.hpp - ${PROJECT_BINARY_DIR}/src/ast/derivative_original_function_block.hpp ${PROJECT_BINARY_DIR}/src/ast/derivimplicit_callback.hpp ${PROJECT_BINARY_DIR}/src/ast/destructor_block.hpp ${PROJECT_BINARY_DIR}/src/ast/diff_eq_expression.hpp diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index e31cb4aca4..292cb567c8 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -88,27 +88,21 @@ - finalize_block: brief: "Statement block to be executed after calling linear solver" type: StatementBlock - - DerivativeOriginalFunctionBlock: - nmodl: "DERIVATIVE_ORIGINAL_FUNCTION " + - CvodeBlock: + nmodl: "CVODE_BLOCK " members: - name: - brief: "Name of the derivative block" + brief: "Name of the block" type: Name node_name: true suffix: {value: " "} - - statement_block: - brief: "Block with statements vector" + - function_block: + brief: "Block with statements of the form Dvar = f(var)" type: StatementBlock - getter: {override: true} - brief: "Represents a copy of the `DERIVATIVE` block in NMODL with prime vars replaced by D vars" - description: | - The original `DERIVATIVE` block in NMODL is - replaced in-place if the system of ODEs is - solvable analytically. Therefore, this - block's sole purpose is to keep the unsolved - block in the AST. This is primarily useful - when we need to solve the ODE system using - implicit methods, for instance, CVode. + - diagonal_jacobian_block: + brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f))" + type: StatementBlock + brief: "Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks" - WrappedExpression: brief: "Wrap any other expression type" diff --git a/src/main.cpp b/src/main.cpp index f150753479..b3620dc46b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -499,10 +499,10 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_sparse = solver_exists(*ast, "sparse"); if (neuron_code) { - logger->info("Running derivative visitor"); - DerivativeOriginalVisitor().visit_program(*ast); + logger->info("Running cvode visitor"); + CvodeVisitor().visit_program(*ast); SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("derivative_original")); + ast_to_nmodl(*ast, filepath("cvode")); } if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/derivative_original_visitor.cpp index 43b3eb6df8..97af9c32f7 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/derivative_original_visitor.cpp @@ -21,27 +21,26 @@ namespace nmodl { namespace visitor { -void DerivativeOriginalVisitor::visit_derivative_block(ast::DerivativeBlock& node) { +void CvodeVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); - der_block_function = std::shared_ptr(node.clone()); + der_block = std::shared_ptr(node.clone()); } -void DerivativeOriginalVisitor::visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) { +void CvodeVisitor::visit_cvode_block(ast::CvodeBlock& node) { derivative_block = true; node.visit_children(*this); derivative_block = false; } -void DerivativeOriginalVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { +void CvodeVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { differential_equation = true; node.visit_children(*this); differential_equation = false; } -void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& node) { +void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { const auto& lhs = node.get_lhs(); /// we have to only solve ODEs under original derivative block where lhs is variable @@ -66,13 +65,13 @@ void DerivativeOriginalVisitor::visit_binary_expression(ast::BinaryExpression& n } } -void DerivativeOriginalVisitor::visit_program(ast::Program& node) { +void CvodeVisitor::visit_program(ast::Program& node) { program_symtab = node.get_symbol_table(); node.visit_children(*this); - if (der_block_function) { - auto der_node = - new ast::DerivativeOriginalFunctionBlock(der_block_function->get_name(), - der_block_function->get_statement_block()); + if (der_block) { + auto der_node = new ast::CvodeBlock(der_block->get_name(), + der_block->get_statement_block(), + der_block->get_statement_block()); node.emplace_back_node(der_node); } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/derivative_original_visitor.hpp index 2fb3b26297..9edb792186 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/derivative_original_visitor.hpp @@ -33,10 +33,10 @@ namespace visitor { * solution. This block is inserted before that to prevent losing access to * information about the block. */ -class DerivativeOriginalVisitor: public AstVisitor { +class CvodeVisitor: public AstVisitor { private: /// The copy of the derivative block we are solving - std::shared_ptr der_block_function = nullptr; + std::shared_ptr der_block = nullptr; /// true while visiting differential equation bool differential_equation = false; @@ -50,8 +50,7 @@ class DerivativeOriginalVisitor: public AstVisitor { public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; - void visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) override; + void visit_cvode_block(ast::CvodeBlock& node) override; void visit_diff_eq_expression(ast::DiffEqExpression& node) override; void visit_binary_expression(ast::BinaryExpression& node) override; }; diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index e7b955a5c0..42936ae5e6 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -399,9 +399,8 @@ void SympySolverVisitor::visit_var_name(ast::VarName& node) { } } -// Skip visiting DERIVATIVE_ORIGINAL block -void SympySolverVisitor::visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) {} +// Skip visiting CVODE block +void SympySolverVisitor::visit_cvode_block(ast::CvodeBlock& node) {} void SympySolverVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { const auto& lhs = node.get_expression()->get_lhs(); diff --git a/src/visitors/sympy_solver_visitor.hpp b/src/visitors/sympy_solver_visitor.hpp index 627451d4b7..7642b79411 100644 --- a/src/visitors/sympy_solver_visitor.hpp +++ b/src/visitors/sympy_solver_visitor.hpp @@ -185,8 +185,7 @@ class SympySolverVisitor: public AstVisitor { void visit_expression_statement(ast::ExpressionStatement& node) override; void visit_statement_block(ast::StatementBlock& node) override; void visit_program(ast::Program& node) override; - void visit_derivative_original_function_block( - ast::DerivativeOriginalFunctionBlock& node) override; + void visit_cvode_block(ast::CvodeBlock& node) override; }; /** @} */ // end of visitor_classes diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/derivative_original.cpp index 4533de36d5..5fbc60fcaa 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/derivative_original.cpp @@ -17,17 +17,17 @@ using namespace test_utils; using nmodl::parser::NmodlDriver; -auto run_derivative_original_visitor(const std::string& text) { +auto run_cvode_visitor(const std::string& text) { NmodlDriver driver; const auto& ast = driver.parse_string(text); SymtabVisitor().visit_program(*ast); - DerivativeOriginalVisitor().visit_program(*ast); + CvodeVisitor().visit_program(*ast); return ast; } -TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative_original]") { +TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][cvode]") { GIVEN("DERIVATIVE block") { std::string nmodl_text = R"( NEURON { @@ -41,12 +41,11 @@ TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][derivative z' = z * x } )"; - auto ast = run_derivative_original_visitor(nmodl_text); - THEN("DERIVATIVE_ORIGINAL_FUNCTION block is added") { - auto block = collect_nodes(*ast, - {ast::AstNodeType::DERIVATIVE_ORIGINAL_FUNCTION_BLOCK}); + auto ast = run_cvode_visitor(nmodl_text); + THEN("CVODE block is added") { + auto block = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); REQUIRE(!block.empty()); - THEN("No primed variables exist in the DERIVATIVE_ORIGINAL_FUNCTION block") { + THEN("No primed variables exist in the CVODE block") { auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); REQUIRE(primed_vars.empty()); } From 044dfd93250209ae794fd37fa8ff8702574e5c48 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 09:14:34 +0200 Subject: [PATCH 09/47] Finish renaming --- src/main.cpp | 2 +- src/visitors/CMakeLists.txt | 2 +- ...ative_original_visitor.cpp => cvode_visitor.cpp} | 6 +++--- ...ative_original_visitor.hpp => cvode_visitor.hpp} | 13 ++++--------- test/unit/CMakeLists.txt | 2 +- .../visitor/{derivative_original.cpp => cvode.cpp} | 2 +- 6 files changed, 11 insertions(+), 16 deletions(-) rename src/visitors/{derivative_original_visitor.cpp => cvode_visitor.cpp} (91%) rename src/visitors/{derivative_original_visitor.hpp => cvode_visitor.hpp} (70%) rename test/unit/visitor/{derivative_original.cpp => cvode.cpp} (96%) diff --git a/src/main.cpp b/src/main.cpp index b3620dc46b..1ac71b752c 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -25,7 +25,7 @@ #include "visitors/after_cvode_to_cnexp_visitor.hpp" #include "visitors/ast_visitor.hpp" #include "visitors/constant_folder_visitor.hpp" -#include "visitors/derivative_original_visitor.hpp" +#include "visitors/cvode_visitor.hpp" #include "visitors/function_callpath_visitor.hpp" #include "visitors/global_var_visitor.hpp" #include "visitors/implicit_argument_visitor.hpp" diff --git a/src/visitors/CMakeLists.txt b/src/visitors/CMakeLists.txt index ede77671eb..f51a65b732 100644 --- a/src/visitors/CMakeLists.txt +++ b/src/visitors/CMakeLists.txt @@ -11,7 +11,7 @@ add_library( visitor STATIC after_cvode_to_cnexp_visitor.cpp constant_folder_visitor.cpp - derivative_original_visitor.cpp + cvode_visitor.cpp defuse_analyze_visitor.cpp function_callpath_visitor.cpp global_var_visitor.cpp diff --git a/src/visitors/derivative_original_visitor.cpp b/src/visitors/cvode_visitor.cpp similarity index 91% rename from src/visitors/derivative_original_visitor.cpp rename to src/visitors/cvode_visitor.cpp index 97af9c32f7..ee60c9451b 100644 --- a/src/visitors/derivative_original_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -5,7 +5,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include "visitors/derivative_original_visitor.hpp" +#include "visitors/cvode_visitor.hpp" #include "ast/all.hpp" #include "lexer/token_mapping.hpp" @@ -52,7 +52,7 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { if (name->is_prime_name()) { auto varname = "D" + name->get_node_name(); - logger->debug("DerivativeOriginalVisitor :: replacing {} with {} on LHS of {}", + logger->debug("CvodeVisitor :: replacing {} with {} on LHS of {}", name->get_node_name(), varname, to_nmodl(node)); @@ -75,7 +75,7 @@ void CvodeVisitor::visit_program(ast::Program& node) { node.emplace_back_node(der_node); } - // re-visit the AST since we now inserted the DERIVATIVE_ORIGINAL block + // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); } diff --git a/src/visitors/derivative_original_visitor.hpp b/src/visitors/cvode_visitor.hpp similarity index 70% rename from src/visitors/derivative_original_visitor.hpp rename to src/visitors/cvode_visitor.hpp index 9edb792186..baeed0f84f 100644 --- a/src/visitors/derivative_original_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -9,7 +9,7 @@ /** * \file - * \brief \copybrief nmodl::visitor::DerivativeOriginalVisitor + * \brief \copybrief nmodl::visitor::CvodeVisitor */ #include "symtab/decl.hpp" @@ -25,17 +25,12 @@ namespace visitor { */ /** - * \class DerivativeOriginalVisitor - * \brief Make a copy of the `DERIVATIVE` block (if it exists), and insert back as - * `DERIVATIVE_ORIGINAL_FUNCTION` block. - * - * If \ref SympySolverVisitor runs successfully, it replaces the original - * solution. This block is inserted before that to prevent losing access to - * information about the block. + * \class CvodeVisitor + * \brief Visitor used for generating the necessary AST nodes for CVODE */ class CvodeVisitor: public AstVisitor { private: - /// The copy of the derivative block we are solving + /// The copy of the derivative block of a given mod file std::shared_ptr der_block = nullptr; /// true while visiting differential equation diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 9ed95d8aff..f12d5167bb 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -45,7 +45,7 @@ add_executable( visitor/kinetic_block.cpp visitor/localize.cpp visitor/localrename.cpp - visitor/derivative_original.cpp + visitor/cvode.cpp visitor/local_to_assigned.cpp visitor/lookup.cpp visitor/loop_unroll.cpp diff --git a/test/unit/visitor/derivative_original.cpp b/test/unit/visitor/cvode.cpp similarity index 96% rename from test/unit/visitor/derivative_original.cpp rename to test/unit/visitor/cvode.cpp index 5fbc60fcaa..bdc4777665 100644 --- a/test/unit/visitor/derivative_original.cpp +++ b/test/unit/visitor/cvode.cpp @@ -4,7 +4,7 @@ #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" #include "visitors/checkparent_visitor.hpp" -#include "visitors/derivative_original_visitor.hpp" +#include "visitors/cvode_visitor.hpp" #include "visitors/nmodl_visitor.hpp" #include "visitors/symtab_visitor.hpp" #include "visitors/visitor_utils.hpp" From 50f38cec0843fa5db2db4768783c198e21295222 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 12:06:12 +0200 Subject: [PATCH 10/47] Add item with Jacobian --- python/nmodl/ode.py | 25 +++++++++-- src/main.cpp | 17 ++++---- src/pybind/wrapper.cpp | 40 +++++++++++++++++- src/pybind/wrapper.hpp | 7 ++++ src/visitors/cvode_visitor.cpp | 77 +++++++++++++++++++++++++++++----- src/visitors/cvode_visitor.hpp | 12 ++++-- test/unit/visitor/cvode.cpp | 2 +- 7 files changed, 151 insertions(+), 29 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 3fe769e596..2eab38e873 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -608,7 +608,12 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = { + var if isinstance(var, str) else str(var): ( + sp.symbols(var, real=True) if isinstance(var, str) else var + ) + for var in vars + } sympy_vars[dependent_var] = x # parse string into SymPy equation @@ -643,15 +648,27 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): # differentiate w.r.t. x diff = expr.diff(x).simplify() + # could be something generic like f'(x), in which case we use finite differences + if needs_finite_differences(diff): + diff = ( + transform_expression(diff, discretize_derivative) + .subs({finite_difference_step_variable(x): 1e-3}) + .evalf() + ) + + # the codegen method does not like undefined function calls, so we extract + # them here + custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)} + # try to simplify expression in terms of existing variables # ignore any exceptions here, since we already have a valid solution # so if this further simplification step fails the error is not fatal try: # if expression is equal to one of the supplied vars, replace with this var # can do a simple string comparison here since a var cannot be further simplified - diff_as_string = sp.ccode(diff) + diff_as_string = sp.ccode(diff, user_functions=custom_fcts) for v in sympy_vars: - if diff_as_string == sp.ccode(sympy_vars[v]): + if diff_as_string == sp.ccode(sympy_vars[v], user_functions=custom_fcts): diff = sympy_vars[v] # or if equal to rhs of one of the supplied equations, replace with lhs @@ -672,4 +689,4 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): pass # return result as C code in NEURON format - return sp.ccode(diff.evalf()) + return sp.ccode(diff.evalf(), user_functions=custom_fcts) diff --git a/src/main.cpp b/src/main.cpp index 1ac71b752c..a4ea9e266b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -498,18 +498,19 @@ int run_nmodl(int argc, const char* argv[]) { const bool sympy_linear = node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK); const bool sympy_sparse = solver_exists(*ast, "sparse"); - if (neuron_code) { - logger->info("Running cvode visitor"); - CvodeVisitor().visit_program(*ast); - SymtabVisitor(update_symtab).visit_program(*ast); - ast_to_nmodl(*ast, filepath("cvode")); - } - if (sympy_conductance || sympy_analytic || sympy_sparse || sympy_derivimplicit || - sympy_linear) { + sympy_linear || neuron_code) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); + + if (neuron_code) { + logger->info("Running CVODE visitor"); + CvodeVisitor().visit_program(*ast); + SymtabVisitor(update_symtab).visit_program(*ast); + ast_to_nmodl(*ast, filepath("cvode")); + } + if (sympy_conductance) { logger->info("Running sympy conductance visitor"); SympyConductanceVisitor().visit_program(*ast); diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 32c390736c..ae9d414976 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -9,7 +9,7 @@ #include "codegen/codegen_naming.hpp" #include "pybind/pyembed.hpp" - +#include #include #include @@ -186,6 +186,41 @@ except Exception as e: return {std::move(solution), std::move(exception_message)}; } +/// \brief A blunt instrument that differentiates expression w.r.t. variable +/// \return The tuple (solution, exception) +std::tuple call_diff2c( + const std::string& expression, + const std::string& variable, + const std::unordered_map& indexed_vars) { + std::string statements; + // only indexed variables require special treatment + for (const auto& [var, prop]: indexed_vars) { + statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var); + } + auto locals = py::dict("expression"_a = expression, "variable"_a = variable); + std::string script = fmt::format(R"( +_allvars = [] +{} +exception_message = "" +try: + solution = differentiate2c(expression, + variable, + _allvars, + ) +except Exception as e: + # if we fail, fail silently and return empty string + solution = "" + exception_message = str(e) +)", + statements); + + py::exec(nmodl::pybind_wrappers::ode_py + script, locals); + + auto solution = locals["solution"].cast(); + auto exception_message = locals["exception_message"].cast(); + + return {std::move(solution), std::move(exception_message)}; +} void initialize_interpreter_func() { pybind11::initialize_interpreter(true); @@ -203,7 +238,8 @@ NMODL_EXPORT pybind_wrap_api nmodl_init_pybind_wrapper_api() noexcept { &call_solve_nonlinear_system, &call_solve_linear_system, &call_diffeq_solver, - &call_analytic_diff}; + &call_analytic_diff, + &call_diff2c}; } } diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index 725f9f8113..b4ec0a2dff 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -9,6 +9,7 @@ #include #include +#include #include namespace nmodl { @@ -44,6 +45,11 @@ std::tuple call_analytic_diff( const std::vector& expressions, const std::set& used_names_in_block); +std::tuple call_diff2c( + const std::string& expression, + const std::string& variable, + const std::unordered_map& indexed_vars = {}); + struct pybind_wrap_api { decltype(&initialize_interpreter_func) initialize_interpreter; decltype(&finalize_interpreter_func) finalize_interpreter; @@ -51,6 +57,7 @@ struct pybind_wrap_api { decltype(&call_solve_linear_system) solve_linear_system; decltype(&call_diffeq_solver) diffeq_solver; decltype(&call_analytic_diff) analytic_diff; + decltype(&call_diff2c) diff2c; }; #ifdef _WIN32 diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index ee60c9451b..c326fd36ef 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -20,23 +20,56 @@ namespace pywrap = nmodl::pybind_wrappers; namespace nmodl { namespace visitor { +static int get_index(const ast::IndexedName& node) { + return std::stoi(to_nmodl(node.get_length())); +} + +static auto get_name_map(const ast::Expression& node, const std::string& name) { + std::unordered_map name_map; + // all of the "reserved" symbols + auto reserved_symbols = get_external_functions(); + // all indexed vars + auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); + for (const auto& var: indexed_vars) { + if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && + std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { + return var->get_node_name() == item; + })) { + logger->debug( + "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "node_map", + var->get_node_name()); + name_map[var->get_node_name()] = get_index( + *std::dynamic_pointer_cast(var)); + } + } + return name_map; +} void CvodeVisitor::visit_derivative_block(ast::DerivativeBlock& node) { node.visit_children(*this); - der_block = std::shared_ptr(node.clone()); + derivative_block = std::shared_ptr(node.clone()); } void CvodeVisitor::visit_cvode_block(ast::CvodeBlock& node) { - derivative_block = true; + in_cvode_block = true; node.visit_children(*this); - derivative_block = false; + in_cvode_block = false; } void CvodeVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { - differential_equation = true; + in_differential_equation = true; node.visit_children(*this); - differential_equation = false; + in_differential_equation = false; +} + + +void CvodeVisitor::visit_statement_block(ast::StatementBlock& node) { + node.visit_children(*this); + if (in_cvode_block) { + ++block_index; + } } @@ -44,7 +77,7 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { const auto& lhs = node.get_lhs(); /// we have to only solve ODEs under original derivative block where lhs is variable - if (!derivative_block || !differential_equation || !lhs->is_var_name()) { + if (!in_cvode_block || !in_differential_equation || !lhs->is_var_name()) { return; } @@ -62,16 +95,40 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } + if (block_index == 1) { + auto rhs = node.get_rhs(); + // map of all indexed symbols (need special treatment in SymPy) + auto name_map = get_name_map(*rhs, name->get_node_name()); + auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; + auto [jacobian, + exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); + if (!exception_message.empty()) { + logger->warn("DerivativeOriginalVisitor :: python exception: {}", + exception_message); + } + // NOTE: LHS can be anything here, the equality is to keep `create_statement` from + // complaining, we discard the LHS later + auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); + logger->debug("DerivativeOriginalVisitor :: replacing statement {} with {}", + to_nmodl(node), + statement); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); + } } } void CvodeVisitor::visit_program(ast::Program& node) { program_symtab = node.get_symbol_table(); node.visit_children(*this); - if (der_block) { - auto der_node = new ast::CvodeBlock(der_block->get_name(), - der_block->get_statement_block(), - der_block->get_statement_block()); + if (derivative_block) { + auto der_node = new ast::CvodeBlock(derivative_block->get_name(), + derivative_block->get_statement_block(), + std::shared_ptr( + derivative_block->get_statement_block()->clone())); node.emplace_back_node(der_node); } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index baeed0f84f..0bf336cfd4 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -31,16 +31,19 @@ namespace visitor { class CvodeVisitor: public AstVisitor { private: /// The copy of the derivative block of a given mod file - std::shared_ptr der_block = nullptr; + std::shared_ptr derivative_block = nullptr; /// true while visiting differential equation - bool differential_equation = false; + bool in_differential_equation = false; /// global symbol table symtab::SymbolTable* program_symtab = nullptr; - /// visiting derivative block - bool derivative_block = false; + /// true while we are visiting a CVODE block + bool in_cvode_block = false; + + /// index of the block to modify (0 = function block, 1 = Jacobian block) + int block_index = 0; public: void visit_derivative_block(ast::DerivativeBlock& node) override; @@ -48,6 +51,7 @@ class CvodeVisitor: public AstVisitor { void visit_cvode_block(ast::CvodeBlock& node) override; void visit_diff_eq_expression(ast::DiffEqExpression& node) override; void visit_binary_expression(ast::BinaryExpression& node) override; + void visit_statement_block(ast::StatementBlock& node) override; }; /** \} */ // end of visitor_classes diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index bdc4777665..3a57d242d4 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -27,7 +27,7 @@ auto run_cvode_visitor(const std::string& text) { } -TEST_CASE("Make sure DERIVATIVE block is copied properly", "[visitor][cvode]") { +TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { GIVEN("DERIVATIVE block") { std::string nmodl_text = R"( NEURON { From f82fe1f531b6c9b500040a959412eddf292892fe Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 12:39:30 +0200 Subject: [PATCH 11/47] Do not use an int but an enum-wrapped int --- src/visitors/cvode_visitor.cpp | 2 +- src/visitors/cvode_visitor.hpp | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index c326fd36ef..0ed1a97573 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -95,7 +95,7 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } - if (block_index == 1) { + if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) auto name_map = get_name_map(*rhs, name->get_node_name()); diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index 0bf336cfd4..7dcef42839 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -19,6 +19,16 @@ namespace nmodl { namespace visitor { +enum class BlockIndex { FUNCTION = 0, JACOBIAN = 1 }; + +inline BlockIndex& operator++(BlockIndex& index) { + if (index == BlockIndex::FUNCTION) { + index = BlockIndex::JACOBIAN; + } else { + index = BlockIndex::FUNCTION; + } + return index; +} /** * \addtogroup visitor_classes * \{ @@ -42,8 +52,8 @@ class CvodeVisitor: public AstVisitor { /// true while we are visiting a CVODE block bool in_cvode_block = false; - /// index of the block to modify (0 = function block, 1 = Jacobian block) - int block_index = 0; + /// index of the block to modify + BlockIndex block_index = BlockIndex::FUNCTION; public: void visit_derivative_block(ast::DerivativeBlock& node) override; From bd2fd36a5c2f961ff4e9bae724daa161647dbb93 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 14:01:31 +0200 Subject: [PATCH 12/47] Add support for diffing expressions with indexed vars --- python/nmodl/ode.py | 7 ++++++- test/unit/ode/test_ode.py | 17 ++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 3fe769e596..2e110b3842 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -608,7 +608,12 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = { + var if isinstance(var, str) else str(var): ( + sp.symbols(var, real=True) if isinstance(var, str) else var + ) + for var in vars + } sympy_vars[dependent_var] = x # parse string into SymPy equation diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 387cfb801f..0c195bc02b 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -28,7 +28,12 @@ def _equivalent( """ lhs = lhs.replace("pow(", "Pow(") rhs = rhs.replace("pow(", "Pow(") - sympy_vars = {var: sp.symbols(var, real=True) for var in vars} + sympy_vars = { + var if isinstance(var, str) else str(var): ( + sp.symbols(var, real=True) if isinstance(var, str) else var + ) + for var in vars + } for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)): eq_l = sp.sympify(l, locals=sympy_vars) eq_r = sp.sympify(r, locals=sympy_vars) @@ -100,6 +105,16 @@ def test_differentiate2c(): "g", ) + assert _equivalent( + differentiate2c( + "(s[0] + s[1])*(z[0]*z[1]*z[2])*x", + "x", + {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, + ), + "(s[0] + s[1])*(z[0]*z[1]*z[2])", + {sp.IndexedBase("s", shape=[1]), sp.IndexedBase("z", shape=[1])}, + ) + def test_integrate2c(): From b082f0d29b3a1f1d8d30e9c1d43375ac2f15219b Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 14:29:28 +0200 Subject: [PATCH 13/47] Allow diffing implicit functions in `differentiate2c` Uses finite differences --- python/nmodl/ode.py | 18 +++++++++++++++--- test/unit/ode/test_ode.py | 10 ++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 3fe769e596..66b3a752e2 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -643,15 +643,27 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): # differentiate w.r.t. x diff = expr.diff(x).simplify() + # could be something generic like f'(x), in which case we use finite differences + if needs_finite_differences(diff): + diff = ( + transform_expression(diff, discretize_derivative) + .subs({finite_difference_step_variable(x): 1e-3}) + .evalf() + ) + + # the codegen method does not like undefined function calls, so we extract + # them here + custom_fcts = {str(f.func): str(f.func) for f in diff.atoms(sp.Function)} + # try to simplify expression in terms of existing variables # ignore any exceptions here, since we already have a valid solution # so if this further simplification step fails the error is not fatal try: # if expression is equal to one of the supplied vars, replace with this var # can do a simple string comparison here since a var cannot be further simplified - diff_as_string = sp.ccode(diff) + diff_as_string = sp.ccode(diff, user_functions=custom_fcts) for v in sympy_vars: - if diff_as_string == sp.ccode(sympy_vars[v]): + if diff_as_string == sp.ccode(sympy_vars[v], user_functions=custom_fcts): diff = sympy_vars[v] # or if equal to rhs of one of the supplied equations, replace with lhs @@ -672,4 +684,4 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): pass # return result as C code in NEURON format - return sp.ccode(diff.evalf()) + return sp.ccode(diff.evalf(), user_functions=custom_fcts) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 387cfb801f..0d5e7f628a 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -100,6 +100,16 @@ def test_differentiate2c(): "g", ) + assert _equivalent( + differentiate2c( + "-f(x)", + "x", + {}, + ), + "1000.0*f(x - 0.00050000000000000001) - 1000.0*f(x + 0.00050000000000000001)", + {"x"}, + ) + def test_integrate2c(): From edf33a70000bff2908267a5a45251a6be6381675 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 30 Sep 2024 16:34:44 +0200 Subject: [PATCH 14/47] Simplify condition --- python/nmodl/ode.py | 4 +--- test/unit/ode/test_ode.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 2e110b3842..4e5b9be253 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -609,9 +609,7 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars.discard(dependent_var) # declare all other supplied variables sympy_vars = { - var if isinstance(var, str) else str(var): ( - sp.symbols(var, real=True) if isinstance(var, str) else var - ) + str(var): (sp.symbols(var, real=True) if isinstance(var, str) else var) for var in vars } sympy_vars[dependent_var] = x diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 0c195bc02b..33810c16da 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -29,9 +29,7 @@ def _equivalent( lhs = lhs.replace("pow(", "Pow(") rhs = rhs.replace("pow(", "Pow(") sympy_vars = { - var if isinstance(var, str) else str(var): ( - sp.symbols(var, real=True) if isinstance(var, str) else var - ) + str(var): (sp.symbols(var, real=True) if isinstance(var, str) else var) for var in vars } for l, r in zip(lhs.split("=", 1), rhs.split("=", 1)): From 565fa03c1d9647badc146e82b24117ab7f8b3272 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 1 Oct 2024 11:02:57 +0200 Subject: [PATCH 15/47] Better testing --- test/unit/ode/test_ode.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 0d5e7f628a..21ee8a88f9 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 from nmodl.ode import differentiate2c, integrate2c +import numpy as np import sympy as sp @@ -100,15 +101,29 @@ def test_differentiate2c(): "g", ) - assert _equivalent( - differentiate2c( - "-f(x)", - "x", - {}, - ), - "1000.0*f(x - 0.00050000000000000001) - 1000.0*f(x + 0.00050000000000000001)", - {"x"}, + result = differentiate2c( + "-f(x)", + "x", + {}, ) + # instead of comparing the expression as a string, we convert the string + # back to an expression and insert various functions + for function in [sp.sin, sp.exp, sp.tanh]: + for value in np.linspace(-5, 5, 100): + np.testing.assert_allclose( + float( + sp.sympify(result) + .subs(sp.Function("f"), function) + .subs({"x": value}) + .evalf() + ), + float( + -sp.Derivative(function("x")) + .as_finite_difference(1e-3) + .subs({"x": value}) + .evalf() + ), + ) def test_integrate2c(): From 6bd6aed914cfc87bdf1a2b54ec39c6bc8fc3c654 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 1 Oct 2024 15:45:22 +0200 Subject: [PATCH 16/47] Add suggestions from code review --- test/unit/ode/test_ode.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index 21ee8a88f9..e3a25b06e5 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -107,23 +107,22 @@ def test_differentiate2c(): {}, ) # instead of comparing the expression as a string, we convert the string - # back to an expression and insert various functions - for function in [sp.sin, sp.exp, sp.tanh]: - for value in np.linspace(-5, 5, 100): - np.testing.assert_allclose( - float( - sp.sympify(result) - .subs(sp.Function("f"), function) - .subs({"x": value}) - .evalf() - ), - float( - -sp.Derivative(function("x")) - .as_finite_difference(1e-3) - .subs({"x": value}) - .evalf() - ), - ) + # back to an expression and compare with an explicit function + for value in np.linspace(-5, 5, 100): + np.testing.assert_allclose( + float( + sp.sympify(result) + .subs(sp.Function("f"), sp.sin) + .subs({"x": value}) + .evalf() + ), + float( + -sp.Derivative(sp.sin("x")) + .as_finite_difference(1e-3) + .subs({"x": value}) + .evalf() + ), + ) def test_integrate2c(): From 0eba407672a868fdcc24a5c66e84fa7d175ca3ba Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 11:33:14 +0200 Subject: [PATCH 17/47] Add `stepsize` param to `differentiate2c` --- python/nmodl/ode.py | 14 ++++++++++++-- test/unit/ode/test_ode.py | 8 ++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index 66b3a752e2..e40cb47c62 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -568,7 +568,13 @@ def forwards_euler2c(diff_string, dt_var, vars, function_calls): return f"{sp.ccode(x)} = {sp.ccode(solution, user_functions=custom_fcts)}" -def differentiate2c(expression, dependent_var, vars, prev_expressions=None): +def differentiate2c( + expression, + dependent_var, + vars, + prev_expressions=None, + stepsize=1e-3, +): """Analytically differentiate supplied expression, return solution as C code. Expression should be of the form "f(x)", where "x" is @@ -595,11 +601,15 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): vars: set of all other variables used in expression, e.g. {"a", "b", "c"} prev_expressions: time-ordered list of preceeding expressions to evaluate & substitute, e.g. ["b = x + c", "a = 12*b"] + stepsize: in case an analytic expression is not possible, finite differences are used; + this argument sets the step size Returns: string containing analytic derivative of expression (including any substitutions of variables from supplied prev_expressions) w.r.t. dependent_var as C code. """ + if stepsize <= 0: + raise ValueError("arg `stepsize` must be > 0") prev_expressions = prev_expressions or [] # every symbol (a.k.a variable) that SymPy # is going to manipulate needs to be declared @@ -647,7 +657,7 @@ def differentiate2c(expression, dependent_var, vars, prev_expressions=None): if needs_finite_differences(diff): diff = ( transform_expression(diff, discretize_derivative) - .subs({finite_difference_step_variable(x): 1e-3}) + .subs({finite_difference_step_variable(x): stepsize}) .evalf() ) diff --git a/test/unit/ode/test_ode.py b/test/unit/ode/test_ode.py index e3a25b06e5..390c938f9a 100644 --- a/test/unit/ode/test_ode.py +++ b/test/unit/ode/test_ode.py @@ -5,6 +5,7 @@ from nmodl.ode import differentiate2c, integrate2c import numpy as np +import pytest import sympy as sp @@ -123,6 +124,13 @@ def test_differentiate2c(): .evalf() ), ) + with pytest.raises(ValueError): + differentiate2c( + "-f(x)", + "x", + {}, + stepsize=-1, + ) def test_integrate2c(): From c1e7fd3bb9657d07f5d567e435d1893e4a8d06c2 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 2 Oct 2024 12:45:44 +0200 Subject: [PATCH 18/47] Try Python 3.9 maybe? --- .github/workflows/nmodl-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nmodl-ci.yml b/.github/workflows/nmodl-ci.yml index 09a2a4d20d..3d981241cc 100644 --- a/.github/workflows/nmodl-ci.yml +++ b/.github/workflows/nmodl-ci.yml @@ -16,7 +16,7 @@ on: env: CTEST_PARALLEL_LEVEL: 1 - PYTHON_VERSION: 3.8 + PYTHON_VERSION: 3.9 DESIRED_CMAKE_VERSION: 3.15.0 jobs: From 4fde9295ef4c00e79322d3fe14c379715efc59d4 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 7 Oct 2024 09:40:27 +0200 Subject: [PATCH 19/47] Put back Python 3.8 for now --- .github/workflows/nmodl-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nmodl-ci.yml b/.github/workflows/nmodl-ci.yml index 3d981241cc..09a2a4d20d 100644 --- a/.github/workflows/nmodl-ci.yml +++ b/.github/workflows/nmodl-ci.yml @@ -16,7 +16,7 @@ on: env: CTEST_PARALLEL_LEVEL: 1 - PYTHON_VERSION: 3.9 + PYTHON_VERSION: 3.8 DESIRED_CMAKE_VERSION: 3.15.0 jobs: From a08df258e71a8105f26f294bc0b23978c70bd8bd Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 7 Oct 2024 13:37:27 +0200 Subject: [PATCH 20/47] Remove remaining occurrences of `DerivativeOriginalVisitor` --- src/visitors/cvode_visitor.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 0ed1a97573..da053061d2 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -36,7 +36,7 @@ static auto get_name_map(const ast::Expression& node, const std::string& name) { return var->get_node_name() == item; })) { logger->debug( - "DerivativeOriginalVisitor :: adding INDEXED_VARIABLE {} to " + "CvodeVisitor :: adding INDEXED_VARIABLE {} to " "node_map", var->get_node_name()); name_map[var->get_node_name()] = get_index( @@ -103,13 +103,12 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { auto [jacobian, exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); if (!exception_message.empty()) { - logger->warn("DerivativeOriginalVisitor :: python exception: {}", - exception_message); + logger->warn("CvodeVisitor :: python exception: {}", exception_message); } // NOTE: LHS can be anything here, the equality is to keep `create_statement` from // complaining, we discard the LHS later auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); - logger->debug("DerivativeOriginalVisitor :: replacing statement {} with {}", + logger->debug("CvodeVisitor :: replacing statement {} with {}", to_nmodl(node), statement); auto expr_statement = std::dynamic_pointer_cast( From 9fee9a8d42092724bf69c125fd85688a75e88686 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 8 Oct 2024 11:50:04 +0200 Subject: [PATCH 21/47] WIP on CONSERVE --- src/visitors/cvode_visitor.cpp | 40 ++++++++++++++++++++++++++++++++++ src/visitors/cvode_visitor.hpp | 4 ++++ 2 files changed, 44 insertions(+) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index da053061d2..7cb7439e6a 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -24,6 +24,20 @@ static int get_index(const ast::IndexedName& node) { return std::stoi(to_nmodl(node.get_length())); } +void CvodeVisitor::visit_conserve(ast::Conserve& node) { + logger->debug("CvodeVisitor :: CONSERVE statement: {}", to_nmodl(node)); + std::string conserve_equation_statevar; + if (node.get_react()->is_react_var_name()) { + conserve_equation_statevar = node.get_react()->get_node_name(); + } + auto conserve_equation_str = to_nmodl(*node.get_expr()); + logger->debug("CvodeVisitor :: --> replace ODE for state var {} with equation {}", + conserve_equation_statevar, + conserve_equation_str); + conserve_equations[conserve_equation_statevar] = conserve_equation_str; +} + + static auto get_name_map(const ast::Expression& node, const std::string& name) { std::unordered_map name_map; // all of the "reserved" symbols @@ -95,6 +109,28 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } + // case: there is a variable being CONSERVEd, but it's not the current one + if (!conserve_equations.empty() && !conserve_equations.count(name->get_node_name())) { + auto rhs = node.get_rhs(); + auto nodes = collect_nodes(*node.get_rhs(), {ast::AstNodeType::VAR_NAME}); + for (auto& n: nodes) { + if (conserve_equations.count(n->get_node_name())) { + auto statement = fmt::format("{} = {}", n->get_node_name(), conserve_equations[n->get_node_name()]); + logger->debug("CvodeVisitor :: replacing CONSERVEd variable {} with {} in {}", + n->get_node_name(), + conserve_equations[n->get_node_name()], + to_nmodl(*node.get_rhs())); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + auto thing = std::shared_ptr(bin_expr->get_rhs()->clone()); + n = std::move(std::dynamic_pointer_cast(thing)); + std::cout << to_nmodl(*n) << std::endl; + } + } + } + std::cout << to_nmodl(node) << std::endl; if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) @@ -131,6 +167,10 @@ void CvodeVisitor::visit_program(ast::Program& node) { node.emplace_back_node(der_node); } + for (const auto& [key, value]: conserve_equations) { + std::cout << key << ", " << value << std::endl; + } + // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index 7dcef42839..bcd561e973 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -55,6 +55,9 @@ class CvodeVisitor: public AstVisitor { /// index of the block to modify BlockIndex block_index = BlockIndex::FUNCTION; + /// map of state vars to conserve equations + std::unordered_map conserve_equations; + public: void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; @@ -62,6 +65,7 @@ class CvodeVisitor: public AstVisitor { void visit_diff_eq_expression(ast::DiffEqExpression& node) override; void visit_binary_expression(ast::BinaryExpression& node) override; void visit_statement_block(ast::StatementBlock& node) override; + void visit_conserve(ast::Conserve& node) override; }; /** \} */ // end of visitor_classes From d98fcc00d5e21947ebd1aa98a68afc9a57a7d8a5 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 9 Oct 2024 16:54:24 +0200 Subject: [PATCH 22/47] Ignore CONSERVE equations They are just hints to the NMODL compiler, but they are not at all necessary to use when solving the ODEs. --- src/visitors/cvode_visitor.cpp | 42 ++++++---------------------------- src/visitors/cvode_visitor.hpp | 5 ++-- 2 files changed, 10 insertions(+), 37 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 7cb7439e6a..0f87c2242a 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -25,16 +25,11 @@ static int get_index(const ast::IndexedName& node) { } void CvodeVisitor::visit_conserve(ast::Conserve& node) { - logger->debug("CvodeVisitor :: CONSERVE statement: {}", to_nmodl(node)); - std::string conserve_equation_statevar; - if (node.get_react()->is_react_var_name()) { - conserve_equation_statevar = node.get_react()->get_node_name(); + if (in_cvode_block) { + logger->warn("CvodeVisitor :: CONSERVE statement {} will be ignored in CVODE codegen", + to_nmodl(node)); + conserve_equations.emplace(&node); } - auto conserve_equation_str = to_nmodl(*node.get_expr()); - logger->debug("CvodeVisitor :: --> replace ODE for state var {} with equation {}", - conserve_equation_statevar, - conserve_equation_str); - conserve_equations[conserve_equation_statevar] = conserve_equation_str; } @@ -109,28 +104,6 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { symbol->set_original_name(name->get_node_name()); program_symtab->insert(symbol); } - // case: there is a variable being CONSERVEd, but it's not the current one - if (!conserve_equations.empty() && !conserve_equations.count(name->get_node_name())) { - auto rhs = node.get_rhs(); - auto nodes = collect_nodes(*node.get_rhs(), {ast::AstNodeType::VAR_NAME}); - for (auto& n: nodes) { - if (conserve_equations.count(n->get_node_name())) { - auto statement = fmt::format("{} = {}", n->get_node_name(), conserve_equations[n->get_node_name()]); - logger->debug("CvodeVisitor :: replacing CONSERVEd variable {} with {} in {}", - n->get_node_name(), - conserve_equations[n->get_node_name()], - to_nmodl(*node.get_rhs())); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - auto thing = std::shared_ptr(bin_expr->get_rhs()->clone()); - n = std::move(std::dynamic_pointer_cast(thing)); - std::cout << to_nmodl(*n) << std::endl; - } - } - } - std::cout << to_nmodl(node) << std::endl; if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) @@ -167,12 +140,11 @@ void CvodeVisitor::visit_program(ast::Program& node) { node.emplace_back_node(der_node); } - for (const auto& [key, value]: conserve_equations) { - std::cout << key << ", " << value << std::endl; - } - // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); + if (!conserve_equations.empty()) { + node.erase_node(conserve_equations); + } } } // namespace visitor diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index bcd561e973..f324b2b075 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -15,6 +15,7 @@ #include "symtab/decl.hpp" #include "visitors/ast_visitor.hpp" #include +#include namespace nmodl { namespace visitor { @@ -55,8 +56,8 @@ class CvodeVisitor: public AstVisitor { /// index of the block to modify BlockIndex block_index = BlockIndex::FUNCTION; - /// map of state vars to conserve equations - std::unordered_map conserve_equations; + /// list of conserve equations encountered + std::unordered_set conserve_equations; public: void visit_derivative_block(ast::DerivativeBlock& node) override; From 321cdb395e874bca49df4475b5a7dd4610cbb847 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 10:35:22 +0200 Subject: [PATCH 23/47] Add documentation --- docs/contents/cvode.rst | 74 +++++++++++++++++++++++++++++++++++++++++ docs/index.rst | 1 + 2 files changed, 75 insertions(+) create mode 100644 docs/contents/cvode.rst diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst new file mode 100644 index 0000000000..87ecdb309c --- /dev/null +++ b/docs/contents/cvode.rst @@ -0,0 +1,74 @@ +Variable timestep integration (CVODE) +===================================== + +As opposed to fixed timestep integration, variable timestep integration (CVODE +in NEURON parlance) uses the SUNDIALS package to solve a ``DERIVATIVE`` or +``KINETIC`` block using a variable timestep. This allows for faster computation +times if the function in question does not vary too wildly. + +Implementation in NMODL +----------------------- + +The code generation for CVODE is activated only if exactly one of the following +is satisfied: + +1. there is one ``KINETIC`` block in the mod file +2. there is one ``DERIVATIVE`` block in the mod file +3. a ``PROCEDURE`` block is solved with the ``after_cvode``, ``cvode_t``, or + ``cvode_t_v`` methods + +In NMODL, all ``KINETIC`` blocks are internally first converted to +``DERIVATIVE`` blocks. The ``DERIVATIVE`` block is then converted to a +``CVODE`` block, which contains two parts; the first part contains the update +step for linear systems, while the second part contains the update step for +non-linear systems (see `CVODES documentation`_, eqs. (4.8) and (4.9)). Given +a ``DERIVATIVE`` block of the form: + +.. _CVODES documentation: https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html + +.. code-block:: + + DERIVATIVE state { + x_i' = f(x_1, ..., x_n) + } + +the structure of the ``CVODE`` block is then roughly: + +.. code-block:: + + CVODE state { + Dx_i = f_i(x_1, ..., x_n) + }{ + Dx_i = Dx_i / (1 - dt * J_ii(f)) + } + +where ``J_ii(f)`` is the diagonal part of the Jacobian, i.e. + +.. math:: + + J_{ii}(f) = \frac{ \partial f_i(x_1, \ldots, x_n) }{\partial x_i} + +As an example, consider the following ``DERIVATIVE`` +block: + +.. code-block:: + + DERIVATIVE state { + X' = - X + } + +Where ``X`` is a ``STATE`` variable with some initial value, specified in the +``INITIAL`` block. The corresponding ``CVODE`` block is then: + +.. code-block:: + + CVODE state { + DX = - X + }{ + DX = DX / (1 - dt * (-1)) + } + + +**NOTE**: in case there are ``CONSERVE`` statements in ``KINETIC`` blocks, as +they are merely hints to NMODL, and have no impact on the results, they are +removed from ``CVODE`` blocks before the codegen stage. diff --git a/docs/index.rst b/docs/index.rst index 9c4b0105ee..15125ef4a6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -24,6 +24,7 @@ About NMODL contents/pointers contents/cable_equations contents/globals + contents/cvode .. toctree:: :maxdepth: 3 From 2984e46f25ff44a5af09c5f77262982cbb85803f Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 10:47:13 +0200 Subject: [PATCH 24/47] Really delete CONSERVE statements this time --- src/visitors/cvode_visitor.cpp | 7 +++++-- src/visitors/cvode_visitor.hpp | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 0f87c2242a..302b541d5d 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -26,7 +26,7 @@ static int get_index(const ast::IndexedName& node) { void CvodeVisitor::visit_conserve(ast::Conserve& node) { if (in_cvode_block) { - logger->warn("CvodeVisitor :: CONSERVE statement {} will be ignored in CVODE codegen", + logger->warn("CvodeVisitor :: statement {} will be ignored in CVODE codegen", to_nmodl(node)); conserve_equations.emplace(&node); } @@ -143,7 +143,10 @@ void CvodeVisitor::visit_program(ast::Program& node) { // re-visit the AST since we now inserted the CVODE block node.visit_children(*this); if (!conserve_equations.empty()) { - node.erase_node(conserve_equations); + auto blocks = collect_nodes(node, {ast::AstNodeType::CVODE_BLOCK}); + auto block = std::dynamic_pointer_cast(blocks[0]); + block->get_function_block()->erase_statement(conserve_equations); + block->get_diagonal_jacobian_block()->erase_statement(conserve_equations); } } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index f324b2b075..716600be20 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -57,7 +57,7 @@ class CvodeVisitor: public AstVisitor { BlockIndex block_index = BlockIndex::FUNCTION; /// list of conserve equations encountered - std::unordered_set conserve_equations; + std::unordered_set conserve_equations; public: void visit_derivative_block(ast::DerivativeBlock& node) override; From 313330b29c398605147b8b3b94f61fec29830a95 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 10:50:15 +0200 Subject: [PATCH 25/47] Add test for CONSERVE statement --- test/unit/visitor/cvode.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index 3a57d242d4..a9f11a4219 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -37,6 +37,7 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { STATE {x z} DERIVATIVE equation { + CONSERVE x + z = 5 x' = -x + z * z z' = z * x } @@ -49,6 +50,10 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); REQUIRE(primed_vars.empty()); } + THEN("No CONSERVE statements are present in the CVODE block") { + auto conserved_stmts = collect_nodes(*block[0], {ast::AstNodeType::CONSERVE}); + REQUIRE(conserved_stmts.empty()); + } } } } From 03b40e842a190249034e148051a1c41f07486264 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 11:45:56 +0200 Subject: [PATCH 26/47] Fix variable naming --- test/unit/visitor/cvode.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index a9f11a4219..8ed1f595f6 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -44,14 +44,14 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { )"; auto ast = run_cvode_visitor(nmodl_text); THEN("CVODE block is added") { - auto block = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); - REQUIRE(!block.empty()); + auto blocks = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); + REQUIRE(blocks.size() == 1); THEN("No primed variables exist in the CVODE block") { - auto primed_vars = collect_nodes(*block[0], {ast::AstNodeType::PRIME_NAME}); + auto primed_vars = collect_nodes(*blocks[0], {ast::AstNodeType::PRIME_NAME}); REQUIRE(primed_vars.empty()); } THEN("No CONSERVE statements are present in the CVODE block") { - auto conserved_stmts = collect_nodes(*block[0], {ast::AstNodeType::CONSERVE}); + auto conserved_stmts = collect_nodes(*blocks[0], {ast::AstNodeType::CONSERVE}); REQUIRE(conserved_stmts.empty()); } } From 836ec7498d91675785ff8f29d48c1e3f8356451d Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 13:32:39 +0200 Subject: [PATCH 27/47] Update docstring --- docs/contents/cvode.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst index 87ecdb309c..d70378aaf6 100644 --- a/docs/contents/cvode.rst +++ b/docs/contents/cvode.rst @@ -20,9 +20,10 @@ is satisfied: In NMODL, all ``KINETIC`` blocks are internally first converted to ``DERIVATIVE`` blocks. The ``DERIVATIVE`` block is then converted to a ``CVODE`` block, which contains two parts; the first part contains the update -step for linear systems, while the second part contains the update step for -non-linear systems (see `CVODES documentation`_, eqs. (4.8) and (4.9)). Given -a ``DERIVATIVE`` block of the form: +step for non-stiff systems (functional iteration), while the second part +contains the update step for stiff systems (additional step using the +Jacobian). For more information, see `CVODES documentation`_, eqs. (4.8) and +(4.9)). Given a ``DERIVATIVE`` block of the form: .. _CVODES documentation: https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html From 9f6b751b5c70d90742a406284dceb2d2ab1738fd Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 13:46:34 +0200 Subject: [PATCH 28/47] Fix typo --- docs/contents/cvode.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst index d70378aaf6..367e160d48 100644 --- a/docs/contents/cvode.rst +++ b/docs/contents/cvode.rst @@ -23,7 +23,7 @@ In NMODL, all ``KINETIC`` blocks are internally first converted to step for non-stiff systems (functional iteration), while the second part contains the update step for stiff systems (additional step using the Jacobian). For more information, see `CVODES documentation`_, eqs. (4.8) and -(4.9)). Given a ``DERIVATIVE`` block of the form: +(4.9). Given a ``DERIVATIVE`` block of the form: .. _CVODES documentation: https://sundials.readthedocs.io/en/latest/cvodes/Mathematics_link.html From 1348ab946456a3341ca9a8253a71717170214045 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 10 Oct 2024 13:58:54 +0200 Subject: [PATCH 29/47] Update docstring --- src/pybind/wrapper.cpp | 2 -- src/pybind/wrapper.hpp | 6 ++++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index ae9d414976..65fe2b6d6b 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -186,8 +186,6 @@ except Exception as e: return {std::move(solution), std::move(exception_message)}; } -/// \brief A blunt instrument that differentiates expression w.r.t. variable -/// \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, const std::string& variable, diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index b4ec0a2dff..694b0143d7 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -45,6 +45,12 @@ std::tuple call_analytic_diff( const std::vector& expressions, const std::set& used_names_in_block); + +/// \brief Differentiates an expression with respect to a variable +/// \param expression The expression we want to differentiate +/// \param variable The name of the independent variable we are differentiating against +/// \param index_vars A map of array (indexable) variables (and their associated indices) that +/// appear in \ref expression \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, const std::string& variable, From ee9c187afd155d3d0193ecfbdf7c1349b5cf7785 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 13:02:34 +0200 Subject: [PATCH 30/47] Add option for diffing IndexedName --- python/nmodl/ode.py | 2 +- src/pybind/wrapper.cpp | 19 +++++++++++---- src/pybind/wrapper.hpp | 3 ++- src/visitors/cvode_visitor.cpp | 43 ++++++++++++++++++++++++---------- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/python/nmodl/ode.py b/python/nmodl/ode.py index cd6b2b27ae..8219169db8 100644 --- a/python/nmodl/ode.py +++ b/python/nmodl/ode.py @@ -619,7 +619,7 @@ def differentiate2c( # every symbol (a.k.a variable) that SymPy # is going to manipulate needs to be declared # explicitly - x = sp.symbols(dependent_var, real=True) + x = make_symbol(dependent_var) vars = set(vars) vars.discard(dependent_var) # declare all other supplied variables diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 65fe2b6d6b..8385954ed7 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -10,6 +10,7 @@ #include "codegen/codegen_naming.hpp" #include "pybind/pyembed.hpp" #include +#include #include #include @@ -188,17 +189,26 @@ except Exception as e: std::tuple call_diff2c( const std::string& expression, - const std::string& variable, + const std::pair>& variable, const std::unordered_map& indexed_vars) { std::string statements; // only indexed variables require special treatment for (const auto& [var, prop]: indexed_vars) { statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var); } - auto locals = py::dict("expression"_a = expression, "variable"_a = variable); - std::string script = fmt::format(R"( + auto [name, property] = variable; + if (property.has_value()) { + name = fmt::format("sp.IndexedBase('{}', shape=[1])", name); + statements += fmt::format("_allvars.append({})", name); + } else { + name = fmt::format("'{}'", name); + } + auto locals = py::dict("expression"_a = expression); + std::string script = + fmt::format(R"( _allvars = [] {} +variable = {} exception_message = "" try: solution = differentiate2c(expression, @@ -210,7 +220,8 @@ except Exception as e: solution = "" exception_message = str(e) )", - statements); + statements, + property.has_value() ? fmt::format("{}[{}]", name, property.value()) : name); py::exec(nmodl::pybind_wrappers::ode_py + script, locals); diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index 694b0143d7..aad85aef25 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -7,6 +7,7 @@ #pragma once +#include #include #include #include @@ -53,7 +54,7 @@ std::tuple call_analytic_diff( /// appear in \ref expression \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, - const std::string& variable, + const std::pair>& variable, const std::unordered_map& indexed_vars = {}); struct pybind_wrap_api { diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 302b541d5d..f3979474fb 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -24,6 +24,16 @@ static int get_index(const ast::IndexedName& node) { return std::stoi(to_nmodl(node.get_length())); } +static std::pair> parse_independent_var( + std::shared_ptr node) { + auto variable = std::make_pair(node->get_node_name(), std::optional()); + if (node->is_indexed_name()) { + variable.second = std::optional( + get_index(*std::dynamic_pointer_cast(node))); + } + return variable; +} + void CvodeVisitor::visit_conserve(ast::Conserve& node) { if (in_cvode_block) { logger->warn("CvodeVisitor :: statement {} will be ignored in CVODE codegen", @@ -92,25 +102,32 @@ void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { auto name = std::dynamic_pointer_cast(lhs)->get_name(); - if (name->is_prime_name()) { - auto varname = "D" + name->get_node_name(); - logger->debug("CvodeVisitor :: replacing {} with {} on LHS of {}", - name->get_node_name(), - varname, - to_nmodl(node)); - node.set_lhs(std::make_shared(new ast::String(varname))); - if (program_symtab->lookup(varname) == nullptr) { - auto symbol = std::make_shared(varname, ModToken()); - symbol->set_original_name(name->get_node_name()); - program_symtab->insert(symbol); + if (name->is_prime_name() || name->is_indexed_name()) { + std::string varname; + if (name->is_prime_name()) { + varname = "D" + name->get_node_name(); + node.set_lhs(std::make_shared(new ast::String(varname))); + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } + } else { + varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); + auto statement = fmt::format("{} = {}", varname, varname); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); } if (block_index == BlockIndex::JACOBIAN) { auto rhs = node.get_rhs(); // map of all indexed symbols (need special treatment in SymPy) auto name_map = get_name_map(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; - auto [jacobian, - exception_message] = diff2c(to_nmodl(*rhs), name->get_node_name(), name_map); + auto [jacobian, exception_message] = + diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); if (!exception_message.empty()) { logger->warn("CvodeVisitor :: python exception: {}", exception_message); } From 54a480e0da3be17961d7cae4172f775f58e10793 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 16:58:50 +0200 Subject: [PATCH 31/47] Refactor --- src/visitors/cvode_visitor.cpp | 216 +++++++++++++++++++-------------- src/visitors/cvode_visitor.hpp | 35 ------ 2 files changed, 126 insertions(+), 125 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index f3979474fb..8d40764ea2 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -34,16 +34,8 @@ static std::pair> parse_independent_var( return variable; } -void CvodeVisitor::visit_conserve(ast::Conserve& node) { - if (in_cvode_block) { - logger->warn("CvodeVisitor :: statement {} will be ignored in CVODE codegen", - to_nmodl(node)); - conserve_equations.emplace(&node); - } -} - - -static auto get_name_map(const ast::Expression& node, const std::string& name) { +static std::unordered_map get_name_map(const ast::Expression& node, + const std::string& name) { std::unordered_map name_map; // all of the "reserved" symbols auto reserved_symbols = get_external_functions(); @@ -65,105 +57,149 @@ static auto get_name_map(const ast::Expression& node, const std::string& name) { return name_map; } -void CvodeVisitor::visit_derivative_block(ast::DerivativeBlock& node) { - node.visit_children(*this); - derivative_block = std::shared_ptr(node.clone()); -} +static std::string cvode_set_lhs(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); + auto name = std::dynamic_pointer_cast(lhs)->get_name(); -void CvodeVisitor::visit_cvode_block(ast::CvodeBlock& node) { - in_cvode_block = true; - node.visit_children(*this); - in_cvode_block = false; + std::string varname; + if (name->is_prime_name()) { + varname = "D" + name->get_node_name(); + node.set_lhs(std::make_shared(new ast::String(varname))); + } else if (name->is_indexed_name()) { + auto nodes = collect_nodes(*name, {ast::AstNodeType::PRIME_NAME}); + // make sure the LHS isn't just a plain indexed var + if (!nodes.empty()) { + varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); + auto statement = fmt::format("{} = {}", varname, varname); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + } + } + return varname; } -void CvodeVisitor::visit_diff_eq_expression(ast::DiffEqExpression& node) { - in_differential_equation = true; - node.visit_children(*this); - in_differential_equation = false; -} +class CvodeHelperVisitor: public AstVisitor { + protected: + symtab::SymbolTable* program_symtab = nullptr; + bool in_differential_equation = false; + std::unordered_set conserve_equations; + public: + inline void visit_diff_eq_expression(ast::DiffEqExpression& node) { + in_differential_equation = true; + node.visit_children(*this); + in_differential_equation = false; + } +}; -void CvodeVisitor::visit_statement_block(ast::StatementBlock& node) { - node.visit_children(*this); - if (in_cvode_block) { - ++block_index; +class NonStiffVisitor: public CvodeHelperVisitor { + public: + NonStiffVisitor(symtab::SymbolTable* symtab) { + program_symtab = symtab; } -} + inline void visit_binary_expression(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); -void CvodeVisitor::visit_binary_expression(ast::BinaryExpression& node) { - const auto& lhs = node.get_lhs(); + if (!in_differential_equation || !lhs->is_var_name()) { + return; + } + + auto name = std::dynamic_pointer_cast(lhs)->get_name(); + auto varname = cvode_set_lhs(node); - /// we have to only solve ODEs under original derivative block where lhs is variable - if (!in_cvode_block || !in_differential_equation || !lhs->is_var_name()) { - return; + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } } +}; - auto name = std::dynamic_pointer_cast(lhs)->get_name(); +class StiffVisitor: public CvodeHelperVisitor { + public: + StiffVisitor(symtab::SymbolTable* symtab) { + program_symtab = symtab; + } - if (name->is_prime_name() || name->is_indexed_name()) { - std::string varname; - if (name->is_prime_name()) { - varname = "D" + name->get_node_name(); - node.set_lhs(std::make_shared(new ast::String(varname))); - if (program_symtab->lookup(varname) == nullptr) { - auto symbol = std::make_shared(varname, ModToken()); - symbol->set_original_name(name->get_node_name()); - program_symtab->insert(symbol); - } - } else { - varname = "D" + stringutils::remove_character(to_nmodl(node.get_lhs()), '\''); - auto statement = fmt::format("{} = {}", varname, varname); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_lhs(std::shared_ptr(bin_expr->get_lhs()->clone())); + inline void visit_binary_expression(ast::BinaryExpression& node) { + const auto& lhs = node.get_lhs(); + + if (!in_differential_equation || !lhs->is_var_name()) { + return; } - if (block_index == BlockIndex::JACOBIAN) { - auto rhs = node.get_rhs(); - // map of all indexed symbols (need special treatment in SymPy) - auto name_map = get_name_map(*rhs, name->get_node_name()); - auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; - auto [jacobian, exception_message] = - diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); - if (!exception_message.empty()) { - logger->warn("CvodeVisitor :: python exception: {}", exception_message); - } - // NOTE: LHS can be anything here, the equality is to keep `create_statement` from - // complaining, we discard the LHS later - auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); - logger->debug("CvodeVisitor :: replacing statement {} with {}", - to_nmodl(node), - statement); - auto expr_statement = std::dynamic_pointer_cast( - create_statement(statement)); - const auto bin_expr = std::dynamic_pointer_cast( - expr_statement->get_expression()); - node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); + + auto name = std::dynamic_pointer_cast(lhs)->get_name(); + auto varname = cvode_set_lhs(node); + + if (program_symtab->lookup(varname) == nullptr) { + auto symbol = std::make_shared(varname, ModToken()); + symbol->set_original_name(name->get_node_name()); + program_symtab->insert(symbol); + } + + auto rhs = node.get_rhs(); + // map of all indexed symbols (need special treatment in SymPy) + auto name_map = get_name_map(*rhs, name->get_node_name()); + auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; + auto [jacobian, + exception_message] = diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); + if (!exception_message.empty()) { + logger->warn("CvodeVisitor :: python exception: {}", exception_message); } + // NOTE: LHS can be anything here, the equality is to keep `create_statement` from + // complaining, we discard the LHS later + auto statement = fmt::format("{} = {} / (1 - dt * ({}))", varname, varname, jacobian); + logger->debug("CvodeVisitor :: replacing statement {} with {}", to_nmodl(node), statement); + auto expr_statement = std::dynamic_pointer_cast( + create_statement(statement)); + const auto bin_expr = std::dynamic_pointer_cast( + expr_statement->get_expression()); + node.set_rhs(std::shared_ptr(bin_expr->get_rhs()->clone())); } -} +}; + void CvodeVisitor::visit_program(ast::Program& node) { - program_symtab = node.get_symbol_table(); - node.visit_children(*this); - if (derivative_block) { - auto der_node = new ast::CvodeBlock(derivative_block->get_name(), - derivative_block->get_statement_block(), - std::shared_ptr( - derivative_block->get_statement_block()->clone())); - node.emplace_back_node(der_node); - } + auto der_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); + if (!der_blocks.empty()) { + auto der_block = std::dynamic_pointer_cast(der_blocks[0]); + + auto non_stiff_block = der_block->get_statement_block()->clone(); + { + auto conserve_equations = collect_nodes(*non_stiff_block, {ast::AstNodeType::CONSERVE}); + if (!conserve_equations.empty()) { + std::unordered_set eqs; + for (const auto& item: conserve_equations) { + eqs.insert(std::dynamic_pointer_cast(item).get()); + } + non_stiff_block->erase_statement(eqs); + } + } + + auto stiff_block = der_block->get_statement_block()->clone(); + { + auto conserve_equations = collect_nodes(*stiff_block, {ast::AstNodeType::CONSERVE}); + if (!conserve_equations.empty()) { + std::unordered_set eqs; + for (const auto& item: conserve_equations) { + eqs.insert(std::dynamic_pointer_cast(item).get()); + } + stiff_block->erase_statement(eqs); + } + } + - // re-visit the AST since we now inserted the CVODE block - node.visit_children(*this); - if (!conserve_equations.empty()) { - auto blocks = collect_nodes(node, {ast::AstNodeType::CVODE_BLOCK}); - auto block = std::dynamic_pointer_cast(blocks[0]); - block->get_function_block()->erase_statement(conserve_equations); - block->get_diagonal_jacobian_block()->erase_statement(conserve_equations); + NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); + StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); + node.emplace_back_node( + new ast::CvodeBlock(der_block->get_name(), + std::shared_ptr(non_stiff_block), + std::shared_ptr(stiff_block))); } } diff --git a/src/visitors/cvode_visitor.hpp b/src/visitors/cvode_visitor.hpp index 716600be20..c8d37f6a61 100644 --- a/src/visitors/cvode_visitor.hpp +++ b/src/visitors/cvode_visitor.hpp @@ -20,16 +20,6 @@ namespace nmodl { namespace visitor { -enum class BlockIndex { FUNCTION = 0, JACOBIAN = 1 }; - -inline BlockIndex& operator++(BlockIndex& index) { - if (index == BlockIndex::FUNCTION) { - index = BlockIndex::JACOBIAN; - } else { - index = BlockIndex::FUNCTION; - } - return index; -} /** * \addtogroup visitor_classes * \{ @@ -40,33 +30,8 @@ inline BlockIndex& operator++(BlockIndex& index) { * \brief Visitor used for generating the necessary AST nodes for CVODE */ class CvodeVisitor: public AstVisitor { - private: - /// The copy of the derivative block of a given mod file - std::shared_ptr derivative_block = nullptr; - - /// true while visiting differential equation - bool in_differential_equation = false; - - /// global symbol table - symtab::SymbolTable* program_symtab = nullptr; - - /// true while we are visiting a CVODE block - bool in_cvode_block = false; - - /// index of the block to modify - BlockIndex block_index = BlockIndex::FUNCTION; - - /// list of conserve equations encountered - std::unordered_set conserve_equations; - public: - void visit_derivative_block(ast::DerivativeBlock& node) override; void visit_program(ast::Program& node) override; - void visit_cvode_block(ast::CvodeBlock& node) override; - void visit_diff_eq_expression(ast::DiffEqExpression& node) override; - void visit_binary_expression(ast::BinaryExpression& node) override; - void visit_statement_block(ast::StatementBlock& node) override; - void visit_conserve(ast::Conserve& node) override; }; /** \} */ // end of visitor_classes From cefc1596beff987cb7a7bd9313773963983c63c5 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 21:29:26 +0200 Subject: [PATCH 32/47] Enable sympy if NEURON codegen --- src/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main.cpp b/src/main.cpp index a32640276d..89c2813f23 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -523,7 +523,7 @@ int run_nmodl(int argc, const char* argv[]) { } - if (sympy_conductance || sympy_analytic) { + if (sympy_conductance || sympy_analytic || neuron_code) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); From a3e1c6c23a177abd92d34ac731d8abb03d42adca Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 14 Oct 2024 21:49:57 +0200 Subject: [PATCH 33/47] Mark constructors as explicit --- src/visitors/cvode_visitor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 8d40764ea2..d64ba7d214 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -98,7 +98,7 @@ class CvodeHelperVisitor: public AstVisitor { class NonStiffVisitor: public CvodeHelperVisitor { public: - NonStiffVisitor(symtab::SymbolTable* symtab) { + explicit NonStiffVisitor(symtab::SymbolTable* symtab) { program_symtab = symtab; } @@ -122,7 +122,7 @@ class NonStiffVisitor: public CvodeHelperVisitor { class StiffVisitor: public CvodeHelperVisitor { public: - StiffVisitor(symtab::SymbolTable* symtab) { + explicit StiffVisitor(symtab::SymbolTable* symtab) { program_symtab = symtab; } From 659b018f7f39384851ca934d4c2b9c7840316517 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 15 Oct 2024 14:16:53 +0200 Subject: [PATCH 34/47] Update tests --- test/unit/visitor/cvode.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index 8ed1f595f6..e42df2ec45 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -34,12 +34,14 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { SUFFIX example } - STATE {x z} + STATE {X Y[2] Z} DERIVATIVE equation { - CONSERVE x + z = 5 - x' = -x + z * z - z' = z * x + CONSERVE X + Z = 5 + X' = -X + Z * Z + Z' = Z * X + Y'[1] = -Y[0] + Y'[0] = -Y[1] } )"; auto ast = run_cvode_visitor(nmodl_text); From bd376c3ab5bcc431d909a8a9f22562658c87db0a Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 15 Oct 2024 15:58:50 +0200 Subject: [PATCH 35/47] Remove code duplication --- src/visitors/cvode_visitor.cpp | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index d64ba7d214..e93b2b3fd0 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -24,6 +24,17 @@ static int get_index(const ast::IndexedName& node) { return std::stoi(to_nmodl(node.get_length())); } +static void remove_conserve_statements(ast::StatementBlock& node) { + auto conserve_equations = collect_nodes(node, {ast::AstNodeType::CONSERVE}); + if (!conserve_equations.empty()) { + std::unordered_set eqs; + for (const auto& item: conserve_equations) { + eqs.insert(std::dynamic_pointer_cast(item).get()); + } + node.erase_statement(eqs); + } +} + static std::pair> parse_independent_var( std::shared_ptr node) { auto variable = std::make_pair(node->get_node_name(), std::optional()); @@ -170,29 +181,10 @@ void CvodeVisitor::visit_program(ast::Program& node) { auto der_block = std::dynamic_pointer_cast(der_blocks[0]); auto non_stiff_block = der_block->get_statement_block()->clone(); - { - auto conserve_equations = collect_nodes(*non_stiff_block, {ast::AstNodeType::CONSERVE}); - if (!conserve_equations.empty()) { - std::unordered_set eqs; - for (const auto& item: conserve_equations) { - eqs.insert(std::dynamic_pointer_cast(item).get()); - } - non_stiff_block->erase_statement(eqs); - } - } + remove_conserve_statements(*non_stiff_block); auto stiff_block = der_block->get_statement_block()->clone(); - { - auto conserve_equations = collect_nodes(*stiff_block, {ast::AstNodeType::CONSERVE}); - if (!conserve_equations.empty()) { - std::unordered_set eqs; - for (const auto& item: conserve_equations) { - eqs.insert(std::dynamic_pointer_cast(item).get()); - } - stiff_block->erase_statement(eqs); - } - } - + remove_conserve_statements(*stiff_block); NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); From a60577b9957944567240b91b70f11cd8af5387fd Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Tue, 15 Oct 2024 16:27:35 +0200 Subject: [PATCH 36/47] Remove unused class field --- src/visitors/cvode_visitor.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index e93b2b3fd0..c4c7e08248 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -98,7 +98,6 @@ class CvodeHelperVisitor: public AstVisitor { protected: symtab::SymbolTable* program_symtab = nullptr; bool in_differential_equation = false; - std::unordered_set conserve_equations; public: inline void visit_diff_eq_expression(ast::DiffEqExpression& node) { in_differential_equation = true; From 8bc7d18e00344cdaecbae0936fca3b6edd7129d8 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 16 Oct 2024 15:29:37 +0200 Subject: [PATCH 37/47] Only enable sympy if DERIVATIVE block exists --- src/main.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/main.cpp b/src/main.cpp index 89c2813f23..b63b7b0c70 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -517,13 +517,15 @@ int run_nmodl(int argc, const char* argv[]) { enable_sympy(solver_exists(*ast, "derivimplicit"), "'SOLVE ... METHOD derivimplicit'"); enable_sympy(node_exists(*ast, ast::AstNodeType::LINEAR_BLOCK), "'LINEAR' block"); + enable_sympy(node_exists(*ast, ast::AstNodeType::DERIVATIVE_BLOCK), + "'DERIVATIVE' block"); enable_sympy(node_exists(*ast, ast::AstNodeType::NON_LINEAR_BLOCK), "'NONLINEAR' block"); enable_sympy(solver_exists(*ast, "sparse"), "'SOLVE ... METHOD sparse'"); } - if (sympy_conductance || sympy_analytic || neuron_code) { + if (sympy_conductance || sympy_analytic) { nmodl::pybind_wrappers::EmbeddedPythonLoader::get_instance() .api() .initialize_interpreter(); From 40cb10bd7893c2f01da4e74393f4ffbfdd27189c Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 16 Oct 2024 15:46:20 +0200 Subject: [PATCH 38/47] Rename CVODE subblocks with more apt names --- src/language/codegen.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 67c9efe214..73f113c166 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -96,11 +96,11 @@ type: Name node_name: true suffix: {value: " "} - - function_block: - brief: "Block with statements of the form Dvar = f(var)" + - nonstiff_block: + brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems" type: StatementBlock - - diagonal_jacobian_block: - brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f))" + - stiff_block: + brief: "Block with statements of the form Dvar = Dvar / (1 - dt * J(f)), used for updating stiff systems" type: StatementBlock brief: "Represents a block used for variable timestep integration (CVODE) of DERIVATIVE blocks" - LongitudinalDiffusionBlock: From 9b58feb2ff9b47ea532f21e7c204e6f9f1d771c7 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Wed, 16 Oct 2024 15:49:31 +0200 Subject: [PATCH 39/47] `nonstiff` -> `non_stiff` --- src/language/codegen.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 73f113c166..4659965be9 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -96,7 +96,7 @@ type: Name node_name: true suffix: {value: " "} - - nonstiff_block: + - non_stiff_block: brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems" type: StatementBlock - stiff_block: From a6ea5abaa979d70b0f3b850e1829d1754e027347 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 14:28:44 +0200 Subject: [PATCH 40/47] Get # of ODEs to solve --- src/language/codegen.yaml | 5 +++++ src/visitors/cvode_visitor.cpp | 10 ++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/language/codegen.yaml b/src/language/codegen.yaml index 4659965be9..c7ce97c5b8 100644 --- a/src/language/codegen.yaml +++ b/src/language/codegen.yaml @@ -96,6 +96,11 @@ type: Name node_name: true suffix: {value: " "} + - n_odes: + brief: "number of ODEs to solve" + type: Integer + prefix: {value: "["} + suffix: {value: "]"} - non_stiff_block: brief: "Block with statements of the form Dvar = f(var), used for updating non-stiff systems" type: StatementBlock diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index c4c7e08248..54dc3ca721 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -187,10 +187,12 @@ void CvodeVisitor::visit_program(ast::Program& node) { NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); - node.emplace_back_node( - new ast::CvodeBlock(der_block->get_name(), - std::shared_ptr(non_stiff_block), - std::shared_ptr(stiff_block))); + auto prime_vars = collect_nodes(*der_block, {ast::AstNodeType::PRIME_NAME}); + node.emplace_back_node(new ast::CvodeBlock( + der_block->get_name(), + std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), + std::shared_ptr(non_stiff_block), + std::shared_ptr(stiff_block))); } } From 5d94f7a88499d3f6c0fa35cfba407ec98314f6ee Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 16:02:11 +0200 Subject: [PATCH 41/47] der_block(s) -> derivative_block(s) --- src/visitors/cvode_visitor.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 54dc3ca721..41845ead33 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -175,21 +175,22 @@ class StiffVisitor: public CvodeHelperVisitor { void CvodeVisitor::visit_program(ast::Program& node) { - auto der_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); - if (!der_blocks.empty()) { - auto der_block = std::dynamic_pointer_cast(der_blocks[0]); + auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); + if (!derivative_blocks.empty()) { + auto derivative_block = std::dynamic_pointer_cast( + derivative_blocks[0]); - auto non_stiff_block = der_block->get_statement_block()->clone(); + auto non_stiff_block = derivative_block->get_statement_block()->clone(); remove_conserve_statements(*non_stiff_block); - auto stiff_block = der_block->get_statement_block()->clone(); + auto stiff_block = derivative_block->get_statement_block()->clone(); remove_conserve_statements(*stiff_block); NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); - auto prime_vars = collect_nodes(*der_block, {ast::AstNodeType::PRIME_NAME}); + auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME}); node.emplace_back_node(new ast::CvodeBlock( - der_block->get_name(), + derivative_block->get_name(), std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), std::shared_ptr(non_stiff_block), std::shared_ptr(stiff_block))); From bf9db7acef87ab100583dcd2de83ab3c666469ca Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 16:30:25 +0200 Subject: [PATCH 42/47] get_name_map -> get_indexed_variables Also use a set since we don't care about the actual index for the RHS --- src/pybind/wrapper.cpp | 4 ++-- src/pybind/wrapper.hpp | 8 ++++---- src/visitors/cvode_visitor.cpp | 33 ++++++++++++++++----------------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 8385954ed7..d59b579d97 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -190,10 +190,10 @@ except Exception as e: std::tuple call_diff2c( const std::string& expression, const std::pair>& variable, - const std::unordered_map& indexed_vars) { + const std::unordered_set& indexed_vars) { std::string statements; // only indexed variables require special treatment - for (const auto& [var, prop]: indexed_vars) { + for (const auto& var: indexed_vars) { statements += fmt::format("_allvars.append(sp.IndexedBase('{}', shape=[1]))\n", var); } auto [name, property] = variable; diff --git a/src/pybind/wrapper.hpp b/src/pybind/wrapper.hpp index aad85aef25..e93cca51f5 100644 --- a/src/pybind/wrapper.hpp +++ b/src/pybind/wrapper.hpp @@ -10,7 +10,7 @@ #include #include #include -#include +#include #include namespace nmodl { @@ -50,12 +50,12 @@ std::tuple call_analytic_diff( /// \brief Differentiates an expression with respect to a variable /// \param expression The expression we want to differentiate /// \param variable The name of the independent variable we are differentiating against -/// \param index_vars A map of array (indexable) variables (and their associated indices) that -/// appear in \ref expression \return The tuple (solution, exception) +/// \param index_vars A set of array (indexable) variables that appear in \ref expression +/// \return The tuple (solution, exception) std::tuple call_diff2c( const std::string& expression, const std::pair>& variable, - const std::unordered_map& indexed_vars = {}); + const std::unordered_set& indexed_vars = {}); struct pybind_wrap_api { decltype(&initialize_interpreter_func) initialize_interpreter; diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 41845ead33..04eea17093 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -45,27 +45,26 @@ static std::pair> parse_independent_var( return variable; } -static std::unordered_map get_name_map(const ast::Expression& node, - const std::string& name) { - std::unordered_map name_map; - // all of the "reserved" symbols +/// set of all indexed variables not equal to ``name`` +static std::unordered_set get_indexed_variables(const ast::Expression& node, + const std::string& name) { + std::unordered_set indexed_variables; + // all of the "reserved" vars auto reserved_symbols = get_external_functions(); // all indexed vars auto indexed_vars = collect_nodes(node, {ast::AstNodeType::INDEXED_NAME}); for (const auto& var: indexed_vars) { - if (!name_map.count(var->get_node_name()) && var->get_node_name() != name && - std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&var](const auto item) { - return var->get_node_name() == item; - })) { - logger->debug( - "CvodeVisitor :: adding INDEXED_VARIABLE {} to " - "node_map", - var->get_node_name()); - name_map[var->get_node_name()] = get_index( - *std::dynamic_pointer_cast(var)); + const auto& varname = var->get_node_name(); + // skip if it's a reserved var + auto varname_not_reserved = + std::none_of(reserved_symbols.begin(), + reserved_symbols.end(), + [&varname](const auto item) { return varname == item; }); + if (indexed_variables.count(varname) == 0 && varname != name && varname_not_reserved) { + indexed_variables.insert(varname); } } - return name_map; + return indexed_variables; } static std::string cvode_set_lhs(ast::BinaryExpression& node) { @@ -153,8 +152,8 @@ class StiffVisitor: public CvodeHelperVisitor { } auto rhs = node.get_rhs(); - // map of all indexed symbols (need special treatment in SymPy) - auto name_map = get_name_map(*rhs, name->get_node_name()); + // all indexed variables (need special treatment in SymPy) + auto name_map = get_indexed_variables(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; auto [jacobian, exception_message] = diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); From fd0c619c580e1fcf2e38d3ec836b64bf756eeedb Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 23:03:20 +0200 Subject: [PATCH 43/47] Add check for multiple DERIVATIVE blocks --- src/visitors/cvode_visitor.cpp | 54 ++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 04eea17093..75e20c43dc 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -175,25 +175,43 @@ class StiffVisitor: public CvodeHelperVisitor { void CvodeVisitor::visit_program(ast::Program& node) { auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); - if (!derivative_blocks.empty()) { - auto derivative_block = std::dynamic_pointer_cast( - derivative_blocks[0]); - - auto non_stiff_block = derivative_block->get_statement_block()->clone(); - remove_conserve_statements(*non_stiff_block); - - auto stiff_block = derivative_block->get_statement_block()->clone(); - remove_conserve_statements(*stiff_block); - - NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); - StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); - auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME}); - node.emplace_back_node(new ast::CvodeBlock( - derivative_block->get_name(), - std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), - std::shared_ptr(non_stiff_block), - std::shared_ptr(stiff_block))); + if (derivative_blocks.empty()) { + return; } + + // steady state adds a DERIVATIVE block with a `_steadystate` suffix + auto not_steadystate = [](const auto& item) { + auto name = std::dynamic_pointer_cast(item)->get_node_name(); + return !stringutils::ends_with(name, "_steadystate"); + }; + decltype(derivative_blocks) derivative_blocks_copy; + std::copy_if(derivative_blocks.begin(), + derivative_blocks.end(), + std::back_inserter(derivative_blocks_copy), + not_steadystate); + if (derivative_blocks_copy.size() > 1) { + auto message = "CvodeVisitor :: cannot have multiple DERIVATIVE blocks"; + logger->error(message); + throw std::runtime_error(message); + } + + auto derivative_block = std::dynamic_pointer_cast( + derivative_blocks_copy[0]); + + auto non_stiff_block = derivative_block->get_statement_block()->clone(); + remove_conserve_statements(*non_stiff_block); + + auto stiff_block = derivative_block->get_statement_block()->clone(); + remove_conserve_statements(*stiff_block); + + NonStiffVisitor(node.get_symbol_table()).visit_statement_block(*non_stiff_block); + StiffVisitor(node.get_symbol_table()).visit_statement_block(*stiff_block); + auto prime_vars = collect_nodes(*derivative_block, {ast::AstNodeType::PRIME_NAME}); + node.emplace_back_node(new ast::CvodeBlock( + derivative_block->get_name(), + std::shared_ptr(new ast::Integer(prime_vars.size(), nullptr)), + std::shared_ptr(non_stiff_block), + std::shared_ptr(stiff_block))); } } // namespace visitor From 5f4a00aba712faeb6ec2535a887bc328ec799fb1 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 23:03:35 +0200 Subject: [PATCH 44/47] Update tests for CVODE --- test/unit/visitor/cvode.cpp | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/test/unit/visitor/cvode.cpp b/test/unit/visitor/cvode.cpp index e42df2ec45..96d09549ab 100644 --- a/test/unit/visitor/cvode.cpp +++ b/test/unit/visitor/cvode.cpp @@ -28,8 +28,16 @@ auto run_cvode_visitor(const std::string& text) { TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { + GIVEN("No DERIVATIVE block") { + auto nmodl_text = "NEURON { SUFFIX example }"; + auto ast = run_cvode_visitor(nmodl_text); + THEN("No CVODE block is added") { + auto blocks = collect_nodes(*ast, {ast::AstNodeType::CVODE_BLOCK}); + REQUIRE(blocks.empty()); + } + } GIVEN("DERIVATIVE block") { - std::string nmodl_text = R"( + auto nmodl_text = R"( NEURON { SUFFIX example } @@ -58,4 +66,24 @@ TEST_CASE("Make sure CVODE block is generated properly", "[visitor][cvode]") { } } } + GIVEN("Multiple DERIVATIVE blocks") { + auto nmodl_text = R"( + NEURON { + SUFFIX example + } + + STATE {X} + + DERIVATIVE equation { + X' = -X + } + + DERIVATIVE equation2 { + X' = -X * X + } +)"; + THEN("An error is raised") { + REQUIRE_THROWS_AS(run_cvode_visitor(nmodl_text), std::runtime_error); + } + } } From e30c27cbfbf68d41253fad467d002f23bad97e59 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Mon, 21 Oct 2024 23:28:36 +0200 Subject: [PATCH 45/47] Update docs --- docs/contents/cvode.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/contents/cvode.rst b/docs/contents/cvode.rst index 367e160d48..6179d824c2 100644 --- a/docs/contents/cvode.rst +++ b/docs/contents/cvode.rst @@ -37,20 +37,20 @@ the structure of the ``CVODE`` block is then roughly: .. code-block:: - CVODE state { + CVODE state[n] { Dx_i = f_i(x_1, ..., x_n) }{ Dx_i = Dx_i / (1 - dt * J_ii(f)) } -where ``J_ii(f)`` is the diagonal part of the Jacobian, i.e. +where ``N`` is the total number of ODEs to solve, and ``J_ii(f)`` is the +diagonal part of the Jacobian, i.e. .. math:: J_{ii}(f) = \frac{ \partial f_i(x_1, \ldots, x_n) }{\partial x_i} -As an example, consider the following ``DERIVATIVE`` -block: +As an example, consider the following ``DERIVATIVE`` block: .. code-block:: @@ -63,7 +63,7 @@ Where ``X`` is a ``STATE`` variable with some initial value, specified in the .. code-block:: - CVODE state { + CVODE state[1] { DX = - X }{ DX = DX / (1 - dt * (-1)) From 167d38f03aa0ea90cd81d7da97ec950a5ca8aa95 Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 24 Oct 2024 13:43:24 +0200 Subject: [PATCH 46/47] Address comments from review --- src/visitors/cvode_visitor.cpp | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index 75e20c43dc..ea86b14bfa 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -153,10 +153,10 @@ class StiffVisitor: public CvodeHelperVisitor { auto rhs = node.get_rhs(); // all indexed variables (need special treatment in SymPy) - auto name_map = get_indexed_variables(*rhs, name->get_node_name()); + auto indexed_variables = get_indexed_variables(*rhs, name->get_node_name()); auto diff2c = pywrap::EmbeddedPythonLoader::get_instance().api().diff2c; - auto [jacobian, - exception_message] = diff2c(to_nmodl(*rhs), parse_independent_var(name), name_map); + auto [jacobian, exception_message] = + diff2c(to_nmodl(*rhs), parse_independent_var(name), indexed_variables); if (!exception_message.empty()) { logger->warn("CvodeVisitor :: python exception: {}", exception_message); } @@ -172,11 +172,10 @@ class StiffVisitor: public CvodeHelperVisitor { } }; - -void CvodeVisitor::visit_program(ast::Program& node) { +static std::shared_ptr get_derivative_block(ast::Program& node) { auto derivative_blocks = collect_nodes(node, {ast::AstNodeType::DERIVATIVE_BLOCK}); if (derivative_blocks.empty()) { - return; + return nullptr; } // steady state adds a DERIVATIVE block with a `_steadystate` suffix @@ -195,8 +194,15 @@ void CvodeVisitor::visit_program(ast::Program& node) { throw std::runtime_error(message); } - auto derivative_block = std::dynamic_pointer_cast( - derivative_blocks_copy[0]); + return std::dynamic_pointer_cast(derivative_blocks_copy[0]); +} + + +void CvodeVisitor::visit_program(ast::Program& node) { + auto derivative_block = get_derivative_block(node); + if (derivative_block == nullptr) { + return; + } auto non_stiff_block = derivative_block->get_statement_block()->clone(); remove_conserve_statements(*non_stiff_block); From 4ce3d854da35c17c16772213a65ff205207a9a4f Mon Sep 17 00:00:00 2001 From: Goran Jelic-Cizmek Date: Thu, 24 Oct 2024 14:51:05 +0200 Subject: [PATCH 47/47] Fix variable name --- src/visitors/cvode_visitor.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/visitors/cvode_visitor.cpp b/src/visitors/cvode_visitor.cpp index ea86b14bfa..617e1e6c65 100644 --- a/src/visitors/cvode_visitor.cpp +++ b/src/visitors/cvode_visitor.cpp @@ -45,9 +45,9 @@ static std::pair> parse_independent_var( return variable; } -/// set of all indexed variables not equal to ``name`` +/// set of all indexed variables not equal to ``ignored_name`` static std::unordered_set get_indexed_variables(const ast::Expression& node, - const std::string& name) { + const std::string& ignored_name) { std::unordered_set indexed_variables; // all of the "reserved" vars auto reserved_symbols = get_external_functions(); @@ -60,7 +60,8 @@ static std::unordered_set get_indexed_variables(const ast::Expressi std::none_of(reserved_symbols.begin(), reserved_symbols.end(), [&varname](const auto item) { return varname == item; }); - if (indexed_variables.count(varname) == 0 && varname != name && varname_not_reserved) { + if (indexed_variables.count(varname) == 0 && varname != ignored_name && + varname_not_reserved) { indexed_variables.insert(varname); } }