-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TensorIR] Renormalize split pattern (#10401)
- Loading branch information
1 parent
b7caa12
commit 9ca2139
Showing
5 changed files
with
349 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file renormalize_split_pattern.cc | ||
* \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) | ||
*/ | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/tir/analysis.h> | ||
#include <tvm/tir/op.h> | ||
#include <tvm/tir/stmt.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
#include <tvm/tir/transform.h> | ||
|
||
#include "../../arith/ir_mutator_with_analyzer.h" | ||
#include "../../arith/pattern_match.h" | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
using namespace arith; | ||
|
||
// macro for doing simple rewrite | ||
#define TRY_REWRITE(SrcExpr, ResExpr) \ | ||
if ((SrcExpr).Match(ret)) { \ | ||
return (ResExpr).Eval(); \ | ||
} | ||
|
||
// macro rewrite + recursive_rewrite only if CondExpr is true after match. | ||
#define TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ | ||
if ((SrcExpr).Match(ret) && (CondExpr)) { \ | ||
return RecursiveRewrite((ResExpr).Eval()); \ | ||
} | ||
|
||
class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { | ||
public: | ||
explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} | ||
|
||
PrimExpr VisitExpr_(const FloorDivNode* op) final { | ||
PrimExpr a = VisitExpr(op->a); | ||
PrimExpr b = VisitExpr(op->b); | ||
PrimExpr ret = floordiv(a, b); | ||
// Pattern var to match any expression | ||
PVar<PrimExpr> x, y, z; | ||
// Pattern var match IntImm | ||
PVar<IntImm> c1, c2, c3; | ||
// Pattern var for lanes in broadcast and ramp | ||
PVar<int> lanes; | ||
|
||
// floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1) | ||
TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2), | ||
floormod(floordiv(x, c2), floordiv(c3, c2)), | ||
c3.Eval()->value % c2.Eval()->value == 0); | ||
TRY_RECURSIVE_REWRITE_IF( | ||
floordiv(floormod(x, broadcast(c3, lanes)), broadcast(c2, lanes)), | ||
floormod(floordiv(x, broadcast(c2, lanes)), broadcast(floordiv(c3, c2), lanes)), | ||
c3.Eval()->value % c2.Eval()->value == 0); | ||
|
||
// floordiv(x*c1*c3 + y, c2*c3) = floordiv(x*c1 + floordiv(y, c3), c2) | ||
if ((floordiv(x * c1 + y, c2)).Match(ret)) { | ||
int64_t c1_val = c1.Eval()->value; | ||
int64_t c2_val = c2.Eval()->value; | ||
if (c1_val > 0 && c2_val > 0) { | ||
int64_t c3 = ZeroAwareGCD(c1_val, c2_val); | ||
if (c3 > 1) { | ||
IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); | ||
IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); | ||
return RecursiveRewrite(floordiv(x.Eval() * c1_div + floordiv(y.Eval(), c3), c2_div)); | ||
} | ||
} | ||
} | ||
if ((floordiv(x * broadcast(c1, lanes) + y, broadcast(c2, lanes))).Match(ret)) { | ||
int64_t c1_val = c1.Eval()->value; | ||
int64_t c2_val = c2.Eval()->value; | ||
if (c1_val > 0 && c2_val > 0) { | ||
int64_t c3 = ZeroAwareGCD(c1_val, c2_val); | ||
if (c3 > 1) { | ||
IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); | ||
IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); | ||
return RecursiveRewrite(floordiv( | ||
x.Eval() * Broadcast(c1_div, lanes.Eval()) + | ||
floordiv(y.Eval(), Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), | ||
Broadcast(c2_div, lanes.Eval()))); | ||
} | ||
} | ||
} | ||
|
||
// floordiv(x*c1*c3 + y + z, c2*c3) = floordiv(x*c1 + floordiv(y + z, c3), c2) | ||
if ((floordiv(x * c1 + y + z, c2)).Match(ret)) { | ||
int64_t c1_val = c1.Eval()->value; | ||
int64_t c2_val = c2.Eval()->value; | ||
if (c1_val > 0 && c2_val > 0) { | ||
int64_t c3 = ZeroAwareGCD(c1_val, c2_val); | ||
if (c3 > 1) { | ||
IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); | ||
IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); | ||
return RecursiveRewrite( | ||
floordiv(x.Eval() * c1_div + floordiv(y.Eval() + z.Eval(), c3), c2_div)); | ||
} | ||
} | ||
} | ||
if ((floordiv(x * broadcast(c1, lanes) + y + z, broadcast(c2, lanes))).Match(ret)) { | ||
int64_t c1_val = c1.Eval()->value; | ||
int64_t c2_val = c2.Eval()->value; | ||
if (c1_val > 0 && c2_val > 0) { | ||
int64_t c3 = ZeroAwareGCD(c1_val, c2_val); | ||
if (c3 > 1) { | ||
IntImm c1_div = IntImm(c1.Eval().dtype(), c1_val / c3); | ||
IntImm c2_div = IntImm(c2.Eval().dtype(), c2_val / c3); | ||
return RecursiveRewrite( | ||
floordiv(x.Eval() * Broadcast(c1_div, lanes.Eval()) + | ||
floordiv(y.Eval() + z.Eval(), | ||
Broadcast(IntImm(c1.Eval().dtype(), c3), lanes.Eval())), | ||
Broadcast(c2_div, lanes.Eval()))); | ||
} | ||
} | ||
} | ||
|
||
return ret; | ||
} | ||
|
||
PrimExpr VisitExpr_(const LENode* op) { return this->VisitExpr(Not(op->b < op->a)); } | ||
|
||
PrimExpr VisitExpr_(const GTNode* op) { return this->VisitExpr(op->b < op->a); } | ||
|
||
PrimExpr VisitExpr_(const GENode* op) { return this->VisitExpr(Not(op->a < op->b)); } | ||
|
||
PrimExpr VisitExpr_(const LTNode* op) { | ||
PrimExpr a = VisitExpr(op->a); | ||
PrimExpr b = VisitExpr(op->b); | ||
PrimExpr ret = tir::LT(a, b); | ||
// Pattern var to match any expression | ||
PVar<PrimExpr> x; | ||
// Pattern var match IntImm | ||
PVar<IntImm> c1, c2; | ||
// x < c2 <=> x/c2 < 1 <=> floor(x / c2) < 1 | ||
TRY_RECURSIVE_REWRITE_IF(x<c2, floordiv(x, c2) < 1, c2.Eval()->value> 0); // NOLINT | ||
return ret; | ||
} | ||
|
||
PrimExpr VisitExpr_(const NotNode* op) { | ||
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op); | ||
// Pattern var to match any expression | ||
PVar<PrimExpr> x, y; | ||
TRY_REWRITE(!(!x), x); | ||
TRY_REWRITE(!(x <= y), y < x); | ||
TRY_REWRITE(!(x >= y), x < y); | ||
TRY_REWRITE(!(x < y), y <= x); | ||
TRY_REWRITE(!(x > y), x <= y); | ||
return ret; | ||
} | ||
|
||
Stmt VisitStmt_(const ForNode* op) final { | ||
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); | ||
With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min); | ||
With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent); | ||
return IRMutatorWithAnalyzer::VisitStmt_(op); | ||
} | ||
|
||
// Recursive rewrite x | ||
// we limit maximum depth of recursive rewrite allowed to | ||
// avoid infinite loop | ||
PrimExpr RecursiveRewrite(const PrimExpr& x) { | ||
if (recur_depth_ >= kMaxRecurDepth) return x; | ||
++recur_depth_; | ||
PrimExpr res = this->VisitExpr(x); | ||
--recur_depth_; | ||
return res; | ||
} | ||
|
||
private: | ||
// counter to record recursive rewrite depth. | ||
int recur_depth_{0}; | ||
// maximum number of recursion allowed during a single pass. | ||
static const constexpr int kMaxRecurDepth = 5; | ||
}; | ||
|
||
namespace transform { | ||
|
||
Pass RenormalizeSplitPattern() { | ||
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { | ||
auto* n = f.CopyOnWrite(); | ||
arith::Analyzer analyzer; | ||
n->body = SplitPatternReNormalizer(&analyzer)(std::move(n->body)); | ||
return f; | ||
}; | ||
return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern") | ||
.set_body_typed(RenormalizeSplitPattern); | ||
|
||
} // namespace transform | ||
|
||
} // namespace tir | ||
} // namespace tvm |
119 changes: 119 additions & 0 deletions
119
tests/python/unittest/test_tir_transform_renormalize_split_pattern.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import tvm | ||
from tvm.script import tir as T | ||
|
||
# fmt: off | ||
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg | ||
|
||
@tvm.script.ir_module | ||
class Before: | ||
@T.prim_func | ||
def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: | ||
# function attr dict | ||
T.func_attr({"global_symbol": "main", "tir.noalias": True}) | ||
# var definition | ||
threadIdx_x = T.env_thread("threadIdx.x") | ||
blockIdx_x = T.env_thread("blockIdx.x") | ||
# body | ||
T.launch_thread(blockIdx_x, 64) | ||
conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") | ||
PadInput_shared = T.allocate([768], "float32", "shared") | ||
weight_shared = T.allocate([4096], "float32", "shared") | ||
T.launch_thread(threadIdx_x, 32) | ||
for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): | ||
T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) | ||
for i6_0 in T.serial(16): | ||
for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): | ||
T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) | ||
for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): | ||
T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) | ||
for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): | ||
T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) | ||
for ax1, ax2 in T.grid(2, 4): | ||
T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) | ||
|
||
|
||
@tvm.script.ir_module | ||
class After: | ||
@T.prim_func | ||
def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: | ||
# function attr dict | ||
T.func_attr({"global_symbol": "main", "tir.noalias": True}) | ||
# var definition | ||
threadIdx_x = T.env_thread("threadIdx.x") | ||
blockIdx_x = T.env_thread("blockIdx.x") | ||
# body | ||
T.launch_thread(blockIdx_x, 64) | ||
conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") | ||
PadInput_shared = T.allocate([768], "float32", "shared") | ||
weight_shared = T.allocate([4096], "float32", "shared") | ||
T.launch_thread(threadIdx_x, 32) | ||
for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): | ||
T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) | ||
for i6_0 in T.serial(16): | ||
for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): | ||
T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(1 <= (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 4 and (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) // 20 < 1 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4 and (blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 + threadIdx_x // 32) % 4) // 5 < 1, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) | ||
for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): | ||
T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 + threadIdx_x * 4 // 128) // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x * 4 // 8) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) | ||
for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): | ||
T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) | ||
for ax1, ax2 in T.grid(2, 4): | ||
T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) | ||
|
||
|
||
@tvm.script.ir_module | ||
class After_simplified: | ||
@T.prim_func | ||
def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: | ||
# function attr dict | ||
T.func_attr({"global_symbol": "main", "tir.noalias": True}) | ||
# var definition | ||
threadIdx_x = T.env_thread("threadIdx.x") | ||
blockIdx_x = T.env_thread("blockIdx.x") | ||
# body | ||
T.launch_thread(blockIdx_x, 64) | ||
conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") | ||
PadInput_shared = T.allocate([768], "float32", "shared") | ||
weight_shared = T.allocate([4096], "float32", "shared") | ||
T.launch_thread(threadIdx_x, 32) | ||
for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): | ||
T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) | ||
for i6_0 in T.serial(16): | ||
for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): | ||
T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(4 <= ax0_ax1_ax2_ax3_fused_0 and ax0_ax1_ax2_ax3_fused_0 < 20 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) | ||
for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): | ||
T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp(ax0_ax1_ax2_ax3_fused_0 // 2 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) | ||
for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): | ||
T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) | ||
for ax1, ax2 in T.grid(2, 4): | ||
T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) | ||
|
||
# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg | ||
# fmt: on | ||
|
||
|
||
def tesd_renormalize_split_pattern(): | ||
after = tvm.tir.transform.RenomalizeSplitPattern()(Before) | ||
tvm.ir.assert_structural_equal(after, After) | ||
after = tvm.tir.transform.Simplify()(after) | ||
tvm.ir.assert_structural_equal(after, After_simplified) | ||
|
||
|
||
if __name__ == "__main__": | ||
tesd_renormalize_split_pattern() |