Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add buffer interface for RosValue arrays #33

Merged
merged 14 commits into from
Dec 9, 2021
14 changes: 8 additions & 6 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,17 @@ git_repository(
shallow_since = "1593046824 -0700",
)

load("@rules_python//python:repositories.bzl", "py_repositories")

py_repositories()

# Only needed if using the packaging rules.
load("@rules_python//python:pip.bzl", "pip_repositories")
load("@rules_python//python:pip.bzl", "pip_repositories", "pip_import")

pip_repositories()

pip_import(
name = "test_python_requirements",
requirements = "//test:requirements.txt",
)
load("@test_python_requirements//:requirements.bzl", pip_install_test_requirements = "pip_install")
pip_install_test_requirements()

# GTest
git_repository(
name = "gtest",
Expand Down
5 changes: 3 additions & 2 deletions lib/BUILD
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
load("@rules_cc//cc:defs.bzl", "cc_binary")
load("@pybind11_bazel//:build_defs.bzl", "pybind_library")
load("//lib:version.bzl", "EMBAG_VERSION")

cc_library(
pybind_library(
name = "embag",
srcs = [
"embag.cc",
Expand Down
2 changes: 1 addition & 1 deletion lib/message_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void MessageParser::initPrimitive(size_t primitive_offset, const RosMsgTypes::Fi
primitive.primitive_info_.offset = message_buffer_offset_;

if (field.type() == RosValue::Type::string) {
message_buffer_offset_ += *reinterpret_cast<const uint32_t* const>(primitive.getPrimitivePointer()) + sizeof(uint32_t);
message_buffer_offset_ += *primitive.getPrimitivePointer<uint32_t>() + sizeof(uint32_t);
} else {
message_buffer_offset_ += field.typeSize();
}
Expand Down
1 change: 1 addition & 0 deletions lib/ros_msg_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class RosMsgTypes{
class FieldDef {
public:
const static primitive_type_map_t primitive_type_map;
static size_t typeToSize(const RosValue::Type type);

struct parseable_info_t {
std::string type_name;
Expand Down
78 changes: 76 additions & 2 deletions lib/ros_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ const std::string RosValue::as<std::string>() const {
throw std::runtime_error("Cannot call as<std::string> for a non string");
}

const uint32_t string_length = *reinterpret_cast<const uint32_t* const>(getPrimitivePointer());
const char* const string_loc = reinterpret_cast<const char* const>(getPrimitivePointer() + sizeof(uint32_t));
const uint32_t string_length = *getPrimitivePointer<uint32_t>();
const char* const string_loc = getPrimitivePointer<char>() + sizeof(uint32_t);
return std::string(string_loc, string_loc + string_length);
}

Expand Down Expand Up @@ -130,6 +130,80 @@ void RosValue::print(const std::string &path) const {
std::cout << toString(path);
}

size_t RosValue::primitiveTypeToSize(const Type type) {
switch (type) {
case (Type::ros_bool):
return sizeof(bool);
case (Type::int8):
return sizeof(int8_t);
case (Type::uint8):
return sizeof(uint8_t);
case (Type::int16):
return sizeof(int16_t);
case (Type::uint16):
return sizeof(uint16_t);
case (Type::int32):
return sizeof(int32_t);
case (Type::uint32):
return sizeof(uint32_t);
case (Type::int64):
return sizeof(int64_t);
case (Type::uint64):
return sizeof(uint64_t);
case (Type::float32):
return sizeof(float);
case (Type::float64):
return sizeof(double);
case (Type::string):
return 0; // The size of string is unknown!
case (Type::ros_time):
return sizeof(ros_time_t);
case (Type::ros_duration):
return sizeof(ros_duration_t);
case (Type::array):
case (Type::object):
default:
throw std::runtime_error("Provided type is not a primitive!");
}
}

std::string RosValue::primitiveTypeToFormat(const Type type) {
switch (type) {
case (Type::ros_bool):
return "?";
case (Type::int8):
return "b";
case (Type::uint8):
return "B";
case (Type::int16):
return "h";
case (Type::uint16):
return "H";
case (Type::int32):
return "i";
case (Type::uint32):
return "I";
case (Type::int64):
return "q";
case (Type::uint64):
return "Q";
case (Type::float32):
return "f";
case (Type::float64):
return "d";
case (Type::ros_time):
return "II";
case (Type::ros_duration):
return "II";
case (Type::string):
throw std::runtime_error("Strings do not have a struct format!");
case (Type::array):
case (Type::object):
default:
throw std::runtime_error("Provided type is not a primitive!");
}
}

/*
--------------
ITERATOR SETUP
Expand Down
33 changes: 30 additions & 3 deletions lib/ros_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <memory>
#include <cstdint>
#include <cstring>
#include <pybind11/buffer_info.h>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -73,6 +74,8 @@ class RosValue {
object,
array,
};
static size_t primitiveTypeToSize(const Type type);
static std::string primitiveTypeToFormat(const Type type);

struct ros_time_t {
uint32_t secs = 0;
Expand Down Expand Up @@ -274,7 +277,7 @@ class RosValue {
}

// TODO: Add check that the underlying type aligns with T
return *reinterpret_cast<const T*>(getPrimitivePointer());
return *getPrimitivePointer<T>();
}

bool has(const std::string &key) const {
Expand Down Expand Up @@ -320,6 +323,29 @@ class RosValue {
return values;
}

// Used for python bindings

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think very few people are going to be familiar with what pybind11::buffer_info does -
Could you write a better comment explaining what exactly this does and how one might properly use it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

pybind11::buffer_info getPrimitiveArrayBufferInfo() {
if (getType() != Embag::RosValue::Type::array) {
throw std::runtime_error("Only arrays can be represented as buffers!");
}

const Embag::RosValue::Type type_of_elements = at(0)->getType();
if (type_of_elements == Embag::RosValue::Type::object || type_of_elements == Embag::RosValue::Type::string) {
throw std::runtime_error("In order to be represented as a buffer, an array's elements must not be objects or strings!");
}

const size_t size_of_elements = Embag::RosValue::primitiveTypeToSize(type_of_elements);
return pybind11::buffer_info(
(void*) at(0)->getPrimitivePointer<void>(),
size_of_elements,
Embag::RosValue::primitiveTypeToFormat(type_of_elements),
1,
{ size() },
{ size_of_elements },
true
);
}

std::string toString(const std::string &path = "") const;
void print(const std::string &path = "") const;

Expand Down Expand Up @@ -349,8 +375,9 @@ class RosValue {
object_info_t object_info_;
};

const char* const getPrimitivePointer() const {
return &primitive_info_.message_buffer->at(primitive_info_.offset);
template<typename T>
const T* const getPrimitivePointer() const {
return reinterpret_cast<const T* const>(&primitive_info_.message_buffer->at(primitive_info_.offset));
}

const ros_value_list_t& getChildren() const {
Expand Down
3 changes: 2 additions & 1 deletion pip_package/macos_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ mkdir -p /tmp/embag /tmp/pip_build /tmp/out
cp -r lib /tmp/embag
cp -r pip_package/* README.md LICENSE /tmp/pip_build

python -m pip install cython wheel

# Build embag libs and echo test binary
bazel build -c opt //python:libembag.so //embag_echo:embag_echo && \
bazel test //test:embag_test //test:embag_test_python3 --test_output=all

# Build wheel
cp bazel-bin/python/libembag.so /tmp/pip_build/embag
python -m pip install wheel
(cd /tmp/pip_build && python setup.py bdist_wheel && \
# FIXME
#python -m pip install dist/embag*.whl && \
Expand Down
3 changes: 2 additions & 1 deletion python/embag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ PYBIND11_MODULE(libembag, m) {
.def_readonly("md5", &Embag::RosMessage::md5)
.def_readonly("raw_data_len", &Embag::RosMessage::raw_data_len);

auto ros_value = py::class_<Embag::RosValue, Embag::VectorItemPointer<Embag::RosValue>>(m, "RosValue", py::dynamic_attr())
auto ros_value = py::class_<Embag::RosValue, Embag::VectorItemPointer<Embag::RosValue>>(m, "RosValue", py::dynamic_attr(), py::buffer_protocol())
.def_buffer(&Embag::RosValue::getPrimitiveArrayBufferInfo)
.def("get", &Embag::RosValue::get)
.def("getType", &Embag::RosValue::getType)
.def("__len__", &Embag::RosValue::size)
Expand Down
8 changes: 8 additions & 0 deletions test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ cc_test(
],
)

load("@test_python_requirements//:requirements.bzl", "requirement")

py_test(
name = "embag_test_python3",
main = "embag_test_python.py",
Expand All @@ -18,6 +20,9 @@ py_test(
"test.bag",
"//python:libembag.so",
],
deps = [
requirement("numpy"),
],
python_version = "PY3",
)

Expand All @@ -29,5 +34,8 @@ py_test(
"test.bag",
"//python:libembag.so",
],
deps = [
requirement("numpy"),
],
python_version = "PY2",
)
8 changes: 8 additions & 0 deletions test/embag_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
import struct
import unittest
import numpy as np


class EmbagTest(unittest.TestCase):
Expand Down Expand Up @@ -170,5 +171,12 @@ def testBagFromBytes(self):
self.testConnectionsInView()
bag_stream.close()

def testBufferInfo(self):
for msg in self.view.getMessages('/base_pose_ground_truth'):
covariance_array = msg.data()['pose']['covariance']
memoryview(covariance_array)
for covariance in np.array(covariance_array, copy=False):
self.assertEqual(covariance, 0)

if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
numpy==1.16.6;python_version<"3"
numpy==1.18.1;python_version>="3"