Skip to content

Commit

Permalink
Ensure interpreted functions can take values that are not TensorValues (
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky authored and tqchen committed Apr 16, 2019
1 parent 561e422 commit fcc5b42
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .. import _make, ir_pass
from ... import register_func, nd
from ..base import NodeBase, register_relay_node
from ..expr import Call, Constant, GlobalVar, Function, const
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder

class Value(NodeBase):
Expand Down Expand Up @@ -112,6 +112,12 @@ def __init__(self, value):
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
elif isinstance(arg, ConstructorValue):
return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
Expand Down
43 changes: 43 additions & 0 deletions tests/python/relay/test_backend_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm.testing
from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor

Expand Down Expand Up @@ -156,6 +157,7 @@ def test_tensor_value():
xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx)


def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
Expand All @@ -170,6 +172,46 @@ def test_kwargs_params():
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)


def test_function_taking_adt_ref_tuple():
mod = relay.Module()
prelude = relay.prelude.Prelude(mod)
intrp = create_executor("debug", mod)

nil_value = ConstructorValue(prelude.nil, [], [])
cons_value = ConstructorValue(prelude.cons, [
TensorValue(np.random.rand(1, 10).astype('float32')),
nil_value
], [relay.TensorType((1, 10), 'float32')])

ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10)
])

id_func = intrp.evaluate(prelude.id)

res_nil = id_func(nil_value)
assert res_nil.constructor == nil_value.constructor
assert len(res_nil.fields) == 0

res_cons = id_func(cons_value)
assert res_cons.constructor == cons_value.constructor
assert len(res_cons.fields) == len(cons_value.fields)
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
cons_value.fields[0].asnumpy())
assert isinstance(res_cons.fields[1], ConstructorValue)
assert res_cons.fields[1].constructor == prelude.nil
assert len(res_cons.fields[1].fields) == 0

res_ref = id_func(ref_value)
tvm.testing.assert_allclose(res_ref.value.asnumpy(), ref_value.value.asnumpy())

res_tuple = id_func(tuple_value)
for i in range(10):
tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
tuple_value.fields[i].asnumpy())


if __name__ == "__main__":
test_id()
test_add_const()
Expand All @@ -181,3 +223,4 @@ def test_kwargs_params():
test_kwargs_params()
test_ref()
test_tensor_value()
test_function_taking_adt_ref_tuple()

0 comments on commit fcc5b42

Please sign in to comment.