Skip to content

Commit

Permalink
simplify bounds of saturating_cast + update is_monotonic
Browse files Browse the repository at this point in the history
  • Loading branch information
rootjalex committed Aug 4, 2022
1 parent 5011b1e commit 91c1352
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 97 deletions.
104 changes: 8 additions & 96 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,102 +1216,14 @@ class Bounds : public IRVisitor {
Expr a = op->args[0];
a.accept(this);
Interval a_interval = interval;

bounds_of_type(t);

// For float to float, guarantee infinities are always pinned to range.
if (t.is_float() && a.type().is_float()) {
if (t.bits() < a.type().bits()) {
// Casting to a smaller float, so clamp and then cast.
if (a_interval.has_lower_bound()) {
// If representable in the type, then use, otherwise return bounds_of_type(t).min
if (can_prove(a_interval.min >= t.min())) {
interval.min = cast(t, a_interval.min);
}
}
if (a_interval.has_upper_bound()) {
if (can_prove(a_interval.max <= t.max())) {
interval.max = cast(t, a_interval.max);
}
}
return;
} else {
// Casting to a wider float, so cast then clamp.
if (a_interval.has_lower_bound()) {
Expr casted_min = cast(t, a_interval.min);

if (can_prove(casted_min >= t.min())) {
interval.min = casted_min;
}
}
if (a_interval.has_upper_bound()) {
Expr casted_max = cast(t, a_interval.max);
if (can_prove(casted_max <= t.max())) {
interval.max = casted_max;
}
}
return;
}
} else if (a.type() != t) {
// Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n)
if (a.type().is_float() && !t.is_float() && t.bits() >= a.type().bits()) {
if (a_interval.has_lower_bound()) {
// min values turn out to be always representable
if (can_prove(a_interval.min >= t.min())) {
interval.min = cast(t, a_interval.min);
}
}
if (a_interval.has_upper_bound()) {
// This line depends on t.max() rounding upward, which should always
// be the case as it is one less than a representable value, thus
// the one larger is always the closest.
if (can_prove(a_interval.max <= t.max())) {
interval.max = cast(t, a_interval.max);
}
}
return;
} else {
// We can safely cast a_interval iff we can prove the values
// are within the range of t.
Expr min_bound = lossless_cast(a.type(), t.min());
Expr max_bound = lossless_cast(a.type(), t.max());
// If the inner type is not a uint and we can represent t.min() in a.type(),
// then we need to check that value is >= t.min();
const bool check_lower_bound = !a.type().is_uint() && min_bound.defined();
// We should always check upper bounds if a.type() can represent t.max().
const bool check_upper_bound = max_bound.defined();

// Define a helper function for performing saturation.
auto check_safe_cast = [&](const Expr &value, const Expr &base) {
if (check_upper_bound && check_lower_bound) {
if (can_prove((value >= min_bound) && (value <= max_bound))) {
return cast(t, value);
}
} else if (check_upper_bound) {
if (can_prove(value <= max_bound)) {
return cast(t, value);
}
} else if (check_lower_bound) {
if (can_prove(a_interval.min >= min_bound)) {
return cast(t, value);
}
}
return base;
};

if (a_interval.has_lower_bound()) {
interval.min = check_safe_cast(a_interval.min, interval.min);
}
if (a_interval.has_upper_bound()) {
interval.max = check_safe_cast(a_interval.max, interval.max);
}
return;
}
} else {
// a.type() == t
interval = a_interval;
return;
if (a_interval.has_lower_bound()) {
interval.min = saturating_cast(op->type, a_interval.min);
}
if (a_interval.has_upper_bound()) {
interval.max = saturating_cast(op->type, a_interval.max);
}
return;
} else if (op->is_intrinsic(Call::unsafe_promise_clamped) ||
op->is_intrinsic(Call::promise_clamped)) {
// Unlike an explicit clamp, we are also permitted to
Expand Down Expand Up @@ -3677,7 +3589,7 @@ void bounds_test() {
check(scope, saturating_cast<uint8_t>(clamp(x, 5, 10)), cast<uint8_t>(5), cast<uint8_t>(10));
{
scope.push("x", Interval(UInt(32).min(), UInt(32).max()));
check(scope, saturating_cast<int32_t>(max(cast<uint32_t>(x), cast<uint32_t>(5))), cast<int32_t>(5), Interval::pos_inf());
check(scope, saturating_cast<int32_t>(max(cast<uint32_t>(x), cast<uint32_t>(5))), cast<int32_t>(5), Int(32).max());
scope.pop("x");
}
{
Expand All @@ -3692,7 +3604,7 @@ void bounds_test() {
{
Expr z = Variable::make(UInt(32), "z");
scope.push("z", Interval(UInt(32).max(), UInt(32).max()));
check(scope, saturating_cast<int32_t>(z), Interval::neg_inf(), Interval::pos_inf());
check(scope, saturating_cast<int32_t>(z), Int(32).max(), Int(32).max());

This comment has been minimized.

Copy link
@abadams

abadams Aug 4, 2022

Member

Should the first one be Int(32).min()?

This comment has been minimized.

Copy link
@rootjalex

rootjalex Aug 4, 2022

Author Member

The bounds of z are just [UInt(32).max(), UInt(32).max()], so I think the output should also be a single point

This comment has been minimized.

Copy link
@abadams

abadams Aug 4, 2022

Member

Oh, I misread and assumed nothing was known about z

scope.pop("z");
}

Expand Down
3 changes: 2 additions & 1 deletion src/Monotonic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ class DerivativeBounds : public IRVisitor {
}

if (op->is_intrinsic(Call::unsafe_promise_clamped) ||
op->is_intrinsic(Call::promise_clamped)) {
op->is_intrinsic(Call::promise_clamped) ||
op->is_intrinsic(Call::saturating_cast)) {
op->args[0].accept(this);
return;
}
Expand Down
16 changes: 16 additions & 0 deletions src/Simplify_Call.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "Simplify_Internal.h"

#include "FindIntrinsics.h"
#include "Simplify.h"

#ifdef _MSC_VER
Expand Down Expand Up @@ -351,6 +352,21 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) {
} else {
return absd(a, b);
}
} else if (op->is_intrinsic(Call::saturating_cast)) {
internal_assert(op->args.size() == 1);
ExprInfo a_bounds;
Expr a = mutate(op->args[0], &a_bounds);

// TODO(rootjalex): We could be intelligent about using a_bounds to remove saturating_casts;

if (is_const(a)) {
a = lower_saturating_cast(op->type, a);
return mutate(a, bounds);
} else if (!a.same_as(op->args[0])) {
return saturating_cast(op->type, a);
} else {
return op;
}
} else if (op->is_intrinsic(Call::stringify)) {
// Eagerly concat constant arguments to a stringify.
bool changed = false;
Expand Down

0 comments on commit 91c1352

Please sign in to comment.