diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index efbf648b4807..7eb499f1023a 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -921,8 +921,8 @@ RELAY_REGISTER_OP("relax.memory.kill_storage") .set_num_inputs(1) .add_argument("storage", "Expr", "The storage to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) - // deallocation also isn't considered a "visible effect" as far as purity is concerned - .set_attr("FPurity", Bool(true)); + // We mark this as impure so it wouldn't be removed by "remove_all_unused" + .set_attr("FPurity", Bool(false)); Expr MakeMemKillStorage(Expr storage) { static const Op& op = Op::Get("relax.memory.kill_storage"); @@ -937,8 +937,8 @@ RELAY_REGISTER_OP("relax.memory.kill_tensor") .set_num_inputs(1) .add_argument("tensor", "Expr", "The tensor to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) - // memory deallocation also isn't considered a "visible effect" as far as purity is concerned - .set_attr("FPurity", Bool(true)); + // We mark this as impure so it wouldn't be removed by "remove_all_unused" + .set_attr("FPurity", Bool(false)); Expr MakeMemKillTensor(Expr tensor) { static const Op& op = Op::Get("relax.memory.kill_tensor"); @@ -1013,8 +1013,8 @@ TVM_REGISTER_OP("relax.vm.kill_object") .set_num_inputs(1) .add_argument("obj", "Expr", "The object to be killed.") .set_attr("FInferStructInfo", ReturnVoidStructInfo) - // deallocation also isn't considered a "visible effect" as far as purity is concerned - .set_attr("FPurity", Bool(true)); + // We mark this as impure so it wouldn't be removed by "remove_all_unused" + .set_attr("FPurity", Bool(false)); Expr MakeVMKillObject(Expr obj) { static const Op& op = Op::Get("relax.vm.kill_object"); @@ -1031,7 +1031,8 @@ RELAY_REGISTER_OP("relax.vm.call_tir_dyn") .add_argument("args", "Tuple", "The input arguments (list of tensors and last argument is ShapeExpr)") .set_attr("FInferStructInfo", ReturnVoidStructInfo) - .set_attr("FPurity", Bool(true)); + // "relax.vm.call_tir_dyn" works in an in-place way, which is impure. + .set_attr("FPurity", Bool(false)); Expr MakeCallTIRDyn(Expr func, Tuple args) { static const Op& op = Op::Get("relax.vm.call_tir_dyn"); diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 28ca13ad8991..c790b1bc5142 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -19,19 +19,21 @@ import tvm import tvm.testing -from tvm import tir from tvm import relax as rx +from tvm import tir from tvm.relax.analysis import ( - has_reshape_pattern, - udchain, - remove_all_unused, - name_to_binding, - all_vars, all_global_vars, - free_vars, + all_vars, bound_vars, + free_vars, + has_reshape_pattern, + name_to_binding, + remove_all_unused, + udchain, ) -from tvm.script import relax as R, tir as T +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]: @@ -352,6 +354,30 @@ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor: tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) +def test_retain_calls_to_impure_builtin_ops(): + @I.ir_module + class Module: + @T.prim_func(private=True) + def my_tir(A: T.handle, B: T.handle, n: T.int64): + T.evaluate(0) + + @R.function(pure=False) + def main(x: R.Tensor(("n",), "float32")): + cls = Module + n = T.int64() + storage = R.memory.alloc_storage((n * 4,), 0, "global", "float32") + alloc = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), "float32") + # "call_tir_dyn" is impure which shouldn't be removed. + R.vm.call_tir_dyn(cls.my_tir, (x, alloc, R.shape([n]))) + # "kill_tensor"/"kill_storage" are impure which shouldn't be removed. + R.memory.kill_tensor(alloc) + R.memory.kill_storage(storage) + return x + + after = remove_all_unused(Module["main"]) + tvm.ir.assert_structural_equal(after, Module["main"], map_free_vars=True) + + def test_name_to_binding_var_shadowing(): @R.function def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index b491577314ec..0998fb67c044 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -435,7 +435,7 @@ def sum( def test_do_not_eliminate_dtype(): @I.ir_module class Before: - @R.function + @R.function(pure=False) def foo() -> R.Tensor((32, 64), "int32"): obj: R.Object = R.vm.alloc_storage( R.shape([24576]), runtime_device_index=0, dtype="uint8" diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 3f806de28dbd..109971ce37a4 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1552,7 +1552,7 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")): def test_vm_ops(): - @R.function + @R.function(pure=False) def foo(x: R.Tensor(("m", "n"), dtype="float32")): m = T.int64() n = T.int64()