From 0b4d970521510f724071d2c024c27466df4285c4 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 30 Jan 2023 15:03:18 -0500 Subject: [PATCH] [Fix][Arith] Analyzer simplification starts with canonical This PR updates the order of arithmetic analyzer simplification, by adding a stage of canonical simplification at the very beginning so that every simplification always starts with a canonical round. This is because the rewrite simplification may destroy some PrimExpr property that the canonical simplification can make use of. Therefore, adding the canonical one in the front can maximize the use of canonical simplification. --- src/arith/analyzer.cc | 4 ++ src/tir/op/op.cc | 13 ++++++- tests/python/unittest/test_arith_intset.py | 11 +----- tests/python/unittest/test_arith_simplify.py | 38 +++++++++++++++++++ tests/python/unittest/test_tir_buffer.py | 2 +- .../unittest/test_tir_schedule_analysis.py | 6 +-- .../unittest/test_tir_schedule_rfactor.py | 4 +- 7 files changed, 61 insertions(+), 17 deletions(-) create mode 100644 tests/python/unittest/test_arith_simplify.py diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 921f8ac7094b7..4714cf1df59fe 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -129,6 +129,10 @@ bool Analyzer::CanProve(const PrimExpr& expr) { PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { PrimExpr res = expr; + // Always starts with a canonical simplification, as some structural property + // of an expression might be destroyed by rewrite simplification. + res = this->canonical_simplify(res); + for (int i = 0; i < steps; ++i) { if (tir::is_const_int(res)) { return res; diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 078e32ca57c70..4694144438f6c 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -218,11 +218,20 @@ PrimExpr min_value(const DataType& dtype, Span span) { // floatimm min bug) return (*f)(dtype.bits()); } else if (dtype.is_int()) { + // Here we use the actual min value + 1. + // This is because the in integer system, the actual min value and + // max value are not symmetric. In arithmetic analyzer and integer + // expression simplification methods, it is very common to take + // the negative value of integers. If the actual min value of an + // integer dtype is taken the negative value, the result will be + // out of the integer dtype range and lead to many other issues. + // So here we use the actual min value + 1 to avoid the issues of + // "taking the negative of the min value". if (dtype.bits() == 64) { - return IntImm(dtype, std::numeric_limits::lowest(), span); + return IntImm(dtype, std::numeric_limits::lowest() + 1, span); } else if (dtype.bits() < 64) { int64_t val = 1; - val = -(val << (dtype.bits() - 1)); + val = -(val << (dtype.bits() - 1)) + 1; return IntImm(dtype, val, span); } } else if (dtype.is_uint()) { diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 24228fb527032..da3fd94f8192b 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -182,7 +182,6 @@ def check_region_bound(expect_region, var_dom, mode, predicate=None): expect_begin, expect_end = expect_desc[binding] result_begin = analyzer.simplify(intset.min_value, 3) result_end = analyzer.simplify(intset.max_value + 1, 3) - print(result_end) assert analyzer.can_prove_equal( result_begin - expect_begin, 0 ), f"{result_begin} vs {expect_begin}" @@ -306,10 +305,7 @@ def test_region_lower_bound_for_non_perfect_tile(): + h2: { (): ( tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 224, + tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), # h3 == 0: region is [1, 10) ((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10) @@ -333,10 +329,7 @@ def test_region_lower_bound_for_non_perfect_tile(): + h1: { (): ( tvm.tir.max(h3 * 8, 1), - tvm.tir.max(h3 * 8, 1) - - tvm.tir.max(h3 * 8, 214) - - tvm.tir.max(1 - h3 * 8, 0) - + 224, + tvm.tir.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), ((h3, 10),): (h3 * 8, h3 * 8 + 10), diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py new file mode 100644 index 0000000000000..aa9d5179aa3f8 --- /dev/null +++ b/tests/python/unittest/test_arith_simplify.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import tir + + +def test_simplify_reshape_flattened_index(): + ana = tvm.arith.Analyzer() + + i0 = tir.Var("i0", "int64") + i1 = tir.Var("i1", "int64") + ana.bind(i0, tvm.ir.Range(0, 8)) + ana.bind(i1, tvm.ir.Range(0, 3)) + + i_flattened = i0 * 3 + i1 + assert tvm.ir.structural_equal( + ana.simplify((i_flattened) // 12 * 12 + (i_flattened) % 12 // 4 * 4 + (i_flattened) % 4), + i_flattened, + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index 55c83167392f7..95ad81db889bd 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -150,7 +150,7 @@ def assert_simplified_equal(index_simplified, index_direct): index_simplified = A.offset_of( (idxd(idxm(k0, idxd(k1, s)), n), idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)) ) - index_direct = A.offset_of((0, idxm(k0, k1) + idxm(k0, idxd(k1, s)))) + index_direct = A.offset_of((0, idxm(k0, idxd(k1, s)) + idxm(k0, k1))) assert_simplified_equal(index_simplified, index_direct) # Test Case3 index_simplified = A.offset_of( diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 38bd4bba1418d..349c4734c9ee8 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -126,7 +126,7 @@ def test_suggest_index_map_winograd(): floordiv(i0, 2), floordiv(i1, 2), floormod(i0, 2), - floormod(((i1 * 4) + floordiv(i2, 32)), 8), + floormod(i1, 2) * 4 + floordiv(i2, 32), floormod(i2, 32), floordiv(i3, 32), floormod(i3, 32), @@ -137,8 +137,8 @@ def test_suggest_index_map_winograd(): expected_inverse_index_map = IndexMap.from_func( lambda i0, i1, i2, i3, i4, i5, i6: ( ((i0 * 2) + i2), - ((i1 * 2) + floordiv(((i3 * 32) + i4), 128)), - floormod(((i3 * 32) + i4), 128), + i1 * 2 + floordiv(i3, 4), + floormod(i3, 4) * 32 + i4, ((i5 * 32) + i6), ) ) diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 964fe772d8af5..83bc649933a5e 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -1147,7 +1147,7 @@ def argmax_topi_rfactor( T.writes(placeholder_red_temp_v0_rf[ax0, vi1_1], placeholder_red_temp_v1_rf[ax0, vi1_1]) with T.init(): placeholder_red_temp_v0_rf[ax0, vi1_1] = -1 - placeholder_red_temp_v1_rf[ax0, vi1_1] = -2147483648 + placeholder_red_temp_v1_rf[ax0, vi1_1] = T.min_value("int32") v_placeholder_red_temp_v0_rf: T.int32 = T.Select( placeholder_red_temp_v1_rf[ax0, vi1_1] > placeholder[ax0, vi1_0 * 8 + vi1_1] or placeholder_red_temp_v1_rf[ax0, vi1_1] == placeholder[ax0, vi1_0 * 8 + vi1_1] @@ -1169,7 +1169,7 @@ def argmax_topi_rfactor( T.writes(placeholder_red_temp_v0[ax0], placeholder_red_temp_v1[ax0]) with T.init(): placeholder_red_temp_v0[ax0] = -1 - placeholder_red_temp_v1[ax0] = -2147483648 + placeholder_red_temp_v1[ax0] = T.min_value("int32") v_placeholder_red_temp_v0: T.int32 = T.Select( placeholder_red_temp_v1[ax0] > placeholder_red_temp_v1_rf[ax0, vi1_1] or placeholder_red_temp_v1[ax0] == placeholder_red_temp_v1_rf[ax0, vi1_1]