Skip to content

Commit

Permalink
[Relay] Fix Fuse (#3035)
Browse files Browse the repository at this point in the history
* save

* fix

* Update fuse_ops.cc
  • Loading branch information
MarisaKirisame authored and masahi committed Apr 17, 2019
1 parent d5f7224 commit 8d50312
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
10 changes: 9 additions & 1 deletion src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -865,9 +865,17 @@ class FuseMutator : private ExprMutator {
}

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
// If the function has no call, it is not a primitive function.
struct HasCallVisitor : ExprVisitor {
bool has_call = false;
void VisitExpr_(const CallNode* op) final {
has_call = true;
}
} visitor;
visitor(body);
const GroupInfo& ginfo = ginfo_[group];
auto func = FunctionNode::make(ginfo.params, body, ret_type, {});
func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
func = FunctionSetAttr(func, "Primitive", tvm::Integer(visitor.has_call));
return CallNode::make(func, ginfo.arguments, Attrs());
}

Expand Down
10 changes: 9 additions & 1 deletion tests/python/relay/test_backend_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def test_tuple_value():
np.testing.assert_allclose(tv[2].asnumpy(), 3)


def test_tuple_getitem():
two = relay.add(relay.const(1), relay.const(1))
func = relay.Function([], relay.TupleGetItem(relay.Tuple([relay.const(1), relay.const(2)]), 0))
check_eval(func, [], 1)


def test_id():
x = relay.var('x', 'float32')
ident = relay.Function([x], x)
Expand Down Expand Up @@ -223,4 +229,6 @@ def test_function_taking_adt_ref_tuple():
test_kwargs_params()
test_ref()
test_tensor_value()
test_function_taking_adt_ref_tuple()
test_tuple_value()
test_tuple_getitem()
test_function_taking_adt_ref_tuple()

0 comments on commit 8d50312

Please sign in to comment.