From e6b5da8a3b691ebea1f01197ae958a45b0d4356b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 12 Jan 2024 21:37:05 +0000 Subject: [PATCH 1/2] [Relax] Ignore non-relax functions in relax.transform.RunCodegen The `relax.transform.RunCodegen` pass replaces calls to relax functions with the `"Codegen"` attribute with calls into a compiled module. Prior to this commit, while calls to relax functions without the `"Codegen"` attribute were ignored, calls to non-relax functions would raise an error. This commit updates `relax.transform.RunCodegen` to also ignore calls to non-relax functions. --- src/relax/transform/run_codegen.cc | 4 +-- .../relax/test_transform_codegen_pass.py | 32 ++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index c385ae46ef0f..fe0e73d99e99 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -116,9 +116,9 @@ class CodeGenRunner : ExprMutator { auto ret_sinfo = GetStructInfo(call); if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) { return create_call_dps_packed(it->second, ret_sinfo); - } else { + } else if (auto opt_func = builder_->GetContextIRModule()->Lookup(gvar).as()) { // TODO(@sunggg): Is there any better way to get this func? - Function func = Downcast(builder_->GetContextIRModule()->Lookup(gvar)); + Function func = opt_func.value(); Expr new_func = VisitExpr(func); if (new_func->IsInstance()) { diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index cc8f390b96a1..83ce4dbd6efe 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -40,7 +40,7 @@ ) # Global variable in pytest that applies markers to all tests. -pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime] +# pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime] # Target gpu target_str = "nvidia/nvidia-t4" @@ -350,6 +350,36 @@ def main( after = relax.transform.RunCodegen()(Before) tvm.ir.assert_structural_equal(after["main"], Expected["main"]) + after.show() + + +def test_no_op_for_call_to_tir(): + """Calls to PrimFunc are ignored + + RunCodegen should only update calls to Relax functions annotated + with the `"Codegen"` attribute. Calls to any other function type + should be ignored. + + This is a regression test. Previous implementations performed an + unconditional cast from `tvm::BaseFunc` to `tvm::relax::Function`, + which produced an error. + """ + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor): + R.func_attr({"relax.force_pure": True}) + _ = Before.shape_func(x) + return x + + @T.prim_func(private=True) + def shape_func(H: T.Buffer(T.int64(4), "int64")): + H[T.int64(0)] = H[T.int64(0)] + T.int64(1) + + Expected = Before + After = relax.transform.RunCodegen()(Before) + tvm.ir.assert_structural_equal(Expected, After) # TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding) From 60bdbebe59a262caf2960a3270a5a36916d99e2f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 16 Feb 2024 17:23:42 +0000 Subject: [PATCH 2/2] Remove debug changes --- tests/python/relax/test_transform_codegen_pass.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index 83ce4dbd6efe..560bd3bc0b53 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -40,7 +40,7 @@ ) # Global variable in pytest that applies markers to all tests. -# pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime] +pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime] # Target gpu target_str = "nvidia/nvidia-t4" @@ -350,7 +350,6 @@ def main( after = relax.transform.RunCodegen()(Before) tvm.ir.assert_structural_equal(after["main"], Expected["main"]) - after.show() def test_no_op_for_call_to_tir():