From 8d50312f74addb1c6d4533c7741b626b228bddf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Tue, 16 Apr 2019 22:33:31 -0700 Subject: [PATCH] [Relay] Fix Fuse (#3035) * save * fix * Update fuse_ops.cc --- src/relay/pass/fuse_ops.cc | 10 +++++++++- tests/python/relay/test_backend_interpreter.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 6de9c2d65f90..12e3174dcade 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -865,9 +865,17 @@ class FuseMutator : private ExprMutator { } Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { + // If the function has no call, it is not a primitive function. + struct HasCallVisitor : ExprVisitor { + bool has_call = false; + void VisitExpr_(const CallNode* op) final { + has_call = true; + } + } visitor; + visitor(body); const GroupInfo& ginfo = ginfo_[group]; auto func = FunctionNode::make(ginfo.params, body, ret_type, {}); - func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); + func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.has_call)); return CallNode::make(func, ginfo.arguments, Attrs()); } diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 5d8ceb4c7bdc..e8a99e14d741 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -51,6 +51,12 @@ def test_tuple_value(): np.testing.assert_allclose(tv[2].asnumpy(), 3) +def test_tuple_getitem(): + two = relay.add(relay.const(1), relay.const(1)) + func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0)) + check_eval(func, [], 1) + + def test_id(): x = relay.var('x', 'float32') ident = relay.Function([x], x) @@ -223,4 +229,6 @@ def test_function_taking_adt_ref_tuple(): test_kwargs_params() test_ref() test_tensor_value() - test_function_taking_adt_ref_tuple() + test_tuple_value() + test_tuple_getitem() + test_function_taking_adt_ref_tuple() \ No newline at end of file