Skip to content

Commit

Permalink
[py systems] Adjust scalar conversion to avoid unique_ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
jwnimmer-tri committed Jan 23, 2025
1 parent 62b40b0 commit 2b97a5e
Show file tree
Hide file tree
Showing 13 changed files with 425 additions and 21 deletions.
72 changes: 58 additions & 14 deletions bindings/pydrake/systems/framework_py_systems.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "drake/systems/framework/system_visitor.h"
#include "drake/systems/framework/vector_system.h"
#include "drake/systems/framework/witness_function.h"
#include "drake/systems/framework/wrapped_system.h"

using std::make_unique;
using std::string;
Expand Down Expand Up @@ -45,6 +46,7 @@ using systems::SystemVisitor;
using systems::UnrestrictedUpdateEvent;
using systems::VectorSystem;
using systems::WitnessFunction;
using systems::internal::WrappedSystem;

// NOLINTNEXTLINE(build/namespaces): Emulate placement in namespace.
using namespace drake::systems;
Expand Down Expand Up @@ -91,8 +93,7 @@ struct Impl {
// (LeafSystemPublic::*)(...), since this typeid is not exposed in pybind.
// If needed, solution is to expose it as an intermediate type if needed.

// Expose protected methods for binding, no need for virtual overrides
// (ordered by how they are bound).
// Expose protected methods for binding, no need for virtual overrides.
using Base::DeclareAbstractInputPort;
using Base::DeclareAbstractOutputPort;
using Base::DeclareAbstractParameter;
Expand All @@ -110,6 +111,7 @@ struct Impl {
using Base::get_mutable_forced_discrete_update_events;
using Base::get_mutable_forced_publish_events;
using Base::get_mutable_forced_unrestricted_update_events;
using Base::HandlePostConstructionScalarConversion;
using Base::MakeWitnessFunction;

// Because `LeafSystem<T>::DoCalcTimeDerivatives` is protected, and we had
Expand Down Expand Up @@ -557,13 +559,24 @@ Note: The above is for the C++ documentation. For Python, use
doc.System.ToSymbolicMaybe.doc)
.def("FixInputPortsFrom", &System<T>::FixInputPortsFrom,
py::arg("other_system"), py::arg("other_context"),
py::arg("target_context"), doc.System.FixInputPortsFrom.doc);
py::arg("target_context"), doc.System.FixInputPortsFrom.doc)
.def("get_system_scalar_converter",
&System<T>::get_system_scalar_converter, py_rvp::reference_internal,
doc.System.get_system_scalar_converter.doc);
auto def_to_scalar_type = [&cls](auto dummy) {
using U = decltype(dummy);
AddTemplateMethod(
cls, "ToScalarType",
[](const System<T>& self) { return self.template ToScalarType<U>(); },
GetPyParam<U>(), doc.System.ToScalarType.doc_0args);
AddTemplateMethod(
cls, "_HandlePostConstructionScalarConversion",
[](System<T>& self, const System<U>& from) {
LeafSystemPublic::HandlePostConstructionScalarConversion(
from, &self);
},
GetPyParam<U>(),
doc.System.HandlePostConstructionScalarConversion.doc);
};
type_visit(def_to_scalar_type, CommonScalarPack{});

Expand Down Expand Up @@ -1067,6 +1080,16 @@ Note: The above is for the C++ documentation. For Python, use
}
}

static void DefineWrappedSystem(py::module m) {
using Class = WrappedSystem<T>;
auto cls = DefineTemplateClassWithDefault<Class, Diagram<T>>(m,
"_WrappedSystem", GetPyParam<T>(),
"Wrapper that enables scalar-conversion of Python leaf systems.");
cls // BR
.def("unwrap", &Class::unwrap, py_rvp::reference_internal,
"Returns the underlying system.");
}

