Skip to content

Commit

Permalink
[Relax] Remove the legalization of cumsum/cumprob (#16676)
Browse files Browse the repository at this point in the history
* [Relax] Remove the legalization of cumsum/cumprob

* remove related tests
  • Loading branch information
yongwww authored Mar 7, 2024
1 parent d284cf4 commit 6ca2341
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 84 deletions.
14 changes: 0 additions & 14 deletions python/tvm/relax/transform/legalize_ops/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,3 @@ def _variance(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.min", _statistical(topi.min))
register_legalize("relax.prod", _statistical(topi.prod))
register_legalize("relax.sum", _statistical(topi.sum))


@register_legalize("relax.cumsum")
def _cumsum(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.cumsum, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
)


@register_legalize("relax.cumprod")
def _cumprod(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
topi.cumprod, call.args[0], call.attrs.axis, call.attrs.dtype, call.attrs.exclusive
)
1 change: 0 additions & 1 deletion tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,6 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), sorted_prob: R.Tensor((2, 3), d

target = tvm.target.Target("cuda -libs=thrust", host="llvm")
with target:
mod = relax.backend.DispatchSortScan()(mod)
mod = relax.transform.LegalizeOps()(mod)
mod = tir.transform.DefaultGPUSchedule()(mod)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1066,74 +1066,5 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((3, 4), dtype="
tvm.ir.assert_structural_equal(mod, Expected)


def test_cumsum():
# fmt: off
@I.ir_module
class Cumsum:
@R.function
def main(x: R.Tensor((3, 2, 3), "float32")):
gv = R.cumsum(x, axis=1, dtype="int32")
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "int32")):
T.func_attr({"tir.noalias": True})
rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(3)), offset_factor=1)
with T.block("cumsum_generic"):
for fused in T.parallel(T.int64(9)):
out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)] = T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3)) % T.int64(3)])
for _k in range(T.int64(1)):
out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)] = out_buf[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1) - T.int64(1)) * T.int64(3)) % T.int64(3)] + T.Cast("int32", rxplaceholder[(fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) // T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) // T.int64(3) % T.int64(2), (fused // T.int64(3) * T.int64(2) * T.int64(3) + fused % T.int64(3) + (_k + T.int64(1)) * T.int64(3)) % T.int64(3)])

@R.function
def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((3, 2, 3), dtype="int32"):
cls = Expected
gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((3, 2, 3), dtype="int32"))
return gv
# fmt: on

mod = LegalizeOps()(Cumsum)
tvm.ir.assert_structural_equal(mod, Expected)


def test_cumsum_symbolic():
# fmt: off
@I.ir_module
class Cumsum:
@R.function
def main(x: R.Tensor(("a", "b", "c"), "float32")):
gv = R.cumsum(x, axis=1, dtype="int32")
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def cumsum(var_rxplaceholder: T.handle, var_cumsum_generic: T.handle):
T.func_attr({"tir.noalias": True})
a, b, c = T.int64(), T.int64(), T.int64()
rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c), offset_factor=1)
out_buf = T.match_buffer(var_cumsum_generic, (a, b, c), "int32")
with T.block("cumsum_generic"):
for fused in T.parallel(a * c):
out_buf[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c] = T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c) // c // b, (fused // c * b * c + fused % c) // c % b, (fused // c * b * c + fused % c) % c])
for _k in range(b - T.int64(1)):
out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c] = out_buf[(fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1) - T.int64(1)) * c) % c] + T.Cast("int32", rxplaceholder[(fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c // b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) // c % b, (fused // c * b * c + fused % c + (_k + T.int64(1)) * c) % c])

@R.function
def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", "c"), dtype="int32"):
a = T.int64()
b = T.int64()
c = T.int64()
cls = Expected
gv = R.call_tir(cls.cumsum, (x,), out_sinfo=R.Tensor((a, b, c), dtype="int32"))
return gv
# fmt: on

mod = LegalizeOps()(Cumsum)
tvm.ir.assert_structural_equal(mod, Expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 6ca2341

Please sign in to comment.