Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Revamp IntSet #3272

Merged
merged 10 commits into from
Jun 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 109 additions & 85 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,71 +328,14 @@ class ConstraintContext {
std::function<void()> exit_;
};

/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.

* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};

//-----------------------------------------------
// Integer set abstraction API.
// Integer set data structure.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
* \brief Sign of an expression or set.
* \brief Sign type of an integer expression.
*/
enum SignType {
kPositive,
Expand All @@ -401,8 +344,13 @@ enum SignType {
kUnknown
};

// internal node container of int set.
struct IntSetNode;
/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};

/*!
* \brief Integer set class, represent a set of integers in one dimension.
Expand All @@ -424,11 +372,6 @@ class IntSet : public NodeRef {
* \return The covering range.
*/
Range cover_range(Range max_range) const;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
IntSet cover_interval() const;
/*! \return Lower bound of the set */
Expr min() const;
/*! \return upper bound of the set */
Expand Down Expand Up @@ -493,33 +436,91 @@ class IntSet : public NodeRef {
};

/*!
* \brief Base class of all IntSet containers.
* \brief Integer set analyzer.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
class IntSetAnalyzer {
public:
/*!
* \brief Find a symbolic integer set that contains all possible values of
* expr given the domain of each variables.
*
* \param expr The expression of interest.
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);

private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
~IntSetAnalyzer();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
* \brief Analyzer that contains bunch of sub-analyzers.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);

/*!
* \brief Detect if expression corresponds to clip bound of the vars
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overridden.
*/
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer int_set;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we prove expr >= val.

* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};

//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
Expand Down Expand Up @@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

// Expression pattern detector.
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e,
const Array<Var>& vars);

/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<Expr> DetectClipBound(const Expr& e,
const Array<Var>& vars);

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
Expand Down
43 changes: 31 additions & 12 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ def is_everything(self):
return _api_internal._IntSetIsEverything(self)


@register_node
@register_node("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval"""
def min(self):
"""get the minimum value"""
return _api_internal._IntervalSetGetMin(self)

def max(self):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)
"""Represent set of continuous interval [min_value, max_value]

Parameters
----------
min_value : Expr
The minimum value in the interval.

@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
max_value : Expr
The maximum value in the interval.
"""
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_make_IntervalSet, min_value, max_value)


@register_node("arith.ModularSet")
Expand Down Expand Up @@ -114,6 +114,7 @@ def __init__(self):
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")

def const_int_bound(self, expr):
Expand Down Expand Up @@ -176,6 +177,24 @@ def canonical_simplify(self, expr):
"""
return self._canonical_simplify(expr)

def int_set(self, expr, dom_map):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.

Parameters
----------
expr : tvm.Expr
The expression.

dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.

Returns
-------
result : IntSet
The result.
"""
return self._int_set(expr, dom_map)

def bind(self, var, expr):
"""Bind a variable to the expression.

Expand Down
5 changes: 5 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_API("arith.intset_interval")
.set_body_typed(IntSet::interval);


TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);

Expand Down Expand Up @@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "int_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->int_set(args[0], args[1]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
Expand Down
5 changes: 3 additions & 2 deletions src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this) {
canonical_simplify(this),
int_set(this) {
}

void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Expand Down Expand Up @@ -74,7 +75,7 @@ void ConstraintContext::ExitWithScope() {

bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
if (const auto* ptr = expr.as<ir::IntImm>()) {
return ptr->value > lower_bound;
return ptr->value >= lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->min_value >= lower_bound) return true;
Expand Down
8 changes: 4 additions & 4 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -30,12 +30,12 @@

#include <unordered_set>
#include <unordered_map>
#include "int_set.h"

namespace tvm {
namespace arith {

using namespace ir;
using HalideIR::Internal::Interval;

// a visitor to find the path to the target variable
// from a expression.
Expand Down Expand Up @@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) {
min = d.result;
} else {
Expand Down
Loading