Skip to content

Commit

Permalink
fix typo in backend interpreter (apache#2752)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and wweic committed Mar 12, 2019
1 parent f197307 commit 0c343c2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, value):

def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0)))
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
Expand Down
8 changes: 7 additions & 1 deletion tests/python/relay/test_backend_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tvm
import tvm.testing
from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue
from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor

Expand Down Expand Up @@ -135,6 +135,11 @@ def test_binds():
tvm.testing.assert_allclose(xx + xx, res)


def test_tensor_value():
x = relay.var("x", shape=(1, 10))
xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx)

def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
Expand All @@ -159,3 +164,4 @@ def test_kwargs_params():
test_binds()
test_kwargs_params()
test_ref()
test_tensor_value()

0 comments on commit 0c343c2

Please sign in to comment.