From d13bb5104667e6e9e6e7b5bdd2a03e374794b74d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 3 Oct 2023 15:43:40 -0700 Subject: [PATCH] [Arith] Simplify the result of non-divisible floordiv --- src/arith/iter_affine_map.cc | 9 +++++ .../unittest/test_arith_iter_affine_map.py | 34 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 1c782f454653..366784c04fc0 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1904,6 +1904,15 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P /* lower_factor = */ padded->lower_factor * rhs, /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), /* scale = */ padded->scale); + } else if (is_one(padded->lower_factor) && + analyzer_->CanProveEqual(padded->extent, padded->source->extent)) { + // floordiv(floormod(floordiv(iter, lower_factor), ext), c) + // = floordiv(iter, c) + // when lower_factor = 1 and ext = iter.extent + new_split = IterSplitExpr(padded->source, + /* lower_factor = */ rhs, + /* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)), + /* scale = */ padded->scale); } else { new_split = IterSplitExpr(IterMark(padded, padded->extent), /* lower_factor = */ rhs, diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 912edcbcedb6..3a10ec05efeb 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -1227,11 +1227,29 @@ def test_iter_map_simplify_unit_loop_order(): def assert_normalize_to_iter_sum(index, input_iters, args, base): + """Assert the result of arith.normalize_to_iter_sum is correct + + Parameters + ---------- + index : tvm.tir.PrimExpr + The index to be normalized + input_iters : Mapping[Var, Range] + The input iterators + args : List[Union[tvm.arith.IterSplitExpr, Tuple[PrimExpr, PrimExpr]]] + The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be + either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the + iterator normalized to PrimExpr and the second element is the scale. + base : tvm.tir.PrimExpr + The expected base + """ res = tvm.arith.normalize_to_iter_sum(index, input_iters) assert isinstance(res, tvm.arith.IterSumExpr) assert len(res.args) == len(args) for split, item in zip(res.args, args): + if isinstance(item, tvm.arith.IterSplitExpr): + tvm.ir.assert_structural_equal(split, item) + continue tvm.testing.assert_prim_expr_equal(split.scale, item[1]) tvm.testing.assert_prim_expr_equal( tvm.arith.normalize_iter_map_to_expr(split), item[0] * item[1] @@ -1245,6 +1263,7 @@ def test_normalize_to_iter_sum(): z = tvm.tir.Var("z", "int64") a = tvm.tir.Var("a", "int64") n = tvm.tir.Var("n", "int64") + flm = tvm.tir.floormod assert_normalize_to_iter_sum( z + ((y + x * 4 + 2) * n) + 3, @@ -1285,6 +1304,21 @@ def test_normalize_to_iter_sum(): 0, ) + # non-divisible + assert_normalize_to_iter_sum( + x // 5, + var_dom([(x, 4096)]), + [ + tvm.arith.IterSplitExpr( + tvm.arith.IterMark(x, 4096), + lower_factor=tvm.tir.const(5, "int64"), + extent=tvm.tir.const(820, "int64"), + scale=tvm.tir.const(1, "int64"), + ) + ], + 0, + ) + # iter simplify assert_normalize_to_iter_sum( z * 2 + 2 * y * 3 + 4 * (x // 4) + (x % 4),