From 05943c77a41231b8a1759858cdbed8a08ed3224c Mon Sep 17 00:00:00 2001 From: Jeremy Nimmer Date: Tue, 15 Oct 2024 06:56:05 -0700 Subject: [PATCH] [pydrake] Add dynamic_attr support to cpp_template_pybind --- bindings/pydrake/common/cpp_template_pybind.h | 18 ++++++++++--- .../common/test/cpp_template_pybind_test.cc | 26 +++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/bindings/pydrake/common/cpp_template_pybind.h b/bindings/pydrake/common/cpp_template_pybind.h index d1b3dcc5fe39..8cae9dc7a0bf 100644 --- a/bindings/pydrake/common/cpp_template_pybind.h +++ b/bindings/pydrake/common/cpp_template_pybind.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -76,18 +77,23 @@ inline py::object AddTemplateClass( // BR /// and a default instantiation (if not already defined). /// The default instantiation is named `default_name`, while the template is /// named `default_name + template_suffix`. +/// The `template_suffix` defaults to "_" when not provided. +/// The caller may opt-in to py::dynamic_attr() as the last argument. /// @return pybind11 class template py::class_ DefineTemplateClassWithDefault( // BR py::handle scope, const std::string& default_name, py::tuple param, - const char* doc_string = "", const std::string& template_suffix = "_") { + const char* doc_string = "", + const std::optional& template_suffix = {}, + std::optional dynamic_attr = {}) { // The default instantiation is immediately assigned its correct class name. // Other instantiations use a temporary name here that will be overwritten // by the AddTemplateClass function during registration. const bool is_default = !py::hasattr(scope, default_name.c_str()); const std::string class_name = is_default ? default_name : TemporaryClassName(); - const std::string template_name = default_name + template_suffix; + const std::string template_name = + default_name + template_suffix.value_or("_"); // Define the class. std::string doc; if (is_default) { @@ -99,8 +105,12 @@ py::class_ DefineTemplateClassWithDefault( // BR } else { doc = doc_string; } - py::class_ py_class( - scope, class_name.c_str(), doc.c_str()); + py::class_ py_class = + dynamic_attr.has_value() + ? py::class_( + scope, class_name.c_str(), doc.c_str(), *dynamic_attr) + : py::class_( + scope, class_name.c_str(), doc.c_str()); // Register it as a template instantiation. const bool skip_rename = is_default; AddTemplateClass(scope, template_name, py_class, param, skip_rename); diff --git a/bindings/pydrake/common/test/cpp_template_pybind_test.cc b/bindings/pydrake/common/test/cpp_template_pybind_test.cc index c8c48ae9ec2e..c39a4c671cb8 100644 --- a/bindings/pydrake/common/test/cpp_template_pybind_test.cc +++ b/bindings/pydrake/common/test/cpp_template_pybind_test.cc @@ -46,6 +46,22 @@ void CheckValue(const string& expr, const T& expected) { EXPECT_EQ(py::eval(expr).cast(), expected); } +template +struct TemplateWithDefault { + string GetName() { return NiceTypeName::Get(); } +}; + +template +void BindTemplateWithDefault(py::module m) { + using Class = TemplateWithDefault; + auto py_class = + DefineTemplateClassWithDefault(m, "TemplateWithDefault", + GetPyParam(), "Documentation", std::nullopt, py::dynamic_attr()); + py_class // BR + .def(py::init<>()) + .def("GetName", &Class::GetName); +} + GTEST_TEST(CppTemplateTest, TemplateClass) { py::module m = py::module::create_extension_module("__main__", "", new PyModuleDef()); @@ -54,6 +70,9 @@ GTEST_TEST(CppTemplateTest, TemplateClass) { m.attr("DefaultInst") = cls_1; auto cls_2 = BindSimpleTemplate(m); + BindTemplateWithDefault(m); + BindTemplateWithDefault(m); + const vector expected_1 = {"int"}; const vector expected_2 = {"int", "double"}; SynchronizeGlobalsForPython3(m); @@ -62,6 +81,13 @@ GTEST_TEST(CppTemplateTest, TemplateClass) { CheckValue("SimpleTemplate[int]().GetNames()", expected_1); CheckValue("SimpleTemplate[int, float]().GetNames()", expected_2); + CheckValue("TemplateWithDefault().GetName()", string{"double"}); + CheckValue("TemplateWithDefault_[float]().GetName()", string{"double"}); + CheckValue("TemplateWithDefault_[int]().GetName()", string{"int"}); + + // Sanity test of the py::dynamic_attr(). + CheckValue("TemplateWithDefault().__dict__.setdefault('_foo', 1)", 1); + m.def("simple_func", [](const SimpleTemplate&) {}); SynchronizeGlobalsForPython3(m);