Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arith] Add internal NarrowPredicateExpression utility #13041

Merged
merged 1 commit into from
Oct 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions src/arith/narrow_predicate_expression.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file narrow_predicate_expression.cc
* \brief Utility to deduce bound of expression
*/
#include <tvm/arith/int_solver.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace arith {

using namespace tir;

/* \brief Given a true expression that includes free parameter,
* generate a true expression without the free parameters.
*
* This function provides two guarantees:
*
* 1. If the resulting expression evaluates to True, then the original
* expression also evaluates to True.
*
* 2. The resulting expression does not contain any of the free
* parameters.
*
*/
// Utility for generating a known true expression from an expression
// with free parameters, and the range of those parameters.
class ExpressionNarrower : public tir::ExprMutator {
public:
static PrimExpr Apply(PrimExpr expr, Map<Var, Range> free_parameters) {
ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr;
ExpressionNarrower mutator(free_parameters);
return mutator(expr);
}

private:
explicit ExpressionNarrower(Map<Var, Range> free_parameters)
: free_parameters_(free_parameters) {}

using Parent = tir::ExprMutator;
using Parent::VisitExpr_;

enum class Context {
Maximize,
Minimize,
};

template <typename T>
PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) {
PrimExpr a = [&]() {
WithContext context(this, a_ctx);
return this->VisitExpr(t->a);
}();

PrimExpr b = [&]() {
WithContext context(this, b_ctx);
return this->VisitExpr(t->b);
}();

if (contains_unknown_expr_ && t.dtype().is_bool()) {
contains_unknown_expr_ = false;
return Bool(CurrentContext() == Context::Minimize);
} else if (a.same_as(t->a) && b.same_as(t->b)) {
return std::move(t);
} else {
return T(a, b);
}
}

PrimExpr VisitExpr_(const FloorModNode* op) override {
// FloorMod is non-monotonic, so inserting min/max won't remove
// the free parameters.
contains_unknown_expr_ = true;
return Parent::VisitExpr_(op);
}

PrimExpr VisitExpr_(const FloorDivNode* op) override {
auto res_a = this->VisitExpr(op->a);
auto res_b = this->VisitExpr(op->b);
if (is_zero(res_b)) {
contains_unknown_expr_ = true;
return IntImm(op->dtype, 0);
} else {
return floordiv(res_a, res_b);
}
}

PrimExpr VisitExpr_(const GTNode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<GT>(op), OppositeContext(current), current);
}

PrimExpr VisitExpr_(const GENode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<GE>(op), OppositeContext(current), current);
}

PrimExpr VisitExpr_(const LTNode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<LT>(op), current, OppositeContext(current));
}

PrimExpr VisitExpr_(const LENode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<LE>(op), current, OppositeContext(current));
}

PrimExpr VisitExpr_(const EQNode* op) override {
auto res_a = this->VisitExpr(op->a <= op->b);
auto res_b = this->VisitExpr(op->b <= op->a);
return res_a && res_b;
}

PrimExpr VisitExpr_(const NENode* op) override {
auto res_a = this->VisitExpr(op->a < op->b);
auto res_b = this->VisitExpr(op->b < op->a);
return res_a || res_b;
}

PrimExpr VisitExpr_(const SubNode* op) override {
auto current = CurrentContext();
return VisitInequality(GetRef<Sub>(op), current, OppositeContext(current));
}

PrimExpr VisitExpr_(const NotNode* op) override {
auto current = CurrentContext();
WithContext context(this, OppositeContext(current));
return !VisitExpr(op->a);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) override {
contains_unknown_expr_ = true;
return GetRef<PrimExpr>(op);
}

PrimExpr VisitExpr_(const VarNode* op) override {
auto it = free_parameters_.find(GetRef<Var>(op));
if (it == free_parameters_.end()) {
return Parent::VisitExpr_(op);
}

Range range = (*it).second;

switch (CurrentContext()) {
case Context::Minimize:
return range->min;

case Context::Maximize:
return range->min + range->extent - 1;
}

return Parent::VisitExpr_(op);
}

Context CurrentContext() const {
if (context_stack_.size()) {
return context_stack_.back();
} else {
return Context::Maximize;
}
}

