diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 5d3f80bb02b7b..a2a3e96dd567e 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1203,10 +1203,11 @@ class CompositeFunctionAnnotator : public ExprMutator {
           func->GetAttr(attr::kCodegen).defined()) {
         continue;
       }
-      auto new_body = VisitExpr(func->body);
+
+      auto new_body = VisitWithNewScope(func->body, func->params);
       if (!new_body.same_as(func->body)) {
-        auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
-                                 func->is_pure, func->attrs, func->span);
+        auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure,
+                                 func->attrs, func->span);
         builder_->UpdateFunction(entry.first, new_func);
       }
     }
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index b6bcf01862b88..5e700b277f329 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1130,5 +1130,92 @@ def test_error_on_repeated_variable_definitions():
         relax.transform.FuseOpsByPattern(patterns)(mod)
 
 
+def test_matmul_symbolic_var():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1024], "float16"),
+            w1: R.Tensor([1024, 1024], "float16"),
+            w2: R.Tensor([1024, "M"], "float16"),
+        ):
+            with R.dataflow():
+                matmul1 = R.matmul(x, w1)
+                matmul2 = R.matmul(x, w2)
+                out = (matmul1, matmul2)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1024], "float16"),
+            w1: R.Tensor([1024, 1024], "float16"),
+            w2: R.Tensor([1024, "M"], "float16"),
+        ) -> R.Tuple(
+            R.Tensor(["batch_size", 1024], "float16"),
+            R.Tensor(["batch_size", "M"], "float16"),
+        ):
+            cls = Expected
+            with R.dataflow():
+                matmul1 = cls.fused_relax_matmul_cublas(x, w1)
+                matmul2 = cls.fused_relax_matmul1_cublas(x, w2)
+                out = (matmul1, matmul2)
+                R.output(out)
+            return out
+
+        @R.function
+        def fused_relax_matmul_cublas(
+            x: R.Tensor(["batch_size", 1024], "float16"),
+            w1: R.Tensor([1024, 1024], "float16"),
+        ) -> R.Tensor(["batch_size", 1024], "float16"):
+            batch_size = T.int64()
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def inner_func(
+                x: R.Tensor([batch_size, 1024], "float16"),
+                w1: R.Tensor([1024, 1024], "float16"),
+            ) -> R.Tensor([batch_size, 1024], "float16"):
+                R.func_attr({"Composite": "cublas.matmul"})
+                with R.dataflow():
+                    out = R.matmul(x, w1)
+                    R.output(out)
+                return out
+
+            out = inner_func(x, w1)
+            return out
+
+        @R.function
+        def fused_relax_matmul1_cublas(
+            x: R.Tensor(["batch_size", 1024], "float16"),
+            w2: R.Tensor([1024, "M"], "float16"),
+        ) -> R.Tensor(["batch_size", "M"], "float16"):
+            batch_size = T.int64()
+            M = T.int64()
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def inner_func(
+                x: R.Tensor([batch_size, 1024], "float16"),
+                w2: R.Tensor((1024, M), "float16"),
+            ) -> R.Tensor([batch_size, M], "float16"):
+                R.func_attr({"Composite": "cublas.matmul"})
+                with R.dataflow():
+                    out = R.matmul(x, w2)
+                    R.output(out)
+                return out
+
+            out = inner_func(x, w2)
+            return out
+
+    patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul")
+    After = relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(
+        Before
+    )
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])