Skip to content

Commit

Permalink
[ARITH] Revamp IntSet
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jun 2, 2019
1 parent 4767554 commit 0e5701f
Show file tree
Hide file tree
Showing 17 changed files with 1,174 additions and 840 deletions.
194 changes: 109 additions & 85 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,71 +328,14 @@ class ConstraintContext {
std::function<void()> exit_;
};

/*!
* \brief Analyzer that contains bunch of sub-analyzers.
*
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};

//-----------------------------------------------
// Integer set abstraction API.
// Integer set data structure.
//
// This is a API build on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
* \brief Sign of an expression or set.
* \brief Sign type of an integer expression.
*/
enum SignType {
kPositive,
Expand All @@ -401,8 +344,13 @@ enum SignType {
kUnknown
};

// internal node container of int set.
struct IntSetNode;
/*!
* \brief Base class of all IntSet containers.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
};

/*!
* \brief Integer set class, represent a set of integers in one dimension.
Expand All @@ -424,11 +372,6 @@ class IntSet : public NodeRef {
* \return The covering range.
*/
Range cover_range(Range max_range) const;
/*!
* \brief find an interval that covers the set.
* \return The covering interval set.
*/
IntSet cover_interval() const;
/*! \return Lower bound of the set */
Expr min() const;
/*! \return upper bound of the set */
Expand Down Expand Up @@ -493,33 +436,91 @@ class IntSet : public NodeRef {
};

/*!
* \brief Base class of all IntSet containers.
* \brief Integer set analyzer.
*/
struct IntSetNode : public Node {
static constexpr const char* _type_key = "IntSet";
TVM_DECLARE_BASE_NODE_INFO(IntSetNode, Node);
class IntSetAnalyzer {
public:
/*!
* \brief Find an symbolic integer set that contains all possible values of
* expr given the domain of each variables.
*
* \param expr The expression of interest.
* \param dom_map The domain map to indicate which variable to relax.
* \return the result of the analysis.
*/
IntSet operator()(const Expr& expr, const Map<Var, IntSet>& dom_map);

private:
friend class Analyzer;
explicit IntSetAnalyzer(Analyzer* parent);
~IntSetAnalyzer();
class Impl;
/*! \brief Internal impl */
Impl* impl_;
};

/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
* \brief Analyzer that contains bunch of sub-analyzers.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars);

/*!
* \brief Detect if expression corresponds to clip bound of the vars
* Each sub-analyzer can make use of another sub-analyzer
* by weak reference of this.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
* NOTE for sub-analyzer developers:
* If the analyzer uses memoization, we need to clear the internal
* cache when information about a Var has been overrideen.
*/
Array<Expr> DetectClipBound(const Expr& e, const Array<Var>& vars);
class Analyzer {
public:
/*! \brief sub-analyzer: const integer bound */
ConstIntBoundAnalyzer const_int_bound;
/*! \brief sub-analyzer: modular set */
ModularSetAnalyzer modular_set;
/*! \brief sub-analyzer rewrite simplify */
RewriteSimplifier rewrite_simplify;
/*! \brief sub-analyzer canonical simplify */
CanonicalSimplifier canonical_simplify;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer int_set;
/*! \brief constructor */
Analyzer();
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to expr.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param expr The expression we bind to.
*/
void Bind(const VarExpr& var, const Expr& expr);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
*
* Each var can only be binded once.
*
* \param var The variable.
* \param range The range we bind to.
*/
void Bind(const VarExpr& var, const Range& range);
/*!
* \brief Whether can we proof expr >= val.
* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param lower_bound The lower bound.
* \return Whether we can proof it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const Expr& expr, int64_t lower_bound);
};

//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
Expand Down Expand Up @@ -638,6 +639,29 @@ IntSet DeduceBound(Expr v, Expr cond,
*/
Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides);

// Expression pattern detector.
/*!
* \brief Detect if e can be rewritten as e = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]
* Where coeff[i] and base are invariant of var[j] for all i and j.
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return [coeff[i]] if it is possible, empty array if it is not.
*/
Array<Expr> DetectLinearEquation(const Expr& e,
const Array<Var>& vars);

/*!
* \brief Detect if expression corresponds to clip bound of the vars
*
* \param e The expression to be detected.
* \param vars List of variables to be used in detection.
* \return concat([min_value[i], max_value[i]]), None is returned if there is no min or max value
* return empty if the e does not match the pattern.
*/
Array<Expr> DetectClipBound(const Expr& e,
const Array<Var>& vars);

// implementation
inline const IntSetNode* IntSet::operator->() const {
return static_cast<const IntSetNode*>(node_.get());
Expand Down
43 changes: 31 additions & 12 deletions python/tvm/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,21 @@ def is_everything(self):
return _api_internal._IntSetIsEverything(self)


@register_node
@register_node("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval"""
def min(self):
"""get the minimum value"""
return _api_internal._IntervalSetGetMin(self)

def max(self):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)
"""Represent set of continuous interval [min_value, max_value]
Parameters
----------
min_value : Expr
The minimum value in the interval.
@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
max_value : Expr
The maximum value in the interval.
"""
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_make_IntervalSet, min_value, max_value)


@register_node("arith.ModularSet")
Expand Down Expand Up @@ -114,6 +114,7 @@ def __init__(self):
self._modular_set = _mod("modular_set")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
self._enter_constraint_context = _mod("enter_constraint_context")

def const_int_bound(self, expr):
Expand Down Expand Up @@ -176,6 +177,24 @@ def canonical_simplify(self, expr):
"""
return self._canonical_simplify(expr)

def int_set(self, expr, dom_map):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.
Parameters
----------
expr : tvm.Expr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
Returns
-------
result : IntSet
The result.
"""
return self._int_set(expr, dom_map)

def bind(self, var, expr):
"""Bind a variable to the expression.
Expand Down
5 changes: 5 additions & 0 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_API("arith.intset_interval")
.set_body_typed(IntSet::interval);


TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);

Expand Down Expand Up @@ -110,6 +111,10 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->canonical_simplify(args[0]);
});
} else if (name == "int_set") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
*ret = self->int_set(args[0], args[1]);
});
} else if (name == "bind") {
return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
auto& sptr = args[1].node_sptr();
Expand Down
4 changes: 3 additions & 1 deletion src/arithmetic/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ Analyzer::Analyzer()
: const_int_bound(this),
modular_set(this),
rewrite_simplify(this),
canonical_simplify(this) {
canonical_simplify(this),
int_set(this) {
}

void Analyzer::Bind(const VarExpr& v, const Expr& expr) {
Expand Down Expand Up @@ -77,6 +78,7 @@ bool Analyzer::CanProveGreaterEqual(const Expr& expr, int64_t lower_bound) {
return ptr->value > lower_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
LOG(INFO) << bd;
if (bd->min_value >= lower_bound) return true;
return false;
}
Expand Down
8 changes: 4 additions & 4 deletions src/arithmetic/bound_deducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -30,12 +30,12 @@

#include <unordered_set>
#include <unordered_map>
#include "int_set.h"

namespace tvm {
namespace arith {

using namespace ir;
using HalideIR::Internal::Interval;

// a visitor to find the path to the target variable
// from a expression.
Expand Down Expand Up @@ -293,7 +293,7 @@ IntSet DeduceBound(Expr v, Expr e,
BoundDeducer d(v, e, hint_map, relax_map);
d.Deduce();
if (!d.success) return IntSet::nothing();
Expr min = Interval::neg_inf, max = Interval::pos_inf;
Expr min = neg_inf(), max = pos_inf();
if (d.is_greater) {
min = d.result;
} else {
Expand Down
1 change: 0 additions & 1 deletion src/arithmetic/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* \file canonical_simplify.cc
* \brief Canonical form based simplification.
*/
Expand Down
Loading

0 comments on commit 0e5701f

Please sign in to comment.