Skip to content

Commit

Permalink
std::function<double()> as alternative to double* (#37)
Browse files Browse the repository at this point in the history
* Also added unit test for std::function<double()> elements.

Co-authored-by: Jorge Blanco Alonso <[email protected]>
  • Loading branch information
olupton and jorblancoa authored Apr 28, 2023
1 parent 4910364 commit b95d1c8
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 14 deletions.
17 changes: 17 additions & 0 deletions include/bbp/sonata/reports.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,21 @@ char* sonata_restore(uint64_t node_id, const int* piece_count, const int* length

#if defined(__cplusplus)
}

#include <functional>

/**
* \brief Add an element value to an existing node on a report
*
* C++-only analogue of sonata_add_element that takes a generic handle to element_value instead of a
* raw pointer.
*
* \return 0 if operator succeeded, -2 if the report doesn't exist, -3 if the specified node
* doesn't exist, -1 for other errors.
*/
int sonata_add_element_handle(const char* report_name,
const char* population_name,
uint64_t node_id,
uint32_t element_id,
std::function<double()> element_value);
#endif
46 changes: 42 additions & 4 deletions src/data/node.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include <algorithm>
#include <cassert>
#include <stdexcept>

#include "node.h"

Expand All @@ -9,18 +11,54 @@ Node::Node(uint64_t node_id)
: node_id_(node_id) {}

void Node::add_element(double* element_value, uint32_t element_id) {
if (!element_handles_.empty()) {
throw std::runtime_error(
"bbp::sonata::Node::add_element: mixing raw pointers and generic handles is not "
"supported");
}
elements_.push_back(element_value);
element_ids_.push_back(element_id);
}

void Node::add_element(std::function<double()> element_value, uint32_t element_id) {
if (!elements_.empty()) {
throw std::runtime_error(
"bbp::sonata::Node::add_element: mixing raw pointers and generic handles is not "
"supported");
}
element_handles_.push_back(std::move(element_value));
element_ids_.push_back(element_id);
}

void Node::fill_data(std::vector<float>::iterator it) {
std::transform(elements_.begin(), elements_.end(), it, [](auto elem) {
return static_cast<float>(*elem);
});
assert(elements_.empty() || element_handles_.empty());
if (!elements_.empty()) {
std::transform(elements_.begin(), elements_.end(), it, [](auto elem) -> float {
return *elem;
});
} else if (!element_handles_.empty()) {
std::transform(element_handles_.begin(),
element_handles_.end(),
it,
[](auto const& elem) -> float { return elem(); });
}
}

void Node::refresh_pointers(std::function<double*(double*)> refresh_function) {
std::transform(elements_.begin(), elements_.end(), elements_.begin(), refresh_function);
if (!elements_.empty()) {
std::transform(elements_.begin(), elements_.end(), elements_.begin(), refresh_function);
} else if (!element_handles_.empty()) {
std::transform(element_handles_.begin(),
element_handles_.end(),
element_handles_.begin(),
[&refresh_function](auto const& elem) -> std::function<double()> {
return [elem, refresh_function]() -> double {
double value = elem();
double* refreshed_value = refresh_function(&value);
return *refreshed_value;
};
});
}
}

} // namespace sonata
Expand Down
4 changes: 3 additions & 1 deletion src/data/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ class Node
void fill_data(std::vector<float>::iterator it);
void refresh_pointers(std::function<double*(double*)> refresh_function);
virtual void add_element(double* element_value, uint32_t element_id);
virtual void add_element(std::function<double()> element_value, uint32_t element_id);

uint64_t get_node_id() const noexcept {
return node_id_;
}
virtual size_t get_num_elements() const noexcept {
return elements_.size();
return element_ids_.size();
}
const std::vector<uint32_t>& get_element_ids() const noexcept {
return element_ids_;
Expand All @@ -33,6 +34,7 @@ class Node
protected:
std::vector<uint32_t> element_ids_;
std::vector<double*> elements_;
std::vector<std::function<double()>> element_handles_;
};

using nodes_t = std::map<uint64_t, std::shared_ptr<Node>>;
Expand Down
10 changes: 9 additions & 1 deletion src/data/soma_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@ SomaNode::SomaNode(uint64_t node_id)
: Node(node_id) {}

