Skip to content

Commit

Permalink
[pydrake] Add dynamic_attr support to cpp_template_pybind
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri committed Oct 15, 2024
1 parent 1771dbf commit 05943c7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
18 changes: 14 additions & 4 deletions bindings/pydrake/common/cpp_template_pybind.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <optional>
#include <string>
#include <utility>

Expand Down Expand Up @@ -76,18 +77,23 @@ inline py::object AddTemplateClass( // BR
/// and a default instantiation (if not already defined).
/// The default instantiation is named `default_name`, while the template is
/// named `default_name + template_suffix`.
/// The `template_suffix` defaults to "_" when not provided.
/// The caller may opt-in to py::dynamic_attr() as the last argument.
/// @return pybind11 class
template <typename Class, typename... Options>
py::class_<Class, Options...> DefineTemplateClassWithDefault( // BR
py::handle scope, const std::string& default_name, py::tuple param,
const char* doc_string = "", const std::string& template_suffix = "_") {
const char* doc_string = "",
const std::optional<std::string>& template_suffix = {},
std::optional<py::dynamic_attr> dynamic_attr = {}) {
// The default instantiation is immediately assigned its correct class name.
// Other instantiations use a temporary name here that will be overwritten
// by the AddTemplateClass function during registration.
const bool is_default = !py::hasattr(scope, default_name.c_str());
const std::string class_name =
is_default ? default_name : TemporaryClassName<Class>();
const std::string template_name = default_name + template_suffix;
const std::string template_name =
default_name + template_suffix.value_or("_");
// Define the class.
std::string doc;
if (is_default) {
Expand All @@ -99,8 +105,12 @@ py::class_<Class, Options...> DefineTemplateClassWithDefault( // BR
} else {
doc = doc_string;
}
py::class_<Class, Options...> py_class(
scope, class_name.c_str(), doc.c_str());
py::class_<Class, Options...> py_class =
dynamic_attr.has_value()
? py::class_<Class, Options...>(
scope, class_name.c_str(), doc.c_str(), *dynamic_attr)
: py::class_<Class, Options...>(
scope, class_name.c_str(), doc.c_str());
// Register it as a template instantiation.
const bool skip_rename = is_default;
AddTemplateClass(scope, template_name, py_class, param, skip_rename);
Expand Down
26 changes: 26 additions & 0 deletions bindings/pydrake/common/test/cpp_template_pybind_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ void CheckValue(const string& expr, const T& expected) {
EXPECT_EQ(py::eval(expr).cast<T>(), expected);
}

template <typename T>
struct TemplateWithDefault {
string GetName() { return NiceTypeName::Get<T>(); }
};

template <typename T>
void BindTemplateWithDefault(py::module m) {
using Class = TemplateWithDefault<T>;
auto py_class =
DefineTemplateClassWithDefault<Class>(m, "TemplateWithDefault",
GetPyParam<T>(), "Documentation", std::nullopt, py::dynamic_attr());
py_class // BR
.def(py::init<>())
.def("GetName", &Class::GetName);
}

GTEST_TEST(CppTemplateTest, TemplateClass) {
py::module m =
py::module::create_extension_module("__main__", "", new PyModuleDef());
Expand All @@ -54,6 +70,9 @@ GTEST_TEST(CppTemplateTest, TemplateClass) {
m.attr("DefaultInst") = cls_1;
auto cls_2 = BindSimpleTemplate<int, double>(m);

BindTemplateWithDefault<double>(m);
BindTemplateWithDefault<int>(m);

const vector<string> expected_1 = {"int"};
const vector<string> expected_2 = {"int", "double"};
SynchronizeGlobalsForPython3(m);
Expand All @@ -62,6 +81,13 @@ GTEST_TEST(CppTemplateTest, TemplateClass) {
CheckValue("SimpleTemplate[int]().GetNames()", expected_1);
CheckValue("SimpleTemplate[int, float]().GetNames()", expected_2);

CheckValue("TemplateWithDefault().GetName()", string{"double"});
CheckValue("TemplateWithDefault_[float]().GetName()", string{"double"});
CheckValue("TemplateWithDefault_[int]().GetName()", string{"int"});

// Sanity test of the py::dynamic_attr().
CheckValue("TemplateWithDefault().__dict__.setdefault('_foo', 1)", 1);

m.def("simple_func", [](const SimpleTemplate<int>&) {});
SynchronizeGlobalsForPython3(m);

Expand Down

0 comments on commit 05943c7

Please sign in to comment.