Make saturating_cast an intrinsic #6900

Merged · 21 commits · Aug 8, 2022
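This PR makes saturating_cast a first-class Call intrinsic: Bounds.cpp learns to bound it directly, and the ARM, WebAssembly, and x86 backends move their saturating-cast patterns from the Cast visitor to the Call visitor. For readers new to the helper: saturating_cast(t, e) clamps e to the representable range of t before converting, rather than wrapping. A minimal usage sketch, assuming only Halide's public API (the same helper the new bounds tests below exercise):

#include "Halide.h"
using namespace Halide;

int main() {
    Func f;
    Var x;
    // Values below 0 clamp to 0; values above 255 clamp to 255.
    f(x) = saturating_cast<uint8_t>(x * 100 - 50);
    Buffer<uint8_t> out = f.realize({5});
    // For x = 0..4 the operand is {-50, 50, 150, 250, 350},
    // so out holds {0, 50, 150, 250, 255}.
    return 0;
}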
36 changes: 36 additions & 0 deletions src/Bounds.cpp
@@ -1210,6 +1210,20 @@ class Bounds : public IRVisitor {
bounds_of_type(t);
}
}
} else if (op->is_intrinsic(Call::saturating_cast)) {
internal_assert(op->args.size() == 1);

Expr a = op->args[0];
a.accept(this);
Interval a_interval = interval;
bounds_of_type(t);
if (a_interval.has_lower_bound()) {
interval.min = saturating_cast(t, a_interval.min);
}
if (a_interval.has_upper_bound()) {
interval.max = saturating_cast(t, a_interval.max);
}
return;
} else if (op->is_intrinsic(Call::unsafe_promise_clamped) ||
op->is_intrinsic(Call::promise_clamped)) {
// Unlike an explicit clamp, we are also permitted to
@@ -3572,6 +3586,28 @@ void bounds_test() {
check(scope, cast<uint16_t>(u8_1) + cast<uint16_t>(u8_2),
u16(0), u16(255 * 2));

check(scope, saturating_cast<uint8_t>(clamp(x, 5, 10)), cast<uint8_t>(5), cast<uint8_t>(10));
{
scope.push("x", Interval(UInt(32).min(), UInt(32).max()));
check(scope, saturating_cast<int32_t>(max(cast<uint32_t>(x), cast<uint32_t>(5))), cast<int32_t>(5), Int(32).max());
scope.pop("x");
}
{
Expr z = Variable::make(Float(32), "z");
scope.push("z", Interval(cast<float>(-1), cast<float>(1)));
check(scope, saturating_cast<int32_t>(z), cast<int32_t>(-1), cast<int32_t>(1));
check(scope, saturating_cast<double>(z), cast<double>(-1), cast<double>(1));
check(scope, saturating_cast<float16_t>(z), cast<float16_t>(-1), cast<float16_t>(1));
check(scope, saturating_cast<uint8_t>(z), cast<uint8_t>(0), cast<uint8_t>(1));
scope.pop("z");
}
{
Expr z = Variable::make(UInt(32), "z");
scope.push("z", Interval(UInt(32).max(), UInt(32).max()));
check(scope, saturating_cast<int32_t>(z), Int(32).max(), Int(32).max());
scope.pop("z");
}

{
Scope<Interval> scope;
Expr x = Variable::make(UInt(16), "x");
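The new visitor case above leans on saturating_cast being monotonic: casting the endpoints of the argument's interval gives valid bounds for the result, and bounds_of_type(t) supplies the fallback when an endpoint is missing. A self-contained sketch of the same rule (illustrative; assumes Halide's Interval API):

#include "Halide.h"
using namespace Halide;
using namespace Halide::Internal;

// Bounds of saturating_cast(t, e), given an interval for e. For example,
// an argument interval of [-50, 350] cast to UInt(8) yields [0, 255].
Interval saturating_cast_bounds(Type t, const Interval &arg) {
    // The visitor starts from bounds_of_type(t) (the natural range of t);
    // everything() is the conservative stand-in used here.
    Interval result = Interval::everything();
    if (arg.has_lower_bound()) {
        result.min = saturating_cast(t, arg.min);
    }
    if (arg.has_upper_bound()) {
        result.max = saturating_cast(t, arg.max);
    }
    return result;
}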
147 changes: 79 additions & 68 deletions src/CodeGen_ARM.cpp
@@ -142,43 +142,43 @@ CodeGen_ARM::CodeGen_ARM(const Target &target)
// TODO: We need to match rounding shift right, and negate the RHS.

// SQRSHRN, SQRSHRUN, UQRSHRN - Saturating rounding narrowing shift right (by immediate in [1, output bits])
casts.emplace_back("saturating_rounding_shift_right_narrow", i8_sat(rounding_shift_right(wild_i16x_, wild_u16_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_u16x_, wild_u16_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_i16x_, wild_u16_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", i16_sat(rounding_shift_right(wild_i32x_, wild_u32_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_u32x_, wild_u32_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_i32x_, wild_u32_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", i32_sat(rounding_shift_right(wild_i64x_, wild_u64_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_u64x_, wild_u64_)));
casts.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_i64x_, wild_u64_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", i8_sat(rounding_shift_right(wild_i16x_, wild_u16_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_u16x_, wild_u16_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", u8_sat(rounding_shift_right(wild_i16x_, wild_u16_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", i16_sat(rounding_shift_right(wild_i32x_, wild_u32_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_u32x_, wild_u32_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", u16_sat(rounding_shift_right(wild_i32x_, wild_u32_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", i32_sat(rounding_shift_right(wild_i64x_, wild_u64_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_u64x_, wild_u64_)));
calls.emplace_back("saturating_rounding_shift_right_narrow", u32_sat(rounding_shift_right(wild_i64x_, wild_u64_)));

// SQSHL, UQSHL, SQSHLU - Saturating shift left by signed register.
for (const Expr &rhs : {wild_i8x_, wild_u8x_}) {
casts.emplace_back("saturating_shift_left", i8_sat(widening_shift_left(wild_i8x_, rhs)));
casts.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_u8x_, rhs)));
casts.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_i8x_, rhs)));
calls.emplace_back("saturating_shift_left", i8_sat(widening_shift_left(wild_i8x_, rhs)));
calls.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_u8x_, rhs)));
calls.emplace_back("saturating_shift_left", u8_sat(widening_shift_left(wild_i8x_, rhs)));
}
for (const Expr &rhs : {wild_i16x_, wild_u16x_}) {
casts.emplace_back("saturating_shift_left", i16_sat(widening_shift_left(wild_i16x_, rhs)));
casts.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_u16x_, rhs)));
casts.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_i16x_, rhs)));
calls.emplace_back("saturating_shift_left", i16_sat(widening_shift_left(wild_i16x_, rhs)));
calls.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_u16x_, rhs)));
calls.emplace_back("saturating_shift_left", u16_sat(widening_shift_left(wild_i16x_, rhs)));
}
for (const Expr &rhs : {wild_i32x_, wild_u32x_}) {
casts.emplace_back("saturating_shift_left", i32_sat(widening_shift_left(wild_i32x_, rhs)));
casts.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_u32x_, rhs)));
casts.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_i32x_, rhs)));
calls.emplace_back("saturating_shift_left", i32_sat(widening_shift_left(wild_i32x_, rhs)));
calls.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_u32x_, rhs)));
calls.emplace_back("saturating_shift_left", u32_sat(widening_shift_left(wild_i32x_, rhs)));
}

// SQSHRN, UQSHRN, SQSHRUN - Saturating narrowing shift right (by immediate in [1, output bits])
casts.emplace_back("saturating_shift_right_narrow", i8_sat(wild_i16x_ >> wild_u16_));
casts.emplace_back("saturating_shift_right_narrow", u8_sat(wild_u16x_ >> wild_u16_));
casts.emplace_back("saturating_shift_right_narrow", u8_sat(wild_i16x_ >> wild_u16_));
casts.emplace_back("saturating_shift_right_narrow", i16_sat(wild_i32x_ >> wild_u32_));
casts.emplace_back("saturating_shift_right_narrow", u16_sat(wild_u32x_ >> wild_u32_));
casts.emplace_back("saturating_shift_right_narrow", u16_sat(wild_i32x_ >> wild_u32_));
casts.emplace_back("saturating_shift_right_narrow", i32_sat(wild_i64x_ >> wild_u64_));
casts.emplace_back("saturating_shift_right_narrow", u32_sat(wild_u64x_ >> wild_u64_));
casts.emplace_back("saturating_shift_right_narrow", u32_sat(wild_i64x_ >> wild_u64_));
calls.emplace_back("saturating_shift_right_narrow", i8_sat(wild_i16x_ >> wild_u16_));
calls.emplace_back("saturating_shift_right_narrow", u8_sat(wild_u16x_ >> wild_u16_));
calls.emplace_back("saturating_shift_right_narrow", u8_sat(wild_i16x_ >> wild_u16_));
calls.emplace_back("saturating_shift_right_narrow", i16_sat(wild_i32x_ >> wild_u32_));
calls.emplace_back("saturating_shift_right_narrow", u16_sat(wild_u32x_ >> wild_u32_));
calls.emplace_back("saturating_shift_right_narrow", u16_sat(wild_i32x_ >> wild_u32_));
calls.emplace_back("saturating_shift_right_narrow", i32_sat(wild_i64x_ >> wild_u64_));
calls.emplace_back("saturating_shift_right_narrow", u32_sat(wild_u64x_ >> wild_u64_));
calls.emplace_back("saturating_shift_right_narrow", u32_sat(wild_i64x_ >> wild_u64_));

// SRSHL, URSHL - Rounding shift left (by signed vector)
// These are already written as rounding_shift_left
@@ -190,15 +190,15 @@ CodeGen_ARM::CodeGen_ARM(const Target &target)
// These patterns are almost the identity; we just need to strip off the broadcast.

// SQXTN, UQXTN, SQXTUN - Saturating narrow.
casts.emplace_back("saturating_narrow", i8_sat(wild_i16x_));
casts.emplace_back("saturating_narrow", u8_sat(wild_u16x_));
casts.emplace_back("saturating_narrow", u8_sat(wild_i16x_));
casts.emplace_back("saturating_narrow", i16_sat(wild_i32x_));
casts.emplace_back("saturating_narrow", u16_sat(wild_u32x_));
casts.emplace_back("saturating_narrow", u16_sat(wild_i32x_));
casts.emplace_back("saturating_narrow", i32_sat(wild_i64x_));
casts.emplace_back("saturating_narrow", u32_sat(wild_u64x_));
casts.emplace_back("saturating_narrow", u32_sat(wild_i64x_));
calls.emplace_back("saturating_narrow", i8_sat(wild_i16x_));
calls.emplace_back("saturating_narrow", u8_sat(wild_u16x_));
calls.emplace_back("saturating_narrow", u8_sat(wild_i16x_));
calls.emplace_back("saturating_narrow", i16_sat(wild_i32x_));
calls.emplace_back("saturating_narrow", u16_sat(wild_u32x_));
calls.emplace_back("saturating_narrow", u16_sat(wild_i32x_));
calls.emplace_back("saturating_narrow", i32_sat(wild_i64x_));
calls.emplace_back("saturating_narrow", u32_sat(wild_u64x_));
calls.emplace_back("saturating_narrow", u32_sat(wild_i64x_));

// SQNEG - Saturating negate
negations.emplace_back("saturating_negate", -max(wild_i8x_, -127));
@@ -798,38 +798,6 @@ void CodeGen_ARM::visit(const Cast *op) {
return;
}
}

// If we didn't find a pattern, try rewriting the cast.
static const vector<pair<Expr, Expr>> cast_rewrites = {
// Double or triple narrowing saturating casts are better expressed as
// regular narrowing casts.
{u8_sat(wild_u32x_), u8_sat(u16_sat(wild_u32x_))},
{u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))},
{u8_sat(wild_f32x_), u8_sat(i16_sat(wild_f32x_))},
{i8_sat(wild_u32x_), i8_sat(u16_sat(wild_u32x_))},
{i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))},
{i8_sat(wild_f32x_), i8_sat(i16_sat(wild_f32x_))},
{u16_sat(wild_u64x_), u16_sat(u32_sat(wild_u64x_))},
{u16_sat(wild_i64x_), u16_sat(i32_sat(wild_i64x_))},
{u16_sat(wild_f64x_), u16_sat(i32_sat(wild_f64x_))},
{i16_sat(wild_u64x_), i16_sat(u32_sat(wild_u64x_))},
{i16_sat(wild_i64x_), i16_sat(i32_sat(wild_i64x_))},
{i16_sat(wild_f64x_), i16_sat(i32_sat(wild_f64x_))},
{u8_sat(wild_u64x_), u8_sat(u16_sat(u32_sat(wild_u64x_)))},
{u8_sat(wild_i64x_), u8_sat(i16_sat(i32_sat(wild_i64x_)))},
{u8_sat(wild_f64x_), u8_sat(i16_sat(i32_sat(wild_f64x_)))},
{i8_sat(wild_u64x_), i8_sat(u16_sat(u32_sat(wild_u64x_)))},
{i8_sat(wild_i64x_), i8_sat(i16_sat(i32_sat(wild_i64x_)))},
{i8_sat(wild_f64x_), i8_sat(i16_sat(i32_sat(wild_f64x_)))},
};
for (const auto &i : cast_rewrites) {
if (expr_match(i.first, op, matches)) {
Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes()));
debug(3) << "rewriting cast to: " << replacement << " from " << Expr(op) << "\n";
value = codegen(replacement);
return;
}
}
}

// LLVM fptoui generates fcvtzs if src is fp16 scalar else fcvtzu.
@@ -1177,12 +1145,55 @@ void CodeGen_ARM::visit(const Call *op) {
vector<Expr> matches;
for (const Pattern &pattern : calls) {
if (expr_match(pattern.pattern, op, matches)) {
if (pattern.intrin.find("shift_right_narrow") != string::npos) {
// The shift_right_narrow patterns need the shift to be constant in [1, output_bits].
const uint64_t *const_b = as_const_uint(matches[1]);
if (!const_b || *const_b == 0 || (int)*const_b > op->type.bits()) {
continue;
}
}
if (target.bits == 32 && pattern.intrin.find("shift_right") != string::npos) {
// The 32-bit ARM backend wants right shifts as negative values.
matches[1] = simplify(-cast(matches[1].type().with_code(halide_type_int), matches[1]));
}
value = call_overloaded_intrin(op->type, pattern.intrin, matches);
if (value) {
return;
}
}
}

// If we didn't find a pattern, try rewriting any saturating casts.
static const vector<pair<Expr, Expr>> cast_rewrites = {
// Double or triple narrowing saturating casts are better expressed as
// combinations of single narrowing saturating casts.
{u8_sat(wild_u32x_), u8_sat(u16_sat(wild_u32x_))},
{u8_sat(wild_i32x_), u8_sat(i16_sat(wild_i32x_))},
{u8_sat(wild_f32x_), u8_sat(i16_sat(wild_f32x_))},
{i8_sat(wild_u32x_), i8_sat(u16_sat(wild_u32x_))},
{i8_sat(wild_i32x_), i8_sat(i16_sat(wild_i32x_))},
{i8_sat(wild_f32x_), i8_sat(i16_sat(wild_f32x_))},
{u16_sat(wild_u64x_), u16_sat(u32_sat(wild_u64x_))},
{u16_sat(wild_i64x_), u16_sat(i32_sat(wild_i64x_))},
{u16_sat(wild_f64x_), u16_sat(i32_sat(wild_f64x_))},
{i16_sat(wild_u64x_), i16_sat(u32_sat(wild_u64x_))},
{i16_sat(wild_i64x_), i16_sat(i32_sat(wild_i64x_))},
{i16_sat(wild_f64x_), i16_sat(i32_sat(wild_f64x_))},
{u8_sat(wild_u64x_), u8_sat(u16_sat(u32_sat(wild_u64x_)))},
{u8_sat(wild_i64x_), u8_sat(i16_sat(i32_sat(wild_i64x_)))},
{u8_sat(wild_f64x_), u8_sat(i16_sat(i32_sat(wild_f64x_)))},
{i8_sat(wild_u64x_), i8_sat(u16_sat(u32_sat(wild_u64x_)))},
{i8_sat(wild_i64x_), i8_sat(i16_sat(i32_sat(wild_i64x_)))},
{i8_sat(wild_f64x_), i8_sat(i16_sat(i32_sat(wild_f64x_)))},
};
for (const auto &i : cast_rewrites) {
if (expr_match(i.first, op, matches)) {
Expr replacement = substitute("*", matches[0], with_lanes(i.second, op->type.lanes()));
debug(3) << "rewriting cast to: " << replacement << " from " << Expr(op) << "\n";
value = codegen(replacement);
return;
}
}
}

if (target.has_feature(Target::ARMFp16)) {
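A note on the cast_rewrites table above: AArch64's saturating-narrow instructions (SQXTN, UQXTN, SQXTUN) only halve the lane width, so 4x and 8x narrowing saturating casts are decomposed into chains of 2x steps, each of which then matches a single saturating_narrow pattern. A sketch of the expansion using the public helper (illustrative only):

#include "Halide.h"
using namespace Halide;

// u8_sat(wild_i32x_) has no one-instruction lowering, so the rewrite turns
// it into two 2x narrows: i32 -> i16 (SQXTN), then i16 -> u8 (SQXTUN).
Expr narrow_i32x_to_u8x(const Expr &wide_i32x) {
    return saturating_cast<uint8_t>(saturating_cast<int16_t>(wide_i32x));
}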
41 changes: 36 additions & 5 deletions src/CodeGen_WebAssembly.cpp
@@ -37,6 +37,7 @@ class CodeGen_WebAssembly : public CodeGen_Posix {
bool use_pic() const override;

void visit(const Cast *) override;
void visit(const Call *) override;
void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
};

@@ -147,11 +148,6 @@ void CodeGen_WebAssembly::visit(const Cast *op) {

// clang-format off
static const Pattern patterns[] = {
{"q15mulr_sat_s", i16_sat(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), u16(15))), Target::WasmSimd128},
{"saturating_narrow", i8_sat(wild_i16x_), Target::WasmSimd128},
{"saturating_narrow", u8_sat(wild_i16x_), Target::WasmSimd128},
{"saturating_narrow", i16_sat(wild_i32x_), Target::WasmSimd128},
{"saturating_narrow", u16_sat(wild_i32x_), Target::WasmSimd128},
{"int_to_double", f64(wild_i32x_), Target::WasmSimd128},
{"int_to_double", f64(wild_u32x_), Target::WasmSimd128},
#if LLVM_VERSION == 130
@@ -184,6 +180,41 @@ void CodeGen_WebAssembly::visit(const Cast *op) {
CodeGen_Posix::visit(op);
}

void CodeGen_WebAssembly::visit(const Call *op) {
struct Pattern {
std::string intrin; ///< Name of the intrinsic
Expr pattern; ///< The pattern to match against
Target::Feature required_feature;
};

// clang-format off
static const Pattern patterns[] = {
{"q15mulr_sat_s", i16_sat(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), u16(15))), Target::WasmSimd128},
{"saturating_narrow", i8_sat(wild_i16x_), Target::WasmSimd128},
{"saturating_narrow", u8_sat(wild_i16x_), Target::WasmSimd128},
{"saturating_narrow", i16_sat(wild_i32x_), Target::WasmSimd128},
{"saturating_narrow", u16_sat(wild_i32x_), Target::WasmSimd128},
};
// clang-format on

if (op->type.is_vector()) {
std::vector<Expr> matches;
for (const Pattern &p : patterns) {
if (!target.has_feature(p.required_feature)) {
continue;
}
if (expr_match(p.pattern, op, matches)) {
value = call_overloaded_intrin(op->type, p.intrin, matches);
if (value) {
return;
}
}
}
}

CodeGen_Posix::visit(op);
}

void CodeGen_WebAssembly::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
struct Pattern {
VectorReduce::Operator reduce_op;
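For reference, the q15mulr_sat_s pattern matched above is WebAssembly SIMD's Q15 rounding saturating multiply. A scalar model of one 16-bit lane, just to make the matched expression concrete:

#include <algorithm>
#include <cstdint>

// i16_sat(rounding_shift_right(widening_mul(a, b), 15)), one lane at a time.
int16_t q15mulr_sat_s(int16_t a, int16_t b) {
    int32_t prod = int32_t(a) * int32_t(b);                 // widening_mul
    int32_t rounded = (prod + (1 << 14)) >> 15;             // rounding_shift_right(..., 15)
    rounded = std::clamp<int32_t>(rounded, -32768, 32767);  // i16_sat; only binds for a == b == INT16_MIN
    return static_cast<int16_t>(rounded);
}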
9 changes: 4 additions & 5 deletions src/CodeGen_X86.cpp
@@ -473,11 +473,6 @@ void CodeGen_X86::visit(const Cast *op) {
// saturate the result.
{"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))},

{"saturating_narrow", i16_sat(wild_i32x_)},
{"saturating_narrow", u16_sat(wild_i32x_)},
{"saturating_narrow", i8_sat(wild_i16x_)},
{"saturating_narrow", u8_sat(wild_i16x_)},

{"f32_to_bf16", bf16(wild_f32x_)},
};
// clang-format on
@@ -575,6 +570,10 @@ void CodeGen_X86::visit(const Call *op) {
{"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)},
{"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)},
{"saturating_pmulhrs", rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15)},
{"saturating_narrow", i16_sat(wild_i32x_)},
{"saturating_narrow", u16_sat(wild_i32x_)},
{"saturating_narrow", i8_sat(wild_i16x_)},
{"saturating_narrow", u8_sat(wild_i16x_)},
};
// clang-format on

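On x86, the saturating_narrow patterns moved here map onto the SSE pack family. A sketch with compiler intrinsics (the backend reaches these through LLVM's intrinsic tables rather than emitting them directly):

#include <emmintrin.h>  // SSE2

// i16_sat(wild_i32x_) corresponds to packssdw; u8_sat(wild_i16x_) to packuswb.
__m128i narrow_i32_to_i16_sat(__m128i lo, __m128i hi) {
    return _mm_packs_epi32(lo, hi);   // two 4 x i32 in, 8 x i16 out, signed saturation
}
__m128i narrow_i16_to_u8_sat(__m128i lo, __m128i hi) {
    return _mm_packus_epi16(lo, hi);  // two 8 x i16 in, 16 x u8 out, unsigned saturation
}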