diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py index 1181b3b2a769..bdb79126f012 100644 --- a/python/tvm/relax/transform/legalize_ops/statistical.py +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -85,17 +85,3 @@ def _variance(bb: BlockBuilder, call: Call) -> Expr: register_legalize("relax.min", _statistical(topi.min)) register_legalize("relax.prod", _statistical(topi.prod)) register_legalize("relax.sum", _statistical(topi.sum)) - - -@register_legalize("relax.cumsum") -def _cumsum(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te( - topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive - ) - - -@register_legalize("relax.cumprod") -def _cumprod(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te( - topi.cumprod, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive - ) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 0d579163cdd0..eb1df67a8f81 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -1161,7 +1161,6 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), d target = tvm.target.Target("cuda -libs=thrust", host="llvm") with target: - mod = relax.backend.DispatchSortScan()(mod) mod = relax.transform.LegalizeOps()(mod) mod = tir.transform.DefaultGPUSchedule()(mod) diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index c6c53ff0b9af..2a28151dbe7e 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -1066,74 +1066,5 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype=" tvm.ir.assert_structural_equal(mod, Expected) -def test_cumsum(): - # fmt: off - @I.ir_module - class Cumsum: - @R.function - def main(x: R.Tensor((3, 2, 3), "float32")): - gv = R.cumsum(x, axis=1, dtype="int32") - return gv - - @I.ir_module - class Expected: - @T.prim_func(private=True) - def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) - rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1) - with T.block("cumsum_generic"): - for fused in T.parallel(T.int64(9)): - out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)]) - for _k in range(T.int64(1)): - out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) % T.int64(3)] + T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)]) - - @R.function - def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"): - cls = Expected - gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32")) - return gv - # fmt: on - - mod = LegalizeOps()(Cumsum) - tvm.ir.assert_structural_equal(mod, Expected) - - -def test_cumsum_symbolic(): - # fmt: off - @I.ir_module - class Cumsum: - @R.function - def main(x: R.Tensor(("a", "b", "c"), "float32")): - gv = R.cumsum(x, axis=1, dtype="int32") - return gv - - @I.ir_module - class Expected: - @T.prim_func(private=True) - def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle): - T.func_attr({"tir.noalias": True}) - a, b, c = T.int64(), T.int64(), T.int64() - rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1) - out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32") - with T.block("cumsum_generic"): - for fused in T.parallel(a * c): - out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c]) - for _k in range(b - T.int64(1)): - out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c]) - - @R.function - def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"): - a = T.int64() - b = T.int64() - c = T.int64() - cls = Expected - gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32")) - return gv - # fmt: on - - mod = LegalizeOps()(Cumsum) - tvm.ir.assert_structural_equal(mod, Expected) - - if __name__ == "__main__": tvm.testing.main()