Skip to content

Commit

Permalink
[Bugfix][Transform] Preserve symbolic variables in FuseOps (apache#16637)
Browse files Browse the repository at this point in the history

[Unity][Transform] Preserve symbolic variables in FuseOps

Prior to this commit, the `CompositeFunctionAnnotator` visited the
body of functions without the parameters being considered in-scope.
As a result, `EraseToWellDefined` would remove known shapes from the
function body's `StructInfo`.
  • Loading branch information
Lunderberg committed Mar 12, 2024
1 parent 308599a commit 90d0712
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1203,10 +1203,11 @@ class CompositeFunctionAnnotator : public ExprMutator {
func->GetAttr<String>(attr::kCodegen).defined()) {
continue;
}
auto new_body = VisitExpr(func->body);

auto new_body = VisitWithNewScope(func->body, func->params);
if (!new_body.same_as(func->body)) {
auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
func->is_pure, func->attrs, func->span);
auto new_func = Function(func->params, new_body, func->ret_struct_info, func->is_pure,
func->attrs, func->span);
builder_->UpdateFunction(entry.first, new_func);
}
}
Expand Down
87 changes: 87 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,5 +1130,92 @@ def test_error_on_repeated_variable_definitions():
relax.transform.FuseOpsByPattern(patterns)(mod)


def test_matmul_symbolic_var():
    """FuseOpsByPattern keeps symbolic shape variables in offloaded functions.

    Regression test for apache/tvm#16637: the ``CompositeFunctionAnnotator``
    previously visited a function body without its parameters being in scope,
    so ``EraseToWellDefined`` stripped known symbolic shapes (here
    ``batch_size`` and ``M``) from the ``StructInfo`` of the extracted
    "Codegen"/"Composite" functions.
    """

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor(["batch_size", 1024], "float16"),
            w1: R.Tensor([1024, 1024], "float16"),
            w2: R.Tensor([1024, "M"], "float16"),
        ):
            with R.dataflow():
                matmul1 = R.matmul(x, w1)
                matmul2 = R.matmul(x, w2)
                out = (matmul1, matmul2)
                R.output(out)
            return out

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor(["batch_size", 1024], "float16"),
            w1: R.Tensor([1024, 1024], "float16"),
            w2: R.Tensor([1024, "M"], "float16"),
        ) -> R.Tuple(
            R.Tensor(["batch_size", 1024], "float16"),
            R.Tensor(["batch_size", "M"], "float16"),
        ):
            cls = Expected
            with R.dataflow():
                matmul1 = cls.fused_relax_matmul_cublas(x, w1)
                matmul2 = cls.fused_relax_matmul1_cublas(x, w2)
                out = (matmul1, matmul2)
                R.output(out)
            return out

        @R.function
        def fused_relax_matmul_cublas(
            x: R.Tensor(["batch_size", 1024], "float16"),
            w1: R.Tensor([1024, 1024], "float16"),
        ) -> R.Tensor(["batch_size", 1024], "float16"):
            # The symbolic var must remain well-defined inside the lifted
            # "Codegen" function rather than being erased to an unknown dim.
            batch_size = T.int64()
            R.func_attr({"Codegen": "cublas"})

            @R.function
            def inner_func(
                x: R.Tensor([batch_size, 1024], "float16"),
                w1: R.Tensor([1024, 1024], "float16"),
            ) -> R.Tensor([batch_size, 1024], "float16"):
                R.func_attr({"Composite": "cublas.matmul"})
                with R.dataflow():
                    out = R.matmul(x, w1)
                    R.output(out)
                return out

            out = inner_func(x, w1)
            return out

        @R.function
        def fused_relax_matmul1_cublas(
            x: R.Tensor(["batch_size", 1024], "float16"),
            w2: R.Tensor([1024, "M"], "float16"),
        ) -> R.Tensor(["batch_size", "M"], "float16"):
            # Both symbolic vars ("batch_size" and "M") must survive fusion.
            batch_size = T.int64()
            M = T.int64()
            R.func_attr({"Codegen": "cublas"})

            @R.function
            def inner_func(
                x: R.Tensor([batch_size, 1024], "float16"),
                w2: R.Tensor((1024, M), "float16"),
            ) -> R.Tensor([batch_size, M], "float16"):
                R.func_attr({"Composite": "cublas.matmul"})
                with R.dataflow():
                    out = R.matmul(x, w2)
                    R.output(out)
                return out

            out = inner_func(x, w2)
            return out

    patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul")
    After = relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(
        Before
    )
    tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
    # Allow running this test file directly (e.g. `python <file>`) in
    # addition to collection by a pytest invocation.
    pytest.main([__file__])

0 comments on commit 90d0712

Please sign in to comment.