Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[UnitTest] Parametrized test_arith_iter_affine_map::test_padding #13774

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 70 additions & 90 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,103 +946,83 @@ def test_free_variables():
)


def test_padding():
class TestPadding:
x = tvm.tir.Var("x", "int32")
y = tvm.tir.Var("y", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod

# left padding only, offset divisible
sum = 64 + y
dom_map = var_dom([(y, 192)])
assert_iter_sum_pattern(
{fld(sum, 32): (6, 2, 1), flm(sum, 32): (32, 0, 1)},
dom_map,
check_level="bijective",
)

# left padding only, offset non-divisible
sum = 80 + y
dom_map = var_dom([(y, 176)])
assert_iter_sum_pattern(
{fld(sum, 32): (6, 2, 1)},
dom_map,
)
assert_iter_sum_pattern(
{flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)},
dom_map,
)
assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map)
assert_iter_sum_failure({fld(sum, 32), fld(sum, 4)}, dom_map)

# right padding only, offset divisible
sum = x * 32 + y * 8
dom_map = var_dom([(x, 5), (y, 4)])
assert_iter_sum_pattern(
{fld(sum, 16): (10, 0, 1), flm(sum, 16): (2, 0, 8)},
dom_map,
)
assert_iter_sum_failure({fld(sum, 5)}, dom_map)

# right padding only, offset non-divisible
dom_map = var_dom([(x, 26)])
assert_iter_sum_pattern(
{fld(x, 15): (2, 0, 1)},
dom_map,
)
assert_iter_sum_pattern(
{flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)},
dom_map,
)

# padding constants on both side
sum = x + 71
dom_map = var_dom([(x, 45)])
assert_iter_sum_pattern({fld(sum, 32): (2, 2, 1)}, dom_map)
assert_iter_sum_pattern(
{flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)},
dom_map,
)

# padding for free iteration part
sum = x * 360 + y
dom_map = var_dom([(y, 360)])
assert_iter_sum_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map)
assert_iter_sum_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map)

# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
assert_iter_sum_pattern(
{
flm(x + 10, 3): (3, 0),
flm(fld(x + 10, 3), 4): (4, 0),
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
},
var_dom([(x, 240)]),
)
assert_iter_sum_failure(
{
flm(x + 10, 3),
flm(fld(x + 10, 3), 4),
flm(fld(fld(x + 10, 3), 4), 5),
fld(fld(fld(x + 10, 3), 4), 5),
},
var_dom([(x, 240)]),
)

# different offsets on splits
assert_iter_sum_pattern(
{
flm(x + 1, 3): (3, 0),
flm(fld(x + 10, 3) + 2, 4): (4, 0),
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
},
var_dom([(x, 240)]),
positive_test_case = tvm.testing.parameter(
# left padding only, offset divisible
({y: 192}, {fld(64 + y, 32): (6, 2, 1), flm(64 + y, 32): (32, 0, 1)}, "bijective"),
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32): (6, 2, 1)}),
({y: 176}, {flm(fld(80 + y, 2), 16): (16, 0, 1), flm(80 + y, 2): (2, 0, 1)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 16): (10, 0, 1), flm(x * 32 + y * 8, 16): (2, 0, 8)}),
# right padding only, offset non-divisible
({x: 26}, {fld(x, 15): (2, 0, 1)}),
({x: 26}, {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}),
# padding constants on both side
({x: 45}, {fld(x + 71, 32): (2, 2, 1)}),
({x: 45}, {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}),
# padding for free iteration part
({y: 360}, {fld(x * 360 + y, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}),
({y: 360}, {flm(x * 360 + y, 16): (16, 0, 1)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3): (3, 0),
flm(fld(x + 10, 3), 4): (4, 0),
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
},
),
# different offsets on splits
(
{x: 240},
{
flm(x + 1, 3): (3, 0),
flm(fld(x + 10, 3) + 2, 4): (4, 0),
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
},
),
)

# original extent is smaller than the divident
# it is not surjective wrt to the region [0, 16)
assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
negative_test_case = tvm.testing.parameter(
# left padding only, offset non-divisible
({y: 176}, {fld(80 + y, 32), flm(80 + y, 32)}),
({y: 176}, {fld(80 + y, 32), fld(80 + y, 4)}),
# right padding only, offset divisible
({x: 5, y: 4}, {fld(x * 32 + y * 8, 5)}),
# multiple split with same mark offset, could
# be surjective on missing (padded // LCM)
(
{x: 240},
{
flm(x + 10, 3),
flm(fld(x + 10, 3), 4),
flm(fld(fld(x + 10, 3), 4), 5),
fld(fld(fld(x + 10, 3), 4), 5),
},
),
# original extent is smaller than the divident
# it is not surjective wrt to the region [0, 16)
({x: 3}, {flm(x, 16)}),
)

def test_padding(self, positive_test_case):
iter_extent, mapped_iterators, *args = positive_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_pattern(mapped_iterators, dom_map, check_level=check_level)

def test_padding_error(self, negative_test_case):
iter_extent, mapped_iterators, *args = negative_test_case
check_level = args[0] if args else "surjective"
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
assert_iter_sum_failure(mapped_iterators, dom_map, check_level=check_level)


def test_overlapped_fuse():
Expand Down