Skip to content

Commit

Permalink
[UnitTest] Parametrized test_arith_iter_affine_map::test_padding (apa…
Browse files Browse the repository at this point in the history
…che#13774)

Parametrization helped in the debugging of
apache#13530, but is not otherwise related
to that PR.
  • Loading branch information
Lunderberg authored and fzi-peccia committed Mar 27, 2023
1 parent b8169d6 commit 96a1089
Showing 1 changed file with 70 additions and 90 deletions.
160 changes: 70 additions & 90 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,103 +946,83 @@ def test_free_variables():
)


def test_padding():
class TestPadding:
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod

# left padding only, offset divisible
sum = 64 + y
dom_map = var_dom([(y, 192)])
assert_iter_sum_pattern(
{fld(sum, 32): (6, 2, 1), flm(sum, 32): (32, 0, 1)},
dom_map,
check_level="bijective",
)

# left padding only, offset non-divisible
sum = 80 + y
dom_map = var_dom([(y, 176)])
assert_iter_sum_pattern(
{fld(sum, 32): (6, 2, 1)},
dom_map,
)
assert_iter_sum_pattern(
{flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)},
dom_map,
)
assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map)
assert_iter_sum_failure({fld(sum, 32), fld(sum, 4)}, dom_map)

# right padding only, offset divisible
sum = x * 32 + y * 8
dom_map = var_dom([(x, 5), (y, 4)])
assert_iter_sum_pattern(
{fld(sum, 16): (10, 0, 1), flm(sum, 16): (2, 0, 8)},
dom_map,
)
assert_iter_sum_failure({fld(sum, 5)}, dom_map)

# right padding only, offset non-divisible
dom_map = var_dom([(x, 26)])
assert_iter_sum_pattern(
{fld(x, 15): (2, 0, 1)},
dom_map,
)
assert_iter_sum_pattern(
{flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)},
dom_map,
)

# padding constants on both side
sum = x + 71
dom_map = var_dom([(x, 45)])
assert_iter_sum_pattern({fld(sum, 32): (2, 2, 1)}, dom_map)
assert_iter_sum_pattern(
{flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)},
dom_map,
)

# padding for free iteration part
sum = x * 360 + y
dom_map = var_dom([(y, 360)])
assert_iter_sum_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map)
assert_iter_sum_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map)

# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
assert_iter_sum_pattern(
{
flm(x + 10, 3): (3, 0),
flm(fld(x + 10, 3), 4): (4, 0),
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
},
var_dom([(x, 240)]),
)
assert_iter_sum_failure(
{
flm(x + 10, 3),
flm(fld(x + 10, 3), 4),
flm(fld(fld(x + 10, 3), 4), 5),
fld(fld(fld(x + 10, 3), 4), 5),
},
var_dom([(x, 240)]),
)

# different offsets on splits
assert_iter_sum_pattern(
{
flm(x + 1, 3): (3, 0),
flm(fld(x + 10, 3) + 2, 4): (4, 0),
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
},
var_dom([(x, 240)]),
positive_test_case = tvm.testing.parameter(
# left padding only, offset divisible
({y: 192}, {fld(64 + y, 32): (6, 2, 1), flm(64 + y, 32): (32, 0, 1)}, "bijective"),
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32): (6, 2, 1)}),
({y: 176}, {flm(fld(80 + y, 2), 16): (16, 0, 1), flm(80 + y, 2): (2, 0, 1)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 16): (10, 0, 1), flm(x * 32 + y * 8, 16): (2, 0, 8)}),
# right padding only, offset non-divisible
({x: 26}, {fld(x, 15): (2, 0, 1)}),
({x: 26}, {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}),
# padding constants on both side
({x: 45}, {fld(x + 71, 32): (2, 2, 1)}),
({x: 45}, {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}),
# padding for free iteration part
({y: 360}, {fld(x * 360 + y, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}),
({y: 360}, {flm(x * 360 + y, 16): (16, 0, 1)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3): (3, 0),
flm(fld(x + 10, 3), 4): (4, 0),
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
},
),
# different offsets on splits
(
{x: 240},
{
flm(x + 1, 3): (3, 0),
flm(fld(x + 10, 3) + 2, 4): (4, 0),
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
},
),
)

# original extent is smaller than the divident
# it is not surjective wrt to the region [0, 16)
assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
negative_test_case = tvm.testing.parameter(
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32), flm(80 + y, 32)}),
({y: 176}, {fld(80 + y, 32), fld(80 + y, 4)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 5)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3),
flm(fld(x + 10, 3), 4),
flm(fld(fld(x + 10, 3), 4), 5),
fld(fld(fld(x + 10, 3), 4), 5),
},
),
# original extent is smaller than the divident
# it is not surjective wrt to the region [0, 16)
({x: 3}, {flm(x, 16)}),
)

def test_padding(self, positive_test_case):
iter_extent, mapped_iterators, *args = positive_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_pattern(mapped_iterators, dom_map, check_level=check_level)

def test_padding_error(self, negative_test_case):
iter_extent, mapped_iterators, *args = negative_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_failure(mapped_iterators, dom_map, check_level=check_level)


def test_overlapped_fuse():
Expand Down

0 comments on commit 96a1089

Please sign in to comment.