Skip to content

Commit

Permalink
[bindings] Add custom ref_cycle annotation
Browse files Browse the repository at this point in the history
Introduce the new annotation internal::ref_cycle<M, N>(). It will eventually
replace existing cyclic py::keep_alive<>() annotations (which do their job, but
leak their participants forever). The participants of ref_cycle<>() will be
garbage collectible, so that applications can run loops that use various drake
components without exhausting memory.

This patch just adds the implementation and its unit test.
  • Loading branch information
rpoyner-tri committed Oct 24, 2024
1 parent 44bf559 commit 4aaa667
Show file tree
Hide file tree
Showing 5 changed files with 456 additions and 0 deletions.
31 changes: 31 additions & 0 deletions bindings/pydrake/common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,20 @@ drake_cc_library(
],
)

drake_cc_library(
name = "ref_cycle_pybind",
srcs = ["ref_cycle_pybind.cc"],
hdrs = ["ref_cycle_pybind.h"],
declare_installed_headers = 0,
visibility = ["//visibility:public"],
deps = [
"//bindings/pydrake:pydrake_pybind",
"//common:essential",
"@fmt",
"@pybind11",
],
)

# N.B. Any C++ libraries that include this must include `cpp_template_py` when
# being used in Python.
drake_cc_library(
Expand Down Expand Up @@ -492,6 +506,23 @@ drake_py_unittest(
],
)

drake_pybind_library(
name = "ref_cycle_test_util_py",
testonly = True,
add_install = False,
cc_deps = [":ref_cycle_pybind"],
cc_srcs = ["test/ref_cycle_test_util_py.cc"],
package_info = PACKAGE_INFO,
)

drake_py_unittest(
name = "ref_cycle_test",
deps = [
":common",
":ref_cycle_test_util_py",
],
)

