diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc
index 95bbfbee7ca88..4b26b590ef9a6 100644
--- a/src/relax/transform/allocate_workspace.cc
+++ b/src/relax/transform/allocate_workspace.cc
@@ -125,11 +125,17 @@ class WorkspaceProvider : ExprMutator {
      builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
    }
 
-    auto gvar = mod_->GetGlobalVar("main");
-    auto func = Downcast<Function>(mod_->Lookup(gvar));
-    auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
-                             func->is_pure, func->attrs);
-    builder_->UpdateFunction(gvar, new_func);
+    for (const auto& [gvar, f] : mod_->functions) {
+      workspace_var_main_ = Var();
+      if (!f->IsInstance<FunctionNode>() || f->GetAttr<String>(attr::kCodegen) ||
+          f->GetAttr<String>(attr::kComposite)) {
+        continue;
+      }
+      auto func = Downcast<Function>(mod_->Lookup(gvar));
+      auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
+                               func->is_pure, func->attrs);
+      builder_->UpdateFunction(gvar, new_func);
+    }
 
    return builder_->GetContextIRModule();
  }
diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py
index 7ffbd01b05b25..aca6ea2fe83a0 100644
--- a/tests/python/relax/test_transform_allocate_workspace.py
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -55,7 +55,7 @@ def gv(
         return gv1
 
     @R.function
-    def main(
+    def entry_a(
         q: R.Tensor((32, 8, 16, 8), dtype="float16"),
         k: R.Tensor((32, 8, 16, 8), dtype="float16"),
         v: R.Tensor((32, 8, 16, 8), dtype="float16"),
@@ -68,6 +68,20 @@ def main(
             R.output(gv)
         return gv
 
+    @R.function
+    def entry_b(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        cls = Module
+        with R.dataflow():
+            gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass(
+                q, k, v
+            ) + R.const(1, dtype="float16")
+            R.output(gv)
+        return gv
+
 
 @I.ir_module
 class Expected:
@@ -105,7 +119,7 @@ def gv(
         return gv1
 
     @R.function
-    def main(
+    def entry_a(
         q: R.Tensor((32, 8, 16, 8), dtype="float16"),
         k: R.Tensor((32, 8, 16, 8), dtype="float16"),
         v: R.Tensor((32, 8, 16, 8), dtype="float16"),
@@ -122,6 +136,24 @@ def main(
             R.output(gv)
         return gv
 
+    @R.function
+    def entry_b(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        cls = Expected
+        with R.dataflow():
+            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
+                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            )
+            gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
+                q, k, v, workspace_main
+            ) + R.const(1, dtype="float16")
+            R.output(gv)
+        return gv
+
 
 def test_single_attention():
     rewritten = relax.transform.AllocateWorkspace()(Module)