From e18dc356842dacedc1d4ab6eff4609a52a0a1d2a Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Thu, 13 May 2021 00:43:36 -0600 Subject: [PATCH] support concat in recast (#8028) --- python/tvm/relay/transform/recast.py | 18 ++++++++++++++---- tests/python/relay/test_recast.py | 24 ++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 05a72676a9075..2c88f10dc8f45 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -57,10 +57,20 @@ def visit_call(self, call): # Downcast this op if its the correct type and not skipped. if call.op in self.valid_ops and current_layer not in self.skip_layers: # Recast inputs to specified type. - args = [self.visit(arg) for arg in call.args] - new_args = list() - for arg in args: - new_args.append(relay.cast(arg, dtype=self.dtype)) + if call.op == relay.op.get("concatenate"): + if len(call.args) != 1 or not isinstance(call.args[0], relay.expr.Tuple): + return Call(new_fn, args, call.attrs) + + tuple_args = [self.visit(arg) for arg in call.args[0].fields] + new_args = list() + for arg in tuple_args: + new_args.append(relay.cast(arg, dtype=self.dtype)) + new_args = [relay.expr.Tuple(new_args)] + else: + args = [self.visit(arg) for arg in call.args] + new_args = list() + for arg in args: + new_args.append(relay.cast(arg, dtype=self.dtype)) # If out_dtype is in the attributes, we need to update it. orig_dtype = None diff --git a/tests/python/relay/test_recast.py b/tests/python/relay/test_recast.py index 8c5a562ddbba4..43def9df41ce2 100644 --- a/tests/python/relay/test_recast.py +++ b/tests/python/relay/test_recast.py @@ -102,6 +102,30 @@ def expected(): assert tvm.ir.structural_equal(expected, post) +def test_recast_concat(): + def before(): + x = relay.var("x", shape=[1, 4]) + y = relay.var("y", shape=[1, 4]) + t = relay.Tuple([x, y]) + c = relay.op.concatenate(t, axis=1) + return relay.Function([x, y], c) + + def expected(): + xv = relay.var("x", shape=[1, 4]) + yv = relay.var("y", shape=[1, 4]) + x = relay.cast(xv, "float16") + y = relay.cast(yv, "float16") + t = relay.Tuple([x, y]) + c = relay.op.concatenate(t, axis=1) + c = relay.cast(c, "float32") + return relay.Function([xv, yv], c) + + pre = before() + post = recast(pre, "float16", "float32", ops=["concatenate"]) + expected = expected() + assert tvm.ir.structural_equal(expected, post) + + if __name__ == "__main__": test_recast_simple() test_recast_medium()