diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 10833d2b45c0..beae9bff317d 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1216,102 +1216,14 @@ class Bounds : public IRVisitor { Expr a = op->args[0]; a.accept(this); Interval a_interval = interval; - bounds_of_type(t); - - // For float to float, guarantee infinities are always pinned to range. - if (t.is_float() && a.type().is_float()) { - if (t.bits() < a.type().bits()) { - // Casting to a smaller float, so clamp and then cast. - if (a_interval.has_lower_bound()) { - // If representable in the type, then use, otherwise return bounds_of_type(t).min - if (can_prove(a_interval.min >= t.min())) { - interval.min = cast(t, a_interval.min); - } - } - if (a_interval.has_upper_bound()) { - if (can_prove(a_interval.max <= t.max())) { - interval.max = cast(t, a_interval.max); - } - } - return; - } else { - // Casting to a wider float, so cast then clamp. - if (a_interval.has_lower_bound()) { - Expr casted_min = cast(t, a_interval.min); - - if (can_prove(casted_min >= t.min())) { - interval.min = casted_min; - } - } - if (a_interval.has_upper_bound()) { - Expr casted_max = cast(t, a_interval.max); - if (can_prove(casted_max <= t.max())) { - interval.max = casted_max; - } - } - return; - } - } else if (a.type() != t) { - // Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n) - if (a.type().is_float() && !t.is_float() && t.bits() >= a.type().bits()) { - if (a_interval.has_lower_bound()) { - // min values turn out to be always representable - if (can_prove(a_interval.min >= t.min())) { - interval.min = cast(t, a_interval.min); - } - } - if (a_interval.has_upper_bound()) { - // This line depends on t.max() rounding upward, which should always - // be the case as it is one less than a representable value, thus - // the one larger is always the closest. - if (can_prove(a_interval.max <= t.max())) { - interval.max = cast(t, a_interval.max); - } - } - return; - } else { - // We can safely cast a_interval iff we can prove the values - // are within the range of t. - Expr min_bound = lossless_cast(a.type(), t.min()); - Expr max_bound = lossless_cast(a.type(), t.max()); - // If the inner type is not a uint and we can represent t.min() in a.type(), - // then we need to check that value is >= t.min(); - const bool check_lower_bound = !a.type().is_uint() && min_bound.defined(); - // We should always check upper bounds if a.type() can represent t.max(). - const bool check_upper_bound = max_bound.defined(); - - // Define a helper function for performing saturation. - auto check_safe_cast = [&](const Expr &value, const Expr &base) { - if (check_upper_bound && check_lower_bound) { - if (can_prove((value >= min_bound) && (value <= max_bound))) { - return cast(t, value); - } - } else if (check_upper_bound) { - if (can_prove(value <= max_bound)) { - return cast(t, value); - } - } else if (check_lower_bound) { - if (can_prove(a_interval.min >= min_bound)) { - return cast(t, value); - } - } - return base; - }; - - if (a_interval.has_lower_bound()) { - interval.min = check_safe_cast(a_interval.min, interval.min); - } - if (a_interval.has_upper_bound()) { - interval.max = check_safe_cast(a_interval.max, interval.max); - } - return; - } - } else { - // a.type() == t - interval = a_interval; - return; + if (a_interval.has_lower_bound()) { + interval.min = saturating_cast(op->type, a_interval.min); } + if (a_interval.has_upper_bound()) { + interval.max = saturating_cast(op->type, a_interval.max); + } + return; } else if (op->is_intrinsic(Call::unsafe_promise_clamped) || op->is_intrinsic(Call::promise_clamped)) { // Unlike an explicit clamp, we are also permitted to @@ -3677,7 +3589,7 @@ void bounds_test() { check(scope, saturating_cast(clamp(x, 5, 10)), cast(5), cast(10)); { scope.push("x", Interval(UInt(32).min(), UInt(32).max())); - check(scope, saturating_cast(max(cast(x), cast(5))), cast(5), Interval::pos_inf()); + check(scope, saturating_cast(max(cast(x), cast(5))), cast(5), Int(32).max()); scope.pop("x"); } { @@ -3692,7 +3604,7 @@ void bounds_test() { { Expr z = Variable::make(UInt(32), "z"); scope.push("z", Interval(UInt(32).max(), UInt(32).max())); - check(scope, saturating_cast(z), Interval::neg_inf(), Interval::pos_inf()); + check(scope, saturating_cast(z), Int(32).max(), Int(32).max()); scope.pop("z"); } diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index ae8978b2cb57..cec309571aa8 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -480,7 +480,8 @@ class DerivativeBounds : public IRVisitor { } if (op->is_intrinsic(Call::unsafe_promise_clamped) || - op->is_intrinsic(Call::promise_clamped)) { + op->is_intrinsic(Call::promise_clamped) || + op->is_intrinsic(Call::saturating_cast)) { op->args[0].accept(this); return; } diff --git a/src/Simplify_Call.cpp b/src/Simplify_Call.cpp index a1ff4c5130fe..6c92fd086405 100644 --- a/src/Simplify_Call.cpp +++ b/src/Simplify_Call.cpp @@ -1,5 +1,6 @@ #include "Simplify_Internal.h" +#include "FindIntrinsics.h" #include "Simplify.h" #ifdef _MSC_VER @@ -351,6 +352,21 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) { } else { return absd(a, b); } + } else if (op->is_intrinsic(Call::saturating_cast)) { + internal_assert(op->args.size() == 1); + ExprInfo a_bounds; + Expr a = mutate(op->args[0], &a_bounds); + + // TODO(rootjalex): We could be intelligent about using a_bounds to remove saturating_casts; + + if (is_const(a)) { + a = lower_saturating_cast(op->type, a); + return mutate(a, bounds); + } else if (!a.same_as(op->args[0])) { + return saturating_cast(op->type, a); + } else { + return op; + } } else if (op->is_intrinsic(Call::stringify)) { // Eagerly concat constant arguments to a stringify. bool changed = false;