diff --git a/cmake/PybindWrap.cmake b/cmake/PybindWrap.cmake index dc581be495..331dfff8c4 100644 --- a/cmake/PybindWrap.cmake +++ b/cmake/PybindWrap.cmake @@ -72,6 +72,7 @@ function(pybind_wrap --template ${module_template} ${_WRAP_BOOST_ARG} + DEPENDS ${interface_header} ${module_template} VERBATIM) add_custom_target(pybind_wrap_${module_name} ALL DEPENDS ${generated_cpp}) diff --git a/gtwrap/interface_parser/function.py b/gtwrap/interface_parser/function.py index 526d5e0555..64c7b176bb 100644 --- a/gtwrap/interface_parser/function.py +++ b/gtwrap/interface_parser/function.py @@ -15,8 +15,8 @@ from pyparsing import Optional, ParseResults, delimitedList from .template import Template -from .tokens import (COMMA, IDENT, LOPBRACK, LPAREN, PAIR, ROPBRACK, RPAREN, - SEMI_COLON) +from .tokens import (COMMA, DEFAULT_ARG, EQUAL, IDENT, LOPBRACK, LPAREN, PAIR, + ROPBRACK, RPAREN, SEMI_COLON) from .type import TemplatedType, Type @@ -29,15 +29,29 @@ class Argument: void sayHello(/*`s` is the method argument with type `const string&`*/ const string& s); ``` """ - rule = ((Type.rule ^ TemplatedType.rule)("ctype") + - IDENT("name")).setParseAction(lambda t: Argument(t.ctype, t.name)) + rule = ((Type.rule ^ TemplatedType.rule)("ctype") + IDENT("name") + \ + Optional(EQUAL + (DEFAULT_ARG ^ Type.rule ^ TemplatedType.rule) + \ + Optional(LPAREN + RPAREN) # Needed to parse the parens for default constructors + )("default") + ).setParseAction(lambda t: Argument(t.ctype, t.name, t.default)) - def __init__(self, ctype: Union[Type, TemplatedType], name: str): + def __init__(self, + ctype: Union[Type, TemplatedType], + name: str, + default: ParseResults = None): if isinstance(ctype, Iterable): self.ctype = ctype[0] else: self.ctype = ctype self.name = name + # If the length is 1, it's a regular type, + if len(default) == 1: + default = default[0] + # This means a tuple has been passed so we convert accordingly + elif len(default) > 1: + default = tuple(default.asList()) + self.default = default + self.parent: Union[ArgumentList, None] = None def __repr__(self) -> str: diff --git a/gtwrap/interface_parser/tokens.py b/gtwrap/interface_parser/tokens.py index 432c5407aa..5d2bdeaf3c 100644 --- a/gtwrap/interface_parser/tokens.py +++ b/gtwrap/interface_parser/tokens.py @@ -10,7 +10,9 @@ Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert """ -from pyparsing import Keyword, Literal, Suppress, Word, alphanums, alphas, nums, Or +from pyparsing import (Keyword, Literal, Or, QuotedString, Suppress, Word, + alphanums, alphas, delimitedList, nums, + pyparsing_common) # rule for identifiers (e.g. variable names) IDENT = Word(alphas + '_', alphanums + '_') ^ Word(nums) @@ -19,6 +21,18 @@ LPAREN, RPAREN, LBRACE, RBRACE, COLON, SEMI_COLON = map(Suppress, "(){}:;") LOPBRACK, ROPBRACK, COMMA, EQUAL = map(Suppress, "<>,=") + +# Encapsulating type for numbers, and single and double quoted strings. +# The pyparsing_common utilities ensure correct coversion to the corresponding type. +# E.g. pyparsing_common.number will convert 3.1415 to a float type. +NUMBER_OR_STRING = (pyparsing_common.number ^ QuotedString('"') ^ QuotedString("'")) + +# A python tuple, e.g. (1, 9, "random", 3.1415) +TUPLE = (LPAREN + delimitedList(NUMBER_OR_STRING) + RPAREN) + +# Default argument passed to functions/methods. +DEFAULT_ARG = (NUMBER_OR_STRING ^ pyparsing_common.identifier ^ TUPLE) + CONST, VIRTUAL, CLASS, STATIC, PAIR, TEMPLATE, TYPEDEF, INCLUDE = map( Keyword, [ diff --git a/gtwrap/interface_parser/type.py b/gtwrap/interface_parser/type.py index e4b80a2247..b9f2bd8f74 100644 --- a/gtwrap/interface_parser/type.py +++ b/gtwrap/interface_parser/type.py @@ -203,9 +203,12 @@ def from_parse_result(t: ParseResults): raise ValueError("Parse result is not a Type") def __repr__(self) -> str: - return "{self.is_const} {self.typename} " \ - "{self.is_shared_ptr}{self.is_ptr}{self.is_ref}".format( - self=self) + is_ptr_or_ref = "{0}{1}{2}".format(self.is_shared_ptr, self.is_ptr, + self.is_ref) + return "{is_const}{self.typename}{is_ptr_or_ref}".format( + self=self, + is_const="const " if self.is_const else "", + is_ptr_or_ref=" " + is_ptr_or_ref if is_ptr_or_ref else "") def to_cpp(self, use_boost: bool) -> str: """ diff --git a/gtwrap/pybind_wrapper.py b/gtwrap/pybind_wrapper.py index 1e0d412a0b..801e691c6a 100755 --- a/gtwrap/pybind_wrapper.py +++ b/gtwrap/pybind_wrapper.py @@ -45,7 +45,14 @@ def _py_args_names(self, args_list): """Set the argument names in Pybind11 format.""" names = args_list.args_names() if names: - py_args = ['py::arg("{}")'.format(name) for name in names] + py_args = [] + for arg in args_list.args_list: + if arg.default and isinstance(arg.default, str): + arg.default = "\"{arg.default}\"".format(arg=arg) + argument = 'py::arg("{name}"){default}'.format( + name=arg.name, + default=' = {0}'.format(arg.default) if arg.default else '') + py_args.append(argument) return ", " + ", ".join(py_args) else: return '' @@ -124,35 +131,29 @@ def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix=""): suffix=suffix, )) + # Create __repr__ override + # We allow all arguments to .print() and let the compiler handle type mismatches. if method.name == 'print': # Redirect stdout - see pybind docs for why this is a good idea: # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream - ret = ret.replace('self->', 'py::scoped_ostream_redirect output; self->') + ret = ret.replace('self->print', 'py::scoped_ostream_redirect output; self->print') - # __repr__() uses print's implementation: - type_list = method.args.to_cpp(self.use_boost) - if len(type_list) > 0 and type_list[0].strip() == 'string': - ret += '''{prefix}.def("__repr__", - [](const {cpp_class} &a) {{ + # Make __repr__() call print() internally + ret += '''{prefix}.def("__repr__", + [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ gtsam::RedirectCout redirect; - a.print(""); + self.{method_name}({method_args}); return redirect.str(); - }}){suffix}'''.format( - prefix=prefix, - cpp_class=cpp_class, - suffix=suffix, - ) - else: - ret += '''{prefix}.def("__repr__", - [](const {cpp_class} &a) {{ - gtsam::RedirectCout redirect; - a.print(); - return redirect.str(); - }}){suffix}'''.format( - prefix=prefix, - cpp_class=cpp_class, - suffix=suffix, - ) + }}{py_args_names}){suffix}'''.format( + prefix=prefix, + cpp_class=cpp_class, + opt_comma=', ' if args_names else '', + args_signature_with_names=args_signature_with_names, + method_name=method.name, + method_args=", ".join(args_names) if args_names else '', + py_args_names=py_args_names, + suffix=suffix) + return ret def wrap_methods(self, methods, cpp_class, prefix='\n' + ' ' * 8, suffix=''): diff --git a/gtwrap/template_instantiator.py b/gtwrap/template_instantiator.py index f6c396a99e..bddaa07a8f 100644 --- a/gtwrap/template_instantiator.py +++ b/gtwrap/template_instantiator.py @@ -95,8 +95,10 @@ def instantiate_args_list(args_list, template_typenames, instantiations, for arg in args_list: new_type = instantiate_type(arg.ctype, template_typenames, instantiations, cpp_typename) + default = [arg.default] if isinstance(arg, parser.Argument) else '' instantiated_args.append(parser.Argument(name=arg.name, - ctype=new_type)) + ctype=new_type, + default=default)) return instantiated_args diff --git a/templates/pybind_wrapper.tpl.example b/templates/pybind_wrapper.tpl.example index 399c690aca..8c38ad21c4 100644 --- a/templates/pybind_wrapper.tpl.example +++ b/templates/pybind_wrapper.tpl.example @@ -4,6 +4,7 @@ #include #include #include +#include #include "gtsam/base/serialization.h" #include "gtsam/nonlinear/utilities.h" // for RedirectCout. diff --git a/tests/expected/matlab/MyFactorPosePoint2.m b/tests/expected/matlab/MyFactorPosePoint2.m index 290e41d4e7..ea2e335c71 100644 --- a/tests/expected/matlab/MyFactorPosePoint2.m +++ b/tests/expected/matlab/MyFactorPosePoint2.m @@ -4,6 +4,9 @@ %-------Constructors------- %MyFactorPosePoint2(size_t key1, size_t key2, double measured, Base noiseModel) % +%-------Methods------- +%print(string s, KeyFormatter keyFormatter) : returns void +% classdef MyFactorPosePoint2 < handle properties ptr_MyFactorPosePoint2 = 0 @@ -29,6 +32,16 @@ function delete(obj) %DISPLAY Calls print on the object function disp(obj), obj.display; end %DISP Calls print on the object + function varargout = print(this, varargin) + % PRINT usage: print(string s, KeyFormatter keyFormatter) : returns void + % Doxygen can be found at https://gtsam.org/doxygen/ + if length(varargin) == 2 && isa(varargin{1},'char') && isa(varargin{2},'gtsam.KeyFormatter') + class_wrapper(55, this, varargin{:}); + return + end + error('Arguments do not match any overload of function MyFactorPosePoint2.print'); + end + end methods(Static = true) diff --git a/tests/expected/matlab/TemplatedFunctionRot3.m b/tests/expected/matlab/TemplatedFunctionRot3.m index 132db92da9..5b90c24733 100644 --- a/tests/expected/matlab/TemplatedFunctionRot3.m +++ b/tests/expected/matlab/TemplatedFunctionRot3.m @@ -1,6 +1,6 @@ function varargout = TemplatedFunctionRot3(varargin) if length(varargin) == 1 && isa(varargin{1},'gtsam.Rot3') - functions_wrapper(8, varargin{:}); + functions_wrapper(11, varargin{:}); else error('Arguments do not match any overload of function TemplatedFunctionRot3'); end diff --git a/tests/expected/matlab/class_wrapper.cpp b/tests/expected/matlab/class_wrapper.cpp index 11eda2d96e..3fc2e5dafd 100644 --- a/tests/expected/matlab/class_wrapper.cpp +++ b/tests/expected/matlab/class_wrapper.cpp @@ -661,6 +661,15 @@ void MyFactorPosePoint2_deconstructor_54(int nargout, mxArray *out[], int nargin } } +void MyFactorPosePoint2_print_55(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("print",nargout,nargin-1,2); + auto obj = unwrap_shared_ptr>(in[0], "ptr_MyFactorPosePoint2"); + string& s = *unwrap_shared_ptr< string >(in[1], "ptr_string"); + gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[2], "ptr_gtsamKeyFormatter"); + obj->print(s,keyFormatter); +} + void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { @@ -838,6 +847,9 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) case 54: MyFactorPosePoint2_deconstructor_54(nargout, out, nargin-1, in+1); break; + case 55: + MyFactorPosePoint2_print_55(nargout, out, nargin-1, in+1); + break; } } catch(const std::exception& e) { mexErrMsgTxt(("Exception from gtsam:\n" + std::string(e.what()) + "\n").c_str()); diff --git a/tests/expected/matlab/functions_wrapper.cpp b/tests/expected/matlab/functions_wrapper.cpp index c17c98eadb..b8341b4bae 100644 --- a/tests/expected/matlab/functions_wrapper.cpp +++ b/tests/expected/matlab/functions_wrapper.cpp @@ -196,7 +196,25 @@ void MultiTemplatedFunctionDoubleSize_tDouble_7(int nargout, mxArray *out[], int size_t y = unwrap< size_t >(in[1]); out[0] = wrap< double >(MultiTemplatedFunctionDoubleSize_tDouble(x,y)); } -void TemplatedFunctionRot3_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +void DefaultFuncInt_8(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncInt",nargout,nargin,1); + int a = unwrap< int >(in[0]); + DefaultFuncInt(a); +} +void DefaultFuncString_9(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncString",nargout,nargin,1); + string& s = *unwrap_shared_ptr< string >(in[0], "ptr_string"); + DefaultFuncString(s); +} +void DefaultFuncObj_10(int nargout, mxArray *out[], int nargin, const mxArray *in[]) +{ + checkArguments("DefaultFuncObj",nargout,nargin,1); + gtsam::KeyFormatter& keyFormatter = *unwrap_shared_ptr< gtsam::KeyFormatter >(in[0], "ptr_gtsamKeyFormatter"); + DefaultFuncObj(keyFormatter); +} +void TemplatedFunctionRot3_11(int nargout, mxArray *out[], int nargin, const mxArray *in[]) { checkArguments("TemplatedFunctionRot3",nargout,nargin,1); gtsam::Rot3& t = *unwrap_shared_ptr< gtsam::Rot3 >(in[0], "ptr_gtsamRot3"); @@ -239,7 +257,16 @@ void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) MultiTemplatedFunctionDoubleSize_tDouble_7(nargout, out, nargin-1, in+1); break; case 8: - TemplatedFunctionRot3_8(nargout, out, nargin-1, in+1); + DefaultFuncInt_8(nargout, out, nargin-1, in+1); + break; + case 9: + DefaultFuncString_9(nargout, out, nargin-1, in+1); + break; + case 10: + DefaultFuncObj_10(nargout, out, nargin-1, in+1); + break; + case 11: + TemplatedFunctionRot3_11(nargout, out, nargin-1, in+1); break; } } catch(const std::exception& e) { diff --git a/tests/expected/python/class_pybind.cpp b/tests/expected/python/class_pybind.cpp index 43705778bb..961daeffe4 100644 --- a/tests/expected/python/class_pybind.cpp +++ b/tests/expected/python/class_pybind.cpp @@ -57,9 +57,9 @@ PYBIND11_MODULE(class_py, m_) { .def("return_ptrs",[](Test* self, std::shared_ptr p1, std::shared_ptr p2){return self->return_ptrs(p1, p2);}, py::arg("p1"), py::arg("p2")) .def("print_",[](Test* self){ py::scoped_ostream_redirect output; self->print();}) .def("__repr__", - [](const Test &a) { + [](const Test& self){ gtsam::RedirectCout redirect; - a.print(); + self.print(); return redirect.str(); }) .def("set_container",[](Test* self, std::vector container){ self->set_container(container);}, py::arg("container")) @@ -83,7 +83,14 @@ PYBIND11_MODULE(class_py, m_) { py::class_, std::shared_ptr>>(m_, "MultipleTemplatesIntFloat"); py::class_, std::shared_ptr>>(m_, "MyFactorPosePoint2") - .def(py::init>(), py::arg("key1"), py::arg("key2"), py::arg("measured"), py::arg("noiseModel")); + .def(py::init>(), py::arg("key1"), py::arg("key2"), py::arg("measured"), py::arg("noiseModel")) + .def("print_",[](MyFactor* self, const string& s, const gtsam::KeyFormatter& keyFormatter){ py::scoped_ostream_redirect output; self->print(s, keyFormatter);}, py::arg("s") = "factor: ", py::arg("keyFormatter") = gtsam::DefaultKeyFormatter) + .def("__repr__", + [](const MyFactor& self, const string& s, const gtsam::KeyFormatter& keyFormatter){ + gtsam::RedirectCout redirect; + self.print(s, keyFormatter); + return redirect.str(); + }, py::arg("s") = "factor: ", py::arg("keyFormatter") = gtsam::DefaultKeyFormatter); #include "python/specializations.h" diff --git a/tests/expected/python/functions_pybind.cpp b/tests/expected/python/functions_pybind.cpp index a657bee67c..2513bcf564 100644 --- a/tests/expected/python/functions_pybind.cpp +++ b/tests/expected/python/functions_pybind.cpp @@ -30,6 +30,9 @@ PYBIND11_MODULE(functions_py, m_) { m_.def("overloadedGlobalFunction",[](int a, double b){return ::overloadedGlobalFunction(a, b);}, py::arg("a"), py::arg("b")); m_.def("MultiTemplatedFunctionStringSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); m_.def("MultiTemplatedFunctionDoubleSize_tDouble",[](const T& x, size_t y){return ::MultiTemplatedFunction(x, y);}, py::arg("x"), py::arg("y")); + m_.def("DefaultFuncInt",[](int a){ ::DefaultFuncInt(a);}, py::arg("a") = 123); + m_.def("DefaultFuncString",[](const string& s){ ::DefaultFuncString(s);}, py::arg("s") = "hello"); + m_.def("DefaultFuncObj",[](const gtsam::KeyFormatter& keyFormatter){ ::DefaultFuncObj(keyFormatter);}, py::arg("keyFormatter") = gtsam::DefaultKeyFormatter); m_.def("TemplatedFunctionRot3",[](const gtsam::Rot3& t){ ::TemplatedFunction(t);}, py::arg("t")); #include "python/specializations.h" diff --git a/tests/fixtures/class.i b/tests/fixtures/class.i index f4683a0321..f49725ffa9 100644 --- a/tests/fixtures/class.i +++ b/tests/fixtures/class.i @@ -79,6 +79,8 @@ virtual class ns::OtherClass; template class MyFactor { MyFactor(size_t key1, size_t key2, double measured, const gtsam::noiseModel::Base* noiseModel); + void print(const string &s = "factor: ", + const gtsam::KeyFormatter &keyFormatter = gtsam::DefaultKeyFormatter); }; // and a typedef specializing it diff --git a/tests/fixtures/functions.i b/tests/fixtures/functions.i index d983ac97a1..5e774a05a9 100644 --- a/tests/fixtures/functions.i +++ b/tests/fixtures/functions.i @@ -26,3 +26,8 @@ template void TemplatedFunction(const T& t); typedef TemplatedFunction TemplatedFunctionRot3; + +// Check default arguments +void DefaultFuncInt(int a = 123); +void DefaultFuncString(const string& s = "hello"); +void DefaultFuncObj(const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter); diff --git a/tests/test_interface_parser.py b/tests/test_interface_parser.py index d3c2ad52aa..1b9d5e711a 100644 --- a/tests/test_interface_parser.py +++ b/tests/test_interface_parser.py @@ -179,6 +179,33 @@ def test_argument_list_templated(self): self.assertEqual("vector>", args_list[1].ctype.to_cpp(True)) + def test_default_arguments(self): + """Tests any expression that is a valid default argument""" + args = ArgumentList.rule.parseString( + "string s=\"hello\", int a=3, " + "int b, double pi = 3.1415, " + "gtsam::KeyFormatter kf = gtsam::DefaultKeyFormatter, " + "std::vector p = std::vector(), " + "std::vector l = (1, 2, 'name', \"random\", 3.1415)" + )[0].args_list + + # Test for basic types + self.assertEqual(args[0].default, "hello") + self.assertEqual(args[1].default, 3) + # '' is falsy so we can check against it + self.assertEqual(args[2].default, '') + self.assertFalse(args[2].default) + + self.assertEqual(args[3].default, 3.1415) + + # Test non-basic type + self.assertEqual(repr(args[4].default.typename), 'gtsam::DefaultKeyFormatter') + # Test templated type + self.assertEqual(repr(args[5].default.typename), 'std::vector') + # Test for allowing list as default argument + print(args) + self.assertEqual(args[6].default, (1, 2, 'name', "random", 3.1415)) + def test_return_type(self): """Test ReturnType""" # Test void @@ -490,6 +517,5 @@ class Global{ self.assertEqual(["two", "two_dummy", "two"], [x.name for x in module.content[0].content]) - if __name__ == '__main__': unittest.main()