Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add buffer interface for RosValue arrays #33

Merged
merged 14 commits into from
Dec 9, 2021
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,17 @@ To test, run:
# This will run both the C++ and Python tests against a small bag file
bazel test test:* --test_output=all

NOTE: If you're testing the python2 or python3 interface, you'll need to ensure that your system has numpy installed for each respective python version.

### Usage
To use the C++ API:
```c++
Embag::View view{filename};
view.addBag("another.bag"); # Views support reading from multiple bags

for (const auto &message : view.getMessages({"/fun/topic", "/another/topic"})) {
std::cout << message->timestamp.to_sec() << " : " << message->topic << std::endl;
std::cout << message->data()["fun_array"][0]["fun_field"].as<std::string>() << std::endl;
std::cout << message->timestamp->to_sec() << " : " << message->topic << std::endl;
std::cout << message->data()["fun_array"][0]["fun_field"]->as<std::string>() << std::endl;
}
```
See the [tests](https://github.com/embarktrucks/embag/tree/master/test) for more usage examples.
Expand Down
9 changes: 0 additions & 9 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,6 @@ git_repository(
shallow_since = "1593046824 -0700",
)

load("@rules_python//python:repositories.bzl", "py_repositories")

py_repositories()

# Only needed if using the packaging rules.
load("@rules_python//python:pip.bzl", "pip_repositories")

pip_repositories()

# GTest
git_repository(
name = "gtest",
Expand Down
5 changes: 3 additions & 2 deletions lib/BUILD
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library")
load("@rules_cc//cc:defs.bzl", "cc_binary")
load("@pybind11_bazel//:build_defs.bzl", "pybind_library")
load("//lib:version.bzl", "EMBAG_VERSION")

cc_library(
pybind_library(
name = "embag",
srcs = [
"embag.cc",
Expand Down
32 changes: 16 additions & 16 deletions lib/embag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@
namespace Embag {

const RosMsgTypes::primitive_type_map_t RosMsgTypes::FieldDef::primitive_type_map = {
{"bool", {RosValue::Type::ros_bool, sizeof(bool)}},
{"int8", {RosValue::Type::int8, sizeof(int8_t)}},
{"uint8", {RosValue::Type::uint8, sizeof(uint8_t)}},
{"int16", {RosValue::Type::int16, sizeof(int16_t)}},
{"uint16", {RosValue::Type::uint16, sizeof(uint16_t)}},
{"int32", {RosValue::Type::int32, sizeof(int32_t)}},
{"uint32", {RosValue::Type::uint32, sizeof(uint32_t)}},
{"int64", {RosValue::Type::int64, sizeof(int64_t)}},
{"uint64", {RosValue::Type::uint64, sizeof(uint64_t)}},
{"float32", {RosValue::Type::float32, sizeof(float)}},
{"float64", {RosValue::Type::float64, sizeof(double)}},
{"string", {RosValue::Type::string, 0}}, // The size of string is unknown!
{"time", {RosValue::Type::ros_time, sizeof(RosValue::ros_time_t)}},
{"duration", {RosValue::Type::ros_duration, sizeof(RosValue::ros_duration_t)}},
{"bool", RosValue::Type::ros_bool},
{"int8", RosValue::Type::int8},
{"uint8", RosValue::Type::uint8},
{"int16", RosValue::Type::int16},
{"uint16", RosValue::Type::uint16},
{"int32", RosValue::Type::int32},
{"uint32", RosValue::Type::uint32},
{"int64", RosValue::Type::int64},
{"uint64", RosValue::Type::uint64},
{"float32", RosValue::Type::float32},
{"float64", RosValue::Type::float64},
{"string", RosValue::Type::string},
{"time", RosValue::Type::ros_time},
{"duration", RosValue::Type::ros_duration},

// Deprecated types
{"byte", {RosValue::Type::int8, sizeof(int8_t)}},
{"char", {RosValue::Type::uint8, sizeof(uint8_t)}},
{"byte", RosValue::Type::int8},
{"char", RosValue::Type::uint8},
};

void Bag::BagFromFile::open(const std::string &path) {
Expand Down
2 changes: 1 addition & 1 deletion lib/message_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void MessageParser::initPrimitive(size_t primitive_offset, const RosMsgTypes::Fi
primitive.primitive_info_.offset = message_buffer_offset_;

if (field.type() == RosValue::Type::string) {
message_buffer_offset_ += *reinterpret_cast<const uint32_t* const>(primitive.getPrimitivePointer()) + sizeof(uint32_t);
message_buffer_offset_ += primitive.getPrimitive<uint32_t>() + sizeof(uint32_t);
} else {
message_buffer_offset_ += field.typeSize();
}
Expand Down
20 changes: 13 additions & 7 deletions lib/ros_msg_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ class RosMsgTypes{

// Schema stuff
// TODO: move this stuff elsewhere?
typedef std::unordered_map<std::string, const std::pair<const RosValue::Type, const size_t>> primitive_type_map_t;
typedef std::unordered_map<std::string, const RosValue::Type> primitive_type_map_t;
class FieldDef {
public:
const static primitive_type_map_t primitive_type_map;
static size_t typeToSize(const RosValue::Type type);

struct parseable_info_t {
std::string type_name;
Expand All @@ -28,12 +29,13 @@ class RosMsgTypes{
, type_definition_(nullptr)
{
if (primitive_type_map.count(parsed_info_.type_name)) {
const std::pair<RosValue::Type, size_t>& type_info = primitive_type_map.at(parsed_info_.type_name);
type_ = type_info.first;
type_size_ = type_info.second;
type_ = primitive_type_map.at(parsed_info_.type_name);
if (type_ != RosValue::Type::string) {
// Cache the size of the type for quicker access.
type_size_ = RosValue::primitiveTypeToSize(type_);
}
} else {
type_ = RosValue::Type::object;
type_size_ = 0;
}
}

Expand All @@ -54,6 +56,10 @@ class RosMsgTypes{
}

size_t typeSize() const {
if (type_ == RosValue::Type::object || type_ == RosValue::Type::string) {
throw std::runtime_error("The size of an object or string is not statically known!");
}

return type_size_;
}

Expand Down Expand Up @@ -91,8 +97,8 @@ class RosMsgTypes{
// To maintain performance, we cache this information in each instance of the class.
// If this field is an array, holds the type of the items within the array.
RosValue::Type type_;
// If this field is an object, the size will be 0.
size_t type_size_;
// If this field is an object or a string, the size will be 0.
size_t type_size_ = 0;

// TODO: This can be stored in union with the size_t to reduce space
// If this is an object, cache the associated ros_embedded_msg_def
Expand Down
77 changes: 75 additions & 2 deletions lib/ros_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ const std::string RosValue::as<std::string>() const {
throw std::runtime_error("Cannot call as<std::string> for a non string");
}

const uint32_t string_length = *reinterpret_cast<const uint32_t* const>(getPrimitivePointer());
const char* const string_loc = reinterpret_cast<const char* const>(getPrimitivePointer() + sizeof(uint32_t));
const uint32_t string_length = getPrimitive<uint32_t>();
const char* const string_loc = &getPrimitive<char>() + sizeof(uint32_t);
return std::string(string_loc, string_loc + string_length);
}

Expand Down Expand Up @@ -130,6 +130,79 @@ void RosValue::print(const std::string &path) const {
std::cout << toString(path);
}

size_t RosValue::primitiveTypeToSize(const Type type) {
switch (type) {
case (Type::ros_bool):
return sizeof(bool);
case (Type::int8):
return sizeof(int8_t);
case (Type::uint8):
return sizeof(uint8_t);
case (Type::int16):
return sizeof(int16_t);
case (Type::uint16):
return sizeof(uint16_t);
case (Type::int32):
return sizeof(int32_t);
case (Type::uint32):
return sizeof(uint32_t);
case (Type::int64):
return sizeof(int64_t);
case (Type::uint64):
return sizeof(uint64_t);
case (Type::float32):
return sizeof(float);
case (Type::float64):
return sizeof(double);
case (Type::ros_time):
return sizeof(ros_time_t);
case (Type::ros_duration):
return sizeof(ros_duration_t);
case (Type::string):
case (Type::array):
case (Type::object):
default:
throw std::runtime_error("Provided type is a string or a non-primitive!");
}
}

std::string RosValue::primitiveTypeToFormat(const Type type) {
switch (type) {
case (Type::ros_bool):
return "?";
case (Type::int8):
return "b";
case (Type::uint8):
return "B";
case (Type::int16):
return "h";
case (Type::uint16):
return "H";
case (Type::int32):
return "i";
case (Type::uint32):
return "I";
case (Type::int64):
return "q";
case (Type::uint64):
return "Q";
case (Type::float32):
return "f";
case (Type::float64):
return "d";
case (Type::ros_time):
return "II";
case (Type::ros_duration):
return "II";
case (Type::string):
throw std::runtime_error("Strings do not have a struct format!");
case (Type::array):
case (Type::object):
default:
throw std::runtime_error("Provided type is not a primitive!");
}
}

/*
--------------
ITERATOR SETUP
Expand Down
36 changes: 33 additions & 3 deletions lib/ros_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <memory>
#include <cstdint>
#include <cstring>
#include <pybind11/buffer_info.h>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -73,6 +74,8 @@ class RosValue {
object,
array,
};
static size_t primitiveTypeToSize(const Type type);
static std::string primitiveTypeToFormat(const Type type);

struct ros_time_t {
uint32_t secs = 0;
Expand Down Expand Up @@ -274,7 +277,7 @@ class RosValue {
}

// TODO: Add check that the underlying type aligns with T
return *reinterpret_cast<const T*>(getPrimitivePointer());
return getPrimitive<T>();
}

bool has(const std::string &key) const {
Expand Down Expand Up @@ -320,6 +323,32 @@ class RosValue {
return values;
}

// This interface is used to provide a buffer_info interface to python bindings.
// The buffer_info object essentially provides the python runtime with a way
// to directly access the underlying memory that an object contains, and thus
// operate on it in a much more optimized fashion.
pybind11::buffer_info getPrimitiveArrayBufferInfo() {
if (getType() != Embag::RosValue::Type::array) {
throw std::runtime_error("Only arrays can be represented as buffers!");
}

const Embag::RosValue::Type type_of_elements = at(0)->getType();
if (type_of_elements == Embag::RosValue::Type::object || type_of_elements == Embag::RosValue::Type::string) {
throw std::runtime_error("In order to be represented as a buffer, an array's elements must not be objects or strings!");
}

const size_t size_of_elements = Embag::RosValue::primitiveTypeToSize(type_of_elements);
return pybind11::buffer_info(
(void*) &at(0)->getPrimitive<uint8_t>(),
size_of_elements,
Embag::RosValue::primitiveTypeToFormat(type_of_elements),
1,
{ size() },
{ size_of_elements },
true
);
}

std::string toString(const std::string &path = "") const;
void print(const std::string &path = "") const;

Expand Down Expand Up @@ -349,8 +378,9 @@ class RosValue {
object_info_t object_info_;
};

const char* const getPrimitivePointer() const {
return &primitive_info_.message_buffer->at(primitive_info_.offset);
template<typename T>
const T& getPrimitive() const {
return reinterpret_cast<const T&>(primitive_info_.message_buffer->at(primitive_info_.offset));
}

const ros_value_list_t& getChildren() const {
Expand Down
3 changes: 3 additions & 0 deletions pip_package/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ function build() {
PYTHON_PATH=$1
PYTHON_VERSION=$2

# Install necessary dependencies
"$PYTHON_PATH/pip" install -r /tmp/pip_build/requirements.txt

# Build embag libs and echo test binary
(cd /tmp/embag &&
PYTHON_BIN_PATH="$PYTHON_PATH/python" bazel build -c opt //python:libembag.so //embag_echo:embag_echo &&
Expand Down
3 changes: 2 additions & 1 deletion pip_package/macos_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ mkdir -p /tmp/embag /tmp/pip_build /tmp/out
cp -r lib /tmp/embag
cp -r pip_package/* README.md LICENSE /tmp/pip_build

python -m pip install -r /tmp/pip_build/requirements.txt

# Build embag libs and echo test binary
bazel build -c opt //python:libembag.so //embag_echo:embag_echo && \
bazel test //test:embag_test //test:embag_test_python3 --test_output=all

# Build wheel
cp bazel-bin/python/libembag.so /tmp/pip_build/embag
python -m pip install wheel
(cd /tmp/pip_build && python setup.py bdist_wheel && \
# FIXME
#python -m pip install dist/embag*.whl && \
Expand Down
3 changes: 3 additions & 0 deletions pip_package/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
cython
numpy
wheel
3 changes: 2 additions & 1 deletion python/embag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ PYBIND11_MODULE(libembag, m) {
.def_readonly("md5", &Embag::RosMessage::md5)
.def_readonly("raw_data_len", &Embag::RosMessage::raw_data_len);

auto ros_value = py::class_<Embag::RosValue, Embag::VectorItemPointer<Embag::RosValue>>(m, "RosValue", py::dynamic_attr())
auto ros_value = py::class_<Embag::RosValue, Embag::VectorItemPointer<Embag::RosValue>>(m, "RosValue", py::dynamic_attr(), py::buffer_protocol())
.def_buffer(&Embag::RosValue::getPrimitiveArrayBufferInfo)
.def("get", &Embag::RosValue::get)
.def("getType", &Embag::RosValue::getType)
.def("__len__", &Embag::RosValue::size)
Expand Down
8 changes: 8 additions & 0 deletions test/embag_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import OrderedDict
import struct
import unittest
import numpy as np


class EmbagTest(unittest.TestCase):
Expand Down Expand Up @@ -170,5 +171,12 @@ def testBagFromBytes(self):
self.testConnectionsInView()
bag_stream.close()

def testBufferInfo(self):
for msg in self.view.getMessages('/base_pose_ground_truth'):
covariance_array = msg.data()['pose']['covariance']
self.assertTrue(memoryview(covariance_array).readonly)
for covariance in np.array(covariance_array, copy=False):
self.assertEqual(covariance, 0)

if __name__ == "__main__":
unittest.main()