From 00fc96af6a5adfaef9b0f934ad93b6255db19b19 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Mon, 10 Jan 2022 00:19:32 -0500 Subject: [PATCH] fixing codegen --- include/tvm/tir/transform.h | 5 ++ python/tvm/tir/transform/transform.py | 11 +++ src/arith/modular_set.cc | 12 +++ src/arith/rewrite_simplify.cc | 2 + src/driver/driver_api.cc | 12 +++ .../transforms/renormalize_split_pattern.cc | 87 +++++++++++++++++++ .../python/unittest/test_arith_modular_set.py | 10 +++ .../unittest/test_arith_rewrite_simplify.py | 5 ++ ...tir_transform_renormalize_split_pattern.py | 84 ++++++++++++++++++ 9 files changed, 228 insertions(+) create mode 100644 src/tir/transforms/renormalize_split_pattern.cc create mode 100644 tests/python/unittest/test_tir_transform_renormalize_split_pattern.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index f746996e31b2..80f2ab858dc8 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -496,6 +496,11 @@ TVM_DLL Pass InjectSoftwarePipeline(); */ TVM_DLL Pass LowerAutoCopy(); +/*! + * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + * \return The pass. 
+ */ +TVM_DLL Pass RenormalizeSplitPattern(); } // namespace transform } // namespace tir diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index d82123b97525..1009ba85b824 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -761,6 +761,7 @@ def InjectSoftwarePipeline(): """ return _ffi_api.InjectSoftwarePipeline() # type: ignore + def LowerAutoCopy(): """Automatically do memory optimizations for auto copy blocks @@ -771,3 +772,13 @@ def LowerAutoCopy(): """ return _ffi_api.LowerAutoCopy() + +def RenomalizeSplitPattern(): + """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RenormalizeSplitPattern() diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index ac176b2623a3..fb101c5ff236 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctorb); + if (b.is_const()) { + int64_t c2 = b.base; + ICHECK(c2 != 0) << "MathError: the divisor is 0"; + Entry a = VisitExpr(op->a); + int64_t coeff = ZeroAwareGCD(a.coeff, c2); + return Entry(coeff, a.base % c2); + } + return Everything(); + } + Entry VisitExpr_(const MinNode* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 4a99e10211b7..6ac69e64e183 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x); // floor div TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x); + TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), + c2.Eval()->value > 0); // canonicalization rule // will try rewrite again after 
canonicalization. diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e0e74c32fb20..32398aaa6c85 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -36,6 +36,9 @@ #include #include +#include "../printer/text_printer.h" +#include + namespace tvm { // Register build pipeline related options @@ -187,6 +190,14 @@ transform::Pass Filter(FCond fcond) { return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {}); } +Pass Print() { + auto pass_func = [](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + LOG(INFO) << tir::AsTVMScript(f); + return f; + }; + return tir::transform::CreatePrimFuncPass(pass_func, 0, "tir.Print", {}); +} + Array CreatePassList(bool disable_loop_partition) { transform::PassContext pass_ctx = transform::PassContext::Current(); @@ -271,6 +282,7 @@ Array CreatePassList(bool disable_loop_partition) { // PHASE 3 pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::RenormalizeSplitPattern()); pass_list.push_back(tir::transform::RemoveNoOp()); pass_list.push_back(tir::transform::RewriteUnsafeSelect()); pass_list.push_back(tir::transform::HoistIfThenElse()); diff --git a/src/tir/transforms/renormalize_split_pattern.cc b/src/tir/transforms/renormalize_split_pattern.cc new file mode 100644 index 000000000000..78362d86bcfe --- /dev/null +++ b/src/tir/transforms/renormalize_split_pattern.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file renormalize_split_pattern.cc + * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv()) + */ +#include +#include +#include +#include +#include +#include + +#include "../../arith/ir_mutator_with_analyzer.h" + +namespace tvm { +namespace tir { + +using namespace arith; + +class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { + public: + explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} + + PrimExpr VisitExpr_(const FloorDivNode* op) final { + PrimExpr a = VisitExpr(op->a); + PrimExpr b = VisitExpr(op->b); + // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1) + if (const auto* inner = op->a.as()) { + if (const auto* c2 = op->b.as()) { + if (const auto* c1c2 = inner->b.as()) { + if (c1c2->value % c2->value == 0) { + return analyzer_->Simplify(FloorMod(FloorDiv(inner->a, op->b), IntImm(op->b.dtype(), c1c2->value / c2->value))); + } + } + } + } + if (a.same_as(op->a) && b.same_as(op->b)) { + return GetRef(op); + } else { + return FloorDiv(a, b); + } + } + + Stmt VisitStmt_(const ForNode* op) final { + analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); + With ctx1(analyzer_, op->loop_var >= op->min); + With ctx2(analyzer_, op->loop_var < op->min + op->extent); + return IRMutatorWithAnalyzer::VisitStmt_(op); + } +}; + +namespace transform { + +Pass RenormalizeSplitPattern() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + arith::Analyzer analyzer; + n->body = 
SplitPatternReNormalizer(&analyzer)(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern").set_body_typed(RenormalizeSplitPattern); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py index 4a4cd6a31ef1..7914195effe1 100644 --- a/tests/python/unittest/test_arith_modular_set.py +++ b/tests/python/unittest/test_arith_modular_set.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.arith import analyzer def test_cast(): @@ -50,6 +51,14 @@ def test_mul(): assert m.base == 2 +def test_floormod(): + analyzer = tvm.arith.Analyzer() + x, y = te.var("x"), te.var("y") + m = analyzer.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256)) + assert m.coeff == 4 + assert m.base == 0 + + def test_div_shift(): analyzer = tvm.arith.Analyzer() x, y = te.var("x"), te.var("y") @@ -175,6 +184,7 @@ def test_let(): test_add_sub() test_mul() test_div_shift() + test_floormod() test_min_max_select() test_mix_index() test_constraint_scope() diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 6ca2a2a5fcb0..549882126d50 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -80,6 +80,10 @@ def test_vector_simplify(): ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")) ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)) ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5)) + ck.verify( + fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), + tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4) + ) ck.verify( fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), 
fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), @@ -277,6 +281,7 @@ def test_add_index_simplify(): flm = tvm.te.floormod ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)) ck.verify(fld(x, 8) * 8 + flm(x, 8), x) + ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)) def test_sub_index_simplify(): diff --git a/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py new file mode 100644 index 000000000000..6217d2f0989a --- /dev/null +++ b/tests/python/unittest/test_tir_transform_renormalize_split_pattern.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import tvm +from tvm.script import tir as T + + +@tvm.script.ir_module +class Before: + @T.prim_func + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + blockIdx_x = T.env_thread("blockIdx.x") + # body + T.launch_thread(blockIdx_x, 64) + conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") + PadInput_shared = T.allocate([768], "float32", "shared") + weight_shared = T.allocate([4096], "float32", "shared") + T.launch_thread(threadIdx_x, 32) + for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + for i6_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): + T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): + T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): + 
T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + for ax1, ax2 in T.grid(2, 4): + T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + + +@tvm.script.ir_module +class After: + @T.prim_func + def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + blockIdx_x = T.env_thread("blockIdx.x") + # body + T.launch_thread(blockIdx_x, 64) + conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local") + PadInput_shared = T.allocate([768], "float32", "shared") + weight_shared = T.allocate([4096], "float32", "shared") + T.launch_thread(threadIdx_x, 32) + for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True) + for i6_0 in T.serial(16): + for ax0_ax1_ax2_ax3_fused_0 in T.serial(24): + T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", 
inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True) + for ax0_ax1_ax2_ax3_fused_0 in T.serial(32): + T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4)) + for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2): + T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True) + for ax1, ax2 in T.grid(2, 4): + T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True) + + +def test_renormalize_split_pattern(): + after = tvm.tir.transform.RenomalizeSplitPattern()(Before) + tvm.ir.assert_structural_equal(after, After) + + +if __name__ == "__main__": + test_renormalize_split_pattern()