From 5b7d9e143e31e65df015c033af2f4519b1cbafaf Mon Sep 17 00:00:00 2001 From: Chenfan Date: Fri, 27 Aug 2021 03:33:37 +0800 Subject: [PATCH] [FIX] Bug fix for a floormod rewrite simplify rule (#8852) * Update rewrite_simplify.cc * Update test_arith_rewrite_simplify.py * Update test_arith_rewrite_simplify.py * Update test_arith_rewrite_simplify.py --- src/arith/rewrite_simplify.cc | 16 +++++++++----- .../unittest/test_arith_rewrite_simplify.py | 22 ++++++++++++------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index ff6536ab066b..1d3475b13dad 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -858,14 +858,18 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { ModularSet bmod = analyzer_->modular_set(b1.Eval()); int64_t ramp_min = floordiv(bmod->base, c2val); int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val); - if (bmod->coeff % c2val == 0) { - if (ramp_min == ramp_max) { + if (ramp_min == ramp_max) { + // If the coeff of b1 is divisible by c2 + if (bmod->coeff % c2val == 0) { return ramp(floormod(bmod->base, c2), c1, lanes).Eval(); - } else { - return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } - } else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) { - return ramp(floormod(b1, c2), c1, lanes).Eval(); + // If all indices can be guaranteed to settle inside a coeff range + if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) { + return ramp(floormod(b1, c2), c1, lanes).Eval(); + } + } + if (bmod->coeff % c2val == 0) { + return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval(); } } } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 231c376c50ca..641eed51d5cf 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ 
b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -101,15 +101,16 @@ def test_vector_simplify(): ck.verify( fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), - ) + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1] ck.verify( fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 15; [63, 64, 65, 66] / 64 = [0, 1, 1, 1] ck.verify( fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example negative case: x = 9; [63, 70, 77, 84] / 64 = [0, 1, 1, 1] + # floor mod ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")) ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4)) @@ -136,16 +137,21 @@ def test_vector_simplify(): flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4) ) ck.verify( - flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), tvm.tir.Ramp(flm(x * 4, 64), 1, 5) - ) + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + ) # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0] ck.verify( flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4), - ) + flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), + ) # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2] + ck.verify( + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + ) # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5] ck.verify( flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - ) + ) # Example 
negative case: x = 9; [63, 70, 77, 84] % 64 = [63, 6, 13, 20] # Min/Max rules vx = te.var("vx", dtype="int32x2")