-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Arith] Add internal NarrowPredicateExpression utility
Implements `tvm::arith::NarrowPredicateExpression`, a utility that removes free parameters from a boolean expression, such that the resulting expression being true implies that the original expression is true. For example, the predicate `(0 <= i+f) && (i+f < 16)`, where `f` is a free parameter on the range `0 <= f < 2)`, can be narrowed to the expression `(0 <= i+0) && (i+2 < 16)`. In effect, `NarrowPredicateExpression` functions as a context-sentive `tvm::tir::Substitute`, where the value substituted is selected such that the resulting expression errs on the side of being false. This is an internal utility used as part of the simplifications for layout transformations ([tracking issue link](#12261)).
- Loading branch information
1 parent
605876e
commit c6d78c1
Showing
3 changed files
with
363 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file narrow_predicate_expression.cc | ||
* \brief Utility to deduce bound of expression | ||
*/ | ||
#include <tvm/arith/int_solver.h> | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/tir/analysis.h> | ||
#include <tvm/tir/expr.h> | ||
#include <tvm/tir/op.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
using namespace tir; | ||
|
||
/* \brief Given a true expression that includes free parameter, | ||
* generate a true expression without the free parameters. | ||
* | ||
* This function provides two guarantees: | ||
* | ||
* 1. If the resulting expression evaluates to True, then the original | ||
* expression also evaluates to True. | ||
* | ||
* 2. The resulting expression does not contain any of the free | ||
* parameters. | ||
* | ||
*/ | ||
// Utility for generating a known true expression from an expression | ||
// with free parameters, and the range of those parameters. | ||
class ExpressionNarrower : public tir::ExprMutator { | ||
public: | ||
static PrimExpr Apply(PrimExpr expr, Map<Var, Range> free_parameters) { | ||
ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; | ||
ExpressionNarrower mutator(free_parameters); | ||
return mutator(expr); | ||
} | ||
|
||
private: | ||
explicit ExpressionNarrower(Map<Var, Range> free_parameters) | ||
: free_parameters_(free_parameters) {} | ||
|
||
using Parent = tir::ExprMutator; | ||
using Parent::VisitExpr_; | ||
|
||
enum class Context { | ||
Maximize, | ||
Minimize, | ||
}; | ||
|
||
template <typename T> | ||
PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) { | ||
PrimExpr a = [&]() { | ||
WithContext context(this, a_ctx); | ||
return this->VisitExpr(t->a); | ||
}(); | ||
|
||
PrimExpr b = [&]() { | ||
WithContext context(this, b_ctx); | ||
return this->VisitExpr(t->b); | ||
}(); | ||
|
||
if (contains_unknown_expr_ && t.dtype().is_bool()) { | ||
contains_unknown_expr_ = false; | ||
return Bool(CurrentContext() == Context::Minimize); | ||
} else if (a.same_as(t->a) && b.same_as(t->b)) { | ||
return std::move(t); | ||
} else { | ||
return T(a, b); | ||
} | ||
} | ||
|
||
PrimExpr VisitExpr_(const FloorModNode* op) override { | ||
// FloorMod is non-monotonic, so inserting min/max won't remove | ||
// the free parameters. | ||
contains_unknown_expr_ = true; | ||
return Parent::VisitExpr_(op); | ||
} | ||
|
||
PrimExpr VisitExpr_(const FloorDivNode* op) override { | ||
auto res_a = this->VisitExpr(op->a); | ||
auto res_b = this->VisitExpr(op->b); | ||
if (is_zero(res_b)) { | ||
contains_unknown_expr_ = true; | ||
return IntImm(op->dtype, 0); | ||
} else { | ||
return floordiv(res_a, res_b); | ||
} | ||
} | ||
|
||
PrimExpr VisitExpr_(const GTNode* op) override { | ||
auto current = CurrentContext(); | ||
return VisitInequality(GetRef<GT>(op), OppositeContext(current), current); | ||
} | ||
|
||
PrimExpr VisitExpr_(const GENode* op) override { | ||
auto current = CurrentContext(); | ||
return VisitInequality(GetRef<GE>(op), OppositeContext(current), current); | ||
} | ||
|
||
PrimExpr VisitExpr_(const LTNode* op) override { | ||
auto current = CurrentContext(); | ||
return VisitInequality(GetRef<LT>(op), current, OppositeContext(current)); | ||
} | ||
|
||
PrimExpr VisitExpr_(const LENode* op) override { | ||
auto current = CurrentContext(); | ||
return VisitInequality(GetRef<LE>(op), current, OppositeContext(current)); | ||
} | ||
|
||
PrimExpr VisitExpr_(const EQNode* op) override { | ||
auto res_a = this->VisitExpr(op->a <= op->b); | ||
auto res_b = this->VisitExpr(op->b <= op->a); | ||
return res_a && res_b; | ||
} | ||
|
||
PrimExpr VisitExpr_(const NENode* op) override { | ||
auto res_a = this->VisitExpr(op->a < op->b); | ||
auto res_b = this->VisitExpr(op->b < op->a); | ||
return res_a || res_b; | ||
} | ||
|
||
PrimExpr VisitExpr_(const SubNode* op) override { | ||
auto current = CurrentContext(); | ||
return VisitInequality(GetRef<Sub>(op), current, OppositeContext(current)); | ||
} | ||
|
||
PrimExpr VisitExpr_(const NotNode* op) override { | ||
auto current = CurrentContext(); | ||
WithContext context(this, OppositeContext(current)); | ||
return !VisitExpr(op->a); | ||
} | ||
|
||
PrimExpr VisitExpr_(const BufferLoadNode* op) override { | ||
contains_unknown_expr_ = true; | ||
return GetRef<PrimExpr>(op); | ||
} | ||
|
||
PrimExpr VisitExpr_(const VarNode* op) override { | ||
auto it = free_parameters_.find(GetRef<Var>(op)); | ||
if (it == free_parameters_.end()) { | ||
return Parent::VisitExpr_(op); | ||
} | ||
|
||
Range range = (*it).second; | ||
|
||
switch (CurrentContext()) { | ||
case Context::Minimize: | ||
return range->min; | ||
|
||
case Context::Maximize: | ||
return range->min + range->extent - 1; | ||
} | ||
|
||
return Parent::VisitExpr_(op); | ||
} | ||
|
||
Context CurrentContext() const { | ||
if (context_stack_.size()) { | ||
return context_stack_.back(); | ||
} else { | ||
return Context::Maximize; | ||
} | ||
} | ||
|
||
Context OppositeContext(Context context) const { | ||
switch (context) { | ||
case Context::Minimize: | ||
return Context::Maximize; | ||
|
||
case Context::Maximize: | ||
return Context::Minimize; | ||
|
||
default: | ||
LOG(FATAL) << "Unhandled Context, all legal values should be handled"; | ||
return Context::Maximize; | ||
} | ||
} | ||
|
||
struct WithContext { | ||
WithContext(ExpressionNarrower* self, Context context) : self(self) { | ||
self->context_stack_.push_back(context); | ||
} | ||
~WithContext() { self->context_stack_.pop_back(); } | ||
ExpressionNarrower* self; | ||
}; | ||
|
||
std::vector<Context> context_stack_; | ||
Map<Var, Range> free_parameters_; | ||
bool contains_unknown_expr_{false}; | ||
}; | ||
|
||
PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<Var, Range> free_parameters) { | ||
return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression); | ||
|
||
} // namespace arith | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file narrow_predicate_expression.h | ||
* \brief Utility for extracting and interacting with buffer touch points | ||
*/ | ||
|
||
#include <tvm/ir/expr.h> | ||
#include <tvm/tir/var.h> | ||
|
||
#ifndef TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ | ||
#define TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ | ||
|
||
namespace tvm { | ||
namespace arith { | ||
|
||
/* \brief Narrow a true expression to remove free parameters | ||
* | ||
* This function provides two guarantees: | ||
* | ||
* 1. If the resulting expression evaluates to True, then the original | ||
* expression also evaluates to True. | ||
* | ||
* 2. The resulting expression does not contain any of the free | ||
* parameters. | ||
* | ||
* 3. The resulting expression does not contain any BufferLoad | ||
* | ||
* \param expr The expression to be examined. | ||
* | ||
* \param ranges The variables to be removed from the expression | ||
* | ||
* \returns An expression that, if true, implies that the original | ||
* expression is also true. | ||
*/ | ||
PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<tir::Var, Range> free_parameters); | ||
|
||
} // namespace arith | ||
} // namespace tvm | ||
#endif // TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ |
87 changes: 87 additions & 0 deletions
87
tests/python/unittest/test_arith_narrow_predicate_expression.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import tvm | ||
import tvm.testing | ||
|
||
from tvm import tir | ||
from tvm.runtime import convert | ||
|
||
|
||
i = tir.Var("i", "int32") | ||
j = tir.Var("j", "int32") | ||
n = tir.Var("n", "int32") | ||
m = tir.Var("m", "int32") | ||
b = tir.Var("b", "bool") | ||
buf = tir.decl_buffer(16, "int32", "buf") | ||
|
||
tir_false = tir.IntImm("bool", False) | ||
tir_true = tir.IntImm("bool", True) | ||
|
||
before, expected = tvm.testing.parameters( | ||
# General arithmatic | ||
[tir_true, tir_true], | ||
[tir_false, tir_false], | ||
[b, b], | ||
[i > 5, i > 5], | ||
[i > n, i > 7], | ||
[i < n, i < 0], | ||
[i <= n, i <= 0], | ||
[i >= n, i >= 7], | ||
[n > i, convert(0) > i], | ||
[n < i, convert(7) < i], | ||
[n <= i, convert(7) <= i], | ||
[n >= i, convert(0) >= i], | ||
[i == n, tir.all(i <= 0, convert(7) <= i)], | ||
[n == i, tir.all(convert(7) <= i, i <= 0)], | ||
[i != n, tir.any(i < 0, convert(7) < i)], | ||
[n != i, tir.any(convert(7) < i, i < 0)], | ||
[i // 4 > n, i // 4 > 7], | ||
[n < i // 4, convert(7) < i // 4], | ||
[(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], | ||
[(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)], | ||
[i + n < 10, i + 7 < 10], | ||
[i - n < 10, tir.Sub(i, 0) < 10], | ||
[tir.Not(i < n), tir.Not(i < 7)], | ||
# Use of FloorMod should make the narrowing strategy bail out, as | ||
# it is non-monotonic. | ||
[i % 8 == n, tir_false], | ||
# Ensure that dividing by a free parameter doesn't generate a | ||
# divide-by-zero to be triggered later. | ||
[i // n == 0, tir_false], | ||
### Buffer handling | ||
[buf.vload(0) > 0, tir_false], | ||
[buf.vload(0) > i, tir_false], | ||
[buf.vload(i) > 0, tir_false], | ||
[tir.And(buf.vload(i) > 0, i <= 0), tir.And(tir_false, i <= 0)], | ||
[tir.Or(buf.vload(i) > 0, i <= n), tir.Or(tir_false, i <= 0)], | ||
[tir.Or(tir.Not(buf.vload(i) > 0), i <= n), tir.Or(tir_false, i <= 0)], | ||
) | ||
|
||
|
||
def test_narrow_expression(before, expected): | ||
ranges = {n: tvm.ir.Range(0, 8)} | ||
after = tvm.arith._ffi_api.NarrowPredicateExpression(before, ranges) | ||
|
||
if expected is None: | ||
assert after is None | ||
else: | ||
tvm.ir.assert_structural_equal(after, expected) | ||
|
||
|
||
if __name__ == "__main__": | ||
tvm.testing.main() |