Skip to content

Commit

Permalink
Squashed 'wrap/' changes from bae34fac8..b80bc63cf
Browse files Browse the repository at this point in the history
b80bc63cf Merge pull request #90 from borglab/fix/tpl_dependency
015b12da5 Merge pull request #86 from borglab/feature/optionalargs
362851980 address review comments
e461ca50e Merge pull request #89 from borglab/fix/template_iostream
2d413db57 add pybind cpp generation dependency on tpl file
79881c25e include pybind11 iostream for ostream_redirect in example tpl
5e8323c25 fix test fixture
95495726a Merge branch 'master' into feature/optionalargs
5af826840 clean up the _py_args_names method to reduce copy-pasta
844ff9229 add identifier parsing to _type
c3adca7a4 remove extra spaces from Type repr
350b531d7 slight test improvement
fd4f37578 cleaner default argument parsing
6013deacb overpowered default argument parsing rule
dbcda0ea2 fix unit tests for __repr__ ref  vs ptr
1c23c42e4 fix pointer vs const ref in __repr__
9b40350f1 update matlab tests
df7e9023c handle __repr__ with default arguments
092ef489b update pybind_wrapper for default arguments
3a2d7aa8a unit test default argument pybind
61a2b114e implement default argument parser
c2b92ffec unit test for parsing default arguments

git-subtree-dir: wrap
git-subtree-split: b80bc63cf466f9751e8059c0abb4a4d73b23efbe
  • Loading branch information
varunagrawal committed Apr 17, 2021
1 parent 00a04a1 commit db373df
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 41 deletions.
1 change: 1 addition & 0 deletions cmake/PybindWrap.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ function(pybind_wrap
--template
${module_template}
${_WRAP_BOOST_ARG}
DEPENDS ${interface_header} ${module_template}
VERBATIM)
add_custom_target(pybind_wrap_${module_name} ALL DEPENDS ${generated_cpp})

Expand Down
24 changes: 19 additions & 5 deletions gtwrap/interface_parser/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from pyparsing import Optional, ParseResults, delimitedList

from .template import Template
from .tokens import (COMMA, IDENT, LOPBRACK, LPAREN, PAIR, ROPBRACK, RPAREN,
SEMI_COLON)
from .tokens import (COMMA, DEFAULT_ARG, EQUAL, IDENT, LOPBRACK, LPAREN, PAIR,
ROPBRACK, RPAREN, SEMI_COLON)
from .type import TemplatedType, Type


