-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
simplify bounds of saturating_cast + update is_monotonic
- Loading branch information
Showing
3 changed files
with
26 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1216,102 +1216,14 @@ class Bounds : public IRVisitor { | |
Expr a = op->args[0]; | ||
a.accept(this); | ||
Interval a_interval = interval; | ||
|
||
bounds_of_type(t); | ||
|
||
// For float to float, guarantee infinities are always pinned to range. | ||
if (t.is_float() && a.type().is_float()) { | ||
if (t.bits() < a.type().bits()) { | ||
// Casting to a smaller float, so clamp and then cast. | ||
if (a_interval.has_lower_bound()) { | ||
// If representable in the type, then use, otherwise return bounds_of_type(t).min | ||
if (can_prove(a_interval.min >= t.min())) { | ||
interval.min = cast(t, a_interval.min); | ||
} | ||
} | ||
if (a_interval.has_upper_bound()) { | ||
if (can_prove(a_interval.max <= t.max())) { | ||
interval.max = cast(t, a_interval.max); | ||
} | ||
} | ||
return; | ||
} else { | ||
// Casting to a wider float, so cast then clamp. | ||
if (a_interval.has_lower_bound()) { | ||
Expr casted_min = cast(t, a_interval.min); | ||
|
||
if (can_prove(casted_min >= t.min())) { | ||
interval.min = casted_min; | ||
} | ||
} | ||
if (a_interval.has_upper_bound()) { | ||
Expr casted_max = cast(t, a_interval.max); | ||
if (can_prove(casted_max <= t.max())) { | ||
interval.max = casted_max; | ||
} | ||
} | ||
return; | ||
} | ||
} else if (a.type() != t) { | ||
// Limits for Int(2^n) or UInt(2^n) are not exactly representable in Float(2^n) | ||
if (a.type().is_float() && !t.is_float() && t.bits() >= a.type().bits()) { | ||
if (a_interval.has_lower_bound()) { | ||
// min values turn out to be always representable | ||
if (can_prove(a_interval.min >= t.min())) { | ||
interval.min = cast(t, a_interval.min); | ||
} | ||
} | ||
if (a_interval.has_upper_bound()) { | ||
// This line depends on t.max() rounding upward, which should always | ||
// be the case as it is one less than a representable value, thus | ||
// the one larger is always the closest. | ||
if (can_prove(a_interval.max <= t.max())) { | ||
interval.max = cast(t, a_interval.max); | ||
} | ||
} | ||
return; | ||
} else { | ||
// We can safely cast a_interval iff we can prove the values | ||
// are within the range of t. | ||
Expr min_bound = lossless_cast(a.type(), t.min()); | ||
Expr max_bound = lossless_cast(a.type(), t.max()); | ||
// If the inner type is not a uint and we can represent t.min() in a.type(), | ||
// then we need to check that value is >= t.min(); | ||
const bool check_lower_bound = !a.type().is_uint() && min_bound.defined(); | ||
// We should always check upper bounds if a.type() can represent t.max(). | ||
const bool check_upper_bound = max_bound.defined(); | ||
|
||
// Define a helper function for performing saturation. | ||
auto check_safe_cast = [&](const Expr &value, const Expr &base) { | ||
if (check_upper_bound && check_lower_bound) { | ||
if (can_prove((value >= min_bound) && (value <= max_bound))) { | ||
return cast(t, value); | ||
} | ||
} else if (check_upper_bound) { | ||
if (can_prove(value <= max_bound)) { | ||
return cast(t, value); | ||
} | ||
} else if (check_lower_bound) { | ||
if (can_prove(a_interval.min >= min_bound)) { | ||
return cast(t, value); | ||
} | ||
} | ||
return base; | ||
}; | ||
|
||
if (a_interval.has_lower_bound()) { | ||
interval.min = check_safe_cast(a_interval.min, interval.min); | ||
} | ||
if (a_interval.has_upper_bound()) { | ||
interval.max = check_safe_cast(a_interval.max, interval.max); | ||
} | ||
return; | ||
} | ||
} else { | ||
// a.type() == t | ||
interval = a_interval; | ||
return; | ||
if (a_interval.has_lower_bound()) { | ||
interval.min = saturating_cast(op->type, a_interval.min); | ||
} | ||
if (a_interval.has_upper_bound()) { | ||
interval.max = saturating_cast(op->type, a_interval.max); | ||
} | ||
return; | ||
} else if (op->is_intrinsic(Call::unsafe_promise_clamped) || | ||
op->is_intrinsic(Call::promise_clamped)) { | ||
// Unlike an explicit clamp, we are also permitted to | ||
|
@@ -3677,7 +3589,7 @@ void bounds_test() { | |
check(scope, saturating_cast<uint8_t>(clamp(x, 5, 10)), cast<uint8_t>(5), cast<uint8_t>(10)); | ||
{ | ||
scope.push("x", Interval(UInt(32).min(), UInt(32).max())); | ||
check(scope, saturating_cast<int32_t>(max(cast<uint32_t>(x), cast<uint32_t>(5))), cast<int32_t>(5), Interval::pos_inf()); | ||
check(scope, saturating_cast<int32_t>(max(cast<uint32_t>(x), cast<uint32_t>(5))), cast<int32_t>(5), Int(32).max()); | ||
scope.pop("x"); | ||
} | ||
{ | ||
|
@@ -3692,7 +3604,7 @@ void bounds_test() { | |
{ | ||
Expr z = Variable::make(UInt(32), "z"); | ||
scope.push("z", Interval(UInt(32).max(), UInt(32).max())); | ||
check(scope, saturating_cast<int32_t>(z), Interval::neg_inf(), Interval::pos_inf()); | ||
check(scope, saturating_cast<int32_t>(z), Int(32).max(), Int(32).max()); | ||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
rootjalex
Author
Member
|
||
scope.pop("z"); | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Should the first one be Int(32).min()?