diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index c445d3477b81..fbcff72d7db1 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -1145,6 +1145,17 @@ void CodeGen_ARM::visit(const Call *op) { vector matches; for (const Pattern &pattern : calls) { if (expr_match(pattern.pattern, op, matches)) { + if (pattern.intrin.find("shift_right_narrow") != string::npos) { + // The shift_right_narrow patterns need the shift to be constant in [1, output_bits]. + const uint64_t *const_b = as_const_uint(matches[1]); + if (!const_b || *const_b == 0 || (int)*const_b > op->type.bits()) { + continue; + } + } + if (target.bits == 32 && pattern.intrin.find("shift_right") != string::npos) { + // The 32-bit ARM backend wants right shifts as negative values. + matches[1] = simplify(-cast(matches[1].type().with_code(halide_type_int), matches[1])); + } value = call_overloaded_intrin(op->type, pattern.intrin, matches); if (value) { return;