Skip to content

Commit

Permalink
[Arith] Add internal NarrowPredicateExpression utility
Browse files Browse the repository at this point in the history
Implements `tvm::arith::NarrowPredicateExpression`, a utility that
removes free parameters from a boolean expression, such that the
resulting expression being true implies that the original expression
is true.  For example, the predicate `(0 <= i+f) && (i+f < 16)`, where
`f` is a free parameter on the range `0 <= f < 2)`, can be narrowed to
the expression `(0 <= i+0) && (i+2 < 16)`.

In effect, `NarrowPredicateExpression` functions as a context-sentive
`tvm::tir::Substitute`, where the value substituted is selected such
that the resulting expression errs on the side of being false.  This
is an internal utility used as part of the simplifications for layout
transformations ([tracking issue
link](#12261)).
  • Loading branch information
Lunderberg committed Oct 14, 2022
1 parent 605876e commit c6d78c1
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 0 deletions.
219 changes: 219 additions & 0 deletions src/arith/narrow_predicate_expression.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file narrow_predicate_expression.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/arith/int_solver.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace arith {

using namespace tir;

/* \brief Given a true expression that includes free parameter,
* generate a true expression without the free parameters.
*
* This function provides two guarantees:
*
* 1. If the resulting expression evaluates to True, then the original
* expression also evaluates to True.
*
* 2. The resulting expression does not contain any of the free
* parameters.
*
*/
// Utility for generating a known true expression from an expression
// with free parameters, and the range of those parameters.
class ExpressionNarrower : public tir::ExprMutator {
public:
static PrimExpr Apply(PrimExpr expr, Map<Var, Range> free_parameters) {
ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr;
ExpressionNarrower mutator(free_parameters);
return mutator(expr);
}

private:
explicit ExpressionNarrower(Map<Var, Range> free_parameters)
: free_parameters_(free_parameters) {}

using Parent = tir::ExprMutator;
using Parent::VisitExpr_;

enum class Context {
Maximize,
Minimize,
};

template <typename T>
PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) {
PrimExpr a = [&]() {
WithContext context(this, a_ctx);
return this->VisitExpr(t->a);
}();

PrimExpr b = [&]() {
WithContext context(this, b_ctx);
return this->VisitExpr(t->b);
}();

if (contains_unknown_expr_ && t.dtype().is_bool()) {
contains_unknown_expr_ = false;
return Bool(CurrentContext() == Context::Minimize);
} else if (a.same_as(t->a) && b.same_as(t->b)) {
return std::move(t);
} else {
return T(a, b);
}
}

PrimExpr VisitExpr_(const FloorModNode* op) override {
// FloorMod is non-monotonic, so inserting min/max won't remove
// the free parameters.
contains_unknown_expr_ = true;
return Parent::VisitExpr_(op);
}

PrimExpr VisitExpr_(const FloorDivNode* op) override {
auto res_a = this->VisitExpr(op->a);
auto res_b = this->VisitExpr(op->b);
if (is_zero(res_b)) {
contains_unknown_expr_ = true;
return IntImm(op->dtype, 0);
} else {
return floordiv(res_a, res_b);
}
}

PrimExpr VisitExpr_(const GTNode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<GT>(op), OppositeContext(current), current);
}

PrimExpr VisitExpr_(const GENode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<GE>(op), OppositeContext(current), current);
}

PrimExpr VisitExpr_(const LTNode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<LT>(op), current, OppositeContext(current));
}

PrimExpr VisitExpr_(const LENode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<LE>(op), current, OppositeContext(current));
}

PrimExpr VisitExpr_(const EQNode* op) override {
auto res_a = this->VisitExpr(op->a <= op->b);
auto res_b = this->VisitExpr(op->b <= op->a);
return res_a && res_b;
}

PrimExpr VisitExpr_(const NENode* op) override {
auto res_a = this->VisitExpr(op->a < op->b);
auto res_b = this->VisitExpr(op->b < op->a);
return res_a || res_b;
}

PrimExpr VisitExpr_(const SubNode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<Sub>(op), current, OppositeContext(current));
}

PrimExpr VisitExpr_(const NotNode* op) override {
auto current = CurrentContext();
WithContext context(this, OppositeContext(current));
return !VisitExpr(op->a);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) override {
contains_unknown_expr_ = true;
return GetRef<PrimExpr>(op);
}

