diff --git a/src/language/templates/pybind/pysymtab.cpp b/src/language/templates/pybind/pysymtab.cpp index 104e132203..a3d4e9b9dc 100644 --- a/src/language/templates/pybind/pysymtab.cpp +++ b/src/language/templates/pybind/pysymtab.cpp @@ -200,7 +200,8 @@ void init_symtab_module(py::module& m) { .def("get_id", &Symbol::get_id) .def("get_status", &Symbol::get_status) .def("get_properties", &Symbol::get_properties) - .def("get_node", &Symbol::get_node) + .def("get_node", [](const std::shared_ptr& s){ auto n = s->get_nodes(); return n.empty() ? nullptr : n.front(); }) + .def("get_nodes", &Symbol::get_nodes) .def("get_original_name", &Symbol::get_original_name) .def("get_name", &Symbol::get_name) .def("has_any_property", &Symbol::has_any_property) diff --git a/src/symtab/symbol.cpp b/src/symtab/symbol.cpp index edd9688c5a..852862fd89 100644 --- a/src/symtab/symbol.cpp +++ b/src/symtab/symbol.cpp @@ -7,6 +7,7 @@ #include "symtab/symbol.hpp" #include "utils/logger.hpp" +#include namespace nmodl { namespace symtab { @@ -49,5 +50,19 @@ std::string Symbol::to_string() const { return s; } +std::vector Symbol::get_nodes_by_type( + std::initializer_list l) const noexcept { + std::vector _nodes; + for (const auto& n: nodes) { + for (const auto& m: l) { + if (n->get_node_type() == m) { + _nodes.push_back(n); + break; + } + } + } + return _nodes; +} + } // namespace symtab } // namespace nmodl diff --git a/src/symtab/symbol.hpp b/src/symtab/symbol.hpp index e6c8dddfb1..3e7bf67afb 100644 --- a/src/symtab/symbol.hpp +++ b/src/symtab/symbol.hpp @@ -15,6 +15,7 @@ #include #include +#include "ast/ast_decl.hpp" #include "lexer/modtoken.hpp" #include "symtab/symbol_properties.hpp" @@ -61,14 +62,12 @@ class Symbol { /// unique id or index position when symbol is inserted into specific table int id = 0; - /// first AST node for which symbol is inserted - /// Variable can appear multiple times in the mod file. This node - /// represent the first occurance of the variable in the input. Currently - /// we don't track all AST nodes. - ast::Ast* node = nullptr; + /// All given AST nodes for this symbol. + /// Variable can appear multiple times in the mod file. + std::vector nodes{}; /// token associated with symbol (from node) - ModToken token; + ModToken token{}; /// properties of symbol as a result of usage across whole mod file syminfo::NmodlType properties{syminfo::NmodlType::empty}; @@ -112,9 +111,13 @@ class Symbol { Symbol() = delete; + Symbol(std::string name) + : name(std::move(name)) {} + Symbol(std::string name, ast::Ast* node) - : name(std::move(name)) - , node(node) {} + : name(std::move(name)) { + nodes.push_back(node); + } Symbol(std::string name, ModToken token) : name(std::move(name)) @@ -122,8 +125,9 @@ class Symbol { Symbol(std::string name, ast::Ast* node, ModToken token) : name(std::move(name)) - , node(node) - , token(std::move(token)) {} + , token(std::move(token)) { + nodes.push_back(node); + } /// \} @@ -236,10 +240,17 @@ class Symbol { return status; } - ast::Ast* get_node() const noexcept { - return node; + void add_node(ast::Ast* node) noexcept { + nodes.push_back(node); + } + + std::vector get_nodes() const noexcept { + return nodes; } + std::vector get_nodes_by_type( + std::initializer_list l) const noexcept; + ModToken get_token() const noexcept { return token; } diff --git a/src/symtab/symbol_table.cpp b/src/symtab/symbol_table.cpp index 2a3f7a8a2b..6d3e02f49c 100644 --- a/src/symtab/symbol_table.cpp +++ b/src/symtab/symbol_table.cpp @@ -8,6 +8,7 @@ #include #include "ast/ast.hpp" +#include "ast/ast_decl.hpp" #include "symtab/symbol_table.hpp" #include "utils/logger.hpp" #include "utils/table_data.hpp" @@ -54,15 +55,12 @@ SymbolTable::SymbolTable(const SymbolTable& table) bool SymbolTable::is_method_defined(const std::string& name) const { auto symbol = lookup_in_scope(name); - if (symbol != nullptr) { - auto node = symbol->get_node(); - if (node != nullptr) { - if (node->is_procedure_block() || node->is_function_block()) { - return true; - } - } + if (symbol == nullptr) { + return false; } - return false; + auto nodes = symbol->get_nodes_by_type( + {AstNodeType::FUNCTION_BLOCK, AstNodeType::PROCEDURE_BLOCK}); + return !nodes.empty(); } @@ -202,12 +200,13 @@ std::shared_ptr ModelSymbolTable::lookup(const std::string& name) { void ModelSymbolTable::emit_message(const std::shared_ptr& first, const std::shared_ptr& second, bool redefinition) { - auto node = first->get_node(); + auto nodes = first->get_nodes(); std::string name = first->get_name(); auto properties = to_string(second->get_properties()); std::string type = "UNKNOWN"; - if (node != nullptr) { - type = node->get_node_type_name(); + if (!nodes.empty()) { + // Here we take the first one, because this is a redefinition + type = nodes.front()->get_node_type_name(); } if (redefinition) { @@ -318,6 +317,9 @@ std::shared_ptr ModelSymbolTable::insert(const std::shared_ptr& emit_message(symbol, search_symbol, true); } else { search_symbol->add_properties(symbol->get_properties()); + for (const auto& n: symbol->get_nodes()) { + search_symbol->add_node(n); + } } return search_symbol; } @@ -457,8 +459,9 @@ void SymbolTable::Table::print(std::ostream& stream, std::string title, int inde TableData table; table.title = std::move(title); table.headers = { - "NAME", "PROPERTIES", "STATUS", "LOCATION", "VALUE", "# READS", "# WRITES"}; + "NAME", "# NODES", "PROPERTIES", "STATUS", "LOCATION", "VALUE", "# READS", "# WRITES"}; table.alignments = {text_alignment::left, + text_alignment::left, text_alignment::left, text_alignment::right, text_alignment::right, @@ -482,13 +485,14 @@ void SymbolTable::Table::print(std::ostream& stream, std::string title, int inde auto properties = syminfo::to_string(symbol->get_properties()); auto status = syminfo::to_string(symbol->get_status()); auto reads = std::to_string(symbol->get_read_count()); + auto nodes = std::to_string(symbol->get_nodes().size()); std::string value; auto sym_value = symbol->get_value(); if (sym_value) { value = std::to_string(*sym_value); } auto writes = std::to_string(symbol->get_write_count()); - table.rows.push_back({name, properties, status, position, value, reads, writes}); + table.rows.push_back({name, nodes, properties, status, position, value, reads, writes}); } table.print(stream, indent); } diff --git a/src/visitors/inline_visitor.cpp b/src/visitors/inline_visitor.cpp index b72c9cbdb9..fb43e86271 100644 --- a/src/visitors/inline_visitor.cpp +++ b/src/visitors/inline_visitor.cpp @@ -8,6 +8,7 @@ #include "visitors/inline_visitor.hpp" #include "ast/all.hpp" +#include "ast/ast_decl.hpp" #include "parser/c11_driver.hpp" #include "utils/logger.hpp" #include "visitors/local_var_rename_visitor.hpp" @@ -210,25 +211,21 @@ void InlineVisitor::visit_function_call(FunctionCall& node) { return; } - auto function_definition = symbol->get_node(); - if (function_definition == nullptr) { + auto nodes = symbol->get_nodes_by_type( + {AstNodeType::FUNCTION_BLOCK, AstNodeType::PROCEDURE_BLOCK}); + if (nodes.empty()) { throw std::runtime_error("symbol table doesn't have ast node for " + function_name); } + auto f_block = nodes.front(); /// first inline called function - function_definition->visit_children(*this); + f_block->visit_children(*this); bool inlined = false; - if (function_definition->is_procedure_block()) { - auto proc = dynamic_cast(function_definition); - assert(proc); - inlined = inline_function_call(*proc, node, *caller_block); - } else if (function_definition->is_function_block()) { - auto func = dynamic_cast(function_definition); - assert(func); - inlined = inline_function_call(*func, node, *caller_block); - } + auto block = dynamic_cast(f_block); + assert(block); + inlined = inline_function_call(*block, node, *caller_block); if (inlined) { symbol->mark_inlined(); diff --git a/src/visitors/solve_block_visitor.cpp b/src/visitors/solve_block_visitor.cpp index d7ce65d9a2..59e6b92d70 100644 --- a/src/visitors/solve_block_visitor.cpp +++ b/src/visitors/solve_block_visitor.cpp @@ -47,7 +47,7 @@ ast::SolutionExpression* SolveBlockVisitor::create_solution_expression( throw std::runtime_error( fmt::format("SolveBlockVisitor :: cannot find the block '{}' to solve it", block_name)); } - auto node_to_solve = solve_node_symbol->get_node(); + auto node_to_solve = solve_node_symbol->get_nodes().front(); /// in case of derivimplicit method if neuron solver is used (i.e. not sympy) then /// the solution is not in place but we have to create a callback to newton solver diff --git a/src/visitors/symtab_visitor_helper.hpp b/src/visitors/symtab_visitor_helper.hpp index 477c8e453b..63c9399a98 100644 --- a/src/visitors/symtab_visitor_helper.hpp +++ b/src/visitors/symtab_visitor_helper.hpp @@ -135,7 +135,7 @@ void SymtabVisitor::setup_symbol(ast::Node* node, NmodlType property) { auto name = use_ion->get_name()->get_node_name(); for (const auto& variable: codegen::Ion::get_possible_variables(name)) { std::string ion_variable(codegen::naming::ION_VARNAME_PREFIX + variable); - auto symbol = std::make_shared(ion_variable, nullptr, ModToken()); + auto symbol = std::make_shared(ion_variable); symbol->add_property(NmodlType::codegen_var); modsymtab->insert(symbol); } @@ -169,13 +169,13 @@ static void add_external_symbols(symtab::ModelSymbolTable* symtab) { ModToken tok(true); auto variables = nmodl::get_external_variables(); for (auto variable: variables) { - auto symbol = std::make_shared(variable, nullptr, tok); + auto symbol = std::make_shared(variable, tok); symbol->add_property(NmodlType::extern_neuron_variable); symtab->insert(symbol); } auto methods = nmodl::get_external_functions(); for (auto method: methods) { - auto symbol = std::make_shared(method, nullptr, tok); + auto symbol = std::make_shared(method, tok); symbol->add_property(NmodlType::extern_method); symtab->insert(symbol); } diff --git a/test/unit/symtab/symbol_table.cpp b/test/unit/symtab/symbol_table.cpp index 1aa8d76164..0524259f79 100644 --- a/test/unit/symtab/symbol_table.cpp +++ b/test/unit/symtab/symbol_table.cpp @@ -7,11 +7,14 @@ #define CATCH_CONFIG_MAIN +#include #include #include +#include "ast/float.hpp" #include "ast/program.hpp" +#include "ast/string.hpp" #include "symtab/symbol.hpp" #include "symtab/symbol_table.hpp" @@ -165,7 +168,7 @@ SCENARIO("Symbol table allows operations like insert, lookup") { GIVEN("A global SymbolTable") { auto program = std::make_shared(); auto table = std::make_shared("Na", program.get(), true); - auto symbol = std::make_shared("alpha", ModToken()); + auto symbol = std::make_shared("alpha"); WHEN("checked methods and member variables") { THEN("all members are initialized") { @@ -190,7 +193,7 @@ SCENARIO("Symbol table allows operations like insert, lookup") { } } WHEN("inserting another symbol") { - auto next_symbol = std::make_shared("beta", ModToken()); + auto next_symbol = std::make_shared("beta"); table->insert(next_symbol); THEN("symbol gets added and table size increases") { REQUIRE(table->symbol_count() == 2); @@ -204,7 +207,7 @@ SCENARIO("Symbol table allows operations like insert, lookup") { THEN("table doesn't have any global variables") { REQUIRE(variables.empty()); WHEN("added global symbol") { - auto next_symbol = std::make_shared("gamma", ModToken()); + auto next_symbol = std::make_shared("gamma"); next_symbol->add_property(NmodlType::assigned_definition); table->insert(next_symbol); auto variables = table->get_variables_with_properties( @@ -226,10 +229,10 @@ SCENARIO("Symbol table allows operations like insert, lookup") { } } WHEN("query for symbol with and without properties") { - auto symbol1 = std::make_shared("alpha", ModToken()); - auto symbol2 = std::make_shared("beta", ModToken()); - auto symbol3 = std::make_shared("gamma", ModToken()); - auto symbol4 = std::make_shared("delta", ModToken()); + auto symbol1 = std::make_shared("alpha"); + auto symbol2 = std::make_shared("beta"); + auto symbol3 = std::make_shared("gamma"); + auto symbol4 = std::make_shared("delta"); symbol1->add_property(NmodlType::range_var | NmodlType::param_assign); symbol2->add_property(NmodlType::range_var | NmodlType::param_assign | @@ -280,9 +283,9 @@ SCENARIO("Global symbol table (ModelSymbol) allows scope based operations") { ModelSymbolTable mod_symtab; auto program = std::make_shared(); - auto symbol1 = std::make_shared("alpha", ModToken()); - auto symbol2 = std::make_shared("alpha", ModToken()); - auto symbol3 = std::make_shared("alpha", ModToken()); + auto symbol1 = std::make_shared("alpha"); + auto symbol2 = std::make_shared("alpha"); + auto symbol3 = std::make_shared("alpha"); symbol1->add_property(NmodlType::param_assign); symbol2->add_property(NmodlType::range_var); @@ -303,7 +306,7 @@ SCENARIO("Global symbol table (ModelSymbol) allows scope based operations") { } WHEN("trying to insert without entering scope") { THEN("throws an exception") { - auto symbol = std::make_shared("alpha", ModToken()); + auto symbol = std::make_shared("alpha"); REQUIRE_THROWS_WITH(mod_symtab.insert(symbol), Catch::Contains("Can not insert")); } } @@ -346,3 +349,78 @@ SCENARIO("Global symbol table (ModelSymbol) allows scope based operations") { } } } + +//============================================================================= +// Symbol class tests +//============================================================================= + +SCENARIO("Symbol class allows manipulation") { + GIVEN("A symbol can have several nodes") { + auto st = std::make_shared("node1"); + auto fl = std::make_shared("1.1"); + Symbol symbol1("alpha"); + symbol1.add_node(st.get()); + symbol1.add_node(fl.get()); + + Symbol symbol2("beta"); + + WHEN("trying to get name") { + THEN("it works") { + REQUIRE(symbol1.get_name() == "alpha"); + REQUIRE(symbol2.get_name() == "beta"); + } + } + + WHEN("trying to get all nodes") { + THEN("it works") { + REQUIRE(symbol1.get_nodes().size() == 2); + REQUIRE(symbol2.get_nodes().empty()); + } + } + + WHEN("trying to get specific node") { + auto nodes = symbol1.get_nodes_by_type({ast::AstNodeType::STRING}); + + THEN("it works") { + REQUIRE(nodes.size() == 1); + REQUIRE(nodes.front()->is_string()); + REQUIRE(symbol2.get_nodes_by_type({ast::AstNodeType::STRING}).empty()); + } + } + WHEN("read and write counters works") { + symbol1.read(); + symbol1.read(); + symbol1.write(); + + THEN("it works") { + REQUIRE(symbol1.get_read_count() == 2); + REQUIRE(symbol1.get_write_count() == 1); + REQUIRE(symbol2.get_read_count() == 0); + REQUIRE(symbol2.get_write_count() == 0); + } + } + + WHEN("renaming a symbol") { + symbol2.set_name("gamma"); + THEN("get_name return the new name") { + REQUIRE(symbol2.get_name() == "gamma"); + REQUIRE(symbol2.get_original_name() == "beta"); + } + symbol2.set_original_name("gamma"); + THEN("get_original_name return the new name") { + REQUIRE(symbol2.get_original_name() == "gamma"); + } + } + + WHEN("set as array") { + symbol1.set_as_array(15); + THEN("recognized as an array") { + REQUIRE(symbol1.get_length() == 15); + REQUIRE(symbol1.is_array()); + + REQUIRE(symbol2.get_length() == 1); + REQUIRE(!symbol2.is_array()); + } + } + } +}