From 7292ae6acb6f78d54e87f9986b568936a5141fd2 Mon Sep 17 00:00:00 2001 From: Yutetsu TAKATSUKASA Date: Sun, 31 Mar 2019 03:13:54 +0900 Subject: [PATCH] Consistent result of DetectLinearEquation() when an empy vars is passed (#2860) --- src/arithmetic/detect_linear_equation.cc | 30 ++++++++----------- src/pass/inject_copy_intrin.cc | 11 +++---- .../test_arith_detect_linear_equation.py | 8 +++++ 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 6f4d3cfb53bb..e7bc7e74b675 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -127,25 +127,21 @@ Array DetectLinearEquation(const Expr& e, const Array& vars) { Expr base = e; Array coeff; - if (0 == vars.size()) { - coeff.push_back(make_const(Int(32), 1)); - } else { - for (Var v : vars) { - LinearEqEntry ret; - if (!LinearEqDetector(v).Detect(base, &ret)) { - return Array(); - } - coeff.push_back(ret.coeff); - base = std::move(ret.base); + for (Var v : vars) { + LinearEqEntry ret; + if (!LinearEqDetector(v).Detect(base, &ret)) { + return Array(); } + coeff.push_back(ret.coeff); + base = std::move(ret.base); + } - std::unordered_set vset; - for (size_t i = vars.size(); i != 1; --i) { - vset.insert(vars[i - 1].get()); - // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset)) { - return Array(); - } + std::unordered_set vset; + for (size_t i = vars.size(); i > 1; --i) { + vset.insert(vars[i - 1].get()); + // The previous coeff contains the variable + if (ExprUseVar(coeff[i - 2], vset)) { + return Array(); } } coeff.push_back(base); diff --git a/src/pass/inject_copy_intrin.cc b/src/pass/inject_copy_intrin.cc index 7ca1d133bd2d..7dcfcfdae239 100644 --- a/src/pass/inject_copy_intrin.cc +++ b/src/pass/inject_copy_intrin.cc @@ -39,7 +39,6 @@ class CopyIntrinInjector : public IRMutator { bool MatchCopyPattern(Stmt stmt, Stmt *out) { using namespace arith; Stmt body = stmt; - bool is_single_point_copy = false; // strip the loops std::vector loops; @@ -60,7 +59,6 @@ class CopyIntrinInjector : public IRMutator { const Cast* cast = store->value.as(); const Load* load = store->value.as(); if (0 == loops.size()) { - is_single_point_copy = true; CHECK(!has_cond); } // for now only support true condition matching @@ -83,9 +81,8 @@ class CopyIntrinInjector : public IRMutator { arith::DetectLinearEquation(load->index, loop_vars); if (load_strides.size() == 0 || store_strides.size() == 0) return false; Array dst_shape; - auto loop_var_size = loop_vars.size(); - if (is_single_point_copy) { - loop_var_size = 1; + const size_t loop_var_size = loop_vars.size(); + if (loop_var_size == 0) { dst_shape.push_back(make_const(Int(32), 1)); } else { for (const For* op : loops) { @@ -132,6 +129,10 @@ class CopyIntrinInjector : public IRMutator { CHECK_EQ(load_strides.size(), loop_var_size + 1); Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); + if (loop_var_size == 0) { + src_strides.push_back(make_const(Int(32), 1)); + dst_strides.push_back(make_const(Int(32), 1)); + } Buffer dst = BufferNode::make( Var(store->buffer_var.node_), store->value.type(), diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 2b0f327b65b2..33e266684f09 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -20,6 +20,10 @@ def test_basic(): m = tvm.arith.DetectLinearEquation(b * 7, [a]) assert m[0].value == 0 + m = tvm.arith.DetectLinearEquation(b * 7, []) + assert len(m) == 1 + assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0 + def test_multivariate(): v = [tvm.var("v%d" % i) for i in range(4)] b = tvm.var("b") @@ -42,6 +46,10 @@ def test_multivariate(): assert(m[0].value == 0) assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + m = tvm.arith.DetectLinearEquation((v[0] - v[1]), []) + assert(len(m) == 1) + assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0) + if __name__ == "__main__": test_basic() test_multivariate()