template <typename PyClass>
static void DefineSystemVisitor(py::module m, PyClass* system_cls) {
// TODO(eric.cousineau): Bind virtual methods once we provide a function
Expand Down Expand Up @@ -1254,17 +1277,37 @@ void DefineSystemScalarConverter(PyClass* cls) {
&SystemScalarConverter::IsConvertible<T, U>, GetPyParam<T, U>(),
cls_doc.IsConvertible.doc);
using system_scalar_converter_internal::AddPydrakeConverterFunction;
using ConverterFunction =
std::function<std::unique_ptr<System<T>>(const System<U>&)>;
AddTemplateMethod(converter, "_Add",
WrapCallbacks(
[](SystemScalarConverter* self, const ConverterFunction& func) {
const std::function<System<T>*(const System<U>&)> bare_func =
[func](const System<U>& other) {
return func(other).release();
};
AddPydrakeConverterFunction(self, bare_func);
}),
// N.B. The "_AddConstructor" method is called by scalar_conversion.py
// to register a constructor, similar to MaybeAddConstructor in C++.
using ConverterFunction = std::function<System<T>*(const System<U>&)>;
AddTemplateMethod(
converter, "_AddConstructor",
[](SystemScalarConverter* self,
py::function python_converter_function) {
AddPydrakeConverterFunction(self,
ConverterFunction{
[python_converter_function](const System<U>& system_u_cpp) {
py::gil_scoped_acquire guard;
// Call the Python converter function.
py::object system_u_py =
py::cast(system_u_cpp, py_rvp::reference_internal);
py::object system_t_py =
python_converter_function(system_u_py);
DRAKE_THROW_UNLESS(!system_t_py.is_none());
// Cast its result to a shared_ptr.
std::shared_ptr<System<T>> system_t_cpp =
make_shared_ptr_from_py_object<System<T>>(
std::move(system_t_py));
// Wrap the result in a Diagram so we have a unique_ptr
// instead of a shared_ptr.
std::unique_ptr<System<T>> result =
std::make_unique<WrappedSystem<T>>(
std::move(system_t_cpp));
// Our contract is to return an owned raw pointer. Our
// caller will wrap the unique_ptr back around it.
return result.release();
}});
},
GetPyParam<T, U>());
};
// N.B. When changing the pairs of supported types below, ensure that these
Expand Down Expand Up @@ -1314,6 +1357,7 @@ void DefineFrameworkPySystems(py::module m) {
Impl<T>::DefineLeafSystem(m);
Impl<T>::DefineDiagram(m);
Impl<T>::DefineVectorSystem(m);
Impl<T>::DefineWrappedSystem(m);
Impl<T>::DefineSystemVisitor(m, cls_system);
};
type_visit(bind_common_scalar_types, CommonScalarPack{});
Expand Down
51 changes: 50 additions & 1 deletion bindings/pydrake/systems/scalar_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
import copy
from functools import partial

from pydrake.autodiffutils import AutoDiffXd
from pydrake.common import pretty_class_name
from pydrake.symbolic import Expression
from pydrake.systems.framework import (
LeafSystem_,
SystemScalarConverter,
)
from pydrake.common.cpp_template import (
_get_module_from_stack,
TemplateClass,
TemplateMethod,
)


Expand Down Expand Up @@ -194,6 +197,30 @@ def system_init(self, *args, **kwargs):
self, *args, converter=converter, **kwargs)

cls.__init__ = system_init

# Patch the scalar-conversion functions (only when called from Python,
# not when called from C++) to return the Python type instead of the
# WrappedSystem shim.
cls.ToScalarType = TemplateMethod(
cls=cls, name="ToScalarType")
cls.ToScalarTypeMaybe = TemplateMethod(
cls=cls, name="ToScalarTypeMaybe")
for U in SystemScalarConverter.SupportedScalars:
new_method = template._make_new_to_scalar_type_method(
T=U, U=T, maybe=False)
new_method_maybe = template._make_new_to_scalar_type_method(
T=U, U=T, maybe=True)
if U == AutoDiffXd:
cls.ToAutoDiffXd = new_method
cls.ToAutoDiffXdMaybe = new_method_maybe
elif U == Expression:
cls.ToSymbolic = new_method
cls.ToSymbolicMaybe = new_method_maybe
cls.ToScalarType.add_instantiation(
param=[U,], instantiation=new_method)
cls.ToScalarTypeMaybe.add_instantiation(
param=[U,], instantiation=new_method_maybe)

return cls

def _check_if_copying(self, obj, *args, **kwargs):
Expand Down Expand Up @@ -222,5 +249,27 @@ def _make_converter(self):
# to when the conversion is called.
for (T, U) in self._T_pairs:
conversion = partial(self._make, T, U)
converter._Add[T, U](conversion)
converter._AddConstructor[T, U](conversion)
return converter

def _make_new_to_scalar_type_method(self, *, T, U, maybe):
# Helper for _on_add for replacing the ToScalarType{Maybe}_ methods
# with pure-python implementations that avoid C++ lifetime challenges.
# U is the "from" type (the type of the System being called).
# T is the "to" type (the type of System to return).
def _to_scalar_type(system_U):
converter = system_U.get_system_scalar_converter()
if not converter.IsConvertible[U, T]():
if maybe:
return None
raise RuntimeError(
f"System {system_U.GetSystemPathname()} "
f"of type {type(system_U)} "
f"does not support scalar conversion to type {T}")
result = self._make(T=T, U=U, system_U=system_U)
result._HandlePostConstructionScalarConversion[U](system_U)
return result
_to_scalar_type.T = T
_to_scalar_type.U = U
_to_scalar_type.maybe = maybe
return _to_scalar_type
3 changes: 3 additions & 0 deletions bindings/pydrake/systems/test/general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
SystemVisitor,
SystemBase,
SystemOutput, SystemOutput_,
SystemScalarConverter,
VectorBase, VectorBase_,
TriggerType,
VectorSystem, VectorSystem_,
Expand Down Expand Up @@ -439,6 +440,8 @@ def test_scalar_type_conversion(self):
float_system.get_input_port(0).FixValue(float_context, 1.)
for T in [float, AutoDiffXd, Expression]:
system = Adder_[T](1, 1)
self.assertIsInstance(system.get_system_scalar_converter(),
SystemScalarConverter)
# N.B. Current scalar conversion does not permit conversion to and
# from the same type.
if T != float:
Expand Down
56 changes: 56 additions & 0 deletions bindings/pydrake/systems/test/scalar_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ def _construct_copy(self, other, converter=None):
Example = Example_[None]


