diff --git a/python_bindings/src/PyEnums.cpp b/python_bindings/src/PyEnums.cpp
index ed453ca08d6f..26b8866035c6 100644
--- a/python_bindings/src/PyEnums.cpp
+++ b/python_bindings/src/PyEnums.cpp
@@ -142,6 +142,7 @@ void define_enums(py::module &m) {
         .value("WasmSignExt", Target::Feature::WasmSignExt)
         .value("SVE", Target::Feature::SVE)
         .value("SVE2", Target::Feature::SVE2)
+        .value("ARMDotProd", Target::Feature::ARMDotProd)
         .value("FeatureEnd", Target::Feature::FeatureEnd);
 
     py::enum_(m, "TypeCode")
diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp
index 5e9809d49210..9e9334f1bc30 100644
--- a/src/CodeGen_ARM.cpp
+++ b/src/CodeGen_ARM.cpp
@@ -1,6 +1,7 @@
 #include 
 #include 
 
+#include "CSE.h"
 #include "CodeGen_ARM.h"
 #include "ConciseCasts.h"
 #include "Debug.h"
@@ -483,10 +484,6 @@ void CodeGen_ARM::visit(const Div *op) {
     CodeGen_Posix::visit(op);
 }
 
-void CodeGen_ARM::visit(const Add *op) {
-    CodeGen_Posix::visit(op);
-}
-
 void CodeGen_ARM::visit(const Sub *op) {
     if (neon_intrinsics_disabled()) {
         CodeGen_Posix::visit(op);
@@ -1063,6 +1060,184 @@ void CodeGen_ARM::visit(const LE *op) {
     CodeGen_Posix::visit(op);
 }
 
+void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
+    if (neon_intrinsics_disabled() ||
+        op->op == VectorReduce::Or ||
+        op->op == VectorReduce::And ||
+        op->op == VectorReduce::Mul) {
+        CodeGen_Posix::codegen_vector_reduce(op, init);
+        return;
+    }
+
+    // ARM has a variety of pairwise reduction ops for +, min,
+    // max. The versions that do not widen take two 64-bit args and
+    // return one 64-bit vector of the same type. The versions that
+    // widen take one arg and return something with half the vector
+    // lanes and double the bit-width.
+
+    int factor = op->value.type().lanes() / op->type.lanes();
+
+    // These are the types for which we have reduce intrinsics in the
+    // runtime.
+    bool have_reduce_intrinsic = (op->type.is_int() ||
+                                  op->type.is_uint() ||
+                                  op->type.is_float());
+
+    // We don't have 16-bit float or bfloat horizontal ops
+    if (op->type.is_bfloat() || (op->type.is_float() && op->type.bits() < 32)) {
+        have_reduce_intrinsic = false;
+    }
+
+    // Only aarch64 has float64 horizontal ops
+    if (target.bits == 32 && op->type.element_of() == Float(64)) {
+        have_reduce_intrinsic = false;
+    }
+
+    // For 64-bit integers, we only have addition, not min/max
+    if (op->type.bits() == 64 &&
+        !op->type.is_float() &&
+        op->op != VectorReduce::Add) {
+        have_reduce_intrinsic = false;
+    }
+
+    // We only have intrinsics that reduce by a factor of two
+    if (factor != 2) {
+        have_reduce_intrinsic = false;
+    }
+
+    if (have_reduce_intrinsic) {
+        Expr arg = op->value;
+        if (op->op == VectorReduce::Add &&
+            op->type.bits() >= 16 &&
+            !op->type.is_float()) {
+            Type narrower_type = arg.type().with_bits(arg.type().bits() / 2);
+            Expr narrower = lossless_cast(narrower_type, arg);
+            if (!narrower.defined() && arg.type().is_int()) {
+                // We can also safely accumulate from a uint into a
+                // wider int, because the addition uses at most one
+                // extra bit.
+                narrower = lossless_cast(narrower_type.with_code(Type::UInt), arg);
+            }
+            if (narrower.defined()) {
+                arg = narrower;
+            }
+        }
+        int output_bits;
+        if (target.bits == 32 && arg.type().bits() == op->type.bits()) {
+            // For the non-widening version, the output must be 64-bit
+            output_bits = 64;
+        } else if (op->type.bits() * op->type.lanes() <= 64) {
+            // No point using the 128-bit version of the instruction if the output is narrow.
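// ---------------------------------------------------------------------------
// Editor's aside (illustrative sketch, not part of the patch): a scalar model
// of the two pairwise-reduction flavours described in the comments above,
// assuming output lane i combines input lanes 2*i and 2*i+1. The function
// names and fixed widths here are hypothetical, not Halide or NEON APIs.
#include <array>
#include <cstdint>

// Non-widening flavour (addp-style): takes two vectors of the same type and
// returns the pairwise sums of the first followed by those of the second.
static std::array<int32_t, 4> pairwise_add(const std::array<int32_t, 4> &a,
                                           const std::array<int32_t, 4> &b) {
    return {a[0] + a[1], a[2] + a[3], b[0] + b[1], b[2] + b[3]};
}

// Widening flavour (saddlp-style): takes one vector and returns half the
// lanes at double the bit-width, so the pairwise sums cannot overflow.
static std::array<int32_t, 4> pairwise_widening_add(const std::array<int16_t, 8> &x) {
    return {x[0] + x[1], x[2] + x[3], x[4] + x[5], x[6] + x[7]};
}
// End of editor's aside. The patch resumes below.
// ---------------------------------------------------------------------------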
+ output_bits = 64; + } else { + output_bits = 128; + } + + const int output_lanes = output_bits / op->type.bits(); + Type intrin_type = op->type.with_lanes(output_lanes); + Type arg_type = arg.type().with_lanes(output_lanes * 2); + if (op->op == VectorReduce::Add && + arg.type().bits() == op->type.bits() && + arg_type.is_uint()) { + // For non-widening additions, there is only a signed + // version (because it's equivalent). + arg_type = arg_type.with_code(Type::Int); + intrin_type = intrin_type.with_code(Type::Int); + } else if (arg.type().is_uint() && intrin_type.is_int()) { + // Use the uint version + intrin_type = intrin_type.with_code(Type::UInt); + } + + std::stringstream ss; + vector args; + ss << "pairwise_" << op->op << "_" << intrin_type << "_" << arg_type; + Expr accumulator = init; + if (op->op == VectorReduce::Add && + accumulator.defined() && + arg_type.bits() < intrin_type.bits()) { + // We can use the accumulating variant + ss << "_accumulate"; + args.push_back(init); + accumulator = Expr(); + } + args.push_back(arg); + value = call_intrin(op->type, output_lanes, ss.str(), args); + + if (accumulator.defined()) { + // We still have an initial value to take care of + string n = unique_name('t'); + sym_push(n, value); + Expr v = Variable::make(accumulator.type(), n); + switch (op->op) { + case VectorReduce::Add: + accumulator += v; + break; + case VectorReduce::Min: + accumulator = min(accumulator, v); + break; + case VectorReduce::Max: + accumulator = max(accumulator, v); + break; + default: + internal_error << "unreachable"; + } + codegen(accumulator); + sym_pop(n); + } + + return; + } + + // Pattern-match 8-bit dot product instructions available on newer + // ARM cores. + if (target.has_feature(Target::ARMDotProd) && + factor % 4 == 0 && + op->op == VectorReduce::Add && + target.bits == 64 && + (op->type.element_of() == Int(32) || + op->type.element_of() == UInt(32))) { + const Mul *mul = op->value.as(); + if (mul) { + const int input_lanes = mul->type.lanes(); + Expr a = lossless_cast(UInt(8, input_lanes), mul->a); + Expr b = lossless_cast(UInt(8, input_lanes), mul->b); + if (!a.defined()) { + a = lossless_cast(Int(8, input_lanes), mul->a); + b = lossless_cast(Int(8, input_lanes), mul->b); + } + if (a.defined() && b.defined()) { + if (factor != 4) { + Expr equiv = VectorReduce::make(op->op, op->value, input_lanes / 4); + equiv = VectorReduce::make(op->op, equiv, op->type.lanes()); + codegen_vector_reduce(equiv.as(), init); + return; + } + Expr i = init; + if (!i.defined()) { + i = make_zero(op->type); + } + vector args{i, a, b}; + if (op->type.lanes() <= 2) { + if (op->type.is_uint()) { + value = call_intrin(op->type, 2, "llvm.aarch64.neon.udot.v2i32.v8i8", args); + } else { + value = call_intrin(op->type, 2, "llvm.aarch64.neon.sdot.v2i32.v8i8", args); + } + } else { + if (op->type.is_uint()) { + value = call_intrin(op->type, 4, "llvm.aarch64.neon.udot.v4i32.v16i8", args); + } else { + value = call_intrin(op->type, 4, "llvm.aarch64.neon.sdot.v4i32.v16i8", args); + } + } + return; + } + } + } + + CodeGen_Posix::codegen_vector_reduce(op, init); +} + string CodeGen_ARM::mcpu() const { if (target.bits == 32) { if (target.has_feature(Target::ARMv7s)) { @@ -1098,6 +1273,10 @@ string CodeGen_ARM::mattrs() const { arch_flags = "+sve"; } + if (target.has_feature(Target::ARMDotProd)) { + arch_flags += "+dotprod"; + } + if (target.os == Target::IOS || target.os == Target::OSX) { return arch_flags + "+reserve-x18"; } else { diff --git a/src/CodeGen_ARM.h b/src/CodeGen_ARM.h index 
fc1c3ed848f4..c8d868169296 100644 --- a/src/CodeGen_ARM.h +++ b/src/CodeGen_ARM.h @@ -24,7 +24,6 @@ class CodeGen_ARM : public CodeGen_Posix { /** Nodes for which we want to emit specific neon intrinsics */ // @{ void visit(const Cast *) override; - void visit(const Add *) override; void visit(const Sub *) override; void visit(const Div *) override; void visit(const Mul *) override; @@ -35,6 +34,7 @@ class CodeGen_ARM : public CodeGen_Posix { void visit(const Call *) override; void visit(const LT *) override; void visit(const LE *) override; + void codegen_vector_reduce(const VectorReduce *, const Expr &) override; // @} /** Various patterns to peephole match against */ diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index cc26207e3fb9..5fe39359e47c 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1394,6 +1394,16 @@ Value *CodeGen_LLVM::codegen(const Expr &e) { value = nullptr; e.accept(this); internal_assert(value) << "Codegen of an expr did not produce an llvm value\n"; + + // Halide's type system doesn't distinguish between scalars and + // vectors of size 1, so if a codegen method returned a vector of + // size one, just extract it out as a scalar. + if (e.type().is_scalar() && + value->getType()->isVectorTy()) { + internal_assert(get_vector_num_elements(value->getType()) == 1); + value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0)); + } + // TODO: skip this correctness check for bool vectors, // as eliminate_bool_vectors() will cause a discrepancy for some backends // (eg OpenCL, HVX); for now we're just ignoring the assert, but @@ -1534,6 +1544,27 @@ void CodeGen_LLVM::visit(const Variable *op) { value = sym_get(op->name); } +template +bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) { + const VectorReduce *red = op->a.template as(); + Expr b = op->b; + if (!red) { + red = op->b.template as(); + b = op->a; + } + if (red && + ((std::is_same::value && red->op == VectorReduce::Add) || + (std::is_same::value && red->op == VectorReduce::Min) || + (std::is_same::value && red->op == VectorReduce::Max) || + (std::is_same::value && red->op == VectorReduce::Mul) || + (std::is_same::value && red->op == VectorReduce::And) || + (std::is_same::value && red->op == VectorReduce::Or))) { + codegen_vector_reduce(red, b); + return true; + } + return false; +} + void CodeGen_LLVM::visit(const Add *op) { Type t = upgrade_type_for_arithmetic(op->type); if (t != op->type) { @@ -1541,6 +1572,11 @@ void CodeGen_LLVM::visit(const Add *op) { return; } + // Some backends can fold the add into a vector reduce + if (try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); if (op->type.is_float()) { @@ -1581,6 +1617,10 @@ void CodeGen_LLVM::visit(const Mul *op) { return; } + if (try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); if (op->type.is_float()) { @@ -1637,6 +1677,10 @@ void CodeGen_LLVM::visit(const Min *op) { return; } + if (try_to_fold_vector_reduce(op)) { + return; + } + string a_name = unique_name('a'); string b_name = unique_name('b'); Expr a = Variable::make(op->a.type(), a_name); @@ -1653,6 +1697,10 @@ void CodeGen_LLVM::visit(const Max *op) { return; } + if (try_to_fold_vector_reduce(op)) { + return; + } + string a_name = unique_name('a'); string b_name = unique_name('b'); Expr a = Variable::make(op->a.type(), a_name); @@ -1768,12 +1816,20 @@ void CodeGen_LLVM::visit(const GE *op) { } void CodeGen_LLVM::visit(const And *op) { + if 
(try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); value = builder->CreateAnd(a, b); } void CodeGen_LLVM::visit(const Or *op) { + if (try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); value = builder->CreateOr(a, b); @@ -2352,7 +2408,7 @@ Value *CodeGen_LLVM::codegen_dense_vector_load(const Load *load, Value *vpred) { // For dense vector loads wider than the native vector // width, bust them up into native vectors int load_lanes = load->type.lanes(); - int native_lanes = native_bits / load->type.bits(); + int native_lanes = std::max(1, native_bits / load->type.bits()); vector slices; for (int i = 0; i < load_lanes; i += native_lanes) { int slice_lanes = std::min(native_lanes, load_lanes - i); @@ -4223,11 +4279,251 @@ void CodeGen_LLVM::visit(const Shuffle *op) { } } - if (op->type.is_scalar()) { + if (op->type.is_scalar() && value->getType()->isVectorTy()) { value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0)); } } +void CodeGen_LLVM::visit(const VectorReduce *op) { + codegen_vector_reduce(op, Expr()); +} + +void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { + Expr val = op->value; + const int output_lanes = op->type.lanes(); + const int native_lanes = native_vector_bits() / op->type.bits(); + const int factor = val.type().lanes() / output_lanes; + + Expr (*binop)(Expr, Expr) = nullptr; + switch (op->op) { + case VectorReduce::Add: + binop = Add::make; + break; + case VectorReduce::Mul: + binop = Mul::make; + break; + case VectorReduce::Min: + binop = Min::make; + break; + case VectorReduce::Max: + binop = Max::make; + break; + case VectorReduce::And: + binop = And::make; + break; + case VectorReduce::Or: + binop = Or::make; + break; + } + + if (op->type.is_bool() && op->op == VectorReduce::Or) { + // Cast to u8, use max, cast back to bool. + Expr equiv = cast(op->value.type().with_bits(8), op->value); + equiv = VectorReduce::make(VectorReduce::Max, equiv, op->type.lanes()); + if (init.defined()) { + equiv = max(equiv, init); + } + equiv = cast(op->type, equiv); + equiv.accept(this); + return; + } + + if (op->type.is_bool() && op->op == VectorReduce::And) { + // Cast to u8, use min, cast back to bool. 
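// ---------------------------------------------------------------------------
// Editor's aside (sketch, not part of the patch): the boolean reductions
// handled above and below rely on the identity that an or-reduce is a max over
// {0, 1} and an and-reduce is a min over {0, 1}, computed in uint8 and cast
// back to bool. A scalar model of that identity, with hypothetical names:
#include <algorithm>
#include <cstdint>
#include <vector>

static bool or_reduce_via_max(const std::vector<uint8_t> &bools) {  // values are 0 or 1
    uint8_t m = 0;  // identity element for max
    for (uint8_t b : bools) {
        m = std::max(m, b);
    }
    return m != 0;
}

static bool and_reduce_via_min(const std::vector<uint8_t> &bools) {  // values are 0 or 1
    uint8_t m = 1;  // identity element for min over {0, 1}
    for (uint8_t b : bools) {
        m = std::min(m, b);
    }
    return m != 0;
}
// End of editor's aside. The patch resumes below.
// ---------------------------------------------------------------------------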
+ Expr equiv = cast(op->value.type().with_bits(8), op->value); + equiv = VectorReduce::make(VectorReduce::Min, equiv, op->type.lanes()); + equiv = cast(op->type, equiv); + if (init.defined()) { + equiv = min(equiv, init); + } + equiv.accept(this); + return; + } + + if (op->type.element_of() == Float(16)) { + Expr equiv = cast(op->value.type().with_bits(32), op->value); + equiv = VectorReduce::make(op->op, equiv, op->type.lanes()); + if (init.defined()) { + equiv = binop(equiv, init); + } + equiv = cast(op->type, equiv); + equiv.accept(this); + return; + } + +#if LLVM_VERSION >= 90 + if (output_lanes == 1) { + const int input_lanes = val.type().lanes(); + const int input_bytes = input_lanes * val.type().bytes(); + const bool llvm_has_intrinsic = + // Must be one of these ops + ((op->op == VectorReduce::Add || + op->op == VectorReduce::Mul || + op->op == VectorReduce::Min || + op->op == VectorReduce::Max) && + // Must be a power of two lanes + (input_lanes >= 2) && + ((input_lanes & (input_lanes - 1)) == 0) && + // int versions exist up to 1024 bits + ((!op->type.is_float() && input_bytes <= 1024) || + // float versions exist up to 16 lanes + input_lanes <= 16) && + // As of the release of llvm 10, the 64-bit experimental total + // reductions don't seem to be done yet on arm. + (val.type().bits() != 64 || + target.arch != Target::ARM)); + + if (llvm_has_intrinsic) { + std::stringstream name; + name << "llvm.experimental.vector.reduce."; + const int bits = op->type.bits(); + bool takes_initial_value = false; + Expr initial_value = init; + if (op->type.is_float()) { + switch (op->op) { + case VectorReduce::Add: + name << "v2.fadd.f" << bits; + takes_initial_value = true; + if (!initial_value.defined()) { + initial_value = make_zero(op->type); + } + break; + case VectorReduce::Mul: + name << "v2.fmul.f" << bits; + takes_initial_value = true; + if (!initial_value.defined()) { + initial_value = make_one(op->type); + } + break; + case VectorReduce::Min: + name << "fmin"; + break; + case VectorReduce::Max: + name << "fmax"; + break; + default: + break; + } + } else if (op->type.is_int() || op->type.is_uint()) { + switch (op->op) { + case VectorReduce::Add: + name << "add"; + break; + case VectorReduce::Mul: + name << "mul"; + break; + case VectorReduce::Min: + name << (op->type.is_int() ? 's' : 'u') << "min"; + break; + case VectorReduce::Max: + name << (op->type.is_int() ? 's' : 'u') << "max"; + break; + default: + break; + } + } + name << ".v" << val.type().lanes() << (op->type.is_float() ? 'f' : 'i') << bits; + + string intrin_name = name.str(); + + vector args; + if (takes_initial_value) { + args.push_back(initial_value); + initial_value = Expr(); + } + args.push_back(op->value); + + // Make sure the declaration exists, or the codegen for + // call will assume that the args should scalarize. + if (!module->getFunction(intrin_name)) { + vector arg_types; + for (const Expr &e : args) { + arg_types.push_back(llvm_type_of(e.type())); + } + FunctionType *func_t = FunctionType::get(llvm_type_of(op->type), arg_types, false); + llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, intrin_name, module.get()); + } + + Expr equiv = Call::make(op->type, intrin_name, args, Call::PureExtern); + if (initial_value.defined()) { + equiv = binop(initial_value, equiv); + } + equiv.accept(this); + return; + } + } +#endif + + if (output_lanes == 1 && + factor > native_lanes && + factor % native_lanes == 0) { + // It's a total reduction of multiple native + // vectors. Start by adding the vectors together. 
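// ---------------------------------------------------------------------------
// Editor's aside (sketch, not part of the patch): the code that follows splits
// a total reduction wider than the native vector into native-width slices,
// combines the slices elementwise, and then reduces the single remaining
// vector. A scalar model of that strategy for Add, assuming the input length
// is a multiple of native_lanes (mirroring the factor % native_lanes == 0
// check above); names are hypothetical:
#include <cstddef>
#include <numeric>
#include <vector>

static int total_sum_by_slices(const std::vector<int> &x, size_t native_lanes) {
    // Accumulator starts as the first native-width slice.
    std::vector<int> acc(x.begin(), x.begin() + native_lanes);
    for (size_t i = native_lanes; i < x.size(); i += native_lanes) {
        for (size_t j = 0; j < native_lanes; j++) {
            acc[j] += x[i + j];  // combine slices with the reduction's binop
        }
    }
    // Final reduction of a single native-width vector.
    return std::accumulate(acc.begin(), acc.end(), 0);
}
// End of editor's aside. The patch resumes below.
// ---------------------------------------------------------------------------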
+ Expr equiv; + for (int i = 0; i < factor / native_lanes; i++) { + Expr next = Shuffle::make_slice(val, i * native_lanes, 1, native_lanes); + if (equiv.defined()) { + equiv = binop(equiv, next); + } else { + equiv = next; + } + } + equiv = VectorReduce::make(op->op, equiv, 1); + if (init.defined()) { + equiv = binop(equiv, init); + } + equiv = common_subexpression_elimination(equiv); + equiv.accept(this); + return; + } + + if (factor > 2 && ((factor & 1) == 0)) { + // Factor the reduce into multiple stages. If we're going to + // be widening the type by 4x or more we should also factor the + // widening into multiple stages. + Type intermediate_type = op->value.type().with_lanes(op->value.type().lanes() / 2); + Expr equiv = VectorReduce::make(op->op, op->value, intermediate_type.lanes()); + if (op->op == VectorReduce::Add && + (op->type.is_int() || op->type.is_uint()) && + op->type.bits() >= 32) { + Type narrower_type = op->value.type().with_bits(op->type.bits() / 4); + Expr narrower = lossless_cast(narrower_type, op->value); + if (!narrower.defined() && narrower_type.is_int()) { + // Maybe we can narrow to an unsigned int instead. + narrower_type = narrower_type.with_code(Type::UInt); + narrower = lossless_cast(narrower_type, op->value); + } + if (narrower.defined()) { + // Widen it by 2x before the horizontal add + narrower = cast(narrower.type().with_bits(narrower.type().bits() * 2), narrower); + equiv = VectorReduce::make(op->op, narrower, intermediate_type.lanes()); + // Then widen it by 2x again afterwards + equiv = cast(intermediate_type, equiv); + } + } + equiv = VectorReduce::make(op->op, equiv, op->type.lanes()); + if (init.defined()) { + equiv = binop(equiv, init); + } + equiv = common_subexpression_elimination(equiv); + codegen(equiv); + return; + } + + // Extract each slice and combine + Expr equiv = init; + for (int i = 0; i < factor; i++) { + Expr next = Shuffle::make_slice(val, i, factor, val.type().lanes() / factor); + if (equiv.defined()) { + equiv = binop(equiv, next); + } else { + equiv = next; + } + } + equiv = common_subexpression_elimination(equiv); + codegen(equiv); +} // namespace Internal + void CodeGen_LLVM::visit(const Atomic *op) { if (op->mutex_name != "") { internal_assert(!inside_atomic_mutex_node) @@ -4286,16 +4582,19 @@ Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes, arg_values[i] = codegen(args[i]); } - return call_intrin(llvm_type_of(result_type), + llvm::Type *t = llvm_type_of(result_type); + + return call_intrin(t, intrin_lanes, name, arg_values); } Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, const string &name, vector arg_values) { - internal_assert(result_type->isVectorTy()) << "call_intrin is for vector intrinsics only\n"; - - int arg_lanes = get_vector_num_elements(result_type); + int arg_lanes = 1; + if (result_type->isVectorTy()) { + arg_lanes = get_vector_num_elements(result_type); + } if (intrin_lanes != arg_lanes) { // Cut up each arg into appropriately-sized pieces, call the @@ -4304,17 +4603,24 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, for (int start = 0; start < arg_lanes; start += intrin_lanes) { vector args; for (size_t i = 0; i < arg_values.size(); i++) { + int arg_i_lanes = 1; if (arg_values[i]->getType()->isVectorTy()) { - int arg_i_lanes = get_vector_num_elements(arg_values[i]->getType()); - internal_assert(arg_i_lanes >= arg_lanes); + arg_i_lanes = get_vector_num_elements(arg_values[i]->getType()); + } + if (arg_i_lanes >= arg_lanes) { 
// Horizontally reducing intrinsics may have // arguments that have more lanes than the // result. Assume that the horizontally reduce // neighboring elements... int reduce = arg_i_lanes / arg_lanes; args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce)); - } else { + } else if (arg_i_lanes == 1) { + // It's a scalar arg to an intrinsic that returns + // a vector. Replicate it over the slices. args.push_back(arg_values[i]); + } else { + internal_error << "Argument in call_intrin has " << arg_i_lanes + << " with result type having " << arg_lanes << "\n"; } } @@ -4335,7 +4641,10 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, llvm::Function *fn = module->getFunction(name); if (!fn) { - llvm::Type *intrinsic_result_type = VectorType::get(result_type->getScalarType(), intrin_lanes); + llvm::Type *intrinsic_result_type = result_type->getScalarType(); + if (intrin_lanes > 1) { + intrinsic_result_type = VectorType::get(result_type->getScalarType(), intrin_lanes); + } FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false); fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get()); fn->setCallingConv(CallingConv::C); @@ -4350,12 +4659,21 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, } Value *CodeGen_LLVM::slice_vector(Value *vec, int start, int size) { + // Force the arg to be an actual vector + if (!vec->getType()->isVectorTy()) { + vec = create_broadcast(vec, 1); + } + int vec_lanes = get_vector_num_elements(vec->getType()); if (start == 0 && size == vec_lanes) { return vec; } + if (size == 1) { + return builder->CreateExtractElement(vec, (uint64_t)start); + } + vector indices(size); for (int i = 0; i < size; i++) { int idx = start + i; diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index 487f3ba0effe..3983041c231c 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -401,6 +401,7 @@ class CodeGen_LLVM : public IRVisitor { void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; // @} @@ -512,6 +513,13 @@ class CodeGen_LLVM : public IRVisitor { virtual bool supports_atomic_add(const Type &t) const; + /** Compile a horizontal reduction that starts with an explicit + * initial value. There are lots of complex ways to peephole + * optimize this pattern, especially with the proliferation of + * dot-product instructions, and they can usefully share logic + * across backends. */ + virtual void codegen_vector_reduce(const VectorReduce *op, const Expr &init); + /** Are we inside an atomic node that uses mutex locks? This is used for detecting deadlocks from nested atomics & illegal vectorization. */ bool inside_atomic_mutex_node; @@ -560,6 +568,10 @@ class CodeGen_LLVM : public IRVisitor { void init_codegen(const std::string &name, bool any_strict_float = false); std::unique_ptr finish_codegen(); + + /** A helper routine for generating folded vector reductions. 
*/ + template + bool try_to_fold_vector_reduce(const Op *op); }; } // namespace Internal diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 20ff7cf97860..008f9a678060 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -331,6 +331,163 @@ void CodeGen_PTX_Dev::visit(const Atomic *op) { CodeGen_LLVM::visit(op); } +void CodeGen_PTX_Dev::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { + // Pattern match 8/16-bit dot products + + const int input_lanes = op->value.type().lanes(); + const int factor = input_lanes / op->type.lanes(); + const Mul *mul = op->value.as(); + if (op->op == VectorReduce::Add && + mul && + (factor % 4 == 0) && + (op->type.element_of() == Int(32) || + op->type.element_of() == UInt(32))) { + Expr i = init; + if (!i.defined()) { + i = cast(mul->type, 0); + } + // Try to narrow the multiply args to 8-bit + Expr a = mul->a, b = mul->b; + if (op->type.is_uint()) { + a = lossless_cast(UInt(8, input_lanes), a); + b = lossless_cast(UInt(8, input_lanes), b); + } else { + a = lossless_cast(Int(8, input_lanes), a); + b = lossless_cast(Int(8, input_lanes), b); + if (!a.defined()) { + // try uint + a = lossless_cast(UInt(8, input_lanes), mul->a); + } + if (!b.defined()) { + b = lossless_cast(UInt(8, input_lanes), mul->b); + } + } + // If we only managed to narrow one of them, try to narrow the + // other to 16-bit. Swap the args so that it's always 'a'. + Expr a_orig = mul->a; + if (a.defined() && !b.defined()) { + std::swap(a, b); + a_orig = mul->b; + } + if (b.defined() && !a.defined()) { + // Try 16-bit instead + a = lossless_cast(UInt(16, input_lanes), a_orig); + if (!a.defined() && !op->type.is_uint()) { + a = lossless_cast(Int(16, input_lanes), a_orig); + } + } + + if (a.defined() && b.defined()) { + std::ostringstream ss; + if (a.type().bits() == 8) { + ss << "dp4a"; + } else { + ss << "dp2a"; + } + if (a.type().is_int()) { + ss << "_s32"; + } else { + ss << "_u32"; + } + if (b.type().is_int()) { + ss << "_s32"; + } else { + ss << "_u32"; + } + const int a_32_bit_words_per_sum = (factor * a.type().bits()) / 32; + const int b_32_bit_words_per_sum = (factor * b.type().bits()) / 32; + // Reinterpret a and b as 32-bit values with fewer + // lanes. If they're aligned dense loads we should just do a + // different load. + for (Expr *e : {&a, &b}) { + int sub_lanes = 32 / e->type().bits(); + const Load *load = e->as(); + const Ramp *idx = load ? load->index.as() : nullptr; + if (idx && + is_one(idx->stride) && + load->alignment.modulus % sub_lanes == 0 && + load->alignment.remainder % sub_lanes == 0) { + Expr new_idx = simplify(idx->base / sub_lanes); + int load_lanes = input_lanes / sub_lanes; + if (input_lanes > sub_lanes) { + new_idx = Ramp::make(new_idx, 1, load_lanes); + } + *e = Load::make(Int(32, load_lanes), + load->name, + new_idx, + load->image, + load->param, + const_true(load_lanes), + load->alignment / sub_lanes); + } else { + *e = reinterpret(Int(32, input_lanes / sub_lanes), *e); + } + } + string name = ss.str(); + vector result; + for (int l = 0; l < op->type.lanes(); l++) { + // To compute a single lane of the output, we'll + // extract the appropriate slice of the args, which + // have been reinterpreted as 32-bit vectors, then + // call either dp4a or dp2a the appropriate number of + // times, and finally sum the result. 
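// ---------------------------------------------------------------------------
// Editor's aside (reference model, not the PTX lowering itself): the dp4a
// building block used below is assumed to treat each 32-bit operand as four
// 8-bit lanes and accumulate their dot product into a 32-bit value. A scalar
// model of the signed-by-signed case, with hypothetical names (dp2a is
// analogous, pairing two 16-bit lanes of one operand with 8-bit lanes of the
// other):
#include <cstdint>

static int32_t dp4a_s8_s8(uint32_t a, uint32_t b, int32_t acc) {
    for (int lane = 0; lane < 4; lane++) {
        int8_t a_lane = static_cast<int8_t>((a >> (8 * lane)) & 0xff);
        int8_t b_lane = static_cast<int8_t>((b >> (8 * lane)) & 0xff);
        acc += static_cast<int32_t>(a_lane) * static_cast<int32_t>(b_lane);
    }
    return acc;
}
// End of editor's aside. The patch resumes below.
// ---------------------------------------------------------------------------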
+ Expr i_slice, a_slice, b_slice; + if (i.type().is_scalar()) { + i_slice = i; + } else { + i_slice = Shuffle::make_extract_element(i, l); + } + if (a.type().is_scalar()) { + a_slice = a; + } else { + a_slice = Shuffle::make_slice(a, l * a_32_bit_words_per_sum, 1, a_32_bit_words_per_sum); + } + if (b.type().is_scalar()) { + b_slice = b; + } else { + b_slice = Shuffle::make_slice(b, l * b_32_bit_words_per_sum, 1, b_32_bit_words_per_sum); + } + for (int i = 0; i < b_32_bit_words_per_sum; i++) { + if (a_slice.type().lanes() == b_slice.type().lanes()) { + Expr a_lane, b_lane; + if (b_slice.type().is_scalar()) { + a_lane = a_slice; + b_lane = b_slice; + } else { + a_lane = Shuffle::make_extract_element(a_slice, i); + b_lane = Shuffle::make_extract_element(b_slice, i); + } + i_slice = Call::make(i_slice.type(), name, + {a_lane, b_lane, i_slice}, + Call::PureExtern); + } else { + internal_assert(a_slice.type().lanes() == 2 * b_slice.type().lanes()); + Expr a_lane_lo, a_lane_hi, b_lane; + if (b_slice.type().is_scalar()) { + b_lane = b_slice; + } else { + b_lane = Shuffle::make_extract_element(b_slice, i); + } + a_lane_lo = Shuffle::make_extract_element(a_slice, 2 * i); + a_lane_hi = Shuffle::make_extract_element(a_slice, 2 * i + 1); + i_slice = Call::make(i_slice.type(), name, + {a_lane_lo, a_lane_hi, b_lane, i_slice}, + Call::PureExtern); + } + } + i_slice = simplify(i_slice); + i_slice = common_subexpression_elimination(i_slice); + result.push_back(i_slice); + } + // Concatenate the per-lane results to get the full vector result + Expr equiv = Shuffle::make_concat(result); + equiv.accept(this); + return; + } + } + CodeGen_LLVM::codegen_vector_reduce(op, init); +} + string CodeGen_PTX_Dev::march() const { return "nvptx64"; } diff --git a/src/CodeGen_PTX_Dev.h b/src/CodeGen_PTX_Dev.h index 94925d0a4e0a..7f7c80669f47 100644 --- a/src/CodeGen_PTX_Dev.h +++ b/src/CodeGen_PTX_Dev.h @@ -65,6 +65,7 @@ class CodeGen_PTX_Dev : public CodeGen_LLVM, public CodeGen_GPU_Dev { void visit(const Load *) override; void visit(const Store *) override; void visit(const Atomic *) override; + void codegen_vector_reduce(const VectorReduce *op, const Expr &init) override; // @} std::string march() const; diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 884301c342a8..e8821d01a824 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -107,6 +107,33 @@ void CodeGen_X86::visit(const Sub *op) { } } +void CodeGen_X86::visit(const Mul *op) { + +#if LLVM_VERSION < 110 + // Widening integer multiply of non-power-of-two vector sizes is + // broken in older llvms for older x86: + // https://bugs.llvm.org/show_bug.cgi?id=44976 + const int lanes = op->type.lanes(); + if (!target.has_feature(Target::SSE41) && + (lanes & (lanes - 1)) && + (op->type.bits() >= 32) && + !op->type.is_float()) { + // Any fancy shuffles to pad or slice into smaller vectors + // just gets undone by LLVM and retriggers the bug. Just + // scalarize. + vector result; + for (int i = 0; i < lanes; i++) { + result.emplace_back(Shuffle::make_extract_element(op->a, i) * + Shuffle::make_extract_element(op->b, i)); + } + codegen(Shuffle::make_concat(result)); + return; + } +#endif + + return CodeGen_Posix::visit(op); +} + void CodeGen_X86::visit(const GT *op) { Type t = op->a.type(); @@ -390,6 +417,35 @@ void CodeGen_X86::visit(const Call *op) { CodeGen_Posix::visit(op); } +void CodeGen_X86::visit(const VectorReduce *op) { + const int factor = op->value.type().lanes() / op->type.lanes(); + + // Match pmaddwd. 
X86 doesn't have many horizontal reduction ops, + // and the ones that exist are hit by llvm automatically using the + // base class lowering of VectorReduce (see + // test/correctness/simd_op_check.cpp). + if (const Mul *mul = op->value.as()) { + Type narrower = Int(16, mul->type.lanes()); + Expr a = lossless_cast(narrower, mul->a); + Expr b = lossless_cast(narrower, mul->b); + if (op->type.is_int() && + op->type.bits() == 32 && + a.defined() && + b.defined() && + factor == 2 && + op->op == VectorReduce::Add) { + if (target.has_feature(Target::AVX2) && op->type.lanes() > 4) { + value = call_intrin(op->type, 8, "llvm.x86.avx2.pmadd.wd", {a, b}); + } else { + value = call_intrin(op->type, 4, "llvm.x86.sse2.pmadd.wd", {a, b}); + } + return; + } + } + + CodeGen_Posix::visit(op); +} + string CodeGen_X86::mcpu() const { if (target.has_feature(Target::AVX512_Cannonlake)) return "cannonlake"; if (target.has_feature(Target::AVX512_Skylake)) return "skylake-avx512"; diff --git a/src/CodeGen_X86.h b/src/CodeGen_X86.h index b76ffdeea359..e6e579ec071b 100644 --- a/src/CodeGen_X86.h +++ b/src/CodeGen_X86.h @@ -47,6 +47,8 @@ class CodeGen_X86 : public CodeGen_Posix { void visit(const EQ *) override; void visit(const NE *) override; void visit(const Select *) override; + void visit(const VectorReduce *) override; + void visit(const Mul *) override; // @} }; diff --git a/src/Target.cpp b/src/Target.cpp index 32d436526255..6591f646931e 100644 --- a/src/Target.cpp +++ b/src/Target.cpp @@ -360,6 +360,7 @@ const std::map feature_name_map = { {"wasm_signext", Target::WasmSignExt}, {"sve", Target::SVE}, {"sve2", Target::SVE2}, + {"arm_dot_prod", Target::ARMDotProd}, // NOTE: When adding features to this map, be sure to update PyEnums.cpp as well. }; diff --git a/src/Target.h b/src/Target.h index 182a27fa9418..4786c01de336 100644 --- a/src/Target.h +++ b/src/Target.h @@ -119,6 +119,7 @@ struct Target { WasmSignExt = halide_target_feature_wasm_signext, SVE = halide_target_feature_sve, SVE2 = halide_target_feature_sve2, + ARMDotProd = halide_target_feature_arm_dot_prod, FeatureEnd = halide_target_feature_end }; Target() diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 9161c8da9678..1fcdf5f476a0 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -1313,7 +1313,8 @@ typedef enum halide_target_feature_t { halide_target_feature_sve2, ///< Enable ARM Scalable Vector Extensions v2 halide_target_feature_egl, ///< Force use of EGL support. - halide_target_feature_end ///< A sentinel. Every target is considered to have this feature, and setting this feature does nothing. + halide_target_feature_arm_dot_prod, ///< Enable ARMv8.2-a dotprod extension (i.e. udot and sdot instructions) + halide_target_feature_end ///< A sentinel. Every target is considered to have this feature, and setting this feature does nothing. } halide_target_feature_t; /** This function is called internally by Halide in some situations to determine diff --git a/src/runtime/aarch64.ll b/src/runtime/aarch64.ll index 8adb25eca59d..1472ddfc700d 100644 --- a/src/runtime/aarch64.ll +++ b/src/runtime/aarch64.ll @@ -231,3 +231,513 @@ define weak_odr <4 x float> @fast_inverse_sqrt_f32x4(<4 x float> %x) nounwind al %result = fmul <4 x float> %approx, %correction ret <4 x float> %result } + +; The way llvm represents intrinsics for horizontal addition are +; somewhat ad-hoc, and can be incompatible with the way we slice up +; intrinsics to meet the native vector width. 
We define wrappers for +; everything here instead. + +declare <2 x double> @llvm.aarch64.neon.faddp.v2f64(<2 x double>, <2 x double>) nounwind readnone +declare <2 x float> @llvm.aarch64.neon.faddp.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.addp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i64> @llvm.aarch64.neon.addp.v2i64(<2 x i64>, <2 x i64>) nounwind readnone +declare <4 x float> @llvm.aarch64.neon.faddp.v4f32(<4 x float>, <4 x float>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.addp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.addp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.addp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.addp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <16 x i8> @llvm.aarch64.neon.addp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Add_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.addp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Add_int8x16_int8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.addp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.addp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.addp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.addp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.addp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr i64 @pairwise_Add_int64_int64x2(<2 x i64> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.aarch64.neon.addp.v2i64(<2 x i64> %x, <2 x i64> undef) + %scalar = extractelement <2 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int64x4(<4 x i64> %x) nounwind alwaysinline { + %a = shufflevector <4 x i64> %x, <4 x i64> undef, <2 x i32> + %b = shufflevector <4 x i64> %x, <4 x i64> undef, <2 x i32> + %result = tail call <2 x i64> @llvm.aarch64.neon.addp.v2i64(<2 x i64> %a, <2 x i64> 
%b) + ret <2 x i64> %result +} + +define weak_odr <4 x float> @pairwise_Add_float32x4_float32x8(<8 x float> %x) nounwind alwaysinline { + %a = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %b = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %result = tail call <4 x float> @llvm.aarch64.neon.faddp.v4f32(<4 x float> %a, <4 x float> %b) + ret <4 x float> %result +} + +define weak_odr <2 x float> @pairwise_Add_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.aarch64.neon.faddp.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +define weak_odr double @pairwise_Add_float64_float64x2(<2 x double> %x) nounwind alwaysinline { + %result = tail call <2 x double> @llvm.aarch64.neon.faddp.v2f64(<2 x double> %x, <2 x double> undef) + %scalar = extractelement <2 x double> %result, i32 0 + ret double %scalar +} + +define weak_odr <2 x double> @pairwise_Add_float64x2_float64x4(<4 x double> %x) nounwind alwaysinline { + %a = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %b = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %result = tail call <2 x double> @llvm.aarch64.neon.faddp.v2f64(<2 x double> %a, <2 x double> %b) + ret <2 x double> %result +} + + +declare <1 x i64> @llvm.aarch64.neon.saddlp.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <1 x i64> @llvm.aarch64.neon.uaddlp.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.uaddlp.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i64> @llvm.aarch64.neon.saddlp.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <2 x i64> @llvm.aarch64.neon.uaddlp.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.saddlp.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.uaddlp.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.saddlp.v4i32.v8i16(<8 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.uaddlp.v4i32.v8i16(<8 x i16>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.saddlp.v8i16.v16i8(<16 x i8>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.uaddlp.v8i16.v16i8(<16 x i8>) nounwind readnone + + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.aarch64.neon.saddlp.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.aarch64.neon.saddlp.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.aarch64.neon.saddlp.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.aarch64.neon.saddlp.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2(<2 x i32> %x) nounwind 
alwaysinline { + %result = tail call <1 x i64> @llvm.aarch64.neon.saddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.aarch64.neon.uaddlp.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.aarch64.neon.uaddlp.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.aarch64.neon.uaddlp.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.aarch64.neon.uaddlp.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.aarch64.neon.uaddlp.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2(<2 x i32> %x) nounwind alwaysinline { + %result = tail call <1 x i64> @llvm.aarch64.neon.uaddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %y = tail call <8 x i16> @llvm.aarch64.neon.saddlp.v8i16.v16i8(<16 x i8> %x) + %result = add <8 x i16> %a, %y + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %y = tail call <4 x i16> @llvm.aarch64.neon.saddlp.v4i16.v8i8(<8 x i8> %x) + %result = add <4 x i16> %a, %y + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %y = tail call <4 x i32> @llvm.aarch64.neon.saddlp.v4i32.v8i16(<8 x i16> %x) + %result = add <4 x i32> %a, %y + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %y = tail call <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16> %x) + %result = add <2 x i32> %a, %y + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %y = tail call <2 x i64> @llvm.aarch64.neon.saddlp.v2i64.v4i32(<4 x i32> %x) + %result = add <2 x i64> %a, %y + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %y = tail call <1 x i64> @llvm.aarch64.neon.saddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %y, i32 0 + %result = add i64 %a, %scalar + ret i64 %result +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %y = tail call <8 x i16> @llvm.aarch64.neon.uaddlp.v8i16.v16i8(<16 x i8> %x) + %result = add <8 x i16> %a, %y + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %y = tail call <4 x i16> @llvm.aarch64.neon.uaddlp.v4i16.v8i8(<8 x i8> %x) + %result = add <4 x i16> %a, %y + ret 
<4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %y = tail call <4 x i32> @llvm.aarch64.neon.uaddlp.v4i32.v8i16(<8 x i16> %x) + %result = add <4 x i32> %a, %y + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %y = tail call <2 x i32> @llvm.aarch64.neon.uaddlp.v2i32.v4i16(<4 x i16> %x) + %result = add <2 x i32> %a, %y + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %y = tail call <2 x i64> @llvm.aarch64.neon.uaddlp.v2i64.v4i32(<4 x i32> %x) + %result = add <2 x i64> %a, %y + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %y = tail call <1 x i64> @llvm.aarch64.neon.uaddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %y, i32 0 + %result = add i64 %a, %scalar + ret i64 %result +} + + + +declare <16 x i8> @llvm.aarch64.neon.smaxp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <16 x i8> @llvm.aarch64.neon.umaxp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <2 x double> @llvm.aarch64.neon.fmaxp.v2f64(<2 x double>, <2 x double>) nounwind readnone +declare <2 x float> @llvm.aarch64.neon.fmaxp.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.smaxp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.umaxp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x float> @llvm.aarch64.neon.fmaxp.v4f32(<4 x float>, <4 x float>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.smaxp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.umaxp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.smaxp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.umaxp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.smaxp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.umaxp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.smaxp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.umaxp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Max_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.smaxp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Max_int8x16_int8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.smaxp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.smaxp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Max_int16x8_int16x16(<16 x i16> %x) 
nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.smaxp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Max_int32x4_int32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.smaxp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Max_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.smaxp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <4 x float> @pairwise_Max_float32x4_float32x8(<8 x float> %x) nounwind alwaysinline { + %a = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %b = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %result = tail call <4 x float> @llvm.aarch64.neon.fmaxp.v4f32(<4 x float> %a, <4 x float> %b) + ret <4 x float> %result +} + +define weak_odr <2 x float> @pairwise_Max_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.aarch64.neon.fmaxp.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + + +define weak_odr <2 x double> @pairwise_Max_float64x2_float64x4(<4 x double> %x) nounwind alwaysinline { + %a = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %b = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %result = tail call <2 x double> @llvm.aarch64.neon.fmaxp.v2f64(<2 x double> %a, <2 x double> %b) + ret <2 x double> %result +} + +define weak_odr double @pairwise_Max_float64_float64x2(<2 x double> %x) nounwind alwaysinline { + %result = tail call <2 x double> @llvm.aarch64.neon.fmaxp.v2f64(<2 x double> %x, <2 x double> undef) + %scalar = extractelement <2 x double> %result, i32 0 + ret double %scalar +} + + +define weak_odr <8 x i8> @pairwise_Max_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.umaxp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Max_uint8x16_uint8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.umaxp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.umaxp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Max_uint16x8_uint16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 
x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.umaxp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Max_uint32x4_uint32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.umaxp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Max_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.umaxp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + + +declare <16 x i8> @llvm.aarch64.neon.sminp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <16 x i8> @llvm.aarch64.neon.uminp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <2 x double> @llvm.aarch64.neon.fminp.v2f64(<2 x double>, <2 x double>) nounwind readnone +declare <2 x float> @llvm.aarch64.neon.fminp.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.sminp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.uminp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x float> @llvm.aarch64.neon.fminp.v4f32(<4 x float>, <4 x float>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.sminp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.uminp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.sminp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.uminp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.sminp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.uminp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.sminp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.uminp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Min_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.sminp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Min_int8x16_int8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.sminp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.sminp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Min_int16x8_int16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.sminp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + 
+define weak_odr <4 x i32> @pairwise_Min_int32x4_int32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.sminp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Min_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.sminp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + + +define weak_odr <4 x float> @pairwise_Min_float32x4_float32x8(<8 x float> %x) nounwind alwaysinline { + %a = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %b = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %result = tail call <4 x float> @llvm.aarch64.neon.fminp.v4f32(<4 x float> %a, <4 x float> %b) + ret <4 x float> %result +} + +define weak_odr <2 x float> @pairwise_Min_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.aarch64.neon.fminp.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + + +define weak_odr <2 x double> @pairwise_Min_float64x2_float64x4(<4 x double> %x) nounwind alwaysinline { + %a = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %b = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %result = tail call <2 x double> @llvm.aarch64.neon.fminp.v2f64(<2 x double> %a, <2 x double> %b) + ret <2 x double> %result +} + +define weak_odr double @pairwise_Min_float64_float64x2(<2 x double> %x) nounwind alwaysinline { + %result = tail call <2 x double> @llvm.aarch64.neon.fminp.v2f64(<2 x double> %x, <2 x double> undef) + %scalar = extractelement <2 x double> %result, i32 0 + ret double %scalar +} + +define weak_odr <8 x i8> @pairwise_Min_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.uminp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Min_uint8x16_uint8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.uminp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.uminp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Min_uint16x8_uint16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.uminp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Min_uint32x4_uint32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> 
%x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.uminp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Min_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.uminp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} diff --git a/src/runtime/arm.ll b/src/runtime/arm.ll index ee3f69a86518..42b590e514ca 100644 --- a/src/runtime/arm.ll +++ b/src/runtime/arm.ll @@ -398,3 +398,316 @@ define weak_odr void @strided_store_f32x4(float * %ptr, i32 %stride, <4 x float> ret void } +; The way llvm represents intrinsics for horizontal addition are +; somewhat ad-hoc, and can be incompatible with the way we slice up +; intrinsics to meet the native vector width. We define wrappers for +; everything here instead. + +declare <2 x float> @llvm.arm.neon.vpadd.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpadd.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpadd.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpadd.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Add_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpadd.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpadd.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpadd.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <2 x float> @pairwise_Add_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.arm.neon.vpadd.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +declare <1 x i64> @llvm.arm.neon.vpaddls.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <1 x i64> @llvm.arm.neon.vpaddlu.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpaddls.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpaddlu.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpaddls.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpaddlu.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpaddls.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpaddlu.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpaddls.v4i32.v8i16(<8 x i16>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpaddlu.v4i32.v8i16(<8 x 
i16>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpaddls.v8i16.v16i8(<16 x i8>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpaddlu.v8i16.v16i8(<16 x i8>) nounwind readnone + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpaddls.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpaddls.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpaddls.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpaddls.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpaddls.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2(<2 x i32> %x) nounwind alwaysinline { + %result = tail call <1 x i64> @llvm.arm.neon.vpaddls.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr i64 @pairwise_Add_int64_int64x2(<2 x i64> %x) nounwind alwaysinline { + ; There's no intrinsic for this on arm32, but we include an implementation for completeness. + %a = extractelement <2 x i64> %x, i32 0 + %b = extractelement <2 x i64> %x, i32 1 + %result = add i64 %a, %b + ret i64 %result +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpaddlu.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpaddlu.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpaddlu.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpaddlu.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpaddlu.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2(<2 x i32> %x) nounwind alwaysinline { + %result = tail call <1 x i64> @llvm.arm.neon.vpaddlu.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +declare <4 x i16> @llvm.arm.neon.vpadals.v4i16.v8i8(<4 x i16>, <8 x i8>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpadals.v2i32.v4i16(<2 x i32>, <4 x i16>) nounwind readnone +declare <1 x i64> @llvm.arm.neon.vpadals.v1i64.v2i32(<1 x i64>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpadalu.v4i16.v8i8(<4 x i16>, <8 x i8>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpadalu.v2i32.v4i16(<2 x i32>, <4 x i16>) nounwind readnone +declare <1 x i64> 
@llvm.arm.neon.vpadalu.v1i64.v2i32(<1 x i64>, <2 x i32>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpadals.v8i16.v16i8(<8 x i16>, <16 x i8>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpadals.v4i32.v8i16(<4 x i32>, <8 x i16>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpadals.v2i64.v4i32(<2 x i64>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpadalu.v8i16.v16i8(<8 x i16>, <16 x i8>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpadalu.v4i32.v8i16(<4 x i32>, <8 x i16>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpadalu.v2i64.v4i32(<2 x i64>, <4 x i32>) nounwind readnone + + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpadals.v8i16.v16i8(<8 x i16> %a, <16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpadals.v4i16.v8i8(<4 x i16> %a, <8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpadals.v4i32.v8i16(<4 x i32> %a, <8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpadals.v2i32.v4i16(<2 x i32> %a, <4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpadals.v2i64.v4i32(<2 x i64> %a, <4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %vec = insertelement <1 x i64> undef, i64 %a, i32 0 + %result = tail call <1 x i64> @llvm.arm.neon.vpadals.v1i64.v2i32(<1 x i64> %vec, <2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpadalu.v8i16.v16i8(<8 x i16> %a, <16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpadalu.v4i16.v8i8(<4 x i16> %a, <8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpadalu.v4i32.v8i16(<4 x i32> %a, <8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpadalu.v2i32.v4i16(<2 x i32> %a, <4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpadalu.v2i64.v4i32(<2 x i64> %a, <4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %vec = insertelement <1 x 
i64> undef, i64 %a, i32 0 + %result = tail call <1 x i64> @llvm.arm.neon.vpadalu.v1i64.v2i32(<1 x i64> %vec, <2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +declare <2 x float> @llvm.arm.neon.vpmaxs.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpmaxs.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpmaxu.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpmaxs.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpmaxu.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpmaxs.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpmaxu.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Max_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpmaxs.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpmaxs.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Max_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpmaxs.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <2 x float> @pairwise_Max_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.arm.neon.vpmaxs.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +define weak_odr <8 x i8> @pairwise_Max_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpmaxu.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpmaxu.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Max_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpmaxu.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + + +declare <2 x float> @llvm.arm.neon.vpmins.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpmins.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpminu.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpmins.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> 
@llvm.arm.neon.vpminu.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpmins.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpminu.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Min_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpmins.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpmins.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Min_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpmins.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <2 x float> @pairwise_Min_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.arm.neon.vpmins.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +define weak_odr <8 x i8> @pairwise_Min_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpminu.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpminu.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Min_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpminu.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} \ No newline at end of file diff --git a/src/runtime/arm_cpu_features.cpp b/src/runtime/arm_cpu_features.cpp index 647bbd024c03..7293f333cbf0 100644 --- a/src/runtime/arm_cpu_features.cpp +++ b/src/runtime/arm_cpu_features.cpp @@ -20,6 +20,8 @@ WEAK CpuFeatures halide_get_cpu_features() { // features.set_available(halide_target_feature_armv7s); // } + // TODO: add runtime detection for ARMDotProd extension + // https://github.com/halide/Halide/issues/4727 return features; } diff --git a/src/runtime/ptx_dev.ll b/src/runtime/ptx_dev.ll index 4125e8bd3938..e93d3ebc1253 100644 --- a/src/runtime/ptx_dev.ll +++ b/src/runtime/ptx_dev.ll @@ -345,3 +345,45 @@ define weak_odr i32 @halide_ptx_trap() nounwind uwtable alwaysinline { ret i32 0 } +; llvm doesn't expose dot product instructions as intrinsics +define weak_odr i32 @dp4a_s32_s32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.s32.s32 $0, $1, $2, 
$3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp4a_s32_u32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.s32.u32 $0, $1, $2, $3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp4a_u32_s32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.u32.s32 $0, $1, $2, $3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp4a_u32_u32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.u32.u32 $0, $1, $2, $3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + + +define weak_odr i32 @dp2a_s32_s32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.s32.s32 $0, $1, $3, $4; dp2a.hi.s32.s32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp2a_s32_u32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.s32.u32 $0, $1, $3, $4; dp2a.hi.s32.u32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp2a_u32_s32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.u32.s32 $0, $1, $3, $4; dp2a.hi.u32.s32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp2a_u32_u32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.u32.u32 $0, $1, $3, $4; dp2a.hi.u32.u32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 4ea68cc100c0..d81f13fbcf8b 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -67,6 +67,7 @@ tests(GROUPS correctness convolution.cpp convolution_multiple_kernels.cpp cross_compilation.cpp + cuda_8_bit_dot_product.cpp custom_allocator.cpp custom_auto_scheduler.cpp custom_error_reporter.cpp @@ -323,6 +324,7 @@ tests(GROUPS correctness vector_extern.cpp vector_math.cpp vector_print_bug.cpp + vector_reductions.cpp vector_tile.cpp vectorize_guard_with_if.cpp vectorize_mixed_widths.cpp diff --git a/test/correctness/atomics.cpp b/test/correctness/atomics.cpp index 71b1c268217d..91e35c49d829 100644 --- a/test/correctness/atomics.cpp +++ b/test/correctness/atomics.cpp @@ -22,7 +22,9 @@ template::value>::type * = nullptr> inline void check(int line_number, T x, T target, T threshold = T(1e-6)) { _halide_user_assert(std::fabs((x) - (target)) < threshold) - << "Line " << line_number << ": Expected " << (target) << " instead of " << (x) << "\n"; + << "Line " << line_number + << ": Expected " << (target) + << " instead of " << (x) << "\n"; } inline void check(int line_number, float16_t x, float16_t target) { @@ -37,7 +39,9 @@ template::value, int>::type * = nullptr> inline void check(int line_number, T x, T target) { _halide_user_assert(x == target) - << "Line " << line_number << ": Expected " << (target) << " instead of " << (x) << "\n"; + << "Line " << line_number + << ": Expected " << (int64_t)(target) + << " instead of " << (int64_t)(x) << "\n"; } template @@ -357,7 +361,7 @@ void test_predicated_hist(const Backend &backend) { case 
Backend::CUDAVectorize: { RVar ro, ri; RVar rio, rii; - hist.update() + hist.update(update_id) .atomic(true /*override_assciativity_test*/) .split(r, ro, ri, 32) .split(ri, rio, rii, 4) @@ -824,7 +828,7 @@ void test_hist_rfactor(const Backend &backend) { Func intermediate = hist.update() - .rfactor({{r.y, y}}); + .rfactor(r.y, y); intermediate.compute_root(); hist.compute_root(); switch (backend) { @@ -858,7 +862,13 @@ void test_hist_rfactor(const Backend &backend) { case Backend::CUDAVectorize: { RVar ro, ri; RVar rio, rii; - hist.update().atomic(true).split(r, ro, ri, 32).split(ri, rio, rii, 4).gpu_blocks(ro, DeviceAPI::CUDA).gpu_threads(rio, DeviceAPI::CUDA).vectorize(rii); + intermediate.update() + .atomic(true) + .split(r.x, ro, ri, 32) + .split(ri, rio, rii, 4) + .gpu_blocks(ro, DeviceAPI::CUDA) + .gpu_threads(rio, DeviceAPI::CUDA) + .vectorize(rii); } break; default: { _halide_user_assert(false) << "Unsupported backend.\n"; diff --git a/test/correctness/cuda_8_bit_dot_product.cpp b/test/correctness/cuda_8_bit_dot_product.cpp new file mode 100644 index 000000000000..8ec33f23458e --- /dev/null +++ b/test/correctness/cuda_8_bit_dot_product.cpp @@ -0,0 +1,90 @@ +#include "Halide.h" + +#include + +using namespace Halide; + +template +void test(Target t) { + for (int factor : {4, 16}) { + for (int vec : {1, 4}) { + std::cout + << "Testing dot product of " + << type_of() << " * " << type_of() << " -> " << type_of() + << " with vector width " << vec + << " and reduction factor " << factor << "\n"; + Func in_a, in_b; + Var x, y; + + in_a(x, y) = cast(x - y * 17); + in_a.compute_root(); + + in_b(x, y) = cast(x * 3 + y * 7); + in_b.compute_root(); + + Func g; + RDom r(0, factor * 4); + g(x, y) += cast(in_a(r, x)) * in_b(r, y); + + Func h; + h(x, y) = g(x, y); + + Var xi, yi; + g.update().atomic().vectorize(r, factor).unroll(r); + h.gpu_tile(x, y, xi, yi, 32, 8, TailStrategy::RoundUp); + + Buffer out(128, 128); + h.realize(out); + out.copy_to_host(); + + for (int y = 0; y < out.height(); y++) { + for (int x = 0; x < out.width(); x++) { + Out correct = 0; + for (int r = 0; r < factor * 4; r++) { + A in_a_r_x = (A)(r - x * 17); + B in_b_r_y = (B)(r * 3 + y * 7); + correct += ((Out)(in_a_r_x)) * in_b_r_y; + } + if (out(x, y) != correct) { + printf("out(%d, %d) = %d instead of %d\n", x, y, (int)(out(x, y)), (int)(correct)); + exit(-1); + } + } + } + + // Check the instruction was emitted intended by just grepping the + // compiled code (the PTX source is an embedded string). + Buffer buf = h.compile_to_module(std::vector(), "h", t).compile_to_buffer(); + std::basic_regex regex("dp[24]a[.lo]*[us]32[.][us]32"); + if (!std::regex_search((const char *)buf.begin(), (const char *)buf.end(), regex)) { + printf("Did not find use of dp2a or dp4a in compiled code. 
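+            // (A rough sketch of the sign-bit trick for 16-bit lanes, for
+            //  reference only: minimum(i16 x) == i16(minimum(u16(x) ^ 0x8000) ^ 0x8000),
+            //  so the unsigned phminposuw instruction also covers the signed case.)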
Rerun test with HL_DEBUG_CODEGEN=1 to debug\n"); + exit(-1); + } + } + } +} + +int main(int argc, char **argv) { + Target t = get_jit_target_from_environment(); + if (!t.has_feature(Target::CUDACapability61)) { + printf("[SKIP] Cuda (with compute capability 6.1) is not enabled in target: %s\n", + t.to_string().c_str()); + return 0; + } + + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp index a712f25c0962..782881c18b96 100644 --- a/test/correctness/simd_op_check.cpp +++ b/test/correctness/simd_op_check.cpp @@ -187,7 +187,7 @@ class SimdOpCheck : public SimdOpCheckTest { // SSE 2 - for (int w = 2; w <= 4; w++) { + for (int w : {2, 4}) { check("addpd", w, f64_1 + f64_2); check("subpd", w, f64_1 - f64_2); check("mulpd", w, f64_1 * f64_2); @@ -225,11 +225,8 @@ class SimdOpCheck : public SimdOpCheckTest { check(std::string("packuswb") + check_suffix, 8 * w, u8_sat(i16_1)); } - // SSE 3 + // SSE 3 / SSSE 3 - // We don't do horizontal add/sub ops, so nothing new here - - // SSSE 3 if (use_ssse3) { for (int w = 2; w <= 4; w++) { check("pmulhrsw", 4 * w, i16((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768)); @@ -237,15 +234,68 @@ class SimdOpCheck : public SimdOpCheckTest { check("pabsw", 4 * w, abs(i16_1)); check("pabsd", 2 * w, abs(i32_1)); } + +#if LLVM_VERSION >= 90 + // Horizontal ops. Our support for them uses intrinsics + // from LLVM 9+. + + // Paradoxically, haddps is a bad way to do horizontal + // adds down to a single scalar on most x86. A better + // sequence (according to Peter Cordes on stackoverflow) + // is movshdup, addps, movhlps, addss. haddps is still + // good if you're only partially reducing and your result + // is at least one native vector, if only to save code + // size, but LLVM really really tries to avoid it and + // replace it with shuffles whenever it can, so we won't + // test for it. + // + // See: + // https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-float-vector-sum-on-x86 + + // For reducing down to a scalar we expect to see addps + // and movshdup. We'll sniff for the movshdup. + check("movshdup", 1, sum(in_f32(RDom(0, 2) + 2 * x))); + check("movshdup", 1, sum(in_f32(RDom(0, 4) + 4 * x))); + check("movshdup", 1, sum(in_f32(RDom(0, 16) + 16 * x))); + + // The integer horizontal add operations are pretty + // terrible on all x86 variants, and LLVM does its best to + // avoid generating those too, so we won't test that here + // either. + + // Min reductions should use phminposuw when + // possible. This only exists for u16. X86 is weird. + check("phminposuw", 1, minimum(in_u16(RDom(0, 8) + 8 * x))); + + // Max reductions can use the same instruction by first + // flipping the bits. + check("phminposuw", 1, maximum(in_u16(RDom(0, 8) + 8 * x))); + + // Reductions over signed ints can flip the sign bit + // before and after (equivalent to adding 128). 
+ check("phminposuw", 1, minimum(in_i16(RDom(0, 8) + 8 * x))); + check("phminposuw", 1, maximum(in_i16(RDom(0, 8) + 8 * x))); + + // Reductions over 8-bit ints can widen first + check("phminposuw", 1, minimum(in_u8(RDom(0, 16) + 16 * x))); + check("phminposuw", 1, maximum(in_u8(RDom(0, 16) + 16 * x))); + check("phminposuw", 1, minimum(in_i8(RDom(0, 16) + 16 * x))); + check("phminposuw", 1, maximum(in_i8(RDom(0, 16) + 16 * x))); +#endif } // SSE 4.1 - // skip dot product and argmin - for (int w = 2; w <= 4; w++) { - const char *check_pmaddwd = (use_avx2 && w > 3) ? "vpmaddwd*ymm" : "pmaddwd"; + for (int w = 2; w <= 8; w++) { + // We generated pmaddwd when we do a sum of widening multiplies + const char *check_pmaddwd = + (use_avx2 && w >= 4) ? "vpmaddwd" : "pmaddwd"; check(check_pmaddwd, 2 * w, i32(i16_1) * 3 + i32(i16_2) * 4); check(check_pmaddwd, 2 * w, i32(i16_1) * 3 - i32(i16_2) * 4); + + // And also for dot-products + RDom r(0, 4); + check(check_pmaddwd, 2 * w, sum(i32(in_i16(x * 4 + r)) * in_i16(x * 4 + r + 32))); } // llvm doesn't distinguish between signed and unsigned multiplies @@ -888,12 +938,113 @@ class SimdOpCheck : public SimdOpCheckTest { // VORR X - Bitwise OR // check("vorr", bool1 | bool2); - // VPADAL I - Pairwise Add and Accumulate Long - // VPADD I, F - Pairwise Add - // VPADDL I - Pairwise Add Long - // VPMAX I, F - Pairwise Maximum - // VPMIN I, F - Pairwise Minimum - // We don't do horizontal ops + for (int f : {2, 4}) { + RDom r(0, f); + + // A summation reduction that starts at something + // non-trivial, to avoid llvm simplifying accumulating + // widening summations into just widening summations. + auto sum_ = [&](Expr e) { + Func f; + f(x) = cast(e.type(), 123); + f(x) += e; + return f(x); + }; + + // VPADD I, F - Pairwise Add + check(arm32 ? "vpadd.i8" : "addp", 16, sum_(in_i8(f * x + r))); + check(arm32 ? "vpadd.i8" : "addp", 16, sum_(in_u8(f * x + r))); + check(arm32 ? "vpadd.i16" : "addp", 8, sum_(in_i16(f * x + r))); + check(arm32 ? "vpadd.i16" : "addp", 8, sum_(in_u16(f * x + r))); + check(arm32 ? "vpadd.i32" : "addp", 4, sum_(in_i32(f * x + r))); + check(arm32 ? "vpadd.i32" : "addp", 4, sum_(in_u32(f * x + r))); + check(arm32 ? "vpadd.f32" : "addp", 4, sum_(in_f32(f * x + r))); + // In 32-bit, we don't have a pairwise op for doubles, + // and expect to just get vadd instructions on d + // registers. + check(arm32 ? "vadd.f64" : "addp", 4, sum_(in_f64(f * x + r))); + + if (f == 2) { + // VPADAL I - Pairwise Add and Accumulate Long + + // If we're reducing by a factor of two, we can + // use the forms with an accumulator + check(arm32 ? "vpadal.s8" : "sadalp", 16, sum_(i16(in_i8(f * x + r)))); + check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(i16(in_u8(f * x + r)))); + check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(u16(in_u8(f * x + r)))); + + check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i16(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u16(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(u32(in_u16(f * x + r)))); + + check(arm32 ? "vpadal.s32" : "sadalp", 4, sum_(i64(in_i32(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(i64(in_u32(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(u64(in_u32(f * x + r)))); + } else { + // VPADDL I - Pairwise Add Long + + // If we're reducing by more than that, that's not + // possible. + check(arm32 ? "vpaddl.s8" : "saddlp", 16, sum_(i16(in_i8(f * x + r)))); + check(arm32 ? 
"vpaddl.u8" : "uaddlp", 16, sum_(i16(in_u8(f * x + r)))); + check(arm32 ? "vpaddl.u8" : "uaddlp", 16, sum_(u16(in_u8(f * x + r)))); + + check(arm32 ? "vpaddl.s16" : "saddlp", 8, sum_(i32(in_i16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 8, sum_(i32(in_u16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 8, sum_(u32(in_u16(f * x + r)))); + + check(arm32 ? "vpaddl.s32" : "saddlp", 4, sum_(i64(in_i32(f * x + r)))); + check(arm32 ? "vpaddl.u32" : "uaddlp", 4, sum_(i64(in_u32(f * x + r)))); + check(arm32 ? "vpaddl.u32" : "uaddlp", 4, sum_(u64(in_u32(f * x + r)))); + + // If we're widening the type by a factor of four + // as well as reducing by a factor of four, we + // expect vpaddl followed by vpadal + check(arm32 ? "vpaddl.s8" : "saddlp", 8, sum_(i32(in_i8(f * x + r)))); + check(arm32 ? "vpaddl.u8" : "uaddlp", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpaddl.u8" : "uaddlp", 8, sum_(u32(in_u8(f * x + r)))); + check(arm32 ? "vpaddl.s16" : "saddlp", 4, sum_(i64(in_i16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 4, sum_(i64(in_u16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 4, sum_(u64(in_u16(f * x + r)))); + + // Note that when going from u8 to i32 like this, + // the vpaddl is unsigned and the vpadal is a + // signed, because the intermediate type is u16 + check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i8(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(u32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.s32" : "sadalp", 4, sum_(i64(in_i16(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(i64(in_u16(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(u64(in_u16(f * x + r)))); + } + + // VPMAX I, F - Pairwise Maximum + check(arm32 ? "vpmax.s8" : "smaxp", 16, maximum(in_i8(f * x + r))); + check(arm32 ? "vpmax.u8" : "umaxp", 16, maximum(in_u8(f * x + r))); + check(arm32 ? "vpmax.s16" : "smaxp", 8, maximum(in_i16(f * x + r))); + check(arm32 ? "vpmax.u16" : "umaxp", 8, maximum(in_u16(f * x + r))); + check(arm32 ? "vpmax.s32" : "smaxp", 4, maximum(in_i32(f * x + r))); + check(arm32 ? "vpmax.u32" : "umaxp", 4, maximum(in_u32(f * x + r))); + + // VPMIN I, F - Pairwise Minimum + check(arm32 ? "vpmin.s8" : "sminp", 16, minimum(in_i8(f * x + r))); + check(arm32 ? "vpmin.u8" : "uminp", 16, minimum(in_u8(f * x + r))); + check(arm32 ? "vpmin.s16" : "sminp", 8, minimum(in_i16(f * x + r))); + check(arm32 ? "vpmin.u16" : "uminp", 8, minimum(in_u16(f * x + r))); + check(arm32 ? "vpmin.s32" : "sminp", 4, minimum(in_i32(f * x + r))); + check(arm32 ? 
"vpmin.u32" : "uminp", 4, minimum(in_u32(f * x + r))); + } + + // UDOT/SDOT + if (target.has_feature(Target::ARMDotProd)) { + for (int f : {4, 8}) { + RDom r(0, f); + for (int v : {2, 4}) { + check("udot", v, sum(u32(in_u8(f * x + r)) * in_u8(f * x + r + 32))); + check("sdot", v, sum(i32(in_i8(f * x + r)) * in_i8(f * x + r + 32))); + } + } + } // VPOP X F, D Pop from Stack // VPUSH X F, D Push to Stack diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index ceec22221347..0232f95ef52e 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -130,6 +130,25 @@ class SimdOpCheckTest { TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e) { std::ostringstream error_msg; + class HasInlineReduction : public Internal::IRVisitor { + using Internal::IRVisitor::visit; + void visit(const Internal::Call *op) override { + if (op->call_type == Internal::Call::Halide) { + Internal::Function f(op->func); + if (f.has_update_definition()) { + inline_reduction = f; + result = true; + } + } + IRVisitor::visit(op); + } + + public: + Internal::Function inline_reduction; + bool result = false; + } has_inline_reduction; + e.accept(&has_inline_reduction); + // Define a vectorized Halide::Func that uses the pattern. Halide::Func f(name); f(x, y) = e; @@ -142,10 +161,28 @@ class SimdOpCheckTest { f_scalar.bound(x, 0, W); f_scalar.compute_root(); + if (has_inline_reduction.result) { + // If there's an inline reduction, we want to vectorize it + // over the RVar. + Var xo, xi; + RVar rxi; + Func g{has_inline_reduction.inline_reduction}; + + // Do the reduction separately in f_scalar + g.clone_in(f_scalar); + + g.compute_at(f, x) + .update() + .split(x, xo, xi, vector_width) + .fuse(g.rvars()[0], xi, rxi) + .atomic() + .vectorize(rxi); + } + // The output to the pipeline is the maximum absolute difference as a double. - RDom r(0, W, 0, H); + RDom r_check(0, W, 0, H); Halide::Func error("error_" + name); - error() = Halide::cast(maximum(absd(f(r.x, r.y), f_scalar(r.x, r.y)))); + error() = Halide::cast(maximum(absd(f(r_check.x, r_check.y), f_scalar(r_check.x, r_check.y)))); setup_images(); { diff --git a/test/correctness/tuple_vector_reduce.cpp b/test/correctness/tuple_vector_reduce.cpp new file mode 100644 index 000000000000..7cf15c0d1351 --- /dev/null +++ b/test/correctness/tuple_vector_reduce.cpp @@ -0,0 +1,59 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + // Make sure a tuple-valued associative reduction can be + // horizontally vectorized. + + { + // Tuple addition + Func in; + Var x; + in(x) = {x, 2 * x}; + + Func f; + f() = {0, 0}; + + RDom r(1, 100); + f() = {f()[0] + in(r)[0], f()[1] + in(r)[1]}; + + in.compute_root(); + f.update().atomic().vectorize(r, 8); //.parallel(r); + + f.realize(); + } + + return 0; + + { + // Complex multiplication is associative. Let's multiply a bunch + // of complex numbers together. + Func in; + Var x; + in(x) = {x, x}; + + Func f; + f() = {1, 0}; + + RDom r(1, 100); + Expr a_real = f()[0]; + Expr a_imag = f()[1]; + Expr b_real = in(r)[0]; + Expr b_imag = in(r)[1]; + f() = {a_real * b_real - a_imag * b_imag, + a_real * b_imag + b_real * a_imag}; + + in.compute_root(); + f.update().atomic().vectorize(r, 8); + + // Sadly, this won't actually vectorize, because it's not + // expressible as a horizontal reduction op on a single + // vector. You'd need to rfactor. We can at least check we get + // the right value back though. 
+ f.realize(); + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/vector_reductions.cpp b/test/correctness/vector_reductions.cpp new file mode 100644 index 000000000000..18c0bc259def --- /dev/null +++ b/test/correctness/vector_reductions.cpp @@ -0,0 +1,126 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + for (int dst_lanes : {1, 3}) { + for (int reduce_factor : {2, 3, 4}) { + std::vector types = + {UInt(8), Int(8), UInt(16), Int(16), UInt(32), Int(32), + UInt(64), Int(64), Float(16), Float(32), Float(64)}; + const int src_lanes = dst_lanes * reduce_factor; + for (Type src_type : types) { + for (int widen_factor : {1, 2, 4}) { + Type dst_type = src_type.with_bits(src_type.bits() * widen_factor); + if (std::find(types.begin(), types.end(), dst_type) == types.end()) { + continue; + } + + for (int op = 0; op < 7; op++) { + if (dst_type == Float(16) && reduce_factor > 2) { + // Reductions of float16s is really not very associative + continue; + } + + Var x, xo, xi; + RDom r(0, reduce_factor); + RVar rx; + Func in; + if (src_type.is_float()) { + in(x) = cast(src_type, random_float()); + } else { + in(x) = cast(src_type, random_int()); + } + in.compute_root(); + + Expr rhs = cast(dst_type, in(x * reduce_factor + r)); + Expr rhs2 = cast(dst_type, in(x * reduce_factor + r + 32)); + + if (op == 4 || op == 5) { + rhs = rhs > cast(rhs.type(), 5); + } + + Func f, ref("ref"); + switch (op) { + case 0: + f(x) += rhs; + ref(x) += rhs; + break; + case 1: + f(x) *= rhs; + ref(x) *= rhs; + break; + case 2: + // Widening min/max reductions are not interesting + if (widen_factor != 1) { + continue; + } + f(x) = rhs.type().min(); + ref(x) = rhs.type().min(); + f(x) = max(f(x), rhs); + ref(x) = max(f(x), rhs); + break; + case 3: + if (widen_factor != 1) { + continue; + } + f(x) = rhs.type().max(); + ref(x) = rhs.type().max(); + f(x) = min(f(x), rhs); + ref(x) = min(f(x), rhs); + break; + case 4: + if (widen_factor != 1) { + continue; + } + f(x) = cast(false); + ref(x) = cast(false); + f(x) = f(x) || rhs; + ref(x) = f(x) || rhs; + break; + case 5: + if (widen_factor != 1) { + continue; + } + f(x) = cast(true); + ref(x) = cast(true); + f(x) = f(x) && rhs; + ref(x) = f(x) && rhs; + break; + case 6: + // Dot product + f(x) += rhs * rhs2; + ref(x) += rhs * rhs2; + } + + f.compute_root() + .update() + .split(x, xo, xi, dst_lanes) + .fuse(r, xi, rx) + .atomic() + .vectorize(rx); + ref.compute_root(); + + RDom c(0, 128); + Expr err = cast(maximum(absd(f(c), ref(c)))); + + double e = evaluate(err); + + if (e > 1e-3) { + std::cerr + << "Horizontal reduction produced different output when vectorized!\n" + << "Maximum error = " << e << "\n" + << "Reducing from " << src_type.with_lanes(src_lanes) + << " to " << dst_type.with_lanes(dst_lanes) << "\n" + << "RHS: " << f.update_value() << "\n"; + exit(-1); + } + } + } + } + } + } + + printf("Success!\n"); + return 0; +} diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 304676fca36c..f2e8feb7775d 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -6,7 +6,6 @@ tests(GROUPS error atomics_gpu_8_bit.cpp atomics_gpu_mutex.cpp atomics_self_reference.cpp - atomics_vectorized_mutex.cpp auto_schedule_no_parallel.cpp auto_schedule_no_reorder.cpp autodiff_unbounded.cpp diff --git a/test/error/atomics_vectorized_mutex.cpp b/test/error/atomics_vectorized_mutex.cpp deleted file mode 100644 index 75a71840e9ea..000000000000 --- a/test/error/atomics_vectorized_mutex.cpp +++ /dev/null @@ 
-1,29 +0,0 @@ -#include "Halide.h" - -using namespace Halide; - -int main(int argc, char **argv) { - int img_size = 10000; - - Func f; - Var x; - RDom r(0, img_size); - - f(x) = Tuple(1, 0); - f(r) = Tuple(f(r)[1] + 1, f(r)[0] + 1); - - f.compute_root(); - - f.update() - .atomic() - .vectorize(r, 8); - - // f's update will be lowered to mutex locks, - // and we don't allow vectorization on mutex locks since - // it leads to deadlocks. - // This should throw an error - Realization out = f.realize(img_size); - - printf("Success!\n"); - return 0; -}
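For reference, here is a minimal standalone sketch (not part of the patch) of the kind of pipeline the new CodeGen_ARM::codegen_vector_reduce path targets. It mirrors the split/fuse/atomic/vectorize schedule used in test/correctness/vector_reductions.cpp above; the function and variable names (dot, in_a, in_b, rxi) and the output filename are illustrative, and which instruction is actually selected (udot versus the pairwise/widening adds wrapped in arm.ll) depends on the target features.

#include "Halide.h"
using namespace Halide;

int main() {
    // Two 8-bit inputs whose product is accumulated into 32 bits.
    ImageParam in_a(UInt(8), 1), in_b(UInt(8), 1);

    Var x("x"), xo("xo"), xi("xi");
    RDom r(0, 4);
    RVar rxi("rxi");

    // u8 * u8 -> u32 dot product, reducing 4 input lanes per output lane.
    Func dot("dot");
    dot(x) += cast<uint32_t>(in_a(4 * x + r)) * cast<uint32_t>(in_b(4 * x + r));

    // Vectorizing across the fused reduction variable produces the
    // VectorReduce node that codegen_vector_reduce pattern-matches.
    dot.update()
        .split(x, xo, xi, 4)
        .fuse(r, xi, rxi)
        .atomic()
        .vectorize(rxi);

    // With ARMDotProd this is expected to lower to udot; without it, the
    // pairwise/widening add wrappers defined in arm.ll are used instead.
    Target t("arm-64-linux");
    t.set_feature(Target::ARMDotProd);
    dot.compile_to_assembly("dot.s", {in_a, in_b}, "dot", t);

    return 0;
}

Grepping the generated assembly for udot or sdot (much as simd_op_check.cpp and cuda_8_bit_dot_product.cpp grep the compiled output above) is a quick way to confirm the feature took effect.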