Expand All @@ -29,15 +29,29 @@ class Argument:
void sayHello(/*`s` is the method argument with type `const string&`*/ const string& s);
```
"""
rule = ((Type.rule ^ TemplatedType.rule)("ctype") +
IDENT("name")).setParseAction(lambda t: Argument(t.ctype, t.name))
rule = ((Type.rule ^ TemplatedType.rule)("ctype") + IDENT("name") + \
Optional(EQUAL + (DEFAULT_ARG ^ Type.rule ^ TemplatedType.rule) + \
Optional(LPAREN + RPAREN) # Needed to parse the parens for default constructors
)("default")
).setParseAction(lambda t: Argument(t.ctype, t.name, t.default))

def __init__(self, ctype: Union[Type, TemplatedType], name: str):
def __init__(self,
ctype: Union[Type, TemplatedType],
name: str,
default: ParseResults = None):
if isinstance(ctype, Iterable):
self.ctype = ctype[0]
else:
self.ctype = ctype
self.name = name
# If the length is 1, it's a regular type,
if len(default) == 1:
default = default[0]
# This means a tuple has been passed so we convert accordingly
elif len(default) > 1:
default = tuple(default.asList())
self.default = default

self.parent: Union[ArgumentList, None] = None

def __repr__(self) -> str:
Expand Down
16 changes: 15 additions & 1 deletion gtwrap/interface_parser/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert
"""

from pyparsing import Keyword, Literal, Suppress, Word, alphanums, alphas, nums, Or
from pyparsing import (Keyword, Literal, Or, QuotedString, Suppress, Word,
alphanums, alphas, delimitedList, nums,
pyparsing_common)

# rule for identifiers (e.g. variable names)
IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums)
Expand All @@ -19,6 +21,18 @@

LPAREN, RPAREN, LBRACE, RBRACE, COLON, SEMI_COLON = map(Suppress, "(){}:;")
LOPBRACK, ROPBRACK, COMMA, EQUAL = map(Suppress, "<>,=")

# Encapsulating type for numbers, and single and double quoted strings.
# The pyparsing_common utilities ensure correct coversion to the corresponding type.
# E.g. pyparsing_common.number will convert 3.1415 to a float type.
NUMBER_OR_STRING = (pyparsing_common.number ^ QuotedString('"') ^ QuotedString("'"))

# A python tuple, e.g. (1, 9, "random", 3.1415)
TUPLE = (LPAREN + delimitedList(NUMBER_OR_STRING) + RPAREN)

# Default argument passed to functions/methods.
DEFAULT_ARG = (NUMBER_OR_STRING ^ pyparsing_common.identifier ^ TUPLE)

CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map(
Keyword,
[
Expand Down
9 changes: 6 additions & 3 deletions gtwrap/interface_parser/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,12 @@ def from_parse_result(t: ParseResults):
raise ValueError("Parse result is not a Type")

def __repr__(self) -> str:
return "{self.is_const} {self.typename} " \
"{self.is_shared_ptr}{self.is_ptr}{self.is_ref}".format(
self=self)
is_ptr_or_ref = "{0}{1}{2}".format(self.is_shared_ptr, self.is_ptr,
self.is_ref)
return "{is_const}{self.typename}{is_ptr_or_ref}".format(
self=self,
is_const="const " if self.is_const else "",
is_ptr_or_ref=" " + is_ptr_or_ref if is_ptr_or_ref else "")

def to_cpp(self, use_boost: bool) -> str:
"""
Expand Down
49 changes: 25 additions & 24 deletions gtwrap/pybind_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,14 @@ def _py_args_names(self, args_list):
"""Set the argument names in Pybind11 format."""
names = args_list.args_names()
if names:
py_args = ['py::arg("{}")'.format(name) for name in names]
py_args = []
for arg in args_list.args_list:
if arg.default and isinstance(arg.default, str):
arg.default = "\"{arg.default}\"".format(arg=arg)
argument = 'py::arg("{name}"){default}'.format(
name=arg.name,
default=' = {0}'.format(arg.default) if arg.default else '')
py_args.append(argument)
return ", " + ", ".join(py_args)
else:
return ''
Expand Down Expand Up @@ -124,35 +131,29 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""):
suffix=suffix,
))

# Create __repr__ override
# We allow all arguments to .print() and let the compiler handle type mismatches.
if method.name == 'print':
# Redirect stdout - see pybind docs for why this is a good idea:
# https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream
ret = ret.replace('self->', 'py::scoped_ostream_redirect output; self->')
ret = ret.replace('self->print', 'py::scoped_ostream_redirect output; self->print')

# __repr__() uses print's implementation:
type_list = method.args.to_cpp(self.use_boost)
if len(type_list) > 0 and type_list[0].strip() == 'string':
ret += '''{prefix}.def("__repr__",
[](const {cpp_class} &a) {{
# Make __repr__() call print() internally
ret += '''{prefix}.def("__repr__",
[](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{
gtsam::RedirectCout redirect;
a.print("");
self.{method_name}({method_args});
return redirect.str();
}}){suffix}'''.format(
prefix=prefix,
cpp_class=cpp_class,
suffix=suffix,
)
else:
ret += '''{prefix}.def("__repr__",
[](const {cpp_class} &a) {{
gtsam::RedirectCout redirect;
a.print();
return redirect.str();
}}){suffix}'''.format(
prefix=prefix,
cpp_class=cpp_class,
suffix=suffix,
)
}}{py_args_names}){suffix}'''.format(
prefix=prefix,
cpp_class=cpp_class,
opt_comma=', ' if args_names else '',
args_signature_with_names=args_signature_with_names,
method_name=method.name,
method_args=", ".join(args_names) if args_names else '',
py_args_names=py_args_names,
suffix=suffix)

return ret

def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''):
Expand Down
4 changes: 3 additions & 1 deletion gtwrap/template_instantiator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,10 @@ def instantiate_args_list(args_list, template_typenames, instantiations,
for arg in args_list:
new_type = instantiate_type(arg.ctype, template_typenames,
instantiations, cpp_typename)
default = [arg.default] if isinstance(arg, parser.Argument) else ''
instantiated_args.append(parser.Argument(name=arg.name,
ctype=new_type))
ctype=new_type,
default=default))
return instantiated_args


Expand Down
1 change: 1 addition & 0 deletions templates/pybind_wrapper.tpl.example
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <pybind11/stl_bind.h>
#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
#include <pybind11/iostream.h>
#include "gtsam/base/serialization.h"
#include "gtsam/nonlinear/utilities.h" // for RedirectCout.

Expand Down
13 changes: 13 additions & 0 deletions tests/expected/matlab/MyFactorPosePoint2.m
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
%-------Constructors-------
%MyFactorPosePoint2(size_t key1, size_t key2, double measured, Base noiseModel)
%
%-------Methods-------
%print(string s, KeyFormatter keyFormatter) : returns void
%
classdef MyFactorPosePoint2 < handle
properties
ptr_MyFactorPosePoint2 = 0
Expand All @@ -29,6 +32,16 @@ function delete(obj)
%DISPLAY Calls print on the object
function disp(obj), obj.display; end
%DISP Calls print on the object
function varargout = print(this, varargin)
% PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void
% Doxygen can be found at https://gtsam.org/doxygen/
if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter')
class_wrapper(55, this, varargin{:});
return
end
error('Arguments do not match any overload of function MyFactorPosePoint2.print');
end

end

methods(Static = true)
Expand Down
2 changes: 1 addition & 1 deletion tests/expected/matlab/TemplatedFunctionRot3.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
function varargout = TemplatedFunctionRot3(varargin)
if length(varargin) == 1 && isa(varargin{1},'gtsam.Rot3')
functions_wrapper(8, varargin{:});
functions_wrapper(11, varargin{:});
else
error('Arguments do not match any overload of function TemplatedFunctionRot3');
end
12 changes: 12 additions & 0 deletions tests/expected/matlab/class_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,15 @@ void MyFactorPosePoint2_deconstructor_54(int nargout, mxArray *out[], int nargin
}
}

void MyFactorPosePoint2_print_55(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("print",nargout,nargin-1,2);
auto obj = unwrap_shared_ptr<MyFactor<gtsam::Pose2, gtsam::Matrix>>(in[0], "ptr_MyFactorPosePoint2");
string& s = *unwrap_shared_ptr< string >(in[1], "ptr_string");
gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[2], "ptr_gtsamKeyFormatter");
obj->print(s,keyFormatter);
}


void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
Expand Down Expand Up @@ -838,6 +847,9 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
case 54:
MyFactorPosePoint2_deconstructor_54(nargout, out, nargin-1, in+1);
break;
case 55:
MyFactorPosePoint2_print_55(nargout, out, nargin-1, in+1);
break;
}
} catch(const std::exception& e) {
mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str());
Expand Down
31 changes: 29 additions & 2 deletions tests/expected/matlab/functions_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,25 @@ void MultiTemplatedFunctionDoubleSize_tDouble_7(int nargout, mxArray *out[], int
size_t y = unwrap< size_t >(in[1]);
out[0] = wrap< double >(MultiTemplatedFunctionDoubleSize_tDouble(x,y));
}
void TemplatedFunctionRot3_8(int nargout, mxArray *out[], int nargin, const mxArray *in[])
void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("DefaultFuncInt",nargout,nargin,1);
int a = unwrap< int >(in[0]);
DefaultFuncInt(a);
}
void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("DefaultFuncString",nargout,nargin,1);
string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string");
DefaultFuncString(s);
}
void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("DefaultFuncObj",nargout,nargin,1);
gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[0], "ptr_gtsamKeyFormatter");
DefaultFuncObj(keyFormatter);
}
void TemplatedFunctionRot3_11(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("TemplatedFunctionRot3",nargout,nargin,1);
gtsam::Rot3& t = *unwrap_shared_ptr< gtsam::Rot3 >(in[0], "ptr_gtsamRot3");
Expand Down Expand Up @@ -239,7 +257,16 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
MultiTemplatedFunctionDoubleSize_tDouble_7(nargout, out, nargin-1, in+1);
break;
case 8:
TemplatedFunctionRot3_8(nargout, out, nargin-1, in+1);
DefaultFuncInt_8(nargout, out, nargin-1, in+1);
break;
case 9:
DefaultFuncString_9(nargout, out, nargin-1, in+1);
break;
case 10:
DefaultFuncObj_10(nargout, out, nargin-1, in+1);
break;
case 11:
TemplatedFunctionRot3_11(nargout, out, nargin-1, in+1);
break;
}
} catch(const std::exception& e) {
Expand Down
13 changes: 10 additions & 3 deletions tests/expected/python/class_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ PYBIND11_MODULE(class_py, m_) {
.def("return_ptrs",[](Test* self, std::shared_ptr<Test> p1, std::shared_ptr<Test> p2){return self->return_ptrs(p1, p2);}, py::arg("p1"), py::arg("p2"))
.def("print_",[](Test* self){ py::scoped_ostream_redirect output; self->print();})
.def("__repr__",
[](const Test &a) {
[](const Test& self){
gtsam::RedirectCout redirect;
a.print();
self.print();
return redirect.str();
})
.def("set_container",[](Test* self, std::vector<testing::Test> container){ self->set_container(container);}, py::arg("container"))
Expand All @@ -83,7 +83,14 @@ PYBIND11_MODULE(class_py, m_) {
py::class_<MultipleTemplates<int, float>, std::shared_ptr<MultipleTemplates<int, float>>>(m_, "MultipleTemplatesIntFloat");

py::class_<MyFactor<gtsam::Pose2, gtsam::Matrix>, std::shared_ptr<MyFactor<gtsam::Pose2, gtsam::Matrix>>>(m_, "MyFactorPosePoint2")
.def(py::init<size_t, size_t, double, const std::shared_ptr<gtsam::noiseModel::Base>>(), py::arg("key1"), py::arg("key2"), py::arg("measured"), py::arg("noiseModel"));
.def(py::init<size_t, size_t, double, const std::shared_ptr<gtsam::noiseModel::Base>>(), py::arg("key1"), py::arg("key2"), py::arg("measured"), py::arg("noiseModel"))
.def("print_",[](MyFactor<gtsam::Pose2, gtsam::Matrix>* self, const string& s, const gtsam::KeyFormatter& keyFormatter){ py::scoped_ostream_redirect output; self->print(s, keyFormatter);}, py::arg("s") = "factor: ", py::arg("keyFormatter") = gtsam::DefaultKeyFormatter)
.def("__repr__",
[](const MyFactor<gtsam::Pose2, gtsam::Matrix>& self, const string& s, const gtsam::KeyFormatter& keyFormatter){
gtsam::RedirectCout redirect;
self.print(s, keyFormatter);
return redirect.str();
}, py::arg("s") = "factor: ", py::arg("keyFormatter") = gtsam::DefaultKeyFormatter);


#include "python/specializations.h"
Expand Down
3 changes: 3 additions & 0 deletions tests/expected/python/functions_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ PYBIND11_MODULE(functions_py, m_) {
m_.def("overloadedGlobalFunction",[](int a, double b){return ::overloadedGlobalFunction(a, b);}, py::arg("a"), py::arg("b"));
m_.def("MultiTemplatedFunctionStringSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction<string,size_t,double>(x, y);}, py::arg("x"), py::arg("y"));
m_.def("MultiTemplatedFunctionDoubleSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction<double,size_t,double>(x, y);}, py::arg("x"), py::arg("y"));
m_.def("DefaultFuncInt",[](int a){ ::DefaultFuncInt(a);}, py::arg("a") = 123);
m_.def("DefaultFuncString",[](const string& s){ ::DefaultFuncString(s);}, py::arg("s") = "hello");
m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter);
m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction<Rot3>(t);}, py::arg("t"));

#include "python/specializations.h"
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures/class.i
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ virtual class ns::OtherClass;
template<POSE, POINT>
class MyFactor {
MyFactor(size_t key1, size_t key2, double measured, const gtsam::noiseModel::Base* noiseModel);
void print(const string &s = "factor: ",
const gtsam::KeyFormatter &keyFormatter = gtsam::DefaultKeyFormatter);
};

// and a typedef specializing it
Expand Down
5 changes: 5 additions & 0 deletions tests/fixtures/functions.i
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,8 @@ template<T>
void TemplatedFunction(const T& t);

typedef TemplatedFunction<gtsam::Rot3> TemplatedFunctionRot3;

// Check default arguments
void DefaultFuncInt(int a = 123);
void DefaultFuncString(const string& s = "hello");
void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
28 changes: 27 additions & 1 deletion tests/test_interface_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,33 @@ def test_argument_list_templated(self):
self.assertEqual("vector<boost::shared_ptr<T>>",
args_list[1].ctype.to_cpp(True))

def test_default_arguments(self):
"""Tests any expression that is a valid default argument"""
args = ArgumentList.rule.parseString(
"string s=\"hello\", int a=3, "
"int b, double pi = 3.1415, "
"gtsam::KeyFormatter kf = gtsam::DefaultKeyFormatter, "
"std::vector<size_t> p = std::vector<size_t>(), "
"std::vector<size_t> l = (1, 2, 'name', \"random\", 3.1415)"
)[0].args_list

# Test for basic types
self.assertEqual(args[0].default, "hello")
self.assertEqual(args[1].default, 3)
# '' is falsy so we can check against it
self.assertEqual(args[2].default, '')
self.assertFalse(args[2].default)

self.assertEqual(args[3].default, 3.1415)

# Test non-basic type
self.assertEqual(repr(args[4].default.typename), 'gtsam::DefaultKeyFormatter')
# Test templated type
self.assertEqual(repr(args[5].default.typename), 'std::vector<size_t>')
# Test for allowing list as default argument
print(args)
self.assertEqual(args[6].default, (1, 2, 'name', "random", 3.1415))

def test_return_type(self):
"""Test ReturnType"""
# Test void
Expand Down Expand Up @@ -490,6 +517,5 @@ class Global{
self.assertEqual(["two", "two_dummy", "two"],
[x.name for x in module.content[0].content])


if __name__ == '__main__':
unittest.main()

0 comments on commit db373df

Please sign in to comment.