From 0e5701f3ec994f4e98e00226a9a6e923682d248e Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 27 Apr 2019 16:52:12 -0700 Subject: [PATCH] [ARITH] Revamp IntSet --- include/tvm/arithmetic.h | 194 ++-- python/tvm/arith.py | 43 +- src/api/api_arith.cc | 5 + src/arithmetic/analyzer.cc | 4 +- src/arithmetic/bound_deducer.cc | 8 +- src/arithmetic/canonical_simplify.cc | 1 - src/arithmetic/compute_expr.h | 10 +- src/arithmetic/const_fold.h | 58 +- src/arithmetic/detect_linear_equation.cc | 8 +- src/arithmetic/int_op_overflow.h | 4 +- src/arithmetic/int_set.cc | 1017 +++++++++-------- src/arithmetic/int_set.h | 143 +++ src/arithmetic/int_set_internal.h | 79 -- src/lang/expr_operator.cc | 19 +- src/pass/loop_partition.cc | 29 +- .../unittest/test_arith_deduce_bound.py | 168 +++ tests/python/unittest/test_arith_intset.py | 224 ++-- 17 files changed, 1174 insertions(+), 840 deletions(-) create mode 100644 src/arithmetic/int_set.h delete mode 100644 src/arithmetic/int_set_internal.h create mode 100644 tests/python/unittest/test_arith_deduce_bound.py diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 600e3c565358..f4988409b6b8 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -328,71 +328,14 @@ class ConstraintContext { std::function exit_; }; -/*! - * \brief Analyzer that contains bunch of sub-analyzers. - * - * Each sub-analyzer can make use of another sub-analyzer - * by weak reference of this. - * - * NOTE for sub-analyzer developers: - * If the analyzer uses memoization, we need to clear the internal - * cache when information about a Var has been overrideen. - */ -class Analyzer { - public: - /*! \brief sub-analyzer: const integer bound */ - ConstIntBoundAnalyzer const_int_bound; - /*! \brief sub-analyzer: modular set */ - ModularSetAnalyzer modular_set; - /*! \brief sub-analyzer rewrite simplify */ - RewriteSimplifier rewrite_simplify; - /*! \brief sub-analyzer canonical simplify */ - CanonicalSimplifier canonical_simplify; - /*! \brief constructor */ - Analyzer(); - /*! - * \brief Notify all the sub-analyzers that var - * is created and binded to expr. - * - * Each var can only be binded once. - * - * \param var The variable. - * \param expr The expression we bind to. - */ - void Bind(const VarExpr& var, const Expr& expr); - /*! - * \brief Notify all the sub-analyzers that var - * is created and binded to a range. - * - * Each var can only be binded once. - * - * \param var The variable. - * \param range The range we bind to. - */ - void Bind(const VarExpr& var, const Range& range); - /*! - * \brief Whether can we proof expr >= val. - - * Non-negative proof is very useful in integer analysis - * to lower divisions and mods given difference in trunc and ceil mode. - * - * \param expr The expression. - * \param lower_bound The lower bound. - * \return Whether we can proof it. - * - * \note Analyzer will call into sub-analyzers to get the result. - */ - bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); -}; - //----------------------------------------------- -// Integer set abstraction API. +// Integer set data structure. // // This is a API build on top of the base // integer analysis API to provide set analysis. //------------------------------------------------ /*! - * \brief Sign of an expression or set. + * \brief Sign type of an integer expression. */ enum SignType { kPositive, @@ -401,8 +344,13 @@ enum SignType { kUnknown }; -// internal node container of int set. -struct IntSetNode; +/*! + * \brief Base class of all IntSet containers. + */ +struct IntSetNode : public Node { + static constexpr const char* _type_key = "IntSet"; + TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); +}; /*! * \brief Integer set class, represent a set of integers in one dimension. @@ -424,11 +372,6 @@ class IntSet : public NodeRef { * \return The covering range. */ Range cover_range(Range max_range) const; - /*! - * \brief find an interval that covers the set. - * \return The covering interval set. - */ - IntSet cover_interval() const; /*! \return Lower bound of the set */ Expr min() const; /*! \return upper bound of the set */ @@ -493,33 +436,91 @@ class IntSet : public NodeRef { }; /*! - * \brief Base class of all IntSet containers. + * \brief Integer set analyzer. */ -struct IntSetNode : public Node { - static constexpr const char* _type_key = "IntSet"; - TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node); +class IntSetAnalyzer { + public: + /*! + * \brief Find an symbolic integer set that contains all possible values of + * expr given the domain of each variables. + * + * \param expr The expression of interest. + * \param dom_map The domain map to indicate which variable to relax. + * \return the result of the analysis. + */ + IntSet operator()(const Expr& expr, const Map& dom_map); + + private: + friend class Analyzer; + explicit IntSetAnalyzer(Analyzer* parent); + ~IntSetAnalyzer(); + class Impl; + /*! \brief Internal impl */ + Impl* impl_; }; /*! - * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] - * Where coeff[i] and base are invariant of var[j] for all i and j. + * \brief Analyzer that contains bunch of sub-analyzers. * - * \param e The expression to be detected. - * \param vars List of variables to be used in detection. - * \return [coeff[i]] if it is possible, empty array if it is not. - */ -Array DetectLinearEquation(const Expr& e, const Array& vars); - -/*! - * \brief Detect if expression corresponds to clip bound of the vars + * Each sub-analyzer can make use of another sub-analyzer + * by weak reference of this. * - * \param e The expression to be detected. - * \param vars List of variables to be used in detection. - * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value - * return empty if the e does not match the pattern. + * NOTE for sub-analyzer developers: + * If the analyzer uses memoization, we need to clear the internal + * cache when information about a Var has been overrideen. */ -Array DetectClipBound(const Expr& e, const Array& vars); +class Analyzer { + public: + /*! \brief sub-analyzer: const integer bound */ + ConstIntBoundAnalyzer const_int_bound; + /*! \brief sub-analyzer: modular set */ + ModularSetAnalyzer modular_set; + /*! \brief sub-analyzer rewrite simplify */ + RewriteSimplifier rewrite_simplify; + /*! \brief sub-analyzer canonical simplify */ + CanonicalSimplifier canonical_simplify; + /*! \brief sub-analyzer: int set */ + IntSetAnalyzer int_set; + /*! \brief constructor */ + Analyzer(); + /*! + * \brief Notify all the sub-analyzers that var + * is created and binded to expr. + * + * Each var can only be binded once. + * + * \param var The variable. + * \param expr The expression we bind to. + */ + void Bind(const VarExpr& var, const Expr& expr); + /*! + * \brief Notify all the sub-analyzers that var + * is created and binded to a range. + * + * Each var can only be binded once. + * + * \param var The variable. + * \param range The range we bind to. + */ + void Bind(const VarExpr& var, const Range& range); + /*! + * \brief Whether can we proof expr >= val. + + * Non-negative proof is very useful in integer analysis + * to lower divisions and mods given difference in trunc and ceil mode. + * + * \param expr The expression. + * \param lower_bound The lower bound. + * \return Whether we can proof it. + * + * \note Analyzer will call into sub-analyzers to get the result. + */ + bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound); +}; +//----------------------------------------------- +// Integer set legacy API. +//------------------------------------------------ /*! * \brief Find an symbolic integer set that contains all possible values of * e given the domain of each iteration variables. @@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond, */ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides); +// Expression pattern detector. +/*! + * \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n] + * Where coeff[i] and base are invariant of var[j] for all i and j. + * + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return [coeff[i]] if it is possible, empty array if it is not. + */ +Array DetectLinearEquation(const Expr& e, + const Array& vars); + +/*! + * \brief Detect if expression corresponds to clip bound of the vars + * + * \param e The expression to be detected. + * \param vars List of variables to be used in detection. + * \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value + * return empty if the e does not match the pattern. + */ +Array DetectClipBound(const Expr& e, + const Array& vars); + // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); diff --git a/python/tvm/arith.py b/python/tvm/arith.py index eda5cb825326..4c3c05f75796 100644 --- a/python/tvm/arith.py +++ b/python/tvm/arith.py @@ -32,21 +32,21 @@ def is_everything(self): return _api_internal._IntSetIsEverything(self) -@register_node +@register_node("arith.IntervalSet") class IntervalSet(IntSet): - """Represent set of continuous interval""" - def min(self): - """get the minimum value""" - return _api_internal._IntervalSetGetMin(self) - - def max(self): - """get the maximum value""" - return _api_internal._IntervalSetGetMax(self) + """Represent set of continuous interval [min_value, max_value] + Parameters + ---------- + min_value : Expr + The minimum value in the interval. -@register_node -class StrideSet(IntSet): - """Represent set of strided integers""" + max_value : Expr + The maximum value in the interval. + """ + def __init__(self, min_value, max_value): + self.__init_handle_by_constructor__( + _make_IntervalSet, min_value, max_value) @register_node("arith.ModularSet") @@ -114,6 +114,7 @@ def __init__(self): self._modular_set = _mod("modular_set") self._rewrite_simplify = _mod("rewrite_simplify") self._canonical_simplify = _mod("canonical_simplify") + self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") def const_int_bound(self, expr): @@ -176,6 +177,24 @@ def canonical_simplify(self, expr): """ return self._canonical_simplify(expr) + def int_set(self, expr, dom_map): + """Compute a symbolic IntSet that covers expr for all values in dom_map. + + Parameters + ---------- + expr : tvm.Expr + The expression. + + dom_map : Dict[Var, tvm.arith.IntSet] + The domain for variables to be relaxed. + + Returns + ------- + result : IntSet + The result. + """ + return self._int_set(expr, dom_map) + def bind(self, var, expr): """Bind a variable to the expression. diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 4d5d8bdf58d3..f31f02b1eaf4 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector") TVM_REGISTER_API("arith.intset_interval") .set_body_typed(IntSet::interval); + TVM_REGISTER_API("arith.DetectLinearEquation") .set_body_typed(DetectLinearEquation); @@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer") return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { *ret = self->canonical_simplify(args[0]); }); + } else if (name == "int_set") { + return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { + *ret = self->int_set(args[0], args[1]); + }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { auto& sptr = args[1].node_sptr(); diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index bd8c7005f458..99803ba40877 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -31,7 +31,8 @@ Analyzer::Analyzer() : const_int_bound(this), modular_set(this), rewrite_simplify(this), - canonical_simplify(this) { + canonical_simplify(this), + int_set(this) { } void Analyzer::Bind(const VarExpr& v, const Expr& expr) { @@ -77,6 +78,7 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) { return ptr->value > lower_bound; } auto bd = this->const_int_bound(this->rewrite_simplify(expr)); + LOG(INFO) << bd; if (bd->min_value >= lower_bound) return true; return false; } diff --git a/src/arithmetic/bound_deducer.cc b/src/arithmetic/bound_deducer.cc index 89e556c6f75f..395a371f43af 100644 --- a/src/arithmetic/bound_deducer.cc +++ b/src/arithmetic/bound_deducer.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -30,12 +30,12 @@ #include #include +#include "int_set.h" namespace tvm { namespace arith { using namespace ir; -using HalideIR::Internal::Interval; // a visitor to find the path to the target variable // from a expression. @@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e, BoundDeducer d(v, e, hint_map, relax_map); d.Deduce(); if (!d.success) return IntSet::nothing(); - Expr min = Interval::neg_inf, max = Interval::pos_inf; + Expr min = neg_inf(), max = pos_inf(); if (d.is_greater) { min = d.result; } else { diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 1bf1f84fb635..70779e6c186b 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -18,7 +18,6 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file canonical_simplify.cc * \brief Canonical form based simplification. */ diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index ff2fb8dbd4ac..cc54bff596be 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,8 +27,8 @@ #define TVM_ARITHMETIC_COMPUTE_EXPR_H_ #include -#include #include +#include namespace tvm { namespace arith { @@ -105,12 +105,12 @@ inline Expr ComputeExpr(Expr a, Expr b) { template<> inline Expr ComputeExpr(Expr a, Expr b) { - return HalideIR::Internal::Interval::make_max(a, b); + return max(a, b); } template<> inline Expr ComputeExpr(Expr a, Expr b) { - return HalideIR::Internal::Interval::make_min(a, b); + return min(a, b); } template diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h index fbf8fe7e6f89..ec50aef5c51e 100644 --- a/src/arithmetic/const_fold.h +++ b/src/arithmetic/const_fold.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -206,6 +206,7 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); }); + if (a.same_as(b)) return a; return Expr(); } @@ -216,6 +217,7 @@ inline Expr TryConstFold(Expr a, Expr b) { if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); }); + if (a.same_as(b)) return a; return Expr(); } @@ -307,6 +309,58 @@ inline Expr TryConstFold(Expr a) { return Expr(); } +/*! \brief Helper namespace for symbolic value limits */ +struct SymbolicLimits { + /*! \brief positive infinity */ + static Expr pos_inf_; + /*! \brief negative infinity */ + static Expr neg_inf_; +}; + +/*! + * \brief Opaque expression representing positive infinity. + * + * It can can only be used as parameter of by min/max + * for integer analysis and cannot be used in normal expressions. + * + * \return positive infinity. + */ +inline Expr pos_inf() { + return SymbolicLimits::pos_inf_; +} + +/*! + * \brief Check if value is positive infinity. + * \param value The value to be checked. + * + * \return The check result. + */ +inline bool is_pos_inf(const Expr& value) { + return value.same_as(SymbolicLimits::pos_inf_); +} + +/*! + * \brief Opaque expression representing negative infinity. + * + * It can can only be used as parameter of by min/max + * for integer analysis and cannot be used in normal expressions. + * + * \return negative infinity. + */ +inline Expr neg_inf() { + return SymbolicLimits::neg_inf_; +} + +/*! + * \brief Check if value is negative infinity. + * \param value The value to be checked. + * + * \return The check result. + */ +inline bool is_neg_inf(const Expr& value) { + return value.same_as(SymbolicLimits::neg_inf_); +} + } // namespace arith } // namespace tvm #endif // TVM_ARITHMETIC_CONST_FOLD_H_ diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 2fe21fef7e21..e584c8b1ce33 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,8 +19,8 @@ /*! * Copyright (c) 2017 by Contributors - * \file bound_deducer.cc - * \brief Utility to deduce bound of expression + * \file detect_linear_equation.cc + * \brief Utility to detect patterns in the expression. */ #include #include diff --git a/src/arithmetic/int_op_overflow.h b/src/arithmetic/int_op_overflow.h index 87f4f059e858..b78f21cb1dba 100644 --- a/src/arithmetic/int_op_overflow.h +++ b/src/arithmetic/int_op_overflow.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc index abbb7cd9744e..7493cfbb11ea 100644 --- a/src/arithmetic/int_set.cc +++ b/src/arithmetic/int_set.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,201 +18,55 @@ */ /*! - * Copyright (c) 2017 by Contributors * \file int_set.cc * \brief The integer set functions */ #include -#include -#include #include -#include +#include +#include +#include #include -#include "compute_expr.h" -#include "int_set_internal.h" +#include "int_set.h" +#include "pattern_match.h" namespace tvm { namespace arith { -using HalideIR::Internal::Interval; -using namespace ir; - -inline IntSet IntSet::cover_interval() const { - if ((*this).as()) return *this; - const StrideSet* s = (*this).as(); - if (s) { - CHECK_NE(s->extents.size(), 0U); - Expr max = s->base.max; - for (size_t i = 0; i < s->extents.size(); ++i) { - max = max + s->extents[i] * s->strides[i] - s->strides[i]; - } - return IntervalSet::make(s->base.min, Simplify(max)); - } - LOG(FATAL) << "cannot convert set " << (*this)->type_key() << " to interval"; - return IntSet::everything(); -} - -Range IntSet::cover_range(Range max_range) const { - IntSet temp; - const IntervalSet* s_int = (*this).as(); - if (s_int == nullptr) { - temp = this->cover_interval(); - s_int = temp.as(); - } - if (s_int->i.is_bounded()) { - return Range::make_by_min_extent( - s_int->i.min, Simplify(s_int->i.max + 1 - s_int->i.min)); - } - return max_range; -} - -Expr IntSet::min() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int); - return s_int->i.min; -} - -Expr IntSet::max() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int); - return s_int->i.max; -} - -bool IntSet::is_nothing() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_empty()); -} - -bool IntSet::is_everything() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_everything()); -} +Expr SymbolicLimits::pos_inf_ = Var("pos_inf", Handle()); +Expr SymbolicLimits::neg_inf_ = Var("neg_inf", Handle()); -bool IntSet::is_single_point() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && s_int->i.is_single_point()); +IntervalSet::IntervalSet(Expr min_value, Expr max_value) { + auto node = make_node(); + node->min_value = std::move(min_value); + node->max_value = std::move(max_value); + node_ = std::move(node); } -bool IntSet::can_prove_positive() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && is_positive_const(ir::Simplify(s_int->i.min))); +IntervalSet MakeIntervalSet(Expr min_value, Expr max_value) { + return IntervalSet(min_value, max_value); } -bool IntSet::can_prove_negative() const { - const IntervalSet* s_int = (*this).as(); - return (s_int && is_negative_const(ir::Simplify(s_int->i.max))); -} +TVM_REGISTER_API("arith._make_IntervalSet") +.set_body_typed(MakeIntervalSet); -bool IntSet::can_prove_non_positive() const { - if (const IntervalSet* s_int = (*this).as()) { - auto max = ir::Simplify(s_int->i.max); - return is_zero(max) || is_negative_const(max); - } - return false; -} -bool IntSet::can_prove_non_negative() const { - if (const IntervalSet* s_int = (*this).as()) { - // Any reason why we should or should not use can_prove() to implement - // these functions? - auto min = ir::Simplify(s_int->i.min); - return is_zero(min) || is_positive_const(min); - } - return false; -} - - -SignType IntSet::sign_type() const { - if (can_prove_positive()) { - return kPositive; - } else if (can_prove_negative()) { - return kNegative; - } else if (is_single_point() && is_zero(point_value())) { - return kZero; +IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) { + Expr max_value = min(a->max_value, b->max_value); + Expr min_value = max(a->min_value, b->min_value); + if ((max_value.type().is_int() || max_value.type().is_uint()) && + (min_value.type().is_int() || max_value.type().is_uint()) && + analyzer->CanProveGreaterEqual(min_value - max_value, 1)) { + return IntervalSet::Empty(); } else { - return kUnknown; - } -} -Expr IntSet::point_value() const { - const IntervalSet* s_int = (*this).as(); - CHECK(s_int && s_int->i.is_single_point()); - return s_int->i.min; -} - -IntSet IntSet::nothing() { - return IntervalSet::make(Interval::nothing()); -} - -IntSet IntSet::everything() { - return IntervalSet::make(Interval::everything()); -} - -IntSet IntSet::single_point(Expr x) { - return IntervalSet::make(Interval::single_point(x)); -} - -IntSet IntSet::range(Range r) { - // must make sure it can be matched back by MatchRange. - if (is_one(r->extent)) { - return IntSet::single_point(r->min); - } - if (is_positive_const(r->extent) && is_const(r->min)) { - return IntervalSet::make( - r->min, ComputeExpr(ComputeExpr(r->extent, r->min), 1)); - } - return IntervalSet::make(r->min, (r->extent + r->min) - 1); -} - -IntSet IntSet::interval(Expr min, Expr max) { - if (min.same_as(max)) { - return IntSet::single_point(min); - } - return IntervalSet::make(min, max); -} - -inline bool prove_equal(Expr lhs, Expr rhs) { - return is_zero(ir::Simplify(lhs - rhs)); -} - -// Check if a is created from b. -bool IntSet::match_range(const Range& b) const { - const IntSet& a = *this; - const IntervalSet* a_int = a.as(); - if (!a_int) return false; - const Interval& i = a_int->i; - return prove_equal(i.min, b->min) && - prove_equal(i.max, ComputeExpr(ComputeExpr(b->extent, b->min), 1)); -} - -inline bool MatchPoint(const IntSet& a, - const Expr& b) { - const IntervalSet* a_int = a.as(); - if (!a_int) return false; - const Interval& i = a_int->i; - return i.is_single_point() && i.min.same_as(b); -} - -IntSet Union(const Array& sets) { - if (sets.size() == 0) return IntSet::nothing(); - if (sets.size() == 1) return sets[0]; - Interval x = sets[0].cover_interval().as()->i; - for (size_t i = 1; i < sets.size(); ++i) { - IntSet s = sets[i].cover_interval(); - const Interval& y = s.as()->i; - x.include(y); + return IntervalSet(min_value, max_value); } - x.max = ir::Simplify(x.max); - x.min = ir::Simplify(x.min); - return IntervalSet::make(x); } -IntSet Intersect(const Array& sets) { - Interval x = sets[0].cover_interval().as()->i; - for (size_t i = 1; i < sets.size(); ++i) { - Interval y = sets[i].cover_interval().as()->i; - x = Interval::make_intersection(x, y); - } - return IntervalSet::make(x); +IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) { + Expr max_value = max(a->max_value, b->max_value); + Expr min_value = min(a->min_value, b->min_value); + return IntervalSet(min_value, max_value); } // type traits @@ -227,407 +81,620 @@ struct is_logical_op { static const bool value = true; \ }; -// interval related. -template -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - LOG(WARNING) << "Return Everything in CombineInterval " << OP::_type_key; - return IntSet::everything(); +TVM_DECLARE_LOGICAL_OP(And); +TVM_DECLARE_LOGICAL_OP(Or); +TVM_DECLARE_LOGICAL_OP(EQ); +TVM_DECLARE_LOGICAL_OP(NE); +TVM_DECLARE_LOGICAL_OP(GE); +TVM_DECLARE_LOGICAL_OP(GT); +TVM_DECLARE_LOGICAL_OP(LE); +TVM_DECLARE_LOGICAL_OP(LT); +TVM_DECLARE_LOGICAL_OP(Not); + +/*! + * \brief Combine two interval set under arithmetic operations. + * \note this can possibly relax the set. + */ +template +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + Expr res = TryConstFold(a->min_value, b->min_value); + if (!res.defined()) res = Op::make(a->min_value, b->min_value); + return IntervalSet::SinglePoint(res); + } + if (is_logical_op::value) { + return IntervalSet(make_const(a->min_value.type(), 0), + make_const(a->min_value.type(), 1)); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (a->IsEverything()) return a; + if (b->IsEverything()) return b; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - Interval r = Interval::everything(); - if (a.has_lower_bound() && b.has_lower_bound()) { - r.min = ComputeExpr(a.min, b.min); - } - if (a.has_upper_bound() && b.has_upper_bound()) { - r.max = ComputeExpr(a.max, b.max); - } - return IntervalSet::make(r); +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value + b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + Expr min_value = + a->HasLowerBound() && b->HasLowerBound() ? + a->min_value + b->min_value : neg_inf(); + Expr max_value = + a->HasUpperBound() && b->HasUpperBound() ? + a->max_value + b->max_value : pos_inf(); + return IntervalSet(min_value, max_value); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value - b->min_value); } - Interval r = Interval::everything(); - if (a.has_lower_bound() && b.has_upper_bound()) { - r.min = ComputeExpr(a.min, b.max); - } - if (a.has_upper_bound() && b.has_lower_bound()) { - r.max = ComputeExpr(a.max, b.min); - } - return IntervalSet::make(r); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + Expr min_value = + a->HasLowerBound() && b->HasUpperBound() ? + a->min_value - b->max_value : neg_inf(); + Expr max_value = + a->HasUpperBound() && b->HasLowerBound() ? + a->max_value - b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); } + template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); - } - if (a.is_single_point() && !b.is_single_point()) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value * b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (a->IsSinglePoint()) { + // assert !b->IsSinglePoint(); std::swap(a, b); } - if (b.is_single_point()) { - if (is_zero(b.min)) return IntSet::single_point(0); - if (is_one(b.min)) return IntervalSet::make(a); - Expr e1 = a.has_lower_bound() ? ComputeExpr(a.min, b.min) : a.min; - Expr e2 = a.has_upper_bound() ? ComputeExpr(a.max, b.min) : a.max; - // no relaxation is needed in here due to set is inclusive - // TODO(tqchen): consider convert to StrideSet. - if (is_positive_const(b.min)) { - return IntervalSet::make(e1, e2); - } else if (is_negative_const(b.min)) { - return IntervalSet::make(e2, e1); - } else if (a.is_bounded()) { + if (b->IsSinglePoint()) { + if (is_zero(b->min_value)) return IntervalSet::SinglePoint(b->min_value); + if (is_one(b->min_value)) return a; + if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { + Expr min_value = a->HasLowerBound() ? a->min_value * b->min_value : neg_inf(); + Expr max_value = a->HasUpperBound() ? a->max_value * b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { + Expr min_value = a->HasUpperBound() ? a->max_value * b->min_value : neg_inf(); + Expr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr cmp = b.min >= make_zero(b.min.type().element_of()); - return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1)); + Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr e1 = a->min_value * b->min_value; + Expr e2 = a->max_value * b->min_value; + return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); } } - LOG(WARNING) << "Return Everything in CombineInterval Mul"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Mul"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval
(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr
(a.min, b.min)); - } - if (b.is_single_point()) { - if (is_zero(b.min)) { +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value / b->min_value); + } + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + if (b->IsSinglePoint()) { + if (is_zero(b->min_value)) { LOG(FATAL) << "Divide by zero in CombineInterval Div"; } - if (is_one(b.min)) return IntervalSet::make(a); - Expr e1 = a.has_lower_bound() ? ComputeExpr
(a.min, b.min) : a.min; - Expr e2 = a.has_upper_bound() ? ComputeExpr
(a.max, b.min) : a.max; + if (is_one(b->min_value)) return a; // no relaxation is needed in here due to set is inclusive - if (is_positive_const(b.min)) { - return IntervalSet::make(e1, e2); - } else if (is_negative_const(b.min)) { - return IntervalSet::make(e2, e1); - } else if (a.is_bounded()) { + if (analyzer->CanProveGreaterEqual(b->min_value, 0)) { + Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf(); + Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) { + Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf(); + Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); + return IntervalSet(min_value, max_value); + } else if (a->HasUpperBound() && a->HasLowerBound()) { using ir::Select; - Expr cmp = b.min >= make_zero(b.min.type().element_of()); - return IntervalSet::make(Select::make(cmp, e1, e2), Select::make(cmp, e2, e1)); + Expr sign = b->min_value >= make_zero(b->min_value.type().element_of()); + Expr e1 = a->min_value / b->min_value; + Expr e2 = a->max_value / b->min_value; + return IntervalSet(Select::make(sign, e1, e2), Select::make(sign, e2, e1)); } } - LOG(WARNING) << "Return Everything in CombineInterval Div"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Div"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analyzer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value % b->min_value); } - if (b.is_single_point()) { - Expr divisor = b.min; + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + + if (b->IsSinglePoint()) { + const Expr& divisor = b->min_value; if (is_zero(divisor)) { LOG(FATAL) << "Modular by zero in CombineInterval Mod"; } - return IntervalSet::make(make_zero(divisor.type()), divisor - 1); + if (analyzer->CanProveGreaterEqual(divisor, 0)) { + return IntervalSet(make_zero(divisor.type()), divisor - 1); + } else { + Expr bound = abs(divisor) - 1; + return IntervalSet(-bound, bound); + } } - - LOG(WARNING) << "Return Everything in CombineInterval Mod"; - return IntSet::everything(); + DLOG(WARNING) << "Return Everything in CombineInterval Mod"; + return IntervalSet::Everything(); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value % b->min_value); } - return IntervalSet::make(Interval::make_max(a.min, b.min), - Interval::make_max(a.max, b.max)); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + return IntervalSet(max(a->min_value, b->min_value), + max(a->max_value, b->max_value)); } template<> -inline IntSet CombineInterval(Interval a, Interval b) { - if (a.is_single_point() && b.is_single_point()) { - return IntSet::single_point(ComputeExpr(a.min, b.min)); +inline IntervalSet Combine(Analyzer* analzyer, + IntervalSet a, + IntervalSet b) { + if (a->IsSinglePoint() && b->IsSinglePoint()) { + return IntervalSet::SinglePoint(a->min_value % b->min_value); } - return IntervalSet::make(Interval::make_min(a.min, b.min), - Interval::make_min(a.max, b.max)); -} - -template -inline IntSet CombineInterval_(IntSet a, IntSet b) { - return CombineInterval( - a.as()->i, b.as()->i); -} - -// stride related -inline IntSet AsStrideSet(IntSet a) { - if (a.as()) return a; - const IntervalSet* s = a.as(); - CHECK(s->i.is_bounded()); - NodePtr n = make_node(); - n->base = s->i; - return IntSet(n); -} -template -inline IntSet CombineSets(IntSet a, IntSet b) { - return CombineInterval_(a.cover_interval(), b.cover_interval()); + if (a->IsEmpty()) return a; + if (b->IsEmpty()) return b; + return IntervalSet(min(a->min_value, b->min_value), + min(a->max_value, b->max_value)); } -template<> -inline IntSet CombineSets(IntSet a, IntSet b) { - const IntervalSet* a_int = a.as(); - const IntervalSet* b_int = b.as(); - if (a_int && is_zero(a_int->i.min)) return b; - if (b_int && is_zero(b_int->i.min)) return a; - a = AsStrideSet(a); - b = AsStrideSet(b); - const StrideSet* a_stride = a.as(); - const StrideSet* b_stride = b.as(); - auto n = make_node(*a_stride); - for (size_t i = 0; i < b_stride->extents.size(); ++i) { - n->extents.push_back(b_stride->extents[i]); - n->strides.push_back(b_stride->strides[i]); - } - n->base = CombineInterval( - a_stride->base, b_stride->base).as()->i; - return IntSet(n); -} - -inline IntSet NegateSet(IntSet a) { - const IntervalSet* a_int = a.as(); - if (a_int) { - if (a_int->i.is_single_point()) { - return IntSet::single_point(-a_int->i.min); - } else { - Interval r = Interval::everything(); - if (a_int->i.has_upper_bound()) { - r.min = -(a_int->i.max); - } - if (a_int->i.has_lower_bound()) { - r.max = -(a_int->i.min); - } - return IntervalSet::make(r); - } - } else { - return NegateSet(a.cover_interval()); +// internal helper function to get an interval set +IntervalSet ToIntervalSet(IntSet set) { + if (auto* node = set.as()) { + return GetRef(node); } + DLOG(INFO) << "cannot resolve int set " << set; + return IntervalSet::Everything(); } -template<> -inline IntSet CombineSets(IntSet a, IntSet b) { - return CombineSets(a, NegateSet(b)); -} - -TVM_DECLARE_LOGICAL_OP(And); -TVM_DECLARE_LOGICAL_OP(Or); -TVM_DECLARE_LOGICAL_OP(EQ); -TVM_DECLARE_LOGICAL_OP(NE); -TVM_DECLARE_LOGICAL_OP(GE); -TVM_DECLARE_LOGICAL_OP(GT); -TVM_DECLARE_LOGICAL_OP(LE); -TVM_DECLARE_LOGICAL_OP(LT); -TVM_DECLARE_LOGICAL_OP(Not); +using namespace ir; -// generic combine operations of two sets -template -inline IntSet Combine(const IntSet& a, const IntSet &b) { - if (is_logical_op::value) { - return IntervalSet::make(0, 1); +// Simplified version of int set evaluator that operates on IntervalSet +// We might use better set analysis in the future to replace the intervalset. +class IntervalSetEvaluator : + public ExprFunctor { + public: + IntervalSetEvaluator(Analyzer* analyzer, + const Map& dom_map, + bool eval_vec = false) + : analyzer_(analyzer), + dom_map_(dom_map), + eval_vec_(eval_vec) { } - const IntervalSet* a_int = a.as(); - const IntervalSet* b_int = b.as(); - if (a_int && a_int->i.is_everything()) return a; - if (b_int && b_int->i.is_everything()) return b; - if (a_int && b_int) { - return CombineInterval(a_int->i, b_int->i); + + IntervalSet Eval(const Expr& val) { + return this->VisitExpr(val); } - if (a_int && !(a_int->i.is_bounded())) { - return CombineInterval_(a, b.cover_interval()); + + IntervalSet VisitExpr_(const IntImm* op) final { + return IntervalSet::SinglePoint(GetRef(op)); } - if (b_int && !(b_int->i.is_bounded())) { - return CombineInterval_(a.cover_interval(), b); + + IntervalSet VisitExpr_(const UIntImm* op) final { + return IntervalSet::SinglePoint(GetRef(op)); } - return CombineSets(a, b); -} -class IntSetEvaluator : - public ExprFunctor { - public: - explicit IntSetEvaluator( - const std::unordered_map& dom_map, - bool eval_vec = false) - : dom_map_(dom_map), eval_vec_(eval_vec) {} - // Evaluate. - IntSet Eval(const Expr& e) { - return this->VisitExpr(e, e); - } - IntSet VisitExpr_(const IntImm* op, const Expr& e) final { - return IntSet::single_point(e); - } - IntSet VisitExpr_(const UIntImm* op, const Expr& e) final { - return IntSet::single_point(e); - } - IntSet VisitExpr_(const Variable* op, const Expr& e) final { - auto it = dom_map_.find(op); + IntervalSet VisitExpr_(const Variable* op) final { + Var var = GetRef(op); + auto it = dom_map_.find(var); if (it != dom_map_.end()) { - return it->second; + return ToIntervalSet((*it).second); } else { - return IntSet::single_point(e); + return IntervalSet::SinglePoint(var); } } - IntSet VisitExpr_(const Add* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Add* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Sub* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Sub* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Mul* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Mul* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Div* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Div* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Mod* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Mod* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Min* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Min* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Max* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Max* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const EQ* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const EQ* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const NE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const NE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const LT* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const LT* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const LE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const LE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const GT* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const GT* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const GE* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const GE* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const And* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const And* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Or* op, const Expr& e) final { - return Binary(op, e); + + IntervalSet VisitExpr_(const Or* op) final { + return VisitBinaryExpr_(op); } - IntSet VisitExpr_(const Ramp* op, const Expr& e) final { + + IntervalSet VisitExpr_(const Ramp* op) final { CHECK(eval_vec_); - IntSet base = Eval(op->base); - int vstride; - if (GetConstInt(op->stride, &vstride)) { + IntervalSet base = Eval(op->base); + PVar stride; + if (stride.Match(op->stride)) { Type t = op->base.type(); - if (vstride > 0) { + int64_t vstride = stride.Eval()->value; + if (vstride> 0) { return Combine( + analyzer_, base, - IntSet::interval(make_zero(t), - make_const(t, vstride * op->lanes -1))); + IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1))); } else { return Combine( + analyzer_, base, - IntSet::interval(make_const(t, vstride * op->lanes + 1), - make_zero(t))); + IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t))); } } - LOG(WARNING) << "cannot evaluate set on expression " << e; - return IntSet::everything(); + DLOG(WARNING) << "cannot evaluate set on expression " << GetRef(op); + return IntervalSet::Everything(); } - IntSet VisitExpr_(const Broadcast* op, const Expr& e) final { + + IntervalSet VisitExpr_(const Broadcast* op) final { CHECK(eval_vec_); - return Eval(op->value); + return VisitExpr(op->value); } - IntSet VisitExpr_(const Select* op, const Expr& e) final { - IntSet true_set = this->Eval(op->true_value); - IntSet false_set = this->Eval(op->false_value); - return Union({false_set, true_set}); + + IntervalSet VisitExpr_(const Select* op) final { + IntervalSet true_set = this->Eval(op->true_value); + IntervalSet false_set = this->Eval(op->false_value); + return Union(analyzer_, false_set, true_set); } - IntSet VisitExprDefault_(const Node* op, const Expr& e) final { - LOG(WARNING) << "cannot evaluate set type " << e->type_key(); - return IntSet::everything(); + + IntervalSet VisitExprDefault_(const Node* op) final { + DLOG(WARNING) << "cannot evaluate set type " << op->type_key(); + return IntervalSet::Everything(); } private: + // whether set is exactly single point that equals value. + bool MatchPoint(const IntervalSet& set, + const Expr& value) const { + return set->min_value.same_as(value) && set->max_value.same_as(value); + } + template - inline IntSet Binary(const T* op, const Expr& e) { - IntSet a = this->Eval(op->a); - IntSet b = this->Eval(op->b); + inline IntervalSet VisitBinaryExpr_(const T* op) { + IntervalSet a = this->Eval(op->a); + IntervalSet b = this->Eval(op->b); if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { - return IntSet::single_point(e); + return IntervalSet::SinglePoint(GetRef(op)); } - return Combine(a, b); + return Combine(analyzer_, a, b); } - const std::unordered_map& dom_map_; + Analyzer* analyzer_; + const Map& dom_map_; bool eval_vec_{false}; }; +class IntSetAnalyzer::Impl { + public: + explicit Impl(Analyzer* analyzer) + : analyzer_(analyzer) { + } + + IntSet Eval(const Expr& expr, const Map& dom_map) const { + return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); + } + + private: + Analyzer* analyzer_; +}; + +IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) + : impl_(new Impl(parent)) { +} + +IntSetAnalyzer::~IntSetAnalyzer() { + delete impl_; +} + +IntSet IntSetAnalyzer::operator()(const Expr& expr, + const Map& dom_map) { + return impl_->Eval(expr, dom_map); +} + +// Quickly adapt to IntSet interface +// TODO(tqchen): revisit IntSet interface as well. +Range IntSet::cover_range(Range max_range) const { + IntSet temp; + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int != nullptr); + if (s_int->HasUpperBound() && s_int->HasLowerBound()) { + return Range::make_by_min_extent( + s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value)); + } + return max_range; +} + +Expr IntSet::min() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int); + return s_int->min_value; +} + +Expr IntSet::max() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int); + return s_int->max_value; +} + +bool IntSet::is_nothing() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsEmpty()); +} + +bool IntSet::is_everything() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsEverything()); +} + +bool IntSet::is_single_point() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && s_int->IsSinglePoint()); +} + +bool IntSet::can_prove_positive() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && is_positive_const(ir::Simplify(s_int->min_value))); +} + +bool IntSet::can_prove_negative() const { + const IntervalSetNode* s_int = (*this).as(); + return (s_int && is_negative_const(ir::Simplify(s_int->max_value))); +} + +bool IntSet::can_prove_non_positive() const { + if (const auto* s_int = (*this).as()) { + auto max = ir::Simplify(s_int->max_value); + return is_zero(max) || is_negative_const(max); + } + return false; +} + +bool IntSet::can_prove_non_negative() const { + if (const IntervalSetNode* s_int = (*this).as()) { + auto min = ir::Simplify(s_int->min_value); + return is_zero(min) || is_positive_const(min); + } + return false; +} + +SignType IntSet::sign_type() const { + if (can_prove_positive()) { + return kPositive; + } else if (can_prove_negative()) { + return kNegative; + } else if (is_single_point() && is_zero(point_value())) { + return kZero; + } else { + return kUnknown; + } +} +Expr IntSet::point_value() const { + const IntervalSetNode* s_int = (*this).as(); + CHECK(s_int && s_int->IsSinglePoint()); + return s_int->min_value; +} + +IntSet IntSet::nothing() { + return IntervalSet::Empty(); +} + +IntSet IntSet::everything() { + return IntervalSet::Everything(); +} + +IntSet IntSet::single_point(Expr x) { + return IntervalSet::SinglePoint(x); +} + +IntSet IntSet::interval(Expr min, Expr max) { + if (min.same_as(max)) { + return IntSet::single_point(min); + } + return IntervalSet(min, max); +} + +// Range related code +inline bool ProveEqual(Expr lhs, Expr rhs) { + return is_zero(ir::Simplify(lhs - rhs)); +} + +IntSet IntSet::range(Range r) { + // must make sure it can be matched back by MatchRange. + if (is_one(r->extent)) { + return IntSet::single_point(r->min); + } + return IntervalSet(r->min, r->extent + r->min - 1); +} + +bool IntSet::match_range(const Range& b) const { + const IntSet& a = *this; + const IntervalSetNode* a_int = a.as(); + if (!a_int) return false; + return ProveEqual(a_int->min_value, b->min) && + ProveEqual(a_int->max_value, b->extent + b->min - 1); +} + +IntSet Union(const Array& sets) { + if (sets.size() == 0) return IntSet::nothing(); + if (sets.size() == 1) return sets[0]; + Analyzer ana; + IntervalSet x = ToIntervalSet(sets[0]); + for (size_t i = 1; i < sets.size(); ++i) { + x = Union(&ana, x, ToIntervalSet(sets[i])); + } + return IntervalSet(ir::Simplify(x->min_value), + ir::Simplify(x->max_value)); +} + +IntSet Intersect(const Array& sets) { + if (sets.size() == 0) return IntSet::nothing(); + if (sets.size() == 1) return sets[0]; + Analyzer ana; + IntervalSet x = ToIntervalSet(sets[0]); + for (size_t i = 1; i < sets.size(); ++i) { + x = Intersect(&ana, x, ToIntervalSet(sets[i])); + } + return IntervalSet(ir::Simplify(x->min_value), + ir::Simplify(x->max_value)); +} + +Map ConvertDomMap(const Map& dom_map) { + Map dmap; + for (auto kv : dom_map) { + dmap.Set(kv.first->var, kv.second); + } + return dmap; +} + +Map ConvertDomMap( + const std::unordered_map& dom_map) { + Map dmap; + for (auto kv : dom_map) { + dmap.Set(GetRef(kv.first), kv.second); + } + return dmap; +} + IntSet EvalSet(Expr e, - const std::unordered_map& dom_map) { - return IntSetEvaluator(dom_map, false).Eval(e); + const Map& dom_map) { + Analyzer ana; + return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); } IntSet IntSet::vector(Expr x) { - std::unordered_map dmap; - return IntSetEvaluator(dmap, true).Eval(x); + Analyzer ana; + Map dmap; + return IntervalSetEvaluator(&ana, dmap, true).Eval(x); } IntSet EvalSet(Expr e, const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first->var.as()] = kv.second; - } - return EvalSet(e, dmap); + return EvalSet(e, ConvertDomMap(dom_map)); } -IntSet EvalSet(Range r, +IntSet EvalSet(Expr e, const std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - IntSet min_set = m.Eval(r->min).cover_interval(); + return EvalSet(e, ConvertDomMap(dom_map)); +} + +IntSet EvalSet(Range r, + const Map& dom_map) { + Analyzer ana; + IntervalSetEvaluator m(&ana, dom_map); + IntervalSet min_set = m.Eval(r->min); // Simplifying first can give tighter bounds if r->min and r->extent share variables - Expr sum = ComputeExpr(ComputeExpr(r->min, r->extent), 1); - IntSet max_set = m.Eval(Simplify(sum)).cover_interval(); - const Interval& ni = min_set.as()->i; - const Interval& xi = max_set.as()->i; - if (!ni.has_lower_bound()) return IntSet::everything(); - if (!xi.has_upper_bound()) return IntSet::everything(); - return IntervalSet::make(ni.min, xi.max); + Expr sum = r->min + r->extent - 1; + IntervalSet max_set = m.Eval(Simplify(sum)); + if (!min_set->HasLowerBound()) return IntSet::everything(); + if (!max_set->HasUpperBound()) return IntSet::everything(); + return IntervalSet(min_set->min_value, max_set->max_value); } -IntSet EvalSet(IntSet s, +IntSet EvalSet(Range r, const std::unordered_map& dom_map) { - IntSetEvaluator m(dom_map); - s = s.cover_interval(); - const IntervalSet* s_int = s.as(); - Expr vmax = s_int->i.has_upper_bound() ? - m.Eval(s_int->i.max).cover_interval().max() : s_int->i.max; - Expr vmin = s_int->i.has_lower_bound() ? - m.Eval(s_int->i.min).cover_interval().min() : s_int->i.min; - return IntervalSet::make(vmin, vmax); + return EvalSet(r, ConvertDomMap(dom_map)); } -class SubExprIntSetEvaluator : public IntSetEvaluator { +IntSet EvalSet(IntSet s, + const std::unordered_map& dom_map) { + Analyzer ana; + auto dmap = ConvertDomMap(dom_map); + IntervalSetEvaluator m(&ana, dmap); + const IntervalSetNode* s_int = s.as(); + Expr vmax = s_int->HasUpperBound() ? + m.Eval(s_int->max_value).max() : s_int->max_value; + Expr vmin = s_int->HasLowerBound() ? + m.Eval(s_int->min_value).min() : s_int->min_value; + return IntervalSet(vmin, vmax); +} + +class SubExprIntervalSetEvaluator : public IntervalSetEvaluator { public: - explicit SubExprIntSetEvaluator( - const std::unordered_map& dom_map) - : IntSetEvaluator(dom_map) {} + explicit SubExprIntervalSetEvaluator( + Analyzer* analyzer, + const Map& dom_map) + : IntervalSetEvaluator(analyzer, dom_map) {} - IntSet VisitExpr(const Expr& n, const Expr& e) final { - IntSet ret = IntSetEvaluator::VisitExpr(n, e); + IntervalSet VisitExpr(const Expr& n) final { + IntervalSet ret = IntervalSetEvaluator::VisitExpr(n); expr_map[n] = ret; return ret; } @@ -635,28 +702,26 @@ class SubExprIntSetEvaluator : public IntSetEvaluator { ExprIntSetMap expr_map; }; -ExprIntSetMap EvalSetForEachSubExpr(Expr e, +ExprIntSetMap EvalSetForEachSubExpr( + Expr e, const std::unordered_map& dom_map) { - SubExprIntSetEvaluator m(dom_map); + Analyzer ana; + auto dmap = ConvertDomMap(dom_map); + SubExprIntervalSetEvaluator m(&ana, dmap); m.Eval(e); return m.expr_map; } IntSet EvalSet(Range r, const Map& dom_map) { - std::unordered_map dmap; - for (auto kv : dom_map) { - dmap[kv.first->var.as()] = kv.second; - } - return EvalSet(r, dmap); + return EvalSet(r, ConvertDomMap(dom_map)); } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) -.set_dispatch([](const IntervalSet *op, IRPrinter *p) { - p->stream << "interval-set" - << "[" << op->i.min << ", " - << op->i.max << ']'; +.set_dispatch([](const IntervalSetNode *op, IRPrinter *p) { + p->stream << "IntervalSet" + << "[" << op->min_value << ", " + << op->max_value << ']'; }); - } // namespace arith } // namespace tvm diff --git a/src/arithmetic/int_set.h b/src/arithmetic/int_set.h new file mode 100644 index 000000000000..9bbd9e5e18d0 --- /dev/null +++ b/src/arithmetic/int_set.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file int_set.h + * \brief Internal data structure for integer set. + */ +#ifndef TVM_ARITHMETIC_INT_SET_H_ +#define TVM_ARITHMETIC_INT_SET_H_ + +#include +#include +#include +#include "const_fold.h" + +namespace tvm { +namespace arith { + +/*! + * \brief Symbolic interval set. + * + * \note We intentionally keep the internal of IntSet private, + as we anticipate that + */ +class IntervalSetNode : public IntSetNode { + public: + /*! \brief Minimum value in the interval. */ + Expr min_value; + /*! \brief Maximum value in the interval. */ + Expr max_value; + + // visitor overload. + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("min_value", &min_value); + v->Visit("max_value", &max_value); + } + + /*! \return Whether the interval has upper bound. */ + bool HasUpperBound() const { + return !is_pos_inf(max_value) && !IsEmpty(); + } + /*! \return Whether the interval has lower bound. */ + bool HasLowerBound() const { + return !is_neg_inf(min_value) && !IsEmpty(); + } + /*! \return Whether the interval is a single point. */ + bool IsSinglePoint() const { + return min_value.same_as(max_value); + } + /*! \return whether interval represent nothing */ + bool IsEmpty() const { + // during computations, either extreme could occur. + return is_pos_inf(min_value) || is_neg_inf(max_value); + } + /*! \return whether interval represent everything */ + bool IsEverything() const { + return is_neg_inf(min_value) && is_pos_inf(max_value); + } + + static constexpr const char* _type_key = "arith.IntervalSet"; + TVM_DECLARE_NODE_TYPE_INFO(IntervalSetNode, IntSetNode); +}; + +/*! + * \brief Interval set used for symbolic integer analysis. + * \sa IntervalSetNode + */ +class IntervalSet : public IntSet { + public: + /*! + * \brief Make a new instance of interval set. + * \param min_value The minimum value in the interval. + * \param max_value The maximum value in the interval. + * \return The created set. + */ + TVM_DLL IntervalSet(Expr min_value, Expr max_value); + + /*! + * \brief Create an IntervalSet that represents a single point. + * \param value The value to be represented. + * \return The result set. + */ + static IntervalSet SinglePoint(Expr value) { + return IntervalSet(value, value); + } + /*! + * \brief Create an IntervalSet that represents everything. + * \param value The value to be represented. + * \return The result set. + */ + static IntervalSet Everything() { + return IntervalSet(neg_inf(), pos_inf()); + } + /*! + * \brief Create an empty eet. + * \return The result set. + */ + static IntervalSet Empty() { + return IntervalSet(pos_inf(), neg_inf()); + } + + TVM_DEFINE_NODE_REF_COW(IntervalSetNode); + TVM_DEFINE_NODE_REF_METHODS(IntervalSet, IntSet, IntervalSetNode); +}; + +/*! + * \brief Create union of two IntervalSets. + * \param analyzer The analyzer for simplification analysis. + * \param a The first set. + * \param b The second set. + * \return The result set. + */ +TVM_DLL IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b); + +/*! + * \brief Create insersection of two IntervalSets. + * \param analzyer The analyzer for simplification analysis. + * \param a The first set. + * \param b The second set. + * \return The result set. + */ +TVM_DLL IntervalSet Intersect(Analyzer *analzyer, IntervalSet a, IntervalSet b); + +} // namespace arith +} // namespace tvm + +#endif // TVM_ARITHMETIC_INT_SET_H_ diff --git a/src/arithmetic/int_set_internal.h b/src/arithmetic/int_set_internal.h deleted file mode 100644 index 8b675cfbffda..000000000000 --- a/src/arithmetic/int_set_internal.h +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2017 by Contributors - * \file int_set_internal.h - * \brief Implementations of integer set - */ -#ifndef TVM_ARITHMETIC_INT_SET_INTERNAL_H_ -#define TVM_ARITHMETIC_INT_SET_INTERNAL_H_ - -#include -#include -#include - -namespace tvm { -namespace arith { - -using HalideIR::Internal::Interval; - -/*! \brief Set of continuous interval */ -struct IntervalSet : public IntSetNode { - /*! \brief the internal interval*/ - Interval i; - - static IntSet make(Interval i) { - NodePtr n = - make_node(); - n->i = i; - return IntSet(n); - } - static IntSet make(Expr min, Expr max) { - NodePtr n = - make_node(); - n->i.min = min; - n->i.max = max; - return IntSet(n); - } - - static constexpr const char* _type_key = "IntervalSet"; - TVM_DECLARE_NODE_TYPE_INFO(IntervalSet, IntSetNode); -}; - -/*! - * \brief set represented by strided integers - * Reserved for cases where strided access is supported. - */ -struct StrideSet : public IntSetNode { - /*! \brief the base inetrval */ - Interval base; - /*! \brief additional extents in positive number */ - Array extents; - /*! \brief additional strides in positive number */ - Array strides; - - static constexpr const char* _type_key = "StrideSet"; - TVM_DECLARE_NODE_TYPE_INFO(StrideSet, IntSetNode); -}; - -} // namespace arith -} // namespace tvm - -#endif // TVM_ARITHMETIC_INT_SET_INTERNAL_H_ diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 8537f17b763c..3f5254069b8d 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -188,7 +188,15 @@ Expr operator%(Expr a, Expr b) { return ir::Mod::make(a, b); } + Expr min(Expr a, Expr b) { + // inf-aware simplificaiton + using arith::is_pos_inf; + using arith::is_neg_inf; + if (is_pos_inf(a)) return b; + if (is_neg_inf(a)) return a; + if (is_pos_inf(b)) return a; + if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; @@ -196,6 +204,13 @@ Expr min(Expr a, Expr b) { } Expr max(Expr a, Expr b) { + // inf-aware simplificaiton + using arith::is_pos_inf; + using arith::is_neg_inf; + if (is_pos_inf(a)) return a; + if (is_neg_inf(a)) return b; + if (is_pos_inf(b)) return b; + if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b); Expr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; diff --git a/src/pass/loop_partition.cc b/src/pass/loop_partition.cc index bcb2608682ee..0a5b7410f3cf 100644 --- a/src/pass/loop_partition.cc +++ b/src/pass/loop_partition.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -28,7 +28,7 @@ #include #include #include -#include "../arithmetic/int_set_internal.h" +#include "../arithmetic/int_set.h" #include "../runtime/thread_storage_scope.h" namespace tvm { @@ -366,7 +366,7 @@ class LoopPartitioner : public IRMutator { std::pair> GetIntervalAndCondset(const Partition &partitions, - const arith::Interval &for_interval, + const arith::IntervalSet &for_interval, bool cond_value); inline Stmt MakeFor(const Node* op, Expr extent, Stmt body); @@ -374,6 +374,7 @@ class LoopPartitioner : public IRMutator { /* Candidate IRs that may be partitioned potentially */ std::unordered_map hint_map_; std::unordered_map relax_map_; + arith::Analyzer analyzer_; CandidateSelector selector; }; @@ -381,16 +382,17 @@ class LoopPartitioner : public IRMutator { // given in the second component provably have value given by cond_value std::pair> LoopPartitioner::GetIntervalAndCondset(const Partition &partitions, - const arith::Interval &for_interval, + const arith::IntervalSet &for_interval, bool cond_value) { Array sets; std::unordered_set cond_set; for (const auto &kv : partitions) { if (kv.first.second == cond_value) { - arith::Interval interval = kv.second.as()->i; - arith::Interval intersection = arith::Interval::make_intersection(interval, for_interval); - if (!intersection.is_empty()) { + arith::IntervalSet interval = Downcast(kv.second); + arith::IntervalSet intersection = arith::Intersect( + &analyzer_, interval, for_interval); + if (!intersection->IsEmpty()) { sets.push_back(kv.second); cond_set.insert(kv.first.first); } @@ -463,11 +465,12 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr max, Stmt body, bool partition_thread_scope) { + using namespace arith; PartitionFinder finder(var, hint_map_, relax_map_); finder.Visit(body); if (finder.partitions.empty()) return Stmt(); - arith::Interval for_interval(min, max); + arith::IntervalSet for_interval(min, max); bool cond_value; IntSet middle_interval; std::unordered_set cond_set; @@ -478,7 +481,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, // if such interval doesn't exist, find an interval in which all // conditions on var are false std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, false); + GetIntervalAndCondset(finder.partitions, for_interval, false); if (middle_interval.is_nothing()) // we couldn't find an interval in which the condintions are provably true or false // Therefore, we can't partition the loop based on those conds @@ -488,7 +491,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, cond_value = true; } - arith::Interval middle_interval_i = middle_interval.as()->i; + IntervalSet middle_interval_i = Downcast(middle_interval); // middle_interval is the subrange of the loop variable range for which a // set of conditions are true (or false resp.) // The part of the loop variable range that is before (after resp.) that @@ -499,7 +502,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr body_begin; Stmt pre_stmt; bool pre_stmt_recurse = true; - if (middle_interval_i.has_lower_bound()) { + if (middle_interval_i->HasLowerBound()) { body_begin = ir::Simplify(middle_interval.min()); if (!can_prove(body_begin == min)) { Expr cond = (body_begin - min >= 0); @@ -524,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Node* node, Expr post_doubt_begin; Stmt post_stmt; bool post_stmt_recurse = true; - if (middle_interval_i.has_upper_bound()) { + if (middle_interval_i->HasUpperBound()) { post_doubt_begin = ir::Simplify(middle_interval.max() + 1); if (!can_prove(middle_interval.max() == max)) { // require the extent to be non-negative diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py new file mode 100644 index 000000000000..7fe6f56edea7 --- /dev/null +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm + +def test_deduce(): + a = tvm.var('a') + b = tvm.var('b') + c = tvm.var('c') + d = tvm.var('d') + + b_s = tvm.arith.IntervalSet(2, 3) + c_s = tvm.arith.IntervalSet(10, 15) + d_s = tvm.arith.IntervalSet(-3, -1) + zero = tvm.const(0, "int32") + + e0 = (-b)*a+c-d + res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) + ans0 = ((d - c) /(b*-1)) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + # expression containing variable a is on rhs + res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + e0 = d*a+c-d + res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) + ans0 = ((0-c)/d + 1) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + # expression containing variable a is on rhs + res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res0.max_value)) == str(ans0) + + e1 = (a*4+b < c) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) + ans1 = (((c - b) + -1)/4) + assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + + # expression containing variable a is on rhs + e1 = (c > a*4+b) + res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) + assert str(tvm.ir_pass.Simplify(res1.max_value)) == str(ans1) + + e2 = (tvm.max(5, a * 4) < 0) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" + + # expression containing variable a is on rhs + e2 = (zero < tvm.max(5, a * 4)) + res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) + assert str(res2.max_value) == "neg_inf" + assert str(res2.min_value) == "pos_inf" + + + e3 = (-b)+a*c-d + res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + ans3 = 2/c+1 + assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) + + res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) + assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3) + +def test_check(): + a = tvm.var('a') + b = tvm.var('b') + c = tvm.var('c') + d = tvm.var('d') + + b_s = tvm.arith.IntervalSet(2, 3) + c_s = tvm.arith.IntervalSet(5, 7) + d_s = tvm.arith.IntervalSet(-3, -1) + + # no compare operator + res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) + assert res1.is_nothing() + + # multiple compare operators + res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) + assert res2.is_nothing() + + # multiple target variable + res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) + assert res2.is_nothing() + +def test_deduce_basic(): + def test_basic(a1, a2, coff): + a = tvm.var('a') + b = tvm.var('b') + b_s = tvm.arith.IntervalSet(a1, a2) + e0 = b + a*coff + 3 + + res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 + + res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) + [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 + + test_basic(0, 4, 4) + test_basic(1, 5, 4) + test_basic(2, 6, 4) + test_basic(0, 4, -4) + test_basic(1, 5, -4) + test_basic(2, 6, -4) + +def test_deduce_complex(): + def test_complex(a1, a2, coff): + a = tvm.var('a') + b = tvm.var('b') + b_s = tvm.arith.IntervalSet(a1, a2) + e0 = (b*3 + a* coff) * 4 + + res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 + + res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 + + # expression containing variable a is on rhs + res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) + [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] + assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 + + test_complex(0, 4, 4) + test_complex(0, 4, -4) + test_complex(2, 6, 4) + test_complex(0, 4, -4) + test_complex(1, 5, -4) + test_complex(2, 6, -4) + + +if __name__ == "__main__": + test_check() + test_deduce_basic() + test_deduce_complex() diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index a74162ec07f2..a6dfa449c370 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -16,168 +16,84 @@ # under the License. import tvm + +class IntSetChecker: + def __init__(self): + self.analyzer = tvm.arith.Analyzer() + + def verify(self, data, dmap, expected): + res = self.analyzer.int_set(data, dmap) + def err_msg(): + return "\ndata={}\ndmap={}\nres={}\nexpected={}".format(data, dmap, res, expected) + def equal(x, y): + res = self.analyzer.canonical_simplify(x - y) + return tvm.ir_pass.Equal(res, 0) + assert equal(res.min_value, expected[0]), err_msg() + assert equal(res.max_value, expected[1]), err_msg() + def test_basic(): - s = tvm.arith.intset_interval(2, 3) - assert s.min().value == 2 - assert s.max().value == 3 + s = tvm.arith.IntervalSet(2, 3) + assert s.min_value.value == 2 + assert s.max_value.value == 3 + def test_vector(): base = 10 stride = 3 lanes = 2 s = tvm.arith.intset_vector(tvm.make.Ramp(base, stride, lanes)) - assert s.min().value == base - assert s.max().value == base + stride * lanes - 1 - -def test_deduce(): - a = tvm.var('a') - b = tvm.var('b') - c = tvm.var('c') - d = tvm.var('d') - - b_s = tvm.arith.intset_interval(2, 3) - c_s = tvm.arith.intset_interval(10, 15) - d_s = tvm.arith.intset_interval(-3, -1) - zero = tvm.const(0, "int32") - - e0 = (-b)*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((d - c) /(b*-1)) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - # expression containing variable a is on rhs - res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - e0 = d*a+c-d - res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) - ans0 = ((0-c)/d + 1) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - # expression containing variable a is on rhs - res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res0.max())) == str(ans0) - - e1 = (a*4+b < c) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - ans1 = (((c - b) + -1)/4) - assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) - - # expression containing variable a is on rhs - e1 = (c > a*4+b) - res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) - - e2 = (tvm.max(5, a * 4) < 0) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max()) == "neg_inf" - assert str(res2.min()) == "pos_inf" - - # expression containing variable a is on rhs - e2 = (zero < tvm.max(5, a * 4)) - res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) - assert str(res2.max()) == "neg_inf" - assert str(res2.min()) == "pos_inf" - - - e3 = (-b)+a*c-d - res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - ans3 = 2/c+1 - assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) - - res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - assert str(tvm.ir_pass.Simplify(res3.min())) == str(ans3) - -def test_check(): - a = tvm.var('a') - b = tvm.var('b') - c = tvm.var('c') - d = tvm.var('d') - - b_s = tvm.arith.intset_interval(2, 3) - c_s = tvm.arith.intset_interval(5, 7) - d_s = tvm.arith.intset_interval(-3, -1) - - # no compare operator - res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {}) - assert res1.is_nothing() - - # multiple compare operators - res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) - assert res2.is_nothing() - - # multiple target variable - res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {}) - assert res2.is_nothing() - -def test_deduce_basic(): - def test_basic(a1, a2, coff): - a = tvm.var('a') - b = tvm.var('b') - b_s = tvm.arith.intset_interval(a1, a2) - e0 = b + a*coff + 3 - - res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 - - res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) - [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 - - test_basic(0, 4, 4) - test_basic(1, 5, 4) - test_basic(2, 6, 4) - test_basic(0, 4, -4) - test_basic(1, 5, -4) - test_basic(2, 6, -4) - -def test_deduce_complex(): - def test_complex(a1, a2, coff): - a = tvm.var('a') - b = tvm.var('b') - b_s = tvm.arith.intset_interval(a1, a2) - e0 = (b*3 + a* coff) * 4 - - res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 - - res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 - - # expression containing variable a is on rhs - res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) - [t, x] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] - assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 - - test_complex(0, 4, 4) - test_complex(0, 4, -4) - test_complex(2, 6, 4) - test_complex(0, 4, -4) - test_complex(1, 5, -4) - test_complex(2, 6, -4) + assert s.min_value.value == base + assert s.max_value.value == base + stride * lanes - 1 + + +def test_add_sub(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.verify(x + y, {x : tvm.arith.IntervalSet(0, 10)}, (y, 10 + y)) + ck.verify(x + y, + {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)}, + (1, 21)) + ck.verify(x - y, + {x : tvm.arith.IntervalSet(0, 10), y : tvm.arith.IntervalSet(1, 11)}, + (-11, 9)) + +def test_mul_div(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) + ck.verify(x * y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) + ck.verify(x * 2, {x : tvm.arith.IntervalSet(1, 10)}, (2, 20)) + ck.verify(x * -2, {x : tvm.arith.IntervalSet(1, 10)}, (-20, -2)) + ck.verify(x / y, {x : tvm.arith.IntervalSet(0, 10)}, (0, 10 / y)) + ck.verify(x / 2, {x : tvm.arith.IntervalSet(1, 10)}, (0, 5)) + + +def test_mod(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) + ck.verify(x % y, {x : tvm.arith.IntervalSet(0, 10)}, (0, y - 1)) + ck.verify(x % 10, {x : tvm.arith.IntervalSet(1, 10)}, (0, 9)) + + +def test_max_min(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.verify(tvm.max(x, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (1, 11)) + ck.verify(tvm.min(x - 1, x + 1), {x : tvm.arith.IntervalSet(0, 10)}, (-1, 9)) + + +def test_select(): + ck = IntSetChecker() + x, y = tvm.var("x"), tvm.var("y") + ck.verify(tvm.expr.Select(x > 0, x - 1, x + 1), + {x : tvm.arith.IntervalSet(0, 10)}, (-1, 11)) + if __name__ == "__main__": test_basic() test_vector() - test_deduce() - test_check() - test_deduce_basic() - test_deduce_complex() + test_add_sub() + test_mul_div() + test_max_min() + test_select()