Skip to content

Commit

Permalink
Code generation fixes for array variables in TABLE statement (#924)
Browse files Browse the repository at this point in the history
* Table statements can have array variables. Until now only
  scalar variables were supported in code generation.
* We can check symbol table to find out if the variable is
  an array and it's length.
* Similar to mod2c implementation, generate code for array
  variable assignments:
     https://github.com/BlueBrain/mod2c/blob/469c74dc7d96bbc5a06a42696422154b4cd2ce28/src/mod2c_core/parsact.c#L942
* with this, `glia__dbbs_mod_collection__Cav2_3__0.mod` from #888
  compiles

* Add test and fix the bug for array variables allocation and access
  • Loading branch information
pramodk authored Sep 7, 2022
1 parent 4984c31 commit 6a59caa
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 10 deletions.
80 changes: 70 additions & 10 deletions src/codegen/codegen_c_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,20 @@ static const TableStatement* get_table_statement(const ast::Block& node) {
}


std::tuple<bool, int> CodegenCVisitor::check_if_var_is_array(const std::string& name) {
auto symbol = program_symtab->lookup_in_scope(name);
if (!symbol) {
throw std::runtime_error(
fmt::format("CodegenCVisitor:: {} not found in symbol table!", name));
}
if (symbol->is_array()) {
return {true, symbol->get_length()};
} else {
return {false, 0};
}
}


void CodegenCVisitor::print_table_check_function(const Block& node) {
auto statement = get_table_statement(node);
auto table_variables = statement->get_table_vars();
Expand Down Expand Up @@ -1525,7 +1539,15 @@ void CodegenCVisitor::print_table_check_function(const Block& node) {
auto name = variable->get_node_name();
auto instance_name = get_variable_name(name);
auto table_name = get_variable_name("t_" + name);
printer->fmt_line("{}[i] = {};", table_name, instance_name);
auto [is_array, array_length] = check_if_var_is_array(name);
if (is_array) {
for (int j = 0; j < array_length; j++) {
printer->fmt_line(
"{}[{}][i] = {}[{}];", table_name, j, instance_name, j);
}
} else {
printer->fmt_line("{}[i] = {};", table_name, instance_name);
}
}
} else {
auto table_name = get_variable_name("t_" + name);
Expand Down Expand Up @@ -1587,7 +1609,14 @@ void CodegenCVisitor::print_table_replacement_function(const ast::Block& node) {
if (node.is_procedure_block()) {
for (const auto& var: table_variables) {
auto name = get_variable_name(var->get_node_name());
printer->fmt_line("{} = xi;", name);
auto [is_array, array_length] = check_if_var_is_array(var->get_node_name());
if (is_array) {
for (int j = 0; j < array_length; j++) {
printer->fmt_line("{}[{}] = xi;", name, j);
}
} else {
printer->fmt_line("{} = xi;", name);
}
}
printer->add_line("return 0;");
} else {
Expand All @@ -1602,7 +1631,15 @@ void CodegenCVisitor::print_table_replacement_function(const ast::Block& node) {
auto name = variable->get_node_name();
auto instance_name = get_variable_name(name);
auto table_name = get_variable_name("t_" + name);
printer->fmt_line("{} = {}[index];", instance_name, table_name);
auto [is_array, array_length] = check_if_var_is_array(name);
if (is_array) {
for (int j = 0; j < array_length; j++) {
printer->fmt_line(
"{}[{}] = {}[{}][index];", instance_name, j, table_name, j);
}
} else {
printer->fmt_line("{} = {}[index];", instance_name, table_name);
}
}
printer->add_line("return 0;");
} else {
Expand All @@ -1615,11 +1652,23 @@ void CodegenCVisitor::print_table_replacement_function(const ast::Block& node) {
printer->add_line("double theta = xi - double(i);");
if (node.is_procedure_block()) {
for (const auto& var: table_variables) {
auto instance_name = get_variable_name(var->get_node_name());
auto table_name = get_variable_name("t_" + var->get_node_name());
printer->fmt_line("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);",
instance_name,
table_name);
auto name = var->get_node_name();
auto instance_name = get_variable_name(name);
auto table_name = get_variable_name("t_" + name);
auto [is_array, array_length] = check_if_var_is_array(var->get_node_name());
if (is_array) {
for (size_t j = 0; j < array_length; j++) {
printer->fmt_line(
"{0}[{1}] = {2}[{1}][i] + theta*({2}[{1}][i+1]-{2}[{1}][i]);",
instance_name,
j,
table_name);
}
} else {
printer->fmt_line("{0} = {1}[i] + theta*({1}[i+1]-{1}[i]);",
instance_name,
table_name);
}
}
printer->add_line("return 0;");
} else {
Expand Down Expand Up @@ -2630,8 +2679,19 @@ void CodegenCVisitor::print_mechanism_global_var_structure(bool print_initialise
for (const auto& variable: info.table_statement_variables) {
auto const name = "t_" + variable->get_name();
auto const num_values = variable->get_num_values();
printer->fmt_line(
"{}{} {}[{}]{};", qualifier, float_type, name, num_values, value_initialise);
if (variable->is_array()) {
int array_len = variable->get_length();
printer->fmt_line("{}{} {}[{}][{}]{};",
qualifier,
float_type,
name,
array_len,
num_values,
value_initialise);
} else {
printer->fmt_line(
"{}{} {}[{}]{};", qualifier, float_type, name, num_values, value_initialise);
}
codegen_global_variables.push_back(make_symbol(name));
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/codegen/codegen_c_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,12 @@ class CodegenCVisitor: public visitor::ConstAstVisitor {
*/
virtual bool is_constant_variable(const std::string& name) const;

/**
* Check if the given name exist in the symbol
* \return \c return a tuple <true, array_length> if variable
* is an array otherwise <false, 0>
*/
std::tuple<bool, int> check_if_var_is_array(const std::string& name);

/**
* Print declaration of macro NRN_PRCELLSTATE for debugging
Expand Down
72 changes: 72 additions & 0 deletions test/unit/codegen/codegen_c_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
#include "parser/nmodl_driver.hpp"
#include "test/unit/utils/test_utils.hpp"
#include "visitors/implicit_argument_visitor.hpp"
#include "visitors/inline_visitor.hpp"
#include "visitors/neuron_solve_visitor.hpp"
#include "visitors/perf_visitor.hpp"
#include "visitors/solve_block_visitor.hpp"
#include "visitors/sympy_solver_visitor.hpp"
#include "visitors/symtab_visitor.hpp"

using Catch::Matchers::Contains; // ContainsSubstring in newer Catch2
Expand All @@ -32,6 +36,11 @@ std::shared_ptr<CodegenCVisitor> create_c_visitor(const std::shared_ptr<ast::Pro
/// construct symbol table
SymtabVisitor().visit_program(*ast);

/// run all necessary pass
InlineVisitor().visit_program(*ast);
NeuronSolveVisitor().visit_program(*ast);
SolveBlockVisitor().visit_program(*ast);

/// create C code generation visitor
auto cv = std::make_shared<CodegenCVisitor>("temp.mod", ss, "double", false);
cv->setup(*ast);
Expand All @@ -47,6 +56,15 @@ std::string get_instance_var_setup_function(std::string& nmodl_text) {
return reindent_text(ss.str());
}

/// print entire code
std::string get_cpp_code(const std::string& nmodl_text) {
const auto& ast = NmodlDriver().parse_string(nmodl_text);
std::stringstream ss;
auto cvisitor = create_c_visitor(ast, nmodl_text, ss);
cvisitor->visit_program(*ast);
return reindent_text(ss.str());
}

SCENARIO("Check instance variable definition order", "[codegen][var_order]") {
GIVEN("cal_mig.mod: USEION variables declared as RANGE") {
// In the below mod file, the ion variables cai and cao are also
Expand Down Expand Up @@ -345,3 +363,57 @@ SCENARIO("Check NEURON globals are added to the instance struct on demand",
}
}
}

SCENARIO("Check code generation for TABLE statements", "[codegen][array_variables]") {
GIVEN("A MOD file that uses global and array variables in TABLE") {
std::string const nmodl_text = R"(
NEURON {
SUFFIX glia_Cav2_3
RANGE inf
GLOBAL tau
}
STATE { m }
PARAMETER {
tau = 1
}
ASSIGNED {
inf[2]
}
BREAKPOINT {
SOLVE states METHOD cnexp
}
DERIVATIVE states {
mhn(v)
m' = (inf[0] - m)/tau
}
PROCEDURE mhn(v (mV)) {
TABLE inf, tau DEPEND celsius FROM -100 TO 100 WITH 200
FROM i=0 TO 1 {
inf[i] = v + tau
}
}
)";
THEN("Array and global variables should be correctly generated") {
auto const generated = get_cpp_code(nmodl_text);
REQUIRE_THAT(generated, Contains("double t_inf[2][201]{};"));
REQUIRE_THAT(generated, Contains("double t_tau[201]{};"));

REQUIRE_THAT(generated, Contains("inst->global->t_inf[0][i] = (inst->inf+id*2)[0];"));
REQUIRE_THAT(generated, Contains("inst->global->t_inf[1][i] = (inst->inf+id*2)[1];"));
REQUIRE_THAT(generated, Contains("inst->global->t_tau[i] = inst->global->tau;"));

REQUIRE_THAT(generated,
Contains("(inst->inf+id*2)[0] = inst->global->t_inf[0][index];"));

REQUIRE_THAT(generated, Contains("(inst->inf+id*2)[0] = inst->global->t_inf[0][i]"));
REQUIRE_THAT(generated, Contains("(inst->inf+id*2)[1] = inst->global->t_inf[1][i]"));
REQUIRE_THAT(generated, Contains("inst->global->tau = inst->global->t_tau[i]"));
}
}
}

0 comments on commit 6a59caa

Please sign in to comment.