From fd7bf6f671795c68e14096fb7517d06f7240360c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 22 Jan 2022 15:28:00 -0800 Subject: [PATCH] [Hotfix] A unittest --- .../test_meta_schedule_postproc_rewrite_unbound_block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py index 4ab2741da181..9b39ad1bff3e 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -99,7 +99,7 @@ def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> with T.block("C"): b = T.axis.S(1, 0) i, j = T.axis.remap("RR", [i1, i2]) - T.where(i0_fused_0 * 32 + i0_fused_1 < 1) + T.where(i0_fused_1 < 1) with T.init(): C[b] = T.float32(0) C[b] = C[b] + A[b, i, j] * A[b, i, j] @@ -107,7 +107,7 @@ def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): with T.block("D"): b = T.axis.S(1, 0) - T.where(i0_fused_0 * 32 + i0_fused_1 < 1) + T.where(i0_fused_1 < 1) D[b] = T.sqrt(C[b], dtype="float32")