[TE] Support varargs in te.compute (apache#9796)
* [TE] Support varargs in te.compute

Support varargs (`lambda x, *args: ...`) in te.compute. The varargs take
indices into the remaining dimensions of the output's shape. This
requires using `inspect.getfullargspec` instead of `fcompute.__code__`.

Also add checks that there are no keyword arguments.

* Implicitly broadcast to the remaining dimensions when fcompute takes fewer arguments than the output has dimensions.
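
For context, here is a minimal usage sketch of the varargs form this change enables. It is not taken from the commit or its tests; the tensor names and shapes are illustrative.

```python
from tvm import te

# The lambda names only the first axis explicitly; *args receives the index
# variables for the remaining two axes of the (2, 3, 4) output shape.
A = te.placeholder((2, 3, 4), name="A")
B = te.compute(A.shape, lambda x, *args: A(x, *args) + 1, name="B")
```

Previously the argument count was taken from `fcompute.__code__.co_argcount`, which does not include `*args`, so the trailing dimensions were never passed to `*args`.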
Tristan Konolige authored and ylc committed Jan 13, 2022
1 parent a7247cb commit 2e445b7
Showing 2 changed files with 23 additions and 10 deletions.
python/tvm/te/operation.py (29 changes: 20 additions & 9 deletions)
@@ -18,6 +18,7 @@
 # pylint: disable=invalid-name
 from numbers import Integral as _Integral
 from typing import List, Union
+import inspect
 
 import tvm._ffi
 from tvm._ffi.base import string_types
@@ -89,18 +90,28 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
     shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape
     # for python3
     shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
-    ndim = len(shape)
-    code = fcompute.__code__
-
-    out_ndim = ndim
-    if code.co_argcount == 0:
-        arg_names = ["i%d" % i for i in range(ndim)]
+    out_ndim = len(shape)
+
+    argspec = inspect.getfullargspec(fcompute)
+    if len(argspec.args) == 0 and argspec.varargs is None:
+        arg_names = ["i%d" % i for i in range(out_ndim)]
+    elif argspec.varargs is not None:
+        # if there is a varargs, it takes the remaining dimensions of out_ndim
+        arg_names = argspec.args + [f"i{i}" for i in range(out_ndim - len(argspec.args))]
     else:
-        arg_names = code.co_varnames[: code.co_argcount]
-        out_ndim = code.co_argcount
+        arg_names = argspec.args
+        # if there are fewer args than out dimensions, the remaining dimensions
+        # are implicitly broadcast
+        out_ndim = len(arg_names)
+    assert argspec.varkw is None, "Variable keyword arguments not supported in fcompute"
+    assert argspec.defaults is None, "Default arguments not supported in fcompute"
+    assert len(argspec.kwonlyargs) == 0, "Keyword arguments are not supported in fcompute"
 
     if out_ndim != len(arg_names):
-        raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
+        raise ValueError(
+            "Number of args to fcompute does not match dimension, "
+            "args=%d, dimension=%d" % (len(arg_names), out_ndim)
+        )
 
     dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
     body = fcompute(*[v.var for v in dim_var])
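
As an aside (not part of the commit), the switch to `inspect.getfullargspec` matters because, unlike `fcompute.__code__.co_argcount`, it reports varargs and keyword arguments explicitly, which the new branches and asserts rely on. A quick illustration:

```python
import inspect

spec = inspect.getfullargspec(lambda x, *args: x)
print(spec.args)        # ['x']    named leading axes
print(spec.varargs)     # 'args'   consumes the remaining output dimensions
print(spec.varkw)       # None     a **kwargs here would trip the varkw assert
print(spec.kwonlyargs)  # []       keyword-only args are rejected as well
```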
python/tvm/te/tensor.py (4 changes: 3 additions & 1 deletion)
@@ -60,7 +60,9 @@ class Tensor(DataProducer, _expr.ExprOp):
     def __call__(self, *indices):
         ndim = self.ndim
         if len(indices) != ndim:
-            raise ValueError("Need to provide %d index in tensor slice" % ndim)
+            raise ValueError(
+                "Need to provide %d index in tensor but %d was provided" % (ndim, len(indices))
+            )
         indices = convert_to_object(indices)
         args = []
         for x in indices:
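
For illustration only (the tensor name and shapes below are made up), the reworded error now reports both the expected and the provided index count when a tensor is called with the wrong arity:

```python
from tvm import te

A = te.placeholder((10, 20), name="A")
B = te.compute((10, 20), lambda i, j: A(i, j) * 2, name="B")  # correct: two indices
# Calling A with a single index, e.g. A(0), raises:
#   ValueError: Need to provide 2 index in tensor but 1 was provided
```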
