Skip to content

Commit

Permalink
Updated T.assume(expr) to T.evaluate(T.assume(expr))
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Nov 15, 2022
1 parent 794c0d9 commit 1bbf96f
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions tests/python/unittest/test_tir_transform_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,12 +1127,12 @@ class TestSimplifyInputAssumption(BaseBeforeAfter):
propagate_knowns_to_prove_conditional = True

def before(A: T.Buffer[1, "int32"], n: T.int32):
T.assume(n == 0)
T.evaluate(T.assume(n == 0))
if n == 0:
A[0] = 42

def expected(A: T.Buffer[1, "int32"], n: T.int32):
T.assume(n == 0)
T.evaluate(T.assume(n == 0))
A[0] = 42


Expand All @@ -1142,12 +1142,12 @@ class TestSimplifyInputAssumption(BaseBeforeAfter):
propagate_knowns_to_prove_conditional = True

def before(A: T.Buffer[1, "int32"], n: T.int32):
T.assume(n == 0)
T.evaluate(T.assume(n == 0))
if n == 0:
A[0] = 42

def expected(A: T.Buffer[1, "int32"], n: T.int32):
T.assume(n == 0)
T.evaluate(T.assume(n == 0))
A[0] = 42


Expand All @@ -1158,7 +1158,7 @@ class TestNoSimplifyFromScopedInputAssumption(BaseBeforeAfter):

def before(A: T.Buffer[1, "int32"], n: T.int32, m: T.int32):
if m == 0:
T.assume(n == 0)
T.evaluate(T.assume(n == 0))

if n == 0:
A[0] = 42
Expand Down Expand Up @@ -1232,13 +1232,13 @@ class TestSimplifyUsingBufferAssumption(BaseBeforeAfter):
propagate_knowns_to_prove_conditional = True

def before(A: T.Buffer[1, "int32"]):
T.assume(A[0] == 0)
T.evaluate(T.assume(A[0] == 0))

if A[0] == 0:
A[0] = 42

def expected(A: T.Buffer[1, "int32"]):
T.assume(A[0] == 0)
T.evaluate(T.assume(A[0] == 0))
A[0] = 42


Expand All @@ -1249,15 +1249,15 @@ class TestSimplifyUsingBufferAssumptionInLoop(BaseBeforeAfter):

def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(A[i] == i)
T.evaluate(T.assume(A[i] == i))

for i in T.serial(16):
if A[i] < 100:
A[i] = 0

def expected(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(A[i] == i)
T.evaluate(T.assume(A[i] == i))

for i in T.serial(16):
A[i] = 0
Expand All @@ -1271,7 +1271,7 @@ class TestSimplifyUsingPartiallyKnownBufferConditional(BaseBeforeAfter):
def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
if 14 <= i:
T.assume(A[i] == 0)
T.evaluate(T.assume(A[i] == 0))

for i in T.serial(16):
if 14 <= i:
Expand All @@ -1285,7 +1285,7 @@ def before(A: T.Buffer[16, "int32"]):
def expected(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
if 14 <= i:
T.assume(A[i] == 0)
T.evaluate(T.assume(A[i] == 0))

for i in T.serial(16):
if 14 <= i:
Expand All @@ -1308,7 +1308,7 @@ class TestSimplifyUsingPartiallyKnownBufferExpression(BaseBeforeAfter):

def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(i < 14 or A[i] == 0)
T.evaluate(T.assume(i < 14 or A[i] == 0))

for i in T.serial(16):
if 14 <= i:
Expand All @@ -1317,7 +1317,7 @@ def before(A: T.Buffer[16, "int32"]):

def expected(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(i < 14 or A[i] == 0)
T.evaluate(T.assume(i < 14 or A[i] == 0))

for i in T.serial(16):
if 14 <= i:
Expand All @@ -1338,7 +1338,7 @@ class TestNoSimplificationIfPredicateNotMet(BaseBeforeAfter):
def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
if 14 <= i:
T.assume(A[i] == 0)
T.evaluate(T.assume(A[i] == 0))

for i in T.serial(16):
if i < 14:
Expand Down Expand Up @@ -1375,7 +1375,7 @@ class TestNoSimplifyUsingOverwrittenValue(BaseBeforeAfter):

def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(A[i] == 0)
T.evaluate(T.assume(A[i] == 0))

for i in T.serial(16):
if i == 0:
Expand Down Expand Up @@ -1422,7 +1422,7 @@ class TestSimplifyPriorToOverwrittenValue(BaseBeforeAfter):

def before(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(A[i] == 0)
T.evaluate(T.assume(A[i] == 0))

for i in T.serial(16):
if A[i] == 0:
Expand All @@ -1436,7 +1436,7 @@ def before(A: T.Buffer[16, "int32"]):

def expected(A: T.Buffer[16, "int32"]):
for i in T.serial(16):
T.assume(A[i] == 0)
T.evaluate(T.assume(A[i] == 0))

for i in T.serial(16):
A[i] = 17
Expand Down Expand Up @@ -1515,7 +1515,7 @@ class TestSimplifyUsingTransitiveKnownBufferValue(BaseBeforeAfter):
propagate_knowns_to_prove_conditional = True

def before(A: T.Buffer[1, "int32"]):
T.assume(A[0] == 0)
T.evaluate(T.assume(A[0] == 0))

A[0] = A[0] + 1
A[0] = A[0] + 1
Expand All @@ -1525,7 +1525,7 @@ def before(A: T.Buffer[1, "int32"]):
A[0] = 42

def expected(A: T.Buffer[1, "int32"]):
T.assume(A[0] == 0)
T.evaluate(T.assume(A[0] == 0))

A[0] = A[0] + 1
A[0] = A[0] + 1
Expand Down Expand Up @@ -1591,7 +1591,7 @@ class TestSimplifyUsingPartiallyProvenBufferValueGather(BaseBeforeAfter):
def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
# A has non-zero values only in the range 3 <= i < 17
for i in T.serial(24):
T.assume(((3 <= i) and (i < 17)) or A[i] == 0)
T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))

# After convoluting with F, B has non-zero values only in the
# range 3 <= i < 19.
Expand All @@ -1611,7 +1611,7 @@ def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "i

def expected(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
for i in T.serial(24):
T.assume(((3 <= i) and (i < 17)) or A[i] == 0)
T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))

for i in T.serial(24):
B[i] = 0
Expand All @@ -1637,7 +1637,7 @@ class TestSimplifyUsingPartiallyProvenBufferValueScatter(BaseBeforeAfter):
def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
# A has non-zero values only in the range 3 <= i < 17
for i in T.serial(24):
T.assume(((3 <= i) and (i < 17)) or A[i] == 0)
T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))

for i in T.serial(24):
B[i] = 0
Expand All @@ -1659,7 +1659,7 @@ def before(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "i

def expected(A: T.Buffer[24, "int32"], B: T.Buffer[24, "int32"], F: T.Buffer[3, "int32"]):
for i in T.serial(24):
T.assume(((3 <= i) and (i < 17)) or A[i] == 0)
T.evaluate(T.assume(((3 <= i) and (i < 17)) or A[i] == 0))

for i in T.serial(24):
B[i] = 0
Expand Down

0 comments on commit 1bbf96f

Please sign in to comment.