Skip to content

Commit

Permalink
Codegen for VectorReduce IR nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
abadams committed Jun 13, 2020
1 parent 585bdc7 commit f63c800
Show file tree
Hide file tree
Showing 25 changed files with 2,157 additions and 67 deletions.
1 change: 1 addition & 0 deletions python_bindings/src/PyEnums.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ void define_enums(py::module &m) {
.value("WasmSignExt", Target::Feature::WasmSignExt)
.value("SVE", Target::Feature::SVE)
.value("SVE2", Target::Feature::SVE2)
.value("ARMDotProd", Target::Feature::ARMDotProd)
.value("FeatureEnd", Target::Feature::FeatureEnd);

py::enum_<halide_type_code_t>(m, "TypeCode")
Expand Down
187 changes: 183 additions & 4 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <iostream>
#include <sstream>

#include "CSE.h"
#include "CodeGen_ARM.h"
#include "ConciseCasts.h"
#include "Debug.h"
Expand Down Expand Up @@ -483,10 +484,6 @@ void CodeGen_ARM::visit(const Div *op) {
CodeGen_Posix::visit(op);
}

void CodeGen_ARM::visit(const Add *op) {
CodeGen_Posix::visit(op);
}

void CodeGen_ARM::visit(const Sub *op) {
if (neon_intrinsics_disabled()) {
CodeGen_Posix::visit(op);
Expand Down Expand Up @@ -1063,6 +1060,184 @@ void CodeGen_ARM::visit(const LE *op) {
CodeGen_Posix::visit(op);
}

void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) {
if (neon_intrinsics_disabled() ||
op->op == VectorReduce::Or ||
op->op == VectorReduce::And ||
op->op == VectorReduce::Mul) {
CodeGen_Posix::codegen_vector_reduce(op, init);
return;
}

// ARM has a variety of pairwise reduction ops for +, min,
// max. The versions that do not widen take two 64-bit args and
// return one 64-bit vector of the same type. The versions that
// widen take one arg and return something with half the vector
// lanes and double the bit-width.

int factor = op->value.type().lanes() / op->type.lanes();

// These are the types for which we have reduce intrinsics in the
// runtime.
bool have_reduce_intrinsic = (op->type.is_int() ||
op->type.is_uint() ||
op->type.is_float());

// We don't have 16-bit float or bfloat horizontal ops
if (op->type.is_bfloat() || (op->type.is_float() && op->type.bits() < 32)) {
have_reduce_intrinsic = false;
}

// Only aarch64 has float64 horizontal ops
if (target.bits == 32 && op->type.element_of() == Float(64)) {
have_reduce_intrinsic = false;
}

// For 64-bit integers, we only have addition, not min/max
if (op->type.bits() == 64 &&
!op->type.is_float() &&
op->op != VectorReduce::Add) {
have_reduce_intrinsic = false;
}

// We only have intrinsics that reduce by a factor of two
if (factor != 2) {
have_reduce_intrinsic = false;
}

if (have_reduce_intrinsic) {
Expr arg = op->value;
if (op->op == VectorReduce::Add &&
op->type.bits() >= 16 &&
!op->type.is_float()) {
Type narrower_type = arg.type().with_bits(arg.type().bits() / 2);
Expr narrower = lossless_cast(narrower_type, arg);
if (!narrower.defined() && arg.type().is_int()) {
// We can also safely accumulate from a uint into a
// wider int, because the addition uses at most one
// extra bit.
narrower = lossless_cast(narrower_type.with_code(Type::UInt), arg);
}
if (narrower.defined()) {
arg = narrower;
}
}
int output_bits;
if (target.bits == 32 && arg.type().bits() == op->type.bits()) {
// For the non-widening version, the output must be 64-bit
output_bits = 64;
} else if (op->type.bits() * op->type.lanes() <= 64) {
// No point using the 128-bit version of the instruction if the output is narrow.
output_bits = 64;
} else {
output_bits = 128;
}

const int output_lanes = output_bits / op->type.bits();
Type intrin_type = op->type.with_lanes(output_lanes);
Type arg_type = arg.type().with_lanes(output_lanes * 2);
if (op->op == VectorReduce::Add &&
arg.type().bits() == op->type.bits() &&
arg_type.is_uint()) {
// For non-widening additions, there is only a signed
// version (because it's equivalent).
arg_type = arg_type.with_code(Type::Int);
intrin_type = intrin_type.with_code(Type::Int);
} else if (arg.type().is_uint() && intrin_type.is_int()) {
// Use the uint version
intrin_type = intrin_type.with_code(Type::UInt);
}

std::stringstream ss;
vector<Expr> args;
ss << "pairwise_" << op->op << "_" << intrin_type << "_" << arg_type;
Expr accumulator = init;
if (op->op == VectorReduce::Add &&
accumulator.defined() &&
arg_type.bits() < intrin_type.bits()) {
// We can use the accumulating variant
ss << "_accumulate";
args.push_back(init);
accumulator = Expr();
}
args.push_back(arg);
value = call_intrin(op->type, output_lanes, ss.str(), args);

if (accumulator.defined()) {
// We still have an initial value to take care of
string n = unique_name('t');
sym_push(n, value);
Expr v = Variable::make(accumulator.type(), n);
switch (op->op) {
case VectorReduce::Add:
accumulator += v;
break;
case VectorReduce::Min:
accumulator = min(accumulator, v);
break;
case VectorReduce::Max:
accumulator = max(accumulator, v);
break;
default:
internal_error << "unreachable";
}
codegen(accumulator);
sym_pop(n);
}

return;
}

// Pattern-match 8-bit dot product instructions available on newer
// ARM cores.
if (target.has_feature(Target::ARMDotProd) &&
factor % 4 == 0 &&
op->op == VectorReduce::Add &&
target.bits == 64 &&
(op->type.element_of() == Int(32) ||
op->type.element_of() == UInt(32))) {
const Mul *mul = op->value.as<Mul>();
if (mul) {
const int input_lanes = mul->type.lanes();
Expr a = lossless_cast(UInt(8, input_lanes), mul->a);
Expr b = lossless_cast(UInt(8, input_lanes), mul->b);
if (!a.defined()) {
a = lossless_cast(Int(8, input_lanes), mul->a);
b = lossless_cast(Int(8, input_lanes), mul->b);
}
if (a.defined() && b.defined()) {
if (factor != 4) {
Expr equiv = VectorReduce::make(op->op, op->value, input_lanes / 4);
equiv = VectorReduce::make(op->op, equiv, op->type.lanes());
codegen_vector_reduce(equiv.as<VectorReduce>(), init);
return;
}
Expr i = init;
if (!i.defined()) {
i = make_zero(op->type);
}
vector<Expr> args{i, a, b};
if (op->type.lanes() <= 2) {
if (op->type.is_uint()) {
value = call_intrin(op->type, 2, "llvm.aarch64.neon.udot.v2i32.v8i8", args);
} else {
value = call_intrin(op->type, 2, "llvm.aarch64.neon.sdot.v2i32.v8i8", args);
}
} else {
if (op->type.is_uint()) {
value = call_intrin(op->type, 4, "llvm.aarch64.neon.udot.v4i32.v16i8", args);
} else {
value = call_intrin(op->type, 4, "llvm.aarch64.neon.sdot.v4i32.v16i8", args);
}
}
return;
}
}
}

CodeGen_Posix::codegen_vector_reduce(op, init);
}

string CodeGen_ARM::mcpu() const {
if (target.bits == 32) {
if (target.has_feature(Target::ARMv7s)) {
Expand Down Expand Up @@ -1098,6 +1273,10 @@ string CodeGen_ARM::mattrs() const {
arch_flags = "+sve";
}

if (target.has_feature(Target::ARMDotProd)) {
arch_flags += "+dotprod";
}

if (target.os == Target::IOS || target.os == Target::OSX) {
return arch_flags + "+reserve-x18";
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/CodeGen_ARM.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class CodeGen_ARM : public CodeGen_Posix {
/** Nodes for which we want to emit specific neon intrinsics */
// @{
void visit(const Cast *) override;
void visit(const Add *) override;
void visit(const Sub *) override;
void visit(const Div *) override;
void visit(const Mul *) override;
Expand All @@ -35,6 +34,7 @@ class CodeGen_ARM : public CodeGen_Posix {
void visit(const Call *) override;
void visit(const LT *) override;
void visit(const LE *) override;
void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
// @}

/** Various patterns to peephole match against */
Expand Down
Loading

0 comments on commit f63c800

Please sign in to comment.