From 187d073dfff140d70a44217c5331f30b442bfdac Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Sat, 6 Apr 2019 10:01:22 -0700 Subject: [PATCH] Support wrapping PyTorch builtin functions --- funsor/six.py | 30 +++++++++++++++++++++++++----- test/test_torch.py | 9 +++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/funsor/six.py b/funsor/six.py index eef789240..9428b0d4e 100644 --- a/funsor/six.py +++ b/funsor/six.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, division, print_function import inspect +import re import six @@ -44,11 +45,30 @@ def decorator(fn): def getargspec(fn): - """wrapper to remove annoying DeprecationWarning for inspect.getargspec in Py3""" - if six.PY3: - args, vargs, kwargs, defaults, _, _, _ = inspect.getfullargspec(fn) - else: - args, vargs, kwargs, defaults = inspect.getargspec(fn) + """ + Similar to Python 2's :py:func:`inspect.getargspec` but: + - In Python 3 uses ``getfullargspec`` to avoid ``DeprecationWarning``. + - For builtin functions like ``torch.matmul``, falls back to attmpting + to parse the function docstring, assuming torch-style. + """ + assert callable(fn) + try: + if six.PY3: + args, vargs, kwargs, defaults, _, _, _ = inspect.getfullargspec(fn) + else: + args, vargs, kwargs, defaults = inspect.getargspec(fn) + except TypeError: + # Fall back to attmpting to parse a PyTorch-style docstring. + match = re.match(r"\s{}\(([^)]*)\)".format(fn.__name__), fn.__doc__) + if match is None: + raise + parts = match.group(1).split(", ") + args = [a.split("=")[0] for a in parts] + if not all(re.match(r"^[^\d\W]\w*\Z", arg) for arg in args): + raise + vargs = None + kwargs = None + defaults = () # Ignore defaults. return args, vargs, kwargs, defaults diff --git a/test/test_torch.py b/test/test_torch.py index 26004aff2..3bcbbbfcd 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -587,6 +587,15 @@ def max_and_argmax(x): assert_close(actual_argmax, expected_argmax) +def test_function_of_torch_tensor(): + x = torch.randn(4, 3) + y = torch.randn(3, 2) + f = funsor.torch.function(reals(4, 3), reals(3, 2), reals(4, 2))(torch.matmul) + actual = f(x, y) + expected = f(Tensor(x), Tensor(y)) + assert_close(actual, expected) + + def test_align(): x = Tensor(torch.randn(2, 3, 4), OrderedDict([ ('i', bint(2)),