Context OppositeContext(Context context) const {
switch (context) {
case Context::Minimize:
return Context::Maximize;

case Context::Maximize:
return Context::Minimize;

default:
LOG(FATAL) << "Unhandled Context, all legal values should be handled";
return Context::Maximize;
}
}

struct WithContext {
WithContext(ExpressionNarrower* self, Context context) : self(self) {
self->context_stack_.push_back(context);
}
~WithContext() { self->context_stack_.pop_back(); }
ExpressionNarrower* self;
};

std::vector<Context> context_stack_;
Map<Var, Range> free_parameters_;
bool contains_unknown_expr_{false};
};

PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<Var, Range> free_parameters) {
return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters));
}

TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression);

} // namespace arith
} // namespace tvm
57 changes: 57 additions & 0 deletions src/arith/narrow_predicate_expression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file narrow_predicate_expression.h
* \brief Utility for extracting and interacting with buffer touch points
*/

#include <tvm/ir/expr.h>
#include <tvm/tir/var.h>

#ifndef TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_
#define TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_

namespace tvm {
namespace arith {

/* \brief Narrow a true expression to remove free parameters
*
* This function provides two guarantees:
*
* 1. If the resulting expression evaluates to True, then the original
* expression also evaluates to True.
*
* 2. The resulting expression does not contain any of the free
* parameters.
*
* 3. The resulting expression does not contain any BufferLoad
*
* \param expr The expression to be examined.
*
* \param ranges The variables to be removed from the expression
*
* \returns An expression that, if true, implies that the original
* expression is also true.
*/
PrimExpr NarrowPredicateExpression(PrimExpr expr, Map<tir::Var, Range> free_parameters);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_
87 changes: 87 additions & 0 deletions tests/python/unittest/test_arith_narrow_predicate_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing

from tvm import tir
from tvm.runtime import convert


i = tir.Var("i", "int32")
j = tir.Var("j", "int32")
n = tir.Var("n", "int32")
m = tir.Var("m", "int32")
b = tir.Var("b", "bool")
buf = tir.decl_buffer(16, "int32", "buf")

tir_false = tir.IntImm("bool", False)
tir_true = tir.IntImm("bool", True)

before, expected = tvm.testing.parameters(
# General arithmatic
[tir_true, tir_true],
[tir_false, tir_false],
[b, b],
[i > 5, i > 5],
[i > n, i > 7],
[i < n, i < 0],
[i <= n, i <= 0],
[i >= n, i >= 7],
[n > i, convert(0) > i],
[n < i, convert(7) < i],
[n <= i, convert(7) <= i],
[n >= i, convert(0) >= i],
[i == n, tir.all(i <= 0, convert(7) <= i)],
[n == i, tir.all(convert(7) <= i, i <= 0)],
[i != n, tir.any(i < 0, convert(7) < i)],
[n != i, tir.any(convert(7) < i, i < 0)],
[i // 4 > n, i // 4 > 7],
[n < i // 4, convert(7) < i // 4],
[(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0],
[(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)],
[i + n < 10, i + 7 < 10],
[i - n < 10, tir.Sub(i, 0) < 10],
[tir.Not(i < n), tir.Not(i < 7)],
# Use of FloorMod should make the narrowing strategy bail out, as
# it is non-monotonic.
[i % 8 == n, tir_false],
# Ensure that dividing by a free parameter doesn't generate a
# divide-by-zero to be triggered later.
[i // n == 0, tir_false],
### Buffer handling
[buf.vload(0) > 0, tir_false],
[buf.vload(0) > i, tir_false],
[buf.vload(i) > 0, tir_false],
[tir.And(buf.vload(i) > 0, i <= 0), tir.And(tir_false, i <= 0)],
[tir.Or(buf.vload(i) > 0, i <= n), tir.Or(tir_false, i <= 0)],
[tir.Or(tir.Not(buf.vload(i) > 0), i <= n), tir.Or(tir_false, i <= 0)],
)


def test_narrow_expression(before, expected):
ranges = {n: tvm.ir.Range(0, 8)}
after = tvm.arith._ffi_api.NarrowPredicateExpression(before, ranges)

if expected is None:
assert after is None
else:
tvm.ir.assert_structural_equal(after, expected)


if __name__ == "__main__":
tvm.testing.main()