PrimExpr VisitExpr_(const VarNode* op) override {
auto it = free_parameters_.find(GetRef<Var>(op));
if (it == free_parameters_.end()) {
return Parent::VisitExpr_(op);
}

Range range = (*it).second;

switch (CurrentContext()) {
case Context::Minimize:
return range->min;

case Context::Maximize:
return range->min + range->extent - 1;
}

return Parent::VisitExpr_(op);
}

Context CurrentContext() const {
if (context_stack_.size()) {
return context_stack_.back();
} else {
return Context::Maximize;
}
}

Context OppositeContext(Context context) const {
switch (context) {
case Context::Minimize:
return Context::Maximize;

case Context::Maximize:
return Context::Minimize;

default:
LOG(FATAL) << "Unhandled Context, all legal values should be handled";
return Context::Maximize;
}
}

struct WithContext {
WithContext(ExpressionNarrower* self, Context context) : self(self) {
self->context_stack_.push_back(context);
}
~WithContext() { self->context_stack_.pop_back(); }
ExpressionNarrower* self;
};

std::vector<Context> context_stack_;
Map<Var, Range> free_parameters_;
bool contains_unknown_expr_{false};
};

PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<Var, Range> free_parameters) {
return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters));
}

TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression);

} // namespace arith
} // namespace tvm
57 changes: 57 additions & 0 deletions src/arith/narrow_predicate_expression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file narrow_predicate_expression.h
* \brief Utility for extracting and interacting with buffer touch points
*/

#include <tvm/ir/expr.h>
#include <tvm/tir/var.h>

#ifndef TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_
#define TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_

namespace tvm {
namespace arith {

/* \brief Narrow a true expression to remove free parameters
*
* This function provides two guarantees:
*
* 1. If the resulting expression evaluates to True, then the original
* expression also evaluates to True.
*
* 2. The resulting expression does not contain any of the free
* parameters.
*
* 3. The resulting expression does not contain any BufferLoad
*
* \param expr The expression to be examined.
*
* \param ranges The variables to be removed from the expression
*
* \returns An expression that, if true, implies that the original
* expression is also true.
*/
PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<tir::Var, Range> free_parameters);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_
87 changes: 87 additions & 0 deletions tests/python/unittest/test_arith_narrow_predicate_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing

from tvm import tir
from tvm.runtime import convert


i = tir.Var("i", "int32")
j = tir.Var("j", "int32")
n = tir.Var("n", "int32")
m = tir.Var("m", "int32")
b = tir.Var("b", "bool")
buf = tir.decl_buffer(16, "int32", "buf")

tir_false = tir.IntImm("bool", False)
tir_true = tir.IntImm("bool", True)

before, expected = tvm.testing.parameters(
# General arithmatic
[tir_true, tir_true],
[tir_false, tir_false],
[b, b],
[i > 5, i > 5],
[i > n, i > 7],
[i < n, i < 0],
[i <= n, i <= 0],
[i >= n, i >= 7],
[n > i, convert(0) > i],
[n < i, convert(7) < i],
[n <= i, convert(7) <= i],
[n >= i, convert(0) >= i],
[i == n, tir.all(i <= 0, convert(7) <= i)],
[n == i, tir.all(convert(7) <= i, i <= 0)],
[i != n, tir.any(i < 0, convert(7) < i)],
[n != i, tir.any(convert(7) < i, i < 0)],
[i // 4 > n, i // 4 > 7],
[n < i // 4, convert(7) < i // 4],
[(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0],
[(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)],
[i + n < 10, i + 7 < 10],
[i - n < 10, tir.Sub(i, 0) < 10],
[tir.Not(i < n), tir.Not(i < 7)],
# Use of FloorMod should make the narrowing strategy bail out, as
# it is non-monotonic.
[i % 8 == n, tir_false],
# Ensure that dividing by a free parameter doesn't generate a
# divide-by-zero to be triggered later.
[i // n == 0, tir_false],
### Buffer handling
[buf.vload(0) > 0, tir_false],
[buf.vload(0) > i, tir_false],
[buf.vload(i) > 0, tir_false],
[tir.And(buf.vload(i) > 0, i <= 0), tir.And(tir_false, i <= 0)],
[tir.Or(buf.vload(i) > 0, i <= n), tir.Or(tir_false, i <= 0)],
[tir.Or(tir.Not(buf.vload(i) > 0), i <= n), tir.Or(tir_false, i <= 0)],
)


def test_narrow_expression(before, expected):
ranges = {n: tvm.ir.Range(0, 8)}
after = tvm.arith._ffi_api.NarrowPredicateExpression(before, ranges)

if expected is None:
assert after is None
else:
tvm.ir.assert_structural_equal(after, expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c6d78c1

Please sign in to comment.