drake_py_unittest(
name = "schema_test",
deps = [
Expand Down
101 changes: 101 additions & 0 deletions bindings/pydrake/common/ref_cycle_pybind.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include "drake/bindings/pydrake/common/ref_cycle_pybind.h"

#include <fmt/format.h>

#include "drake/common/drake_assert.h"
#include "drake/common/drake_throw.h"

using pybind11::handle;
using pybind11::detail::function_call;

namespace drake {
namespace pydrake {
namespace internal {

namespace {

void make_ref_cycle(handle p0, handle p1) {
DRAKE_DEMAND(static_cast<bool>(p0));
DRAKE_DEMAND(static_cast<bool>(p1));
DRAKE_DEMAND(!p0.is_none());
DRAKE_DEMAND(!p1.is_none());
DRAKE_DEMAND(PyType_IS_GC(Py_TYPE(p0.ptr())));
DRAKE_DEMAND(PyType_IS_GC(Py_TYPE(p1.ptr())));

// Each peer will have a new/updated attribute, containing a set of
// handles. Insert each into the other's handle set. Create the set first
// if it is not yet existing.
auto make_link = [](handle a, handle b) {
static const char refcycle_peers[] = "_pydrake_internal_ref_cycle_peers";
handle peers;
if (hasattr(a, refcycle_peers)) {
peers = a.attr(refcycle_peers);
} else {
peers = PySet_New(nullptr);
DRAKE_DEMAND(PyType_IS_GC(Py_TYPE(peers.ptr())));
a.attr(refcycle_peers) = peers;
Py_DECREF(peers.ptr());
}
// Ensure the proper ref count on the `peers` set. If it is > 1, the
// objects will live forever. If it is < 1, the cycle will just be deleted
// immediately.
DRAKE_DEMAND(Py_REFCNT(peers.ptr()) == 1);
PySet_Add(peers.ptr(), b.ptr());
};
make_link(p0, p1);
make_link(p1, p0);
}

} // namespace

void ref_cycle_impl(
size_t peer0, size_t peer1, const function_call& call, handle ret) {
// Returns the handle selected by the given index. Throws if the index is
// invalid.
auto get_arg = [&](size_t n) -> handle {
if (n == 0) {
return ret;
}
if (n == 1 && call.init_self) {
return call.init_self;
}
if (n <= call.args.size()) {
return call.args[n - 1];
}
pybind11::pybind11_fail(fmt::format(
"Could not activate ref_cycle: index {} is invalid for function '{}'",
n, call.func.name));
};
handle p0 = get_arg(peer0);
handle p1 = get_arg(peer1);

// Returns false if the handle's value is None. Throws if the handle's value
// is not of a garbage-collectable type.
auto check_handle = [&](size_t n, handle p) -> bool {
if (p.is_none()) {
return false;
}
// Among the reasons the following check may fail is that one of the
// participating pybind11::class_ types does not declare
// pybind11::dynamic_attr().
if (!PyType_IS_GC(Py_TYPE(p.ptr()))) {
pybind11::pybind11_fail(fmt::format(
"Could not activate ref_cycle: object type at index {} for "
"function '{}' is not tracked by garbage collection.",
n, call.func.name));
}
return true;
};
if (!check_handle(peer0, p0) || !check_handle(peer1, p1)) {
// At least one of the handles is None. We can't construct a ref-cycle, but
// neither should we complain. A None variable value could happen for any
// number of legitimate reasons, and does not mean that the ref_cycle call
// policy is defective.
return;
}
make_ref_cycle(p0, p1);
}

} // namespace internal
} // namespace pydrake
} // namespace drake
72 changes: 72 additions & 0 deletions bindings/pydrake/common/ref_cycle_pybind.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#pragma once

#include "drake/bindings/pydrake/pydrake_pybind.h"

namespace drake {
namespace pydrake {
namespace internal {

/* pydrake::internal::ref_cycle is a custom call policy for pybind11.
For an overview of other call policies, See
https://pybind11.readthedocs.io/en/stable/advanced/functions.html#additional-call-policies
`ref_cycle` creates a reference count cycle that Python's cyclic garbage
collection can find and collect, once the cycle's objects are no longer
reachable. Both peer objects must by pybind11 classes that were defined with
the dynamic_attr() annotation.
`ref_cycle` cause each object to refer to the other in a cycle. It is
bidirectional and symmetric. The order of the template arguments does not
matter.
Note the consequences for object lifetimes:
* M keeps N alive, and N keeps M alive.
* Neither object is finalized until:
* both are unreachable, and
* garbage collection runs.
*/
template <size_t Peer0, size_t Peer1>
struct ref_cycle {};

/* This function is used in the template below to select peers by call/return
index. */
void ref_cycle_impl(size_t peer0, size_t peer1,
const pybind11::detail::function_call& call, pybind11::handle ret);

} // namespace internal
} // namespace pydrake
} // namespace drake

namespace pybind11 {
namespace detail {

// Provide a specialization of the pybind11 internal process_attribute
// template; this allows writing an annotation that works seamlessly in
// bindings definitions.
template <size_t Peer0, size_t Peer1>
class process_attribute<drake::pydrake::internal::ref_cycle<Peer0, Peer1>>
: public process_attribute_default<
drake::pydrake::internal::ref_cycle<Peer0, Peer1>> {
public:
// NOLINTNEXTLINE(runtime/references)
static void precall(function_call& call) {
if constexpr (!needs_result()) {
drake::pydrake::internal::ref_cycle_impl(Peer0, Peer1, call, handle());
}
}

// NOLINTNEXTLINE(runtime/references)
static void postcall(function_call& call, handle ret) {
if constexpr (needs_result()) {
drake::pydrake::internal::ref_cycle_impl(Peer0, Peer1, call, ret);
}
}

private:
static constexpr bool needs_result() { return Peer0 == 0 || Peer1 == 0; }
};

} // namespace detail
} // namespace pybind11
175 changes: 175 additions & 0 deletions bindings/pydrake/common/test/ref_cycle_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Unit test for ref_cycle<>() annotation.
See also ref_cycle_test_util_py.cc for the bindings used in the tests.
"""
import functools
import gc
import sys
import unittest
import weakref

from pydrake.common.ref_cycle_test_util import (
NotDynamic, IsDynamic, invalid_arg_index, free_function, ouroboros)


def actual_ref_count(o):
"""Returns the actual ref count of `o`, in the caller's scope."""
# sys.getrefcount() always artificially adds 1 to its result, owing to the
# machinery of python calling the native-code implementation. Since we wrap
# it here, we have to adjust the result to account for the
# python-implemented function call. For extra fun, the ref-count cost of a
# python function call varies with python interpreter versions, so
# wrapped_refcount_cost() actually measures the total (native call plus
# python call) cost.
@functools.cache
def wrapped_refcount_cost():
def wrapped(o):
return sys.getrefcount(o)
return wrapped(object())

return sys.getrefcount(o) - wrapped_refcount_cost()


class TestRefCycle(unittest.TestCase):
def check_is_collectable_cycle(self, p0, p1):
# The edges of the cycle are:
# p0 -> p0.__dict__ -> p0._pydrake_internal_ref_cycle_peers \
# -> p1 -> p1.__dict__ -> p1._pydrake_internal_ref_cycle_peers -> p0
# where the object at each _pydrake_internal_ref_cycle_peers is a set.
#
# It is impractical to check the counts of p0 and p1 here because
# callers may hold an arbitrary number of references.

for x in [p0, p1]:
self.assertEqual(actual_ref_count(x.__dict__), 1)
self.assertEqual(
actual_ref_count(x._pydrake_internal_ref_cycle_peers), 1)

# Check that all parts are tracked by gc.
self.assertTrue(gc.is_tracked(x))
self.assertTrue(gc.is_tracked(x.__dict__))
self.assertTrue(gc.is_tracked(x._pydrake_internal_ref_cycle_peers))

# Check that the peers refer to each other.
self.assertTrue(p1 in p0._pydrake_internal_ref_cycle_peers)
self.assertTrue(p0 in p1._pydrake_internal_ref_cycle_peers)

def check_no_cycle(self, p0, p1):
for x in [p0, p1]:
self.assertFalse(hasattr(x, '_pydrake_internal_ref_cycle_peers'))

def test_invalid_index(self):
with self.assertRaisesRegex(RuntimeError,
"Could not activate ref_cycle.*"):
invalid_arg_index()

def test_ouroboros(self):
# The self-cycle edges are:
# dut -> dut.__dict__ -> dut._pydrake_internal_ref_cycle_peers -> dut
#
# This still passes check_is_collectable_cycle() -- the function just
# does redundant work.
dut = IsDynamic()
returned = ouroboros(dut)
self.assertEqual(returned, dut)
self.assertEqual(len(dut._pydrake_internal_ref_cycle_peers), 1)
self.check_is_collectable_cycle(returned, dut)

def test_free_function(self):
p0 = IsDynamic()
p1 = IsDynamic()
free_function(p0, p1)
self.check_is_collectable_cycle(p0, p1)

def test_not_dynamic_add(self):
dut = NotDynamic()
peer = IsDynamic()
# Un-annotated call is fine.
dut.AddIs(peer)
self.check_no_cycle(dut, peer)
# Annotated call dies because dut is not py::dynamic_attr().
with self.assertRaisesRegex(
RuntimeError, ".type.*index 1.*AddIsCycle.*not tracked.*"):
dut.AddIsCycle(peer)

def test_not_dynamic_return(self):
dut = NotDynamic()
# Un-annotated call is fine.
returned = dut.ReturnIs()
self.check_no_cycle(dut, returned)
# Annotated call dies because dut is not py::dynamic_attr().
with self.assertRaisesRegex(
RuntimeError, ".type.*index 1.*ReturnIsCycle.*not tracked.*"):
dut.ReturnIsCycle()

def test_not_dynamic_null(self):
dut = NotDynamic()
# Un-annotated call is fine.
self.assertIsNone(dut.ReturnNullIs())
# Annotated call does not die because one peer is missing.
self.assertIsNone(dut.ReturnNullIsCycle())

def test_is_dynamic_add_not(self):
dut = IsDynamic()
notpeer = NotDynamic()
dut.AddNot(notpeer)
self.check_no_cycle(dut, notpeer)
# Annotated call dies because notpeer is not py::dynamic_attr().
with self.assertRaisesRegex(
RuntimeError, ".type.*index 2.*AddNotCycle.*not tracked.*"):
dut.AddNotCycle(notpeer)

def test_is_dynamic_return_not(self):
dut = IsDynamic()
# Un-annotated call is fine.
returned = dut.ReturnNot()
self.check_no_cycle(dut, returned)
# Annotated call dies because return is not py::dynamic_attr().
with self.assertRaisesRegex(
RuntimeError, ".type.*index 0.*ReturnNotCycle.*not tracked.*"):
dut.ReturnNotCycle()

def test_is_dynamic_return_null(self):
dut = IsDynamic()
# Un-annotated call is fine.
self.assertIsNone(dut.ReturnNullNot())
self.assertIsNone(dut.ReturnNullIs())
# Annotated call does not die because one peer is missing.
self.assertIsNone(dut.ReturnNullNotCycle())
self.assertIsNone(dut.ReturnNullIsCycle())

def test_is_dynamic_add_is(self):
dut = IsDynamic()
peer = IsDynamic()
# Un-annotated call does not implement a cycle.
dut.AddIs(peer)
self.check_no_cycle(dut, peer)
# Annotated call produces a collectable cycle.
dut.AddIsCycle(peer)
self.check_is_collectable_cycle(dut, peer)

def test_is_dynamic_return_is(self):
dut = IsDynamic()
# Un-annotated call does not implement a cycle.
returned = dut.ReturnIs()
self.check_no_cycle(dut, returned)
# Annotated call produces a collectable cycle.
returned = dut.ReturnIsCycle()
self.check_is_collectable_cycle(dut, returned)

def test_actual_collection(self):

def make_a_cycle():
dut = IsDynamic()
return dut.ReturnIsCycle()

cycle = make_a_cycle()
finalizer = weakref.finalize(cycle, lambda: None)
# Cycle is alive while we refer to it.
self.assertTrue(finalizer.alive)
del cycle
# Cycle is alive because of the ref_cycle.
self.assertTrue(finalizer.alive)
gc.collect()
# Cycle does not survive garbage collection.
self.assertFalse(finalizer.alive)
Loading

0 comments on commit 4aaa667

Please sign in to comment.