From c6d78c1d918d0ce3497e2f43161f01c3035678be Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 11 Oct 2022 13:04:15 -0500 Subject: [PATCH] [Arith] Add internal NarrowPredicateExpression utility Implements `tvm::arith::NarrowPredicateExpression`, a utility that removes free parameters from a boolean expression, such that the resulting expression being true implies that the original expression is true. For example, the predicate `(0 <= i+f) && (i+f < 16)`, where `f` is a free parameter on the range `0 <= f < 2)`, can be narrowed to the expression `(0 <= i+0) && (i+2 < 16)`. In effect, `NarrowPredicateExpression` functions as a context-sentive `tvm::tir::Substitute`, where the value substituted is selected such that the resulting expression errs on the side of being false. This is an internal utility used as part of the simplifications for layout transformations ([tracking issue link](https://github.com/apache/tvm/issues/12261)). --- src/arith/narrow_predicate_expression.cc | 219 ++++++++++++++++++ src/arith/narrow_predicate_expression.h | 57 +++++ .../test_arith_narrow_predicate_expression.py | 87 +++++++ 3 files changed, 363 insertions(+) create mode 100644 src/arith/narrow_predicate_expression.cc create mode 100644 src/arith/narrow_predicate_expression.h create mode 100644 tests/python/unittest/test_arith_narrow_predicate_expression.py diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc new file mode 100644 index 000000000000..1c8931d2dec4 --- /dev/null +++ b/src/arith/narrow_predicate_expression.cc @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file narrow_predicate_expression.cc + * \brief Utility to deduce bound of expression + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace arith { + +using namespace tir; + +/* \brief Given a true expression that includes free parameter, + * generate a true expression without the free parameters. + * + * This function provides two guarantees: + * + * 1. If the resulting expression evaluates to True, then the original + * expression also evaluates to True. + * + * 2. The resulting expression does not contain any of the free + * parameters. + * + */ +// Utility for generating a known true expression from an expression +// with free parameters, and the range of those parameters. +class ExpressionNarrower : public tir::ExprMutator { + public: + static PrimExpr Apply(PrimExpr expr, Map free_parameters) { + ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; + ExpressionNarrower mutator(free_parameters); + return mutator(expr); + } + + private: + explicit ExpressionNarrower(Map free_parameters) + : free_parameters_(free_parameters) {} + + using Parent = tir::ExprMutator; + using Parent::VisitExpr_; + + enum class Context { + Maximize, + Minimize, + }; + + template + PrimExpr VisitInequality(T t, Context a_ctx, Context b_ctx) { + PrimExpr a = [&]() { + WithContext context(this, a_ctx); + return this->VisitExpr(t->a); + }(); + + PrimExpr b = [&]() { + WithContext context(this, b_ctx); + return this->VisitExpr(t->b); + }(); + + if (contains_unknown_expr_ && t.dtype().is_bool()) { + contains_unknown_expr_ = false; + return Bool(CurrentContext() == Context::Minimize); + } else if (a.same_as(t->a) && b.same_as(t->b)) { + return std::move(t); + } else { + return T(a, b); + } + } + + PrimExpr VisitExpr_(const FloorModNode* op) override { + // FloorMod is non-monotonic, so inserting min/max won't remove + // the free parameters. + contains_unknown_expr_ = true; + return Parent::VisitExpr_(op); + } + + PrimExpr VisitExpr_(const FloorDivNode* op) override { + auto res_a = this->VisitExpr(op->a); + auto res_b = this->VisitExpr(op->b); + if (is_zero(res_b)) { + contains_unknown_expr_ = true; + return IntImm(op->dtype, 0); + } else { + return floordiv(res_a, res_b); + } + } + + PrimExpr VisitExpr_(const GTNode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), OppositeContext(current), current); + } + + PrimExpr VisitExpr_(const GENode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), OppositeContext(current), current); + } + + PrimExpr VisitExpr_(const LTNode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), current, OppositeContext(current)); + } + + PrimExpr VisitExpr_(const LENode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), current, OppositeContext(current)); + } + + PrimExpr VisitExpr_(const EQNode* op) override { + auto res_a = this->VisitExpr(op->a <= op->b); + auto res_b = this->VisitExpr(op->b <= op->a); + return res_a && res_b; + } + + PrimExpr VisitExpr_(const NENode* op) override { + auto res_a = this->VisitExpr(op->a < op->b); + auto res_b = this->VisitExpr(op->b < op->a); + return res_a || res_b; + } + + PrimExpr VisitExpr_(const SubNode* op) override { + auto current = CurrentContext(); + return VisitInequality(GetRef(op), current, OppositeContext(current)); + } + + PrimExpr VisitExpr_(const NotNode* op) override { + auto current = CurrentContext(); + WithContext context(this, OppositeContext(current)); + return !VisitExpr(op->a); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) override { + contains_unknown_expr_ = true; + return GetRef(op); + } + + PrimExpr VisitExpr_(const VarNode* op) override { + auto it = free_parameters_.find(GetRef(op)); + if (it == free_parameters_.end()) { + return Parent::VisitExpr_(op); + } + + Range range = (*it).second; + + switch (CurrentContext()) { + case Context::Minimize: + return range->min; + + case Context::Maximize: + return range->min + range->extent - 1; + } + + return Parent::VisitExpr_(op); + } + + Context CurrentContext() const { + if (context_stack_.size()) { + return context_stack_.back(); + } else { + return Context::Maximize; + } + } + + Context OppositeContext(Context context) const { + switch (context) { + case Context::Minimize: + return Context::Maximize; + + case Context::Maximize: + return Context::Minimize; + + default: + LOG(FATAL) << "Unhandled Context, all legal values should be handled"; + return Context::Maximize; + } + } + + struct WithContext { + WithContext(ExpressionNarrower* self, Context context) : self(self) { + self->context_stack_.push_back(context); + } + ~WithContext() { self->context_stack_.pop_back(); } + ExpressionNarrower* self; + }; + + std::vector context_stack_; + Map free_parameters_; + bool contains_unknown_expr_{false}; +}; + +PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters) { + return ExpressionNarrower::Apply(std::move(expr), std::move(free_parameters)); +} + +TVM_REGISTER_GLOBAL("arith.NarrowPredicateExpression").set_body_typed(NarrowPredicateExpression); + +} // namespace arith +} // namespace tvm diff --git a/src/arith/narrow_predicate_expression.h b/src/arith/narrow_predicate_expression.h new file mode 100644 index 000000000000..1e452e3ad493 --- /dev/null +++ b/src/arith/narrow_predicate_expression.h @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file narrow_predicate_expression.h + * \brief Utility for extracting and interacting with buffer touch points + */ + +#include +#include + +#ifndef TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ +#define TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ + +namespace tvm { +namespace arith { + +/* \brief Narrow a true expression to remove free parameters + * + * This function provides two guarantees: + * + * 1. If the resulting expression evaluates to True, then the original + * expression also evaluates to True. + * + * 2. The resulting expression does not contain any of the free + * parameters. + * + * 3. The resulting expression does not contain any BufferLoad + * + * \param expr The expression to be examined. + * + * \param ranges The variables to be removed from the expression + * + * \returns An expression that, if true, implies that the original + * expression is also true. + */ +PrimExpr NarrowPredicateExpression(PrimExpr expr, Map free_parameters); + +} // namespace arith +} // namespace tvm +#endif // TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ diff --git a/tests/python/unittest/test_arith_narrow_predicate_expression.py b/tests/python/unittest/test_arith_narrow_predicate_expression.py new file mode 100644 index 000000000000..d38fe70f6b5c --- /dev/null +++ b/tests/python/unittest/test_arith_narrow_predicate_expression.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing + +from tvm import tir +from tvm.runtime import convert + + +i = tir.Var("i", "int32") +j = tir.Var("j", "int32") +n = tir.Var("n", "int32") +m = tir.Var("m", "int32") +b = tir.Var("b", "bool") +buf = tir.decl_buffer(16, "int32", "buf") + +tir_false = tir.IntImm("bool", False) +tir_true = tir.IntImm("bool", True) + +before, expected = tvm.testing.parameters( + # General arithmatic + [tir_true, tir_true], + [tir_false, tir_false], + [b, b], + [i > 5, i > 5], + [i > n, i > 7], + [i < n, i < 0], + [i <= n, i <= 0], + [i >= n, i >= 7], + [n > i, convert(0) > i], + [n < i, convert(7) < i], + [n <= i, convert(7) <= i], + [n >= i, convert(0) >= i], + [i == n, tir.all(i <= 0, convert(7) <= i)], + [n == i, tir.all(convert(7) <= i, i <= 0)], + [i != n, tir.any(i < 0, convert(7) < i)], + [n != i, tir.any(convert(7) < i, i < 0)], + [i // 4 > n, i // 4 > 7], + [n < i // 4, convert(7) < i // 4], + [(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], + [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)], + [i + n < 10, i + 7 < 10], + [i - n < 10, tir.Sub(i, 0) < 10], + [tir.Not(i < n), tir.Not(i < 7)], + # Use of FloorMod should make the narrowing strategy bail out, as + # it is non-monotonic. + [i % 8 == n, tir_false], + # Ensure that dividing by a free parameter doesn't generate a + # divide-by-zero to be triggered later. + [i // n == 0, tir_false], + ### Buffer handling + [buf.vload(0) > 0, tir_false], + [buf.vload(0) > i, tir_false], + [buf.vload(i) > 0, tir_false], + [tir.And(buf.vload(i) > 0, i <= 0), tir.And(tir_false, i <= 0)], + [tir.Or(buf.vload(i) > 0, i <= n), tir.Or(tir_false, i <= 0)], + [tir.Or(tir.Not(buf.vload(i) > 0), i <= n), tir.Or(tir_false, i <= 0)], +) + + +def test_narrow_expression(before, expected): + ranges = {n: tvm.ir.Range(0, 8)} + after = tvm.arith._ffi_api.NarrowPredicateExpression(before, ranges) + + if expected is None: + assert after is None + else: + tvm.ir.assert_structural_equal(after, expected) + + +if __name__ == "__main__": + tvm.testing.main()