Skip to content

Commit

Permalink
lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Jul 10, 2024
1 parent 7488ea1 commit 6d77ec9
Showing 1 changed file with 9 additions and 27 deletions.
36 changes: 9 additions & 27 deletions tests/python/relax/test_dataflow_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,8 @@ def before(x: R.Tensor([16], "float32")):

@R.function(private=True)
def expected(x: R.Tensor([16], "float32")):
y = R.call_pure_packed(
"my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")
)
z = R.call_pure_packed(
"my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")
)
y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32"))
z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32"))
return z

after = Rewriter(before)
Expand Down Expand Up @@ -316,12 +312,8 @@ def expected(
B: R.Tensor([16], "float32"),
C: R.Tensor([16], "float32"),
):
D = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
)
E = R.call_pure_packed(
"my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")
)
D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32"))
E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32"))
return E

rewriter = RewriteAdd | RewriteMultiply
Expand Down Expand Up @@ -457,9 +449,7 @@ def pattern(A: R.Tensor([16], "float32")):

@R.function
def replacement(A: R.Tensor([16], "float32")):
return R.call_tir(
RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")
)
return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32"))

@T.prim_func(private=True)
def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
Expand Down Expand Up @@ -537,9 +527,7 @@ def pattern(A: R.Tensor([16], "float32")):

@R.function
def replacement(A: R.Tensor([16], "float32")):
return R.call_tir(
RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")
)
return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32"))

@T.prim_func(private=True)
def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
Expand All @@ -559,9 +547,7 @@ class Expected:
@R.function
def main(A: R.Tensor([16], "float32")):
B = Expected.subroutine(A)
C = R.call_tir(
Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")
)
C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32"))
return C

@R.function(private=True)
Expand Down Expand Up @@ -1212,9 +1198,7 @@ def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
)

@R.function(private=True)
def before(
A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
):
def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
if cond:
out = A + B
else:
Expand All @@ -1223,9 +1207,7 @@ def before(
return out

@R.function(private=True)
def expected(
A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
):
def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")):
if cond:
out = R.call_pure_packed(
"my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
Expand Down

0 comments on commit 6d77ec9

Please sign in to comment.