Skip to content

Commit

Permalink
Add more iterator support to python interfaces (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
liambenson authored Feb 23, 2022
1 parent 9942998 commit d07d536
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 20 deletions.
6 changes: 3 additions & 3 deletions lib/ros_msg_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class RosMsgTypes{
}

members_.reserve(parsed_info.members.size());
field_indexes_ = std::make_shared<std::unordered_map<std::string, const size_t>>();
field_indexes_ = std::make_shared<std::unordered_map<std::string, size_t>>();
field_indexes_->reserve(num_fields);
size_t field_num = 0;
for (const auto& member : parsed_info.members) {
Expand Down Expand Up @@ -160,7 +160,7 @@ class RosMsgTypes{
}
}

const std::shared_ptr<std::unordered_map<std::string, const size_t>>& fieldIndexes() const {
const std::shared_ptr<std::unordered_map<std::string, size_t>>& fieldIndexes() const {
return field_indexes_;
}

Expand Down Expand Up @@ -188,7 +188,7 @@ class RosMsgTypes{
}

private:
std::shared_ptr<std::unordered_map<std::string, const size_t>> field_indexes_;
std::shared_ptr<std::unordered_map<std::string, size_t>> field_indexes_;
std::vector<MemberDef> members_;
const std::string name_;
std::string scope_;
Expand Down
16 changes: 9 additions & 7 deletions lib/ros_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class RosValue {
return {value_, index_++};
}
protected:
const_iterator_base(const RosValue& value, size_t index)
const_iterator_base(const RosValue& value, IndexType index)
: value_(value)
, index_(index)
{
Expand Down Expand Up @@ -152,7 +152,7 @@ class RosValue {
class const_iterator<ReturnType, std::unordered_map<std::string, size_t>::const_iterator> : public const_iterator_base<ReturnType, std::unordered_map<std::string, size_t>::const_iterator, const_iterator<ReturnType, std::unordered_map<std::string, size_t>::const_iterator>> {
public:
const_iterator(const RosValue& value, std::unordered_map<std::string, size_t>::const_iterator index)
: const_iterator_base<ReturnType, size_t, std::unordered_map<std::string, size_t>::const_iterator>(value, index)
: const_iterator_base<ReturnType, std::unordered_map<std::string, size_t>::const_iterator, const_iterator<ReturnType, std::unordered_map<std::string, size_t>::const_iterator>>(value, index)
{
if (value.type_ != Type::object) {
throw std::runtime_error("Cannot iterate the keys or key/value pairs of an non-object RosValue");
Expand All @@ -176,15 +176,15 @@ class RosValue {
throw std::runtime_error("Cannot iterate over the items of a RosValue that is not an object");
}

return RosValue::const_iterator<IteratorReturnType, std::unordered_map<std::string, size_t>::const_iterator>(*this, object_info_.field_indexes->begin());
return RosValue::const_iterator<IteratorReturnType, std::unordered_map<std::string, size_t>::const_iterator>(*this, object_info_.field_indexes->cbegin());
}
template<class IteratorReturnType>
const_iterator<IteratorReturnType, std::unordered_map<std::string, size_t>::const_iterator> endItems() const {
if (type_ != Type::object) {
throw std::runtime_error("Cannot iterate over the items of a RosValue that is not an object");
}

return RosValue::const_iterator<IteratorReturnType, std::unordered_map<std::string, size_t>::const_iterator>(*this, object_info_.field_indexes->end());
return RosValue::const_iterator<IteratorReturnType, std::unordered_map<std::string, size_t>::const_iterator>(*this, object_info_.field_indexes->cend());
}

private:
Expand All @@ -206,7 +206,7 @@ class RosValue {
throw std::runtime_error("Cannot create an object or array with this constructor");
}
}
RosValue(const std::shared_ptr<std::unordered_map<std::string, const size_t>>& field_indexes)
RosValue(const std::shared_ptr<std::unordered_map<std::string, size_t>>& field_indexes)
: type_(Type::object)
, object_info_()
{
Expand Down Expand Up @@ -240,7 +240,7 @@ class RosValue {
destroy_object_info();
}

RosValue operator=(const RosValue& other) {
RosValue& operator=(const RosValue& other) {
if (type_ != other.type_) {
destroy_object_info();
}
Expand All @@ -255,6 +255,8 @@ class RosValue {
} else {
primitive_info_ = other.primitive_info_;
}

return *this;
}

void destroy_object_info() {
Expand Down Expand Up @@ -348,7 +350,7 @@ class RosValue {

struct object_info_t {
ros_value_list_t children;
std::shared_ptr<std::unordered_map<std::string, const size_t>> field_indexes;
std::shared_ptr<std::unordered_map<std::string, size_t>> field_indexes;
};

Type type_;
Expand Down
9 changes: 9 additions & 0 deletions python/adapters.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,13 @@ const py::object RosValue::const_iterator<py::object, size_t>::operator*() const
return castValue(value_.at(index_));
}

template<>
const py::str RosValue::const_iterator<py::str, std::unordered_map<std::string, size_t>::const_iterator>::operator*() const {
return index_->first;
}

template<>
const py::tuple RosValue::const_iterator<py::tuple, std::unordered_map<std::string, size_t>::const_iterator>::operator*() const {
return py::make_tuple(index_->first, castValue(value_.object_info_.children.at(index_->second)));
}
}
26 changes: 23 additions & 3 deletions python/embag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,36 @@ PYBIND11_MODULE(libembag, m) {
}, py::arg("path") = "")
.def("__iter__", [](Embag::RosValue::Pointer &v) {
switch (v->getType()) {
// TODO: Allow object iteration
case Embag::RosValue::Type::array:
case Embag::RosValue::Type::primitive_array:
{
return py::make_iterator(v->beginValues<py::object>(), v->endValues<py::object>());
}
case Embag::RosValue::Type::object:
return py::make_iterator(v->beginItems<py::str>(), v->endItems<py::str>());
default:
throw std::runtime_error("Can only iterate array RosValues");
}
}, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
.def("items", [](Embag::RosValue::Pointer &v) {
if (v->getType() != Embag::RosValue::Type::object) {
throw std::runtime_error("Cannot get items of a non-object RosValue");
}

return py::make_iterator(v->beginItems<py::tuple>(), v->endItems<py::tuple>());
}, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
.def("keys", [](Embag::RosValue::Pointer &v) {
if (v->getType() != Embag::RosValue::Type::object) {
throw std::runtime_error("Cannot get keys of a non-object RosValue");
}

return py::make_iterator(v->beginItems<py::str>(), v->endItems<py::str>());
}, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
.def("values", [](Embag::RosValue::Pointer &v) {
if (v->getType() != Embag::RosValue::Type::object) {
throw std::runtime_error("Cannot get values of a non-object RosValue");
}

return py::make_iterator(v->beginValues<py::object>(), v->endValues<py::object>());
}, py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */)
.def("get", getField)
.def("__getattr__", getField)
.def("__getitem__", [](Embag::RosValue::Pointer &v, const std::string &key) {
Expand Down
25 changes: 18 additions & 7 deletions test/embag_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@ def setUp(self):
'scope': 'sensor_msgs', 'md5sum': '1158d486dd51d683ce2f1be655c3c181',
'callerid': '/play_1604515189845695821', 'latching': False, 'message_count': 5},
}

def tearDown(self):
self.bag.close()

def testSchema(self):
known_schema = OrderedDict([
self.known_pointcloud_schema = OrderedDict([
('header',
{'type': 'object',
'children': OrderedDict([('seq', {'type': 'uint32'}),
Expand All @@ -52,8 +47,12 @@ def testSchema(self):
('is_dense', {'type': 'bool'})
])

def tearDown(self):
self.bag.close()

def testSchema(self):
schema = self.bag.getSchema('/luminar_pointcloud')
self.assertDictEqual(schema, known_schema)
self.assertDictEqual(schema, self.known_pointcloud_schema)

def testTopicsInBag(self):
topics = set(self.bag.topics())
Expand Down Expand Up @@ -153,6 +152,18 @@ def testViewMessages(self):
for v in msg_data['pose']['covariance']:
self.assertEqual(v, 0)

def testObjectIterators(self):
for topic, msg, t in self.bag.read_messages(topics=['/luminar_pointcloud']):
assert {field_name for field_name in msg} == {field_name for field_name in self.known_pointcloud_schema}
for field_name, value in msg.items():
if isinstance(value, embag.RosValue):
assert str(msg[field_name]) == str(value)
else:
assert msg[field_name] == value
for field_name in msg.keys():
assert field_name in self.known_pointcloud_schema
assert set(str(v) for v in msg.values()) == {str(msg[field_name]) for field_name in msg}

def testTopicsInView(self):
topics = set(self.view.topics())
self.assertSetEqual(topics, self.known_topics)
Expand Down

0 comments on commit d07d536

Please sign in to comment.