@mut.TemplateSystem.define("NonsymbolicExample_", T_list=(float, AutoDiffXd))
def NonsymbolicExample_(T):

class Impl(LeafSystem_[T]):
"""Testing non-symbolic example."""

def _construct(self, value, converter=None):
LeafSystem_[T].__init__(self, converter=converter)
self.value = value

def _construct_copy(self, other, converter=None):
Impl._construct(self, other.value, converter=converter)

return Impl


NonsymbolicExample = NonsymbolicExample_[None]


class TestScalarConversion(unittest.TestCase):
def test_converter_attributes(self):
conversion_scalars = (
Expand Down Expand Up @@ -207,6 +226,43 @@ def _construct_copy(self, other, converter=None):
with self.assertRaises(AssertionError):
mut.TemplateSystem.define("C", T_pairs=T_pairs_unsupported)

def test_nonsymbolic_example(self):
"""Tests the NonsymbolicExample_ system."""
# Test private properties (do NOT use these in your code!).
self.assertEqual(len(NonsymbolicExample_._T_list), 2)
self.assertEqual(len(NonsymbolicExample_._T_pairs), 2)

# Test calls that we have available for scalar conversion.
for (T, U), use_maybe_variation in itertools.product(
SystemScalarConverter.SupportedConversionPairs, [False, True]):
if U is Expression:
continue
expected_is_convertible = T is not Expression
system_U = NonsymbolicExample_[U](2)
system_U._AddExternalConstraint(_ExternalSystemConstraint())
if expected_is_convertible:
if use_maybe_variation:
system_T = system_U.ToScalarTypeMaybe[T]()
else:
system_T = system_U.ToScalarType[T]()
self.assertEqual(system_T.value, system_U.value)
if T is AutoDiffXd:
if use_maybe_variation:
system_ad = system_U.ToAutoDiffXdMaybe()
else:
system_ad = system_U.ToAutoDiffXd()
self.assertEqual(system_ad.value, system_U.value)
continue
# Carefully check when happens when NOT convertible.
if use_maybe_variation:
system_T = system_U.ToScalarTypeMaybe[T]()
self.assertIsNone(system_T)
continue
with self.assertRaisesRegex(
RuntimeError,
".*NonsymbolicExample.*not support.*Expression"):
system_U.ToScalarType[T]()

def test_inheritance(self):

@mut.TemplateSystem.define("Child_")
Expand Down
40 changes: 38 additions & 2 deletions systems/framework/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ drake_cc_package_library(
":vector",
":vector_system",
":witness_function",
":wrapped_system",
],
)

Expand Down Expand Up @@ -609,19 +610,28 @@ drake_cc_library(

drake_cc_library(
name = "diagram",
srcs = ["diagram.cc"],
srcs = [
"diagram.cc",
# The internal-use-only wrapped_system is a close friend of Diagram.
# It subclasses Diagram and Diagram's implementation uses it, so we
# must group that dependency cycle into a single library. However, note
# that one constructor is defined in wrapped_system_builder.cc as part
# of the actual :wrapped_system library.
"wrapped_system.cc",
"wrapped_system.h",
],
hdrs = ["diagram.h"],
deps = [
":diagram_context",
":diagram_output_port",
":system",
"//common:default_scalars",
"//common:essential",
"//common:string_container",
],
implementation_deps = [
":abstract_value_cloner",
"//common:pointer_cast",
"//common:string_container",
],
)

Expand Down Expand Up @@ -696,6 +706,22 @@ drake_cc_library(
],
)

drake_cc_library(
name = "wrapped_system",
srcs = [
# This file just defines one constructor. Most functions are defined in
# wrapped_system.cc which is necessarily part of the :diagram library.
"wrapped_system_builder.cc",
],
hdrs = ["wrapped_system.h"],
deps = [
":diagram",
],
implementation_deps = [
":diagram_builder",
],
)

# === test/ ===

drake_cc_googletest(
Expand Down Expand Up @@ -1096,4 +1122,14 @@ drake_cc_googletest(
],
)

drake_cc_googletest(
name = "wrapped_system_test",
deps = [
":wrapped_system",
"//common/test_utilities",
"//systems/framework/test_utilities:scalar_conversion",
"//systems/primitives:adder",
],
)

add_lint_tests(enable_clang_format_lint = False)
Loading

0 comments on commit 2b97a5e

Please sign in to comment.