void SomaNode::add_element(double* element_value, uint32_t element_id) {
if (!elements_.empty()) {
if (!elements_.empty() || !element_handles_.empty()) {
throw std::runtime_error("ERROR: Soma report nodes can only have 1 element");
}
elements_.push_back(element_value);
element_ids_.push_back(element_id);
}

void SomaNode::add_element(std::function<double()> element_value, uint32_t element_id) {
if (!elements_.empty() || !element_handles_.empty()) {
throw std::runtime_error("ERROR: Soma report nodes can only have 1 element");
}
element_handles_.push_back(std::move(element_value));
element_ids_.push_back(element_id);
}

} // namespace sonata
} // namespace bbp
1 change: 1 addition & 0 deletions src/data/soma_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class SomaNode: public Node
SomaNode(uint64_t node_id);

void add_element(double* element_value, uint32_t element_id) override;
void add_element(std::function<double()> element_value, uint32_t element_id) override;
size_t get_num_elements() const noexcept override {
return elements_.empty() ? 0 : 1;
};
Expand Down
32 changes: 26 additions & 6 deletions src/library/reports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,20 @@ int sonata_add_node(const char* report_name,
return 0;
}

int sonata_add_element(const char* report_name,
const char* population_name,
uint64_t node_id,
uint32_t element_id,
double* voltage) {
namespace {
template <typename T>
int sonata_add_element_impl(const char* report_name,
const char* population_name,
uint64_t node_id,
uint32_t element_id,
T&& voltage) {
if (!sonata_report.report_exists(report_name)) {
return -2;
}
try {
auto report = sonata_report.get_report(report_name);
auto node = report->get_node(population_name, node_id);
node->add_element(voltage, element_id);
node->add_element(std::forward<T>(voltage), element_id);
} catch (const std::out_of_range& err) {
logger->error(err.what());
return -3;
Expand All @@ -65,6 +67,24 @@ int sonata_add_element(const char* report_name,
}
return 0;
}
} // namespace

int sonata_add_element(const char* report_name,
const char* population_name,
uint64_t node_id,
uint32_t element_id,
double* voltage) {
return sonata_add_element_impl(report_name, population_name, node_id, element_id, voltage);
}

int sonata_add_element_handle(const char* report_name,
const char* population_name,
uint64_t node_id,
uint32_t element_id,
std::function<double()> element_value) {
return sonata_add_element_impl(
report_name, population_name, node_id, element_id, std::move(element_value));
}

void sonata_setup_communicators() {
sonata_report.create_communicators();
Expand Down
35 changes: 33 additions & 2 deletions tests/unit/test_node.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include <catch2/catch.hpp>
#include <memory>
#include <data/node.h>
#include <data/soma_node.h>
#include <memory>
#include <spdlog/spdlog.h>

using namespace bbp::sonata;
Expand All @@ -26,7 +26,7 @@ SCENARIO("Test Node class", "[Node]") {
REQUIRE_NOTHROW(node.refresh_pointers(&square));
}

WHEN("We add a element") {
WHEN("We add a raw pointer element") {
std::vector<double> elements = {10, 11, 12, 13, 14};
size_t i = 0;
for (auto& element : elements) {
Expand All @@ -53,6 +53,37 @@ SCENARIO("Test Node class", "[Node]") {
REQUIRE(result == compare);
}
}

WHEN("We add a std::function<double()> element") {
std::vector<std::function<double()>> elements = {[]() { return 10.0; },
[]() { return 11.0; },
[]() { return 12.0; }};
size_t i = 0;
for (auto& element_handle : elements) {
node.add_element(element_handle, i);
++i;
}
THEN("Number of elements is 3") {
REQUIRE(node.get_num_elements() == 3);
}
THEN("The element_ids are") {
std::vector<uint32_t> compare = {0, 1, 2};
REQUIRE(node.get_element_ids() == compare);
}
THEN("fill_data will return something correct") {
std::vector<float> result(3, -1.0);
node.fill_data(result.begin());
std::vector<float> compare = {10.0, 11.0, 12.0};
REQUIRE(result == compare);
}
THEN("refresh_pointers will call the function on all elements") {
node.refresh_pointers(&square);
std::vector<float> compare{100, 121, 144};
std::vector<float> result(3, -1);
node.fill_data(result.begin());
REQUIRE(result == compare);
}
}
}

GIVEN("An instance of a soma node") {
Expand Down

0 comments on commit b95d1c8

Please sign in to comment.