From 9ca2139d0f3c52912b37a897c34d0d4175ce088a Mon Sep 17 00:00:00 2001
From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Date: Mon, 28 Feb 2022 04:01:04 -0500
Subject: [PATCH] [TensorIR] Renormalize split pattern (#10401)

---
 include/tvm/tir/transform.h                   |   6 +
 python/tvm/tir/transform/transform.py         |  11 +
 src/driver/driver_api.cc                      |   1 +
 .../transforms/renormalize_split_pattern.cc   | 212 ++++++++++++++++++
 ...tir_transform_renormalize_split_pattern.py | 119 ++++++++++
 5 files changed, 349 insertions(+)
 create mode 100644 src/tir/transforms/renormalize_split_pattern.cc
 create mode 100644 tests/python/unittest/test_tir_transform_renormalize_split_pattern.py

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 3bb5491affdf..4330c4f7c64a 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -611,6 +611,12 @@ TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
  */
 TVM_DLL Pass ExtractPrimFuncConstants();
 
+/*!
+ * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
+ * \return The pass.
+ */
+TVM_DLL Pass RenormalizeSplitPattern();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index d5f1a9ae979b..1bb1a3e47a54 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -782,3 +782,14 @@ def ExtractPrimFuncConstants():
         The result pass
     """
     return _ffi_api.ExtractPrimFuncConstants()  # type: ignore
+
+
+def RenormalizeSplitPattern():
+    """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.RenormalizeSplitPattern()  # type: ignore
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 2a0c2f73f2ba..54126aaa5119 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -275,6 +275,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
 
   // PHASE 3
+  pass_list.push_back(tir::transform::RenormalizeSplitPattern());
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc
new file mode 100644
index 000000000000..d55df5ea92eb
--- /dev/null
+++ b/src/tir/transforms/renormalize_split_pattern.cc
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file renormalize_split_pattern.cc
+ * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/pattern_match.h"
+
+namespace tvm {
+namespace tir {
+
+using namespace arith;
+
+// macro for doing simple rewrite
+#define TRY_REWRITE(SrcExpr, ResExpr) \
+  if ((SrcExpr).Match(ret)) {         \
+    return (ResExpr).Eval();          \
+  }
+
+// macro rewrite + recursive_rewrite only if CondExpr is true after match.
+#define TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \
+  if ((SrcExpr).Match(ret) && (CondExpr)) {                  \
+    return RecursiveRewrite((ResExpr).Eval());               \
+  }
+
+class SplitPatternReNormalizer : public IRMutatorWithAnalyzer {
+ public:
+  explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {}
+
+  PrimExpr VisitExpr_(const FloorDivNode* op) final {
+    PrimExpr a = VisitExpr(op->a);
+    PrimExpr b = VisitExpr(op->b);
+    PrimExpr ret = floordiv(a, b);
+    // Pattern var to match any expression
+    PVar<PrimExpr> x, y, z;
+    // Pattern var match IntImm
+    PVar<IntImm> c1, c2, c3;
+    // Pattern var for lanes in broadcast and ramp
+    PVar<int> lanes;
+
+    // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1)
+    TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2),
+                             floormod(floordiv(x, c2), floordiv(c3, c2)),
+                             c3.Eval()->value % c2.Eval()->value == 0);
+    TRY_RECURSIVE_REWRITE_IF(
+        floordiv(floormod(x, broadcast(c3, lanes)), broadcast(c2, lanes)),
+        floormod(floordiv(x, broadcast(c2, lanes)), broadcast(floordiv(c3, c2), lanes)),
+        c3.Eval()->value % c2.Eval()->value == 0);
+
+    // floordiv(x*c1*c3 + y, c2*c3) = floordiv(x*c1 + floordiv(y, c3), c2)
+    if ((floordiv(x * c1 + y, c2)).Match(ret)) {
+      int64_t c1_val = c1.Eval()->value;
+      int64_t c2_val = c2.Eval()->value;
+      if (c1_val > 0 && c2_val > 0) {
+        int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
+        if (c3 > 1) {
+          IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
+          IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
+          return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval(), c3), c2_div));
+        }
+      }
+    }
+    if ((floordiv(x * broadcast(c1, lanes) + y, broadcast(c2, lanes))).Match(ret)) {
+      int64_t c1_val = c1.Eval()->value;
+      int64_t c2_val = c2.Eval()->value;
+      if (c1_val > 0 && c2_val > 0) {
+        int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
+        if (c3 > 1) {
+          IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
+          IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
+          return RecursiveRewrite(floordiv(
+              x.Eval() * Broadcast(c1_div, lanes.Eval()) +
+                  floordiv(y.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())),
+              Broadcast(c2_div, lanes.Eval())));
+        }
+      }
+    }
+
+    // floordiv(x*c1*c3 + y + z, c2*c3) = floordiv(x*c1 + floordiv(y + z, c3), c2)
+    if ((floordiv(x * c1 + y + z, c2)).Match(ret)) {
+      int64_t c1_val = c1.Eval()->value;
+      int64_t c2_val = c2.Eval()->value;
+      if (c1_val > 0 && c2_val > 0) {
+        int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
+        if (c3 > 1) {
+          IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
+          IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
+          return RecursiveRewrite(
+              floordiv(x.Eval() * c1_div + floordiv(y.Eval() + z.Eval(), c3), c2_div));
+        }
+      }
+    }
+    if ((floordiv(x * broadcast(c1, lanes) + y + z, broadcast(c2, lanes))).Match(ret)) {
+      int64_t c1_val = c1.Eval()->value;
+      int64_t c2_val = c2.Eval()->value;
+      if (c1_val > 0 && c2_val > 0) {
+        int64_t c3 = ZeroAwareGCD(c1_val, c2_val);
+        if (c3 > 1) {
+          IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3);
+          IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3);
+          return RecursiveRewrite(
+              floordiv(x.Eval() * Broadcast(c1_div, lanes.Eval()) +
+                           floordiv(y.Eval() + z.Eval(),
+                                    Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())),
+                       Broadcast(c2_div, lanes.Eval())));
+        }
+      }
+    }
+
+    return ret;
+  }
+
+  PrimExpr VisitExpr_(const LENode* op) { return this->VisitExpr(Not(op->b < op->a)); }
+
+  PrimExpr VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); }
+
+  PrimExpr VisitExpr_(const GENode* op) { return this->VisitExpr(Not(op->a < op->b)); }
+
+  PrimExpr VisitExpr_(const LTNode* op) {
+    PrimExpr a = VisitExpr(op->a);
+    PrimExpr b = VisitExpr(op->b);
+    PrimExpr ret = tir::LT(a, b);
+    // Pattern var to match any expression
+    PVar<PrimExpr> x;
+    // Pattern var match IntImm
+    PVar<IntImm> c1, c2;
+    // x < c2 <=> x/c2 < 1 <=> floor(x / c2) < 1
+    TRY_RECURSIVE_REWRITE_IF(x < c2, floordiv(x, c2) < 1, c2.Eval()->value > 0);  // NOLINT
+    return ret;
+  }
+
+  PrimExpr VisitExpr_(const NotNode* op) {
+    PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
+    // Pattern var to match any expression
+    PVar<PrimExpr> x, y;
+    TRY_REWRITE(!(!x), x);
+    TRY_REWRITE(!(x <= y), y < x);
+    TRY_REWRITE(!(x >= y), x < y);
+    TRY_REWRITE(!(x < y), y <= x);
+    TRY_REWRITE(!(x > y), x <= y);
+    return ret;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
+    With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
+    return IRMutatorWithAnalyzer::VisitStmt_(op);
+  }
+
+  // Recursive rewrite x
+  // we limit maximum depth of recursive rewrite allowed to
+  // avoid infinite loop
+  PrimExpr RecursiveRewrite(const PrimExpr& x) {
+    if (recur_depth_ >= kMaxRecurDepth) return x;
+    ++recur_depth_;
+    PrimExpr res = this->VisitExpr(x);
+    --recur_depth_;
+    return res;
+  }
+
+ private:
+  // counter to record recursive rewrite depth.
+  int recur_depth_{0};
+  // maximum number of recursion allowed during a single pass.
+  static const constexpr int kMaxRecurDepth = 5;
+};
+
+namespace transform {
+
+Pass RenormalizeSplitPattern() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    arith::Analyzer analyzer;
+    n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern")
+    .set_body_typed(RenormalizeSplitPattern);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
new file mode 100644
index 000000000000..eb3efd317e9c
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.script import tir as T
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg
+
+@tvm.script.ir_module
+class Before:
+    @T.prim_func
+    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # var definition
+        threadIdx_x = T.env_thread("threadIdx.x")
+        blockIdx_x = T.env_thread("blockIdx.x")
+        # body
+        T.launch_thread(blockIdx_x, 64)
+        conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
+        PadInput_shared = T.allocate([768], "float32", "shared")
+        weight_shared = T.allocate([4096], "float32", "shared")
+        T.launch_thread(threadIdx_x, 32)
+        for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
+            T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True)
+        for i6_0 in T.serial(16):
+            for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
+                T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True)
+            for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
+                T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4))
+            for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
+                T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True)
+        for ax1, ax2 in T.grid(2, 4):
+            T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True)
+
+
+@tvm.script.ir_module
+class After:
+    @T.prim_func
+    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # var definition
+        threadIdx_x = T.env_thread("threadIdx.x")
+        blockIdx_x = T.env_thread("blockIdx.x")
+        # body
+        T.launch_thread(blockIdx_x, 64)
+        conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
+        PadInput_shared = T.allocate([768], "float32", "shared")
+        weight_shared = T.allocate([4096], "float32", "shared")
+        T.launch_thread(threadIdx_x, 32)
+        for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
+            T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True)
+        for i6_0 in T.serial(16):
+            for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
+                T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True)
+            for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
+                T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4))
+            for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
+                T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True)
+        for ax1, ax2 in T.grid(2, 4):
+            T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True)
+
+
+@tvm.script.ir_module
+class After_simplified:
+    @T.prim_func
+    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # var definition
+        threadIdx_x = T.env_thread("threadIdx.x")
+        blockIdx_x = T.env_thread("blockIdx.x")
+        # body
+        T.launch_thread(blockIdx_x, 64)
+        conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
+        PadInput_shared = T.allocate([768], "float32", "shared")
+        weight_shared = T.allocate([4096], "float32", "shared")
+        T.launch_thread(threadIdx_x, 32)
+        for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
+            T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True)
+        for i6_0 in T.serial(16):
+            for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
+                T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True)
+            for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
+                T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4))
+            for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
+                T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True)
+        for ax1, ax2 in T.grid(2, 4):
+            T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True)
+
+# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg
+# fmt: on
+
+
+def test_renormalize_split_pattern():
+    after = tvm.tir.transform.RenormalizeSplitPattern()(Before)
+    tvm.ir.assert_structural_equal(after, After)
+    after = tvm.tir.transform.Simplify()(after)
+    tvm.ir.assert_structural_equal(after, After_simplified)
+
+
+if __name__ == "__main__":
+    test_renormalize_split_pattern()
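
Usage sketch: the snippet below illustrates the core rewrite in isolation. The
module name `SplitPattern` and the buffer shapes are hypothetical (they do not
appear in the patch or in TVM's test suite), and the exact script printed back
may vary by TVM version. With c3 = 8 and c2 = 4 (8 % 4 == 0), the first
FloorDivNode rule, floordiv(floormod(x, c3), c2) ->
floormod(floordiv(x, c2), floordiv(c3, c2)), turns the index (i % 8) // 4
into i // 4 % 2.

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class SplitPattern:  # hypothetical example module, not part of the patch
    @T.prim_func
    def main(A: T.Buffer[(16,), "float32"], B: T.Buffer[(2,), "float32"]) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        for i in T.serial(16):
            # floordiv(floormod(i, 8), 4): the split pattern the pass targets
            B[(i % 8) // 4] = A[i]


# After the pass the index should read floormod(floordiv(i, 4), 2),
# i.e. i // 4 % 2, which tir.Simplify can then fold further.
mod = tvm.tir.transform.RenormalizeSplitPattern()(SplitPattern)
print(mod.script())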