diff --git a/src/codegen/codegen_c_visitor.cpp b/src/codegen/codegen_c_visitor.cpp index 98bc738c0a..7063e4f7f3 100644 --- a/src/codegen/codegen_c_visitor.cpp +++ b/src/codegen/codegen_c_visitor.cpp @@ -1459,6 +1459,20 @@ static const TableStatement* get_table_statement(const ast::Block& node) { } +std::tuple CodegenCVisitor::check_if_var_is_array(const std::string& name) { + auto symbol = program_symtab->lookup_in_scope(name); + if (!symbol) { + throw std::runtime_error( + fmt::format("CodegenCVisitor:: {} not found in symbol table!", name)); + } + if (symbol->is_array()) { + return {true, symbol->get_length()}; + } else { + return {false, 0}; + } +} + + void CodegenCVisitor::print_table_check_function(const Block& node) { auto statement = get_table_statement(node); auto table_variables = statement->get_table_vars(); @@ -1525,7 +1539,15 @@ void CodegenCVisitor::print_table_check_function(const Block& node) { auto name = variable->get_node_name(); auto instance_name = get_variable_name(name); auto table_name = get_variable_name("t_" + name); - printer->fmt_line("{}[i] = {};", table_name, instance_name); + auto [is_array, array_length] = check_if_var_is_array(name); + if (is_array) { + for (int j = 0; j < array_length; j++) { + printer->fmt_line( + "{}[{}][i] = {}[{}];", table_name, j, instance_name, j); + } + } else { + printer->fmt_line("{}[i] = {};", table_name, instance_name); + } } } else { auto table_name = get_variable_name("t_" + name); @@ -1587,7 +1609,14 @@ void CodegenCVisitor::print_table_replacement_function(const ast::Block& node) { if (node.is_procedure_block()) { for (const auto& var: table_variables) { auto name = get_variable_name(var->get_node_name()); - printer->fmt_line("{} = xi;", name); + auto [is_array, array_length] = check_if_var_is_array(var->get_node_name()); + if (is_array) { + for (int j = 0; j < array_length; j++) { + printer->fmt_line("{}[{}] = xi;", name, j); + } + } else { + printer->fmt_line("{} = xi;", name); + } } printer->add_line("return 0;"); } else { @@ -1602,7 +1631,15 @@ void CodegenCVisitor::print_table_replacement_function(const ast::Block& node) { auto name = variable->get_node_name(); auto instance_name = get_variable_name(name); auto table_name = get_variable_name("t_" + name); - printer->fmt_line("{} = {}[index];", instance_name, table_name); + auto [is_array, array_length] = check_if_var_is_array(name); + if (is_array) { + for (int j = 0; j < array_length; j++) { + printer->fmt_line( + "{}[{}] = {}[{}][index];", instance_name, j, table_name, j); + } + } else { + printer->fmt_line("{} = {}[index];", instance_name, table_name); + } } printer->add_line("return 0;"); } else { @@ -1615,11 +1652,23 @@ void CodegenCVisitor::print_table_replacement_function(const ast::Block& node) { printer->add_line("double theta = xi - double(i);"); if (node.is_procedure_block()) { for (const auto& var: table_variables) { - auto instance_name = get_variable_name(var->get_node_name()); - auto table_name = get_variable_name("t_" + var->get_node_name()); - printer->fmt_line("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);", - instance_name, - table_name); + auto name = var->get_node_name(); + auto instance_name = get_variable_name(name); + auto table_name = get_variable_name("t_" + name); + auto [is_array, array_length] = check_if_var_is_array(var->get_node_name()); + if (is_array) { + for (size_t j = 0; j < array_length; j++) { + printer->fmt_line( + "{0}[{1}] = {2}[{1}][i] + theta*({2}[{1}][i+1]-{2}[{1}][i]);", + instance_name, + j, + table_name); + } + } else { + printer->fmt_line("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);", + instance_name, + table_name); + } } printer->add_line("return 0;"); } else { @@ -2630,8 +2679,19 @@ void CodegenCVisitor::print_mechanism_global_var_structure(bool print_initialise for (const auto& variable: info.table_statement_variables) { auto const name = "t_" + variable->get_name(); auto const num_values = variable->get_num_values(); - printer->fmt_line( - "{}{} {}[{}]{};", qualifier, float_type, name, num_values, value_initialise); + if (variable->is_array()) { + int array_len = variable->get_length(); + printer->fmt_line("{}{} {}[{}][{}]{};", + qualifier, + float_type, + name, + array_len, + num_values, + value_initialise); + } else { + printer->fmt_line( + "{}{} {}[{}]{};", qualifier, float_type, name, num_values, value_initialise); + } codegen_global_variables.push_back(make_symbol(name)); } } diff --git a/src/codegen/codegen_c_visitor.hpp b/src/codegen/codegen_c_visitor.hpp index 7bb324e505..3a4fc39c13 100644 --- a/src/codegen/codegen_c_visitor.hpp +++ b/src/codegen/codegen_c_visitor.hpp @@ -1025,6 +1025,12 @@ class CodegenCVisitor: public visitor::ConstAstVisitor { */ virtual bool is_constant_variable(const std::string& name) const; + /** + * Check if the given name exist in the symbol + * \return \c return a tuple if variable + * is an array otherwise + */ + std::tuple check_if_var_is_array(const std::string& name); /** * Print declaration of macro NRN_PRCELLSTATE for debugging diff --git a/test/unit/codegen/codegen_c_visitor.cpp b/test/unit/codegen/codegen_c_visitor.cpp index 66c1c14c80..b38f79f424 100644 --- a/test/unit/codegen/codegen_c_visitor.cpp +++ b/test/unit/codegen/codegen_c_visitor.cpp @@ -13,7 +13,11 @@ #include "parser/nmodl_driver.hpp" #include "test/unit/utils/test_utils.hpp" #include "visitors/implicit_argument_visitor.hpp" +#include "visitors/inline_visitor.hpp" +#include "visitors/neuron_solve_visitor.hpp" #include "visitors/perf_visitor.hpp" +#include "visitors/solve_block_visitor.hpp" +#include "visitors/sympy_solver_visitor.hpp" #include "visitors/symtab_visitor.hpp" using Catch::Matchers::Contains; // ContainsSubstring in newer Catch2 @@ -32,6 +36,11 @@ std::shared_ptr create_c_visitor(const std::shared_ptr("temp.mod", ss, "double", false); cv->setup(*ast); @@ -47,6 +56,15 @@ std::string get_instance_var_setup_function(std::string& nmodl_text) { return reindent_text(ss.str()); } +/// print entire code +std::string get_cpp_code(const std::string& nmodl_text) { + const auto& ast = NmodlDriver().parse_string(nmodl_text); + std::stringstream ss; + auto cvisitor = create_c_visitor(ast, nmodl_text, ss); + cvisitor->visit_program(*ast); + return reindent_text(ss.str()); +} + SCENARIO("Check instance variable definition order", "[codegen][var_order]") { GIVEN("cal_mig.mod: USEION variables declared as RANGE") { // In the below mod file, the ion variables cai and cao are also @@ -345,3 +363,57 @@ SCENARIO("Check NEURON globals are added to the instance struct on demand", } } } + +SCENARIO("Check code generation for TABLE statements", "[codegen][array_variables]") { + GIVEN("A MOD file that uses global and array variables in TABLE") { + std::string const nmodl_text = R"( + NEURON { + SUFFIX glia_Cav2_3 + RANGE inf + GLOBAL tau + } + + STATE { m } + + PARAMETER { + tau = 1 + } + + ASSIGNED { + inf[2] + } + + BREAKPOINT { + SOLVE states METHOD cnexp + } + + DERIVATIVE states { + mhn(v) + m' = (inf[0] - m)/tau + } + + PROCEDURE mhn(v (mV)) { + TABLE inf, tau DEPEND celsius FROM -100 TO 100 WITH 200 + FROM i=0 TO 1 { + inf[i] = v + tau + } + } + )"; + THEN("Array and global variables should be correctly generated") { + auto const generated = get_cpp_code(nmodl_text); + REQUIRE_THAT(generated, Contains("double t_inf[2][201]{};")); + REQUIRE_THAT(generated, Contains("double t_tau[201]{};")); + + REQUIRE_THAT(generated, Contains("inst->global->t_inf[0][i] = (inst->inf+id*2)[0];")); + REQUIRE_THAT(generated, Contains("inst->global->t_inf[1][i] = (inst->inf+id*2)[1];")); + REQUIRE_THAT(generated, Contains("inst->global->t_tau[i] = inst->global->tau;")); + + REQUIRE_THAT(generated, + Contains("(inst->inf+id*2)[0] = inst->global->t_inf[0][index];")); + + REQUIRE_THAT(generated, Contains("(inst->inf+id*2)[0] = inst->global->t_inf[0][i]")); + REQUIRE_THAT(generated, Contains("(inst->inf+id*2)[1] = inst->global->t_inf[1][i]")); + REQUIRE_THAT(generated, Contains("inst->global->tau = inst->global->t_tau[i]")); + } + } +}