From c3734d3a4de2b9722763ffcae7a09c94131ee5fa Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 11 Jun 2020 15:33:29 -0700 Subject: [PATCH 1/4] Add VectorReduce IR node --- src/Bounds.cpp | 34 ++++++++++ src/Deinterleave.cpp | 12 ++++ src/Derivative.cpp | 3 + src/Expr.h | 1 + src/IR.cpp | 30 +++++++++ src/IR.h | 27 ++++++++ src/IREquality.cpp | 9 +++ src/IRMatch.cpp | 14 ++++ src/IRMatch.h | 110 ++++++++++++++++++++++++++++++- src/IRMutator.cpp | 8 +++ src/IRMutator.h | 1 + src/IROperator.cpp | 25 ++++++++ src/IRPrinter.cpp | 34 ++++++++++ src/IRPrinter.h | 4 ++ src/IRVisitor.cpp | 8 +++ src/IRVisitor.h | 5 ++ src/ModulusRemainder.cpp | 9 ++- src/Monotonic.cpp | 18 ++++++ src/Simplify_Exprs.cpp | 135 +++++++++++++++++++++++++++++++++++++++ src/Simplify_Internal.h | 1 + src/StmtToHtml.cpp | 7 ++ 21 files changed, 493 insertions(+), 2 deletions(-) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 891de2a019b2..ece8617b3290 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1403,6 +1403,40 @@ class Bounds : public IRVisitor { interval = result; } + void visit(const VectorReduce *op) override { + TRACK_BOUNDS_INTERVAL; + op->value.accept(this); + int factor = op->value.type().lanes() / op->type.lanes(); + switch (op->op) { + case VectorReduce::Add: + if (interval.has_upper_bound()) { + interval.max *= factor; + } + if (interval.has_lower_bound()) { + interval.min *= factor; + } + break; + case VectorReduce::Mul: + // Technically there are some things we could say + // here. E.g. if all the lanes are positive then we're + // bounded by the upper bound raised to the factor + // power. However it's extremely unlikely that a mul + // reduce will ever make it into a bounds expression, so + // for now we bail. + interval = Interval::everything(); + break; + case VectorReduce::Min: + case VectorReduce::Max: + // The bounds of a single lane are sufficient + break; + case VectorReduce::And: + case VectorReduce::Or: + // Don't try for now + interval = Interval::everything(); + break; + } + } + void visit(const LetStmt *) override { internal_error << "Bounds of statement\n"; } diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index 68617a7ef13f..bd9e61c33b27 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -194,6 +194,18 @@ class Deinterleaver : public IRGraphMutator { using IRMutator::visit; + Expr visit(const VectorReduce *op) override { + std::vector input_lanes; + int factor = op->value.type().lanes() / op->type.lanes(); + for (int i = starting_lane; i < op->type.lanes(); i += lane_stride) { + for (int j = 0; j < factor; j++) { + input_lanes.push_back(i * factor + j); + } + } + Expr in = Shuffle::make({op->value}, input_lanes); + return VectorReduce::make(op->op, in, new_lanes); + } + Expr visit(const Broadcast *op) override { if (new_lanes == 1) { return op->value; diff --git a/src/Derivative.cpp b/src/Derivative.cpp index a916ef1874ee..8c7363323289 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -88,6 +88,9 @@ class ReverseAccumulationVisitor : public IRVisitor { void visit(const Shuffle *op) override { internal_error << "Encounter unexpected expression \"Shuffle\" when differentiating."; } + void visit(const VectorReduce *op) override { + internal_error << "Encounter unexpected expression \"VectorReduce\" when differentiating."; + } void visit(const LetStmt *op) override { internal_error << "Encounter unexpected statement \"LetStmt\" when differentiating."; } diff --git a/src/Expr.h b/src/Expr.h index b1c31541b14d..06ed94638b66 100644 --- a/src/Expr.h +++ 
b/src/Expr.h @@ -56,6 +56,7 @@ enum class IRNodeType { Call, Let, Shuffle, + VectorReduce, // Stmts LetStmt, AssertStmt, diff --git a/src/IR.cpp b/src/IR.cpp index aae9f66ed547..c7359be9046a 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -824,6 +824,28 @@ Stmt Atomic::make(const std::string &producer_name, return node; } +Expr VectorReduce::make(VectorReduce::Operator op, + Expr vec, + int lanes) { + if (vec.type().is_bool()) { + internal_assert(op == VectorReduce::And || op == VectorReduce::Or) + << "The only legal operators for VectorReduce on a Bool" + << "vector are VectorReduce::And and VectorReduce::Or\n"; + } + internal_assert(!vec.type().is_handle()) << "VectorReduce of handle type"; + // Check the output lanes is a factor of the input lanes. They can + // also both be zero if we're constructing a wildcard expression. + internal_assert((lanes == 0 && vec.type().lanes() == 0) || + (lanes != 0 && (vec.type().lanes() % lanes == 0))) + << "Vector reduce output lanes must be a divisor of the number of lanes in the argument " + << lanes << " " << vec.type().lanes() << "\n"; + VectorReduce *node = new VectorReduce; + node->type = vec.type().with_lanes(lanes); + node->op = op; + node->value = std::move(vec); + return node; +} + namespace { // Helper function to determine if a sequence of indices is a @@ -978,6 +1000,10 @@ void ExprNode::accept(IRVisitor *v) const { v->visit((const Shuffle *)this); } template<> +void ExprNode::accept(IRVisitor *v) const { + v->visit((const VectorReduce *)this); +} +template<> void ExprNode::accept(IRVisitor *v) const { v->visit((const Let *)this); } @@ -1159,6 +1185,10 @@ Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const Shuffle *)this); } template<> +Expr ExprNode::mutate_expr(IRMutator *v) const { + return v->visit((const VectorReduce *)this); +} +template<> Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const Let *)this); } diff --git a/src/IR.h b/src/IR.h index a93961572756..0316be7f61e3 100644 --- a/src/IR.h +++ b/src/IR.h @@ -819,6 +819,33 @@ struct Atomic : public StmtNode { static const IRNodeType _node_type = IRNodeType::Atomic; }; +/** Horizontally reduce a vector to a scalar or narrower vector using + * the given commutative and associative binary operator. The reduction + * factor is dictated by the number of lanes in the input and output + * types. Groups of adjacent lanes are combined. The number of lanes + * in the input type must be a divisor of the number of lanes of the + * output type. */ +struct VectorReduce : public ExprNode { + // 99.9% of the time people will use this for horizontal addition, + // but these are all of our commutative and associative primitive + // operators. 
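    // As an illustrative sketch of the semantics (hypothetical names, not
    // part of this patch): reducing a 16-lane vector to 4 lanes with Add
    // sums each group of four adjacent lanes, i.e. a reduction factor of
    // 16 / 4 = 4:
    //
    //     Expr v = Variable::make(Int(32, 16), "v");
    //     Expr sums = VectorReduce::make(VectorReduce::Add, v, 4);
    //     // lane i of sums is v[4*i] + v[4*i+1] + v[4*i+2] + v[4*i+3]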
+ typedef enum { + Add, + Mul, + Min, + Max, + And, + Or, + } Operator; + + Expr value; + Operator op; + + static Expr make(Operator op, Expr vec, int lanes); + + static const IRNodeType _node_type = IRNodeType::VectorReduce; +}; + } // namespace Internal } // namespace Halide diff --git a/src/IREquality.cpp b/src/IREquality.cpp index 229972a71348..93f1be8ebead 100644 --- a/src/IREquality.cpp +++ b/src/IREquality.cpp @@ -97,6 +97,7 @@ class IRComparer : public IRVisitor { void visit(const Shuffle *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; + void visit(const VectorReduce *) override; }; template @@ -589,6 +590,14 @@ void IRComparer::visit(const Atomic *op) { compare_stmt(s->body, op->body); } +void IRComparer::visit(const VectorReduce *op) { + const VectorReduce *e = expr.as(); + + compare_scalar(op->op, e->op); + // We've already compared types, so it's enough to compare the value + compare_expr(op->value, e->value); +} + } // namespace // Now the methods exposed in the header. diff --git a/src/IRMatch.cpp b/src/IRMatch.cpp index f7ed05fada73..6dfb5b8a2b68 100644 --- a/src/IRMatch.cpp +++ b/src/IRMatch.cpp @@ -282,6 +282,16 @@ class IRMatch : public IRVisitor { result = false; } } + + void visit(const VectorReduce *op) override { + const VectorReduce *e = expr.as(); + if (result && e && op->op == e->op) { + expr = e->value; + op->value.accept(this); + } else { + result = false; + } + } }; bool expr_match(const Expr &pattern, const Expr &expr, vector &matches) { @@ -413,6 +423,10 @@ bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept { case IRNodeType::Shuffle: return (equal_helper(((const Shuffle &)a).vectors, ((const Shuffle &)b).vectors) && equal_helper(((const Shuffle &)a).indices, ((const Shuffle &)b).indices)); + case IRNodeType::VectorReduce: + return (((const VectorReduce &)a).op == ((const VectorReduce &)b).op && + equal_helper(((const VectorReduce &)a).value, ((const VectorReduce &)b).value)); + // Explicitly list all the Stmts instead of using a default // clause so that if new Exprs are added without being handled // here we get a compile-time error. diff --git a/src/IRMatch.h b/src/IRMatch.h index a452a1c312ba..958e1293cd28 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -1548,7 +1548,12 @@ struct BroadcastOp { Expr make(MatcherState &state, halide_type_t type_hint) const { const int l = known_lanes ? 
lanes : type_hint.lanes; type_hint.lanes = 1; - return Broadcast::make(a.make(state, type_hint), l); + Expr val = a.make(state, type_hint); + if (l == 1) { + return val; + } else { + return Broadcast::make(std::move(val), l); + } } constexpr static bool foldable = false; @@ -1659,6 +1664,109 @@ HALIDE_ALWAYS_INLINE auto ramp(A a, B b) noexcept -> RampOp +struct VectorReduceOp { + struct pattern_tag {}; + A a; + int lanes; + + constexpr static uint32_t binds = bindings::mask; + + constexpr static IRNodeType min_node_type = IRNodeType::VectorReduce; + constexpr static IRNodeType max_node_type = IRNodeType::VectorReduce; + constexpr static bool canonical = A::canonical; + + template + HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { + if (e.node_type == VectorReduce::_node_type) { + const VectorReduce &op = (const VectorReduce &)e; + if (op.op == reduce_op && + (!known_lanes || lanes == op.type.lanes()) && + a.template match(*op.value.get(), state)) { + return true; + } + } + return false; + } + + template + HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp &op, MatcherState &state) const noexcept { + return (reduce_op == reduce_op_2 && + a.template match(unwrap(op.a), state) && + (lanes == op.lanes || !known_lanes || !known_lanes_2)); + } + + HALIDE_ALWAYS_INLINE + Expr make(MatcherState &state, halide_type_t type_hint) const { + const int l = known_lanes ? lanes : type_hint.lanes; + return VectorReduce::make(reduce_op, a.make(state, type_hint), l); + } + + constexpr static bool foldable = false; +}; + +template +inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp &op) { + s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")"; + return s; +} + +template +inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp &op) { + s << "vector_reduce(" << reduce_op << ", " << op.a << ")"; + return s; +} + +template +HALIDE_ALWAYS_INLINE auto h_add(A a, int lanes) noexcept -> VectorReduceOp { + return {pattern_arg(a), lanes}; +} + +template +HALIDE_ALWAYS_INLINE auto h_add(A a) noexcept -> VectorReduceOp { + return {pattern_arg(a), 0}; +} + +template +HALIDE_ALWAYS_INLINE auto h_min(A a, int lanes) noexcept -> VectorReduceOp { + return {pattern_arg(a), lanes}; +} + +template +HALIDE_ALWAYS_INLINE auto h_min(A a) noexcept -> VectorReduceOp { + return {pattern_arg(a), 0}; +} + +template +HALIDE_ALWAYS_INLINE auto h_max(A a, int lanes) noexcept -> VectorReduceOp { + return {pattern_arg(a), lanes}; +} + +template +HALIDE_ALWAYS_INLINE auto h_max(A a) noexcept -> VectorReduceOp { + return {pattern_arg(a), 0}; +} + +template +HALIDE_ALWAYS_INLINE auto h_and(A a, int lanes) noexcept -> VectorReduceOp { + return {pattern_arg(a), lanes}; +} + +template +HALIDE_ALWAYS_INLINE auto h_and(A a) noexcept -> VectorReduceOp { + return {pattern_arg(a), 0}; +} + +template +HALIDE_ALWAYS_INLINE auto h_or(A a, int lanes) noexcept -> VectorReduceOp { + return {pattern_arg(a), lanes}; +} + +template +HALIDE_ALWAYS_INLINE auto h_or(A a) noexcept -> VectorReduceOp { + return {pattern_arg(a), 0}; +} + template struct NegateOp { struct pattern_tag {}; diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index 0b459d88450a..c2976ec2d17d 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -369,6 +369,14 @@ Expr IRMutator::visit(const Shuffle *op) { return Shuffle::make(new_vectors, op->indices); } +Expr IRMutator::visit(const VectorReduce *op) { + Expr value = mutate(op->value); + if (value.same_as(op->value)) { + return op; + } 
+ return VectorReduce::make(op->op, std::move(value), op->type.lanes()); +} + Stmt IRMutator::visit(const Fork *op) { Stmt first = mutate(op->first); Stmt rest = mutate(op->rest); diff --git a/src/IRMutator.h b/src/IRMutator.h index 64729204247b..4330714e9605 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -69,6 +69,7 @@ class IRMutator { virtual Expr visit(const Call *); virtual Expr visit(const Let *); virtual Expr visit(const Shuffle *); + virtual Expr visit(const VectorReduce *); virtual Stmt visit(const LetStmt *); virtual Stmt visit(const AssertStmt *); diff --git a/src/IROperator.cpp b/src/IROperator.cpp index f32e00b86788..92d94ba96c82 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -463,6 +463,31 @@ Expr lossless_cast(Type t, Expr e) { return Expr(); } } + + if (const VectorReduce *red = e.as()) { + const int factor = red->value.type().lanes() / red->type.lanes(); + switch (red->op) { + case VectorReduce::Add: + if (t.bits() >= 16 && factor < (1 << (t.bits() / 2))) { + Type narrower = red->value.type().with_bits(t.bits() / 2); + Expr val = lossless_cast(narrower, red->value); + if (val.defined()) { + return VectorReduce::make(red->op, val, red->type.lanes()); + } + } + break; + case VectorReduce::Max: + case VectorReduce::Min: { + Expr val = lossless_cast(t, red->value); + if (val.defined()) { + return VectorReduce::make(red->op, val, red->type.lanes()); + } + break; + } + default: + break; + } + } } return Expr(); diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp index 3b0b967e05f6..953ba044883b 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -274,6 +274,30 @@ ostream &operator<<(ostream &out, const ForType &type) { return out; } +ostream &operator<<(ostream &out, const VectorReduce::Operator &op) { + switch (op) { + case VectorReduce::Add: + out << "Add"; + break; + case VectorReduce::Mul: + out << "Mul"; + break; + case VectorReduce::Min: + out << "Min"; + break; + case VectorReduce::Max: + out << "Max"; + break; + case VectorReduce::And: + out << "And"; + break; + case VectorReduce::Or: + out << "Or"; + break; + } + return out; +} + ostream &operator<<(ostream &out, const NameMangling &m) { switch (m) { case NameMangling::Default: @@ -973,6 +997,16 @@ void IRPrinter::visit(const Shuffle *op) { } } +void IRPrinter::visit(const VectorReduce *op) { + stream << "(" + << op->type + << ")vector_reduce(" + << op->op + << ", " + << op->value + << ")\n"; +} + void IRPrinter::visit(const Atomic *op) { if (op->mutex_name.empty()) { stream << get_indent() << "atomic {\n"; diff --git a/src/IRPrinter.h b/src/IRPrinter.h index b7c78890f4f3..e5a873baed95 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -71,6 +71,9 @@ std::ostream &operator<<(std::ostream &stream, const Stmt &); * readable form */ std::ostream &operator<<(std::ostream &stream, const ForType &); +/** Emit a horizontal vector reduction op in human-readable form. 
*/ +std::ostream &operator<<(std::ostream &stream, const VectorReduce::Operator &); + /** Emit a halide name mangling value in a human readable format */ std::ostream &operator<<(std::ostream &stream, const NameMangling &); @@ -186,6 +189,7 @@ class IRPrinter : public IRVisitor { void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; }; diff --git a/src/IRVisitor.cpp b/src/IRVisitor.cpp index a46ef403feed..0357cb13f2a9 100644 --- a/src/IRVisitor.cpp +++ b/src/IRVisitor.cpp @@ -259,6 +259,10 @@ void IRVisitor::visit(const Shuffle *op) { } } +void IRVisitor::visit(const VectorReduce *op) { + op->value.accept(this); +} + void IRVisitor::visit(const Atomic *op) { op->body.accept(this); } @@ -509,6 +513,10 @@ void IRGraphVisitor::visit(const Shuffle *op) { } } +void IRGraphVisitor::visit(const VectorReduce *op) { + include(op->value); +} + void IRGraphVisitor::visit(const Atomic *op) { include(op->body); } diff --git a/src/IRVisitor.h b/src/IRVisitor.h index 4ef099fa39ea..cf429313275c 100644 --- a/src/IRVisitor.h +++ b/src/IRVisitor.h @@ -72,6 +72,7 @@ class IRVisitor { virtual void visit(const IfThenElse *); virtual void visit(const Evaluate *); virtual void visit(const Shuffle *); + virtual void visit(const VectorReduce *); virtual void visit(const Prefetch *); virtual void visit(const Fork *); virtual void visit(const Acquire *); @@ -141,6 +142,7 @@ class IRGraphVisitor : public IRVisitor { void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Acquire *) override; void visit(const Fork *) override; @@ -218,6 +220,8 @@ class VariadicVisitor { return ((T *)this)->visit((const Let *)node, std::forward(args)...); case IRNodeType::Shuffle: return ((T *)this)->visit((const Shuffle *)node, std::forward(args)...); + case IRNodeType::VectorReduce: + return ((T *)this)->visit((const VectorReduce *)node, std::forward(args)...); // Explicitly list the Stmt types rather than using a // default case so that when new IR nodes are added we // don't miss them here. 
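// A minimal sketch of a downstream use of the new visitor hooks
// (hypothetical class, not part of this patch): a visitor that counts the
// horizontal reductions appearing in a statement.
class CountVectorReduces : public IRVisitor {
    using IRVisitor::visit;
    void visit(const VectorReduce *op) override {
        count++;
        IRVisitor::visit(op);  // keep walking into op->value
    }

public:
    int count = 0;
};
// Usage: CountVectorReduces c; stmt.accept(&c); then read c.count.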
@@ -275,6 +279,7 @@ class VariadicVisitor { case IRNodeType::Call: case IRNodeType::Let: case IRNodeType::Shuffle: + case IRNodeType::VectorReduce: internal_error << "Unreachable"; break; case IRNodeType::LetStmt: diff --git a/src/ModulusRemainder.cpp b/src/ModulusRemainder.cpp index d095b4331869..56b3afee9229 100644 --- a/src/ModulusRemainder.cpp +++ b/src/ModulusRemainder.cpp @@ -73,6 +73,7 @@ class ComputeModulusRemainder : public IRVisitor { void visit(const Free *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; }; @@ -489,11 +490,17 @@ void ComputeModulusRemainder::visit(const Let *op) { } void ComputeModulusRemainder::visit(const Shuffle *op) { - // It's possible that scalar expressions are extracting a lane of a vector - don't fail in this case, but stop + // It's possible that scalar expressions are extracting a lane of + // a vector - don't faiql in this case, but stop internal_assert(op->indices.size() == 1) << "modulus_remainder of vector\n"; result = ModulusRemainder{}; } +void ComputeModulusRemainder::visit(const VectorReduce *op) { + internal_assert(op->type.is_scalar()) << "modulus_remainder of vector\n"; + result = ModulusRemainder{}; +} + void ComputeModulusRemainder::visit(const LetStmt *) { internal_error << "modulus_remainder of statement\n"; } diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index df2c5deca1ab..e39fbc930fe4 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -371,6 +371,24 @@ class MonotonicVisitor : public IRVisitor { result = Monotonic::Constant; } + void visit(const VectorReduce *op) override { + op->value.accept(this); + switch (op->op) { + case VectorReduce::Add: + case VectorReduce::Min: + case VectorReduce::Max: + // These reductions are monotonic in the arg + break; + case VectorReduce::Mul: + case VectorReduce::And: + case VectorReduce::Or: + // These ones are not + if (result != Monotonic::Constant) { + result = Monotonic::Unknown; + } + } + } + void visit(const LetStmt *op) override { internal_error << "Monotonic of statement\n"; } diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 78cc9229f933..ebddf0c4a52b 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -42,6 +42,141 @@ Expr Simplify::visit(const Broadcast *op, ExprInfo *bounds) { } } +Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { + Expr value = mutate(op->value, bounds); + if (bounds && op->type.is_int()) { + int factor = op->value.type().lanes() / op->type.lanes(); + switch (op->op) { + case VectorReduce::Add: + // Alignment of result is the alignment of the arg. Bounds + // of the result can grow according to the reduction + // factor. + if (bounds->min_defined) { + bounds->min *= factor; + } + if (bounds->max_defined) { + bounds->max *= factor; + } + break; + case VectorReduce::Mul: + // Don't try to infer anything about bounds. Leave the + // alignment unchanged even though we could theoretically + // upgrade it. + bounds->min_defined = bounds->max_defined = false; + break; + case VectorReduce::Min: + case VectorReduce::Max: + // Bounds and alignment of the result are just the bounds and alignment of the arg. + break; + case VectorReduce::And: + case VectorReduce::Or: + // For integer types this is a bitwise operator. Don't try + // to infer anything for now. 
+ bounds->min_defined = bounds->max_defined = false; + bounds->alignment = ModulusRemainder{}; + break; + } + }; + + // We can pull multiplications by a broadcast out of horizontal + // additions and do the horizontal addition earlier. This means we + // do the multiplication on a vector with fewer lanes. This + // approach applies whenever we have a distributive law. We'll + // exploit the following distributive laws here: + // - Multiplication distributes over addition + // - min/max distributes over min/max + // - and/or distributes over and/or + + // Further, we can collapse min/max/and/or of a broadcast down to + // a narrower broadcast. + + // TODO: There are other rules we could apply here if they ever + // come up in practice: + // - a horizontal min/max/add of a ramp is a different ramp + // - horizontal add of a broadcast is a broadcast + multiply + // - horizontal reduce of an shuffle_vectors may be simplifiable to the + // underlying op on different shuffle_vectors calls + + const int lanes = op->type.lanes(); + const int arg_lanes = op->value.type().lanes(); + switch (op->op) { + case VectorReduce::Add: { + auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); + if (rewrite(h_add(x * broadcast(y)), h_add(x, lanes) * broadcast(y, lanes)) || + rewrite(h_add(broadcast(x) * y), h_add(y, lanes) * broadcast(x, lanes))) { + return mutate(rewrite.result, bounds); + } + break; + } + case VectorReduce::Min: { + auto rewrite = IRMatcher::rewriter(IRMatcher::h_min(value, lanes), op->type); + if (rewrite(h_min(min(x, broadcast(y))), min(h_min(x, lanes), broadcast(y, lanes))) || + rewrite(h_min(min(broadcast(x), y)), min(h_min(y, lanes), broadcast(x, lanes))) || + rewrite(h_min(max(x, broadcast(y))), max(h_min(x, lanes), broadcast(y, lanes))) || + rewrite(h_min(max(broadcast(x), y)), max(h_min(y, lanes), broadcast(x, lanes))) || + rewrite(h_min(broadcast(x)), broadcast(x, lanes)) || + rewrite(h_min(ramp(x, y)), x + min(y * (arg_lanes - 1), 0)) || + false) { + return mutate(rewrite.result, bounds); + } + break; + } + case VectorReduce::Max: { + auto rewrite = IRMatcher::rewriter(IRMatcher::h_max(value, lanes), op->type); + if (rewrite(h_max(min(x, broadcast(y))), min(h_max(x, lanes), broadcast(y, lanes))) || + rewrite(h_max(min(broadcast(x), y)), min(h_max(y, lanes), broadcast(x, lanes))) || + rewrite(h_max(max(x, broadcast(y))), max(h_max(x, lanes), broadcast(y, lanes))) || + rewrite(h_max(max(broadcast(x), y)), max(h_max(y, lanes), broadcast(x, lanes))) || + rewrite(h_max(broadcast(x)), broadcast(x, lanes)) || + rewrite(h_max(ramp(x, y)), x + max(y * (arg_lanes - 1), 0)) || + false) { + return mutate(rewrite.result, bounds); + } + break; + } + case VectorReduce::And: { + auto rewrite = IRMatcher::rewriter(IRMatcher::h_and(value, lanes), op->type); + if (rewrite(h_and(x || broadcast(y)), h_and(x, lanes) || broadcast(y, lanes)) || + rewrite(h_and(broadcast(x) || y), h_and(y, lanes) || broadcast(x, lanes)) || + rewrite(h_and(x && broadcast(y)), h_and(x, lanes) && broadcast(y, lanes)) || + rewrite(h_and(broadcast(x) && y), h_and(y, lanes) && broadcast(x, lanes)) || + rewrite(h_and(broadcast(x)), broadcast(x, lanes)) || + rewrite(h_and(ramp(x, y) < broadcast(z)), x + max(y * (arg_lanes - 1), 0) < z) || + rewrite(h_and(ramp(x, y) <= broadcast(z)), x + max(y * (arg_lanes - 1), 0) <= z) || + rewrite(h_and(broadcast(x) < ramp(y, z)), x < y + min(z * (arg_lanes - 1), 0)) || + rewrite(h_and(broadcast(x) < ramp(y, z)), x <= y + min(z * (arg_lanes - 1), 0)) || + false) { + 
return mutate(rewrite.result, bounds); + } + break; + } + case VectorReduce::Or: { + auto rewrite = IRMatcher::rewriter(IRMatcher::h_or(value, lanes), op->type); + if (rewrite(h_or(x || broadcast(y)), h_or(x, lanes) || broadcast(y, lanes)) || + rewrite(h_or(broadcast(x) || y), h_or(y, lanes) || broadcast(x, lanes)) || + rewrite(h_or(x && broadcast(y)), h_or(x, lanes) && broadcast(y, lanes)) || + rewrite(h_or(broadcast(x) && y), h_or(y, lanes) && broadcast(x, lanes)) || + rewrite(h_or(broadcast(x)), broadcast(x, lanes)) || + rewrite(h_or(ramp(x, y) < broadcast(z)), x + min(y * (arg_lanes - 1), 0) < z) || + rewrite(h_or(ramp(x, y) <= broadcast(z)), x + min(y * (arg_lanes - 1), 0) <= z) || + rewrite(h_or(broadcast(x) < ramp(y, z)), x < y + max(z * (arg_lanes - 1), 0)) || + rewrite(h_or(broadcast(x) < ramp(y, z)), x <= y + max(z * (arg_lanes - 1), 0)) || + false) { + return mutate(rewrite.result, bounds); + } + break; + } + default: + break; + } + + if (value.same_as(op->value)) { + return op; + } else { + return VectorReduce::make(op->op, value, op->type.lanes()); + } +} + Expr Simplify::visit(const Variable *op, ExprInfo *bounds) { if (bounds_and_alignment_info.contains(op->name)) { const ExprInfo &b = bounds_and_alignment_info.get(op->name); diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index bf8862c30dc2..cbfb511873af 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -302,6 +302,7 @@ class Simplify : public VariadicVisitor { Expr visit(const Load *op, ExprInfo *bounds); Expr visit(const Call *op, ExprInfo *bounds); Expr visit(const Shuffle *op, ExprInfo *bounds); + Expr visit(const VectorReduce *op, ExprInfo *bounds); Expr visit(const Let *op, ExprInfo *bounds); Stmt visit(const LetStmt *op); Stmt visit(const AssertStmt *op); diff --git a/src/StmtToHtml.cpp b/src/StmtToHtml.cpp index 52fd4675457f..f4766d815296 100644 --- a/src/StmtToHtml.cpp +++ b/src/StmtToHtml.cpp @@ -694,6 +694,13 @@ class StmtToHtml : public IRVisitor { stream << close_span(); } + void visit(const VectorReduce *op) override { + stream << open_span("VectorReduce"); + stream << open_span("Type") << op->type << close_span(); + print_list(symbol("vector_reduce") + "(", {op->op, op->value}, ")"); + stream << close_span(); + } + void visit(const Atomic *op) override { stream << open_div("Atomic"); int id = unique_id(); From 585bdc7c24b4f532b2a6db77e9546e2494ed5319 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 11 Jun 2020 15:36:35 -0700 Subject: [PATCH 2/4] Handle combination of atomic() and vectorize() in lowering --- src/Lower.cpp | 2 +- src/ScheduleFunctions.cpp | 19 +- src/VectorizeLoops.cpp | 528 ++++++++++++++++++++++++++++++++++++-- src/VectorizeLoops.h | 2 +- 4 files changed, 522 insertions(+), 29 deletions(-) diff --git a/src/Lower.cpp b/src/Lower.cpp index 52f970f4fac7..62c3aa1e00cb 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -335,7 +335,7 @@ Module lower(const vector &output_funcs, << s << "\n\n"; debug(1) << "Vectorizing...\n"; - s = vectorize_loops(s, t); + s = vectorize_loops(s, env, t); s = simplify(s); debug(2) << "Lowering after vectorizing:\n" << s << "\n\n"; diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 59f4269b3427..2b4ab53d5050 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -371,9 +371,22 @@ Stmt build_provide_loop_nest(const map &env, // Make the (multi-dimensional multi-valued) store node. Stmt body = Provide::make(func.name(), values, site); if (def.schedule().atomic()) { // Add atomic node. 
- // If required, we will allocate a mutex buffer called func.name() + ".mutex" - // The buffer is added in the AddAtomicMutex pass. - body = Atomic::make(func.name(), func.name() + ".mutex", body); + bool any_unordered_parallel = false; + for (auto d : def.schedule().dims()) { + any_unordered_parallel |= is_unordered_parallel(d.for_type); + } + if (any_unordered_parallel) { + // If required, we will allocate a mutex buffer called func.name() + ".mutex" + // The buffer is added in the AddAtomicMutex pass. + body = Atomic::make(func.name(), func.name() + ".mutex", body); + } else { + // No mutex is required if there is no parallelism, and it + // wouldn't work if all parallelism is synchronous + // (e.g. vectorization). Vectorization and the like will + // need to handle atomic nodes specially, by either + // emitting VectorReduce ops or scalarizing. + body = Atomic::make(func.name(), std::string{}, body); + } } // Default schedule/values if there is no specialization diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 9bb04cfb9bc8..0231e2592b0e 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -18,6 +18,7 @@ namespace Halide { namespace Internal { +using std::map; using std::pair; using std::string; using std::vector; @@ -170,21 +171,95 @@ Interval bounds_of_lanes(const Expr &e) { } // Take the explicit min and max over the lanes - Expr min_lane = extract_lane(e, 0); - Expr max_lane = min_lane; - for (int i = 1; i < e.type().lanes(); i++) { - Expr next_lane = extract_lane(e, i); - if (e.type().is_bool()) { - min_lane = And::make(min_lane, next_lane); - max_lane = Or::make(max_lane, next_lane); - } else { - min_lane = Min::make(min_lane, next_lane); - max_lane = Max::make(max_lane, next_lane); - } + if (e.type().is_bool()) { + Expr min_lane = VectorReduce::make(VectorReduce::And, e, 1); + Expr max_lane = VectorReduce::make(VectorReduce::Or, e, 1); + return {min_lane, max_lane}; + } else { + Expr min_lane = VectorReduce::make(VectorReduce::Min, e, 1); + Expr max_lane = VectorReduce::make(VectorReduce::Max, e, 1); + return {min_lane, max_lane}; } - return {min_lane, max_lane}; }; +// A ramp with the lanes repeated (e.g. 
<0 0 2 2 4 4 6 6>) +struct InterleavedRamp { + Expr base, stride; + int lanes, repetitions; +}; + +bool is_interleaved_ramp(const Expr &e, const Scope &scope, InterleavedRamp *result) { + if (const Ramp *r = e.as()) { + result->base = r->base; + result->stride = r->stride; + result->lanes = r->lanes; + result->repetitions = 1; + return true; + } else if (const Broadcast *b = e.as()) { + result->base = b->value; + result->stride = 0; + result->lanes = b->lanes; + result->repetitions = 0; + return true; + } else if (const Add *add = e.as()) { + InterleavedRamp ra; + if (is_interleaved_ramp(add->a, scope, &ra) && + is_interleaved_ramp(add->b, scope, result) && + (ra.repetitions == 0 || + result->repetitions == 0 || + ra.repetitions == result->repetitions)) { + result->base = simplify(result->base + ra.base); + result->stride = simplify(result->stride + ra.stride); + if (!result->repetitions) { + result->repetitions = ra.repetitions; + } + return true; + } + } else if (const Sub *sub = e.as()) { + InterleavedRamp ra; + if (is_interleaved_ramp(sub->a, scope, &ra) && + is_interleaved_ramp(sub->b, scope, result) && + (ra.repetitions == 0 || + result->repetitions == 0 || + ra.repetitions == result->repetitions)) { + result->base = simplify(ra.base - result->base); + result->stride = simplify(ra.stride - result->stride); + if (!result->repetitions) { + result->repetitions = ra.repetitions; + } + return true; + } + } else if (const Mul *mul = e.as()) { + const int64_t *b = nullptr; + if (is_interleaved_ramp(mul->a, scope, result) && + (b = as_const_int(mul->b))) { + result->base = simplify(result->base * (int)(*b)); + result->stride = simplify(result->stride * (int)(*b)); + return true; + } + } else if (const Div *div = e.as
()) { + const int64_t *b = nullptr; + if (is_interleaved_ramp(div->a, scope, result) && + (b = as_const_int(div->b)) && + is_one(result->stride) && + (result->repetitions == 1 || + result->repetitions == 0) && + can_prove((result->base % (int)(*b)) == 0)) { + // TODO: Generalize this. Currently only matches + // ramp(base*b, 1, lanes) / b + // broadcast(base * b, lanes) / b + result->base = simplify(result->base / (int)(*b)); + result->repetitions *= (int)(*b); + return true; + } + } else if (const Variable *var = e.as()) { + if (scope.contains(var->name)) { + return is_interleaved_ramp(scope.get(var->name), scope, result); + } + } + return false; +} + // Allocations inside vectorized loops grow an additional inner // dimension to represent the separate copy of the allocation per // vector lane. This means loads and stores to them need to be @@ -384,6 +459,9 @@ class VectorSubs : public IRMutator { // vectors. Scope scope; + // The same set of Exprs, indexed by the vectorized var name + Scope vector_scope; + // A stack of all containing lets. We need to reinject the scalar // version of them if we scalarize inner code. vector> containing_lets; @@ -613,19 +691,24 @@ class VectorSubs : public IRMutator { // If the value was vectorized by this mutator, add a new name to // the scope for the vectorized value expression. - std::string vectorized_name; + string vectorized_name; if (was_vectorized) { vectorized_name = op->name + widening_suffix; scope.push(op->name, mutated_value); + vector_scope.push(vectorized_name, mutated_value); } Expr mutated_body = mutate(op->body); - if (mutated_value.same_as(op->value) && - mutated_body.same_as(op->body)) { + InterleavedRamp ir; + if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) { + return substitute(vectorized_name, mutated_value, mutated_body); + } else if (mutated_value.same_as(op->value) && + mutated_body.same_as(op->body)) { return op; } else if (was_vectorized) { scope.pop(op->name); + vector_scope.pop(vectorized_name); return Let::make(vectorized_name, mutated_value, mutated_body); } else { return Let::make(op->name, mutated_value, mutated_body); @@ -634,7 +717,7 @@ class VectorSubs : public IRMutator { Stmt visit(const LetStmt *op) override { Expr mutated_value = mutate(op->value); - std::string mutated_name = op->name; + string mutated_name = op->name; // Check if the value was vectorized by this mutator. bool was_vectorized = (!op->value.type().is_vector() && @@ -643,6 +726,7 @@ class VectorSubs : public IRMutator { if (was_vectorized) { mutated_name += widening_suffix; scope.push(op->name, mutated_value); + vector_scope.push(mutated_name, mutated_value); // Also keep track of the original let, in case inner code scalarizes. containing_lets.emplace_back(op->name, op->value); } @@ -652,6 +736,7 @@ class VectorSubs : public IRMutator { if (was_vectorized) { containing_lets.pop_back(); scope.pop(op->name); + vector_scope.pop(mutated_name); // Inner code might have extracted my lanes using // extract_lane, which introduces a shuffle_vector. 
If @@ -688,8 +773,11 @@ class VectorSubs : public IRMutator { } } - if (mutated_value.same_as(op->value) && - mutated_body.same_as(op->body)) { + InterleavedRamp ir; + if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) { + return substitute(mutated_name, mutated_value, mutated_body); + } else if (mutated_value.same_as(op->value) && + mutated_body.same_as(op->body)) { return op; } else { return LetStmt::make(mutated_name, mutated_value, mutated_body); @@ -893,7 +981,7 @@ class VectorSubs : public IRMutator { } Stmt visit(const Allocate *op) override { - std::vector new_extents; + vector new_extents; Expr new_expr; int lanes = replacement.type().lanes(); @@ -941,6 +1029,171 @@ class VectorSubs : public IRMutator { return Allocate::make(op->name, op->type, op->memory_type, new_extents, op->condition, body, new_expr, op->free_function); } + Stmt visit(const Atomic *op) override { + // Recognize a few special cases that we can handle as within-vector reduction trees. + do { + if (!op->mutex_name.empty()) { + // We can't vectorize over a mutex + break; + } + + // f[x] = f[x] y + const Store *store = op->body.as(); + if (!store) break; + + VectorReduce::Operator reduce_op = VectorReduce::Add; + Expr a, b; + if (const Add *add = store->value.as()) { + a = add->a; + b = add->b; + reduce_op = VectorReduce::Add; + } else if (const Mul *mul = store->value.as()) { + a = mul->a; + b = mul->b; + reduce_op = VectorReduce::Mul; + } else if (const Min *min = store->value.as()) { + a = min->a; + b = min->b; + reduce_op = VectorReduce::Min; + } else if (const Max *max = store->value.as()) { + a = max->a; + b = max->b; + reduce_op = VectorReduce::Max; + } else if (const Cast *cast_op = store->value.as()) { + if (cast_op->type.element_of() == UInt(8) && + cast_op->value.type().is_bool()) { + if (const And *and_op = cast_op->value.as()) { + a = and_op->a; + b = and_op->b; + reduce_op = VectorReduce::And; + } else if (const Or *or_op = cast_op->value.as()) { + a = or_op->a; + b = or_op->b; + reduce_op = VectorReduce::Or; + } + } + } + + if (!a.defined() || !b.defined()) { + break; + } + + // Bools get cast to uint8 for storage. Strip off that + // cast around any load. + if (b.type().is_bool()) { + const Cast *cast_op = b.as(); + if (cast_op) { + b = cast_op->value; + } + } + if (a.type().is_bool()) { + const Cast *cast_op = b.as(); + if (cast_op) { + a = cast_op->value; + } + } + + if (a.as() && !b.as()) { + std::swap(a, b); + } + + // We require b to be a var, because it should have been lifted. + const Variable *var_b = b.as(); + const Load *load_a = a.as(); + + if (!var_b || + !scope.contains(var_b->name) || + !load_a || + load_a->name != store->name || + !is_one(load_a->predicate) || + !is_one(store->predicate)) { + break; + } + + b = scope.get(var_b->name); + Expr store_index = mutate(store->index); + Expr load_index = mutate(load_a->index); + + // The load and store indices must be the same interleaved + // ramp (or the same scalar, in the total reduction case). 
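            // For example (front-end sketch; the names acc, im, and r are
            // hypothetical and appear nowhere in this pass), a total
            // reduction such as
            //     acc() = 0;
            //     acc() += im(r);
            //     acc.update().atomic().vectorize(r, 8);
            // arrives here with matching scalar load and store indices, so
            // the += collapses to a single VectorReduce::Add over the eight
            // lanes of im plus the previously stored value.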
+ InterleavedRamp store_ir, load_ir; + Expr test; + if (store_index.type().is_scalar()) { + test = simplify(load_index == store_index); + } else if (is_interleaved_ramp(store_index, vector_scope, &store_ir) && + is_interleaved_ramp(load_index, vector_scope, &load_ir) && + store_ir.repetitions == load_ir.repetitions && + store_ir.lanes == load_ir.lanes) { + test = simplify(store_ir.base == load_ir.base && + store_ir.stride == load_ir.stride); + } + + if (!test.defined()) { + break; + } + + if (is_zero(test)) { + break; + } else if (!is_one(test)) { + // TODO: try harder by substituting in more things in scope + break; + } + + int output_lanes = 1; + if (store_index.type().is_scalar()) { + // The index doesn't depend on the value being + // vectorized, so it's a total reduction. + + b = VectorReduce::make(reduce_op, b, 1); + } else { + + output_lanes = store_index.type().lanes() / store_ir.repetitions; + + store_index = Ramp::make(store_ir.base, store_ir.stride, output_lanes); + b = VectorReduce::make(reduce_op, b, output_lanes); + } + + Expr new_load = Load::make(load_a->type.with_lanes(output_lanes), + load_a->name, store_index, load_a->image, + load_a->param, const_true(output_lanes), + ModulusRemainder{}); + + switch (reduce_op) { + case VectorReduce::Add: + b = new_load + b; + break; + case VectorReduce::Mul: + b = new_load * b; + break; + case VectorReduce::Min: + b = min(new_load, b); + break; + case VectorReduce::Max: + b = max(new_load, b); + break; + case VectorReduce::And: + b = cast(new_load.type(), cast(b.type(), new_load) && b); + break; + case VectorReduce::Or: + b = cast(new_load.type(), cast(b.type(), new_load) || b); + break; + } + + Stmt s = Store::make(store->name, b, store_index, store->param, + const_true(b.type().lanes()), store->alignment); + + // We may still need the atomic node, if there was more + // parallelism than just the vectorization. + s = Atomic::make(op->producer_name, op->mutex_name, s); + + return s; + } while (0); + + // In the general case, if a whole stmt has to be done + // atomically, we need to serialize. + return scalarize(op); + } + Stmt scalarize(Stmt s) { // Wrap a serial loop around it. Maybe LLVM will have // better luck vectorizing it. @@ -984,8 +1237,6 @@ class VectorSubs : public IRMutator { } } - debug(0) << e << " -> " << result << "\n"; - return result; } @@ -994,6 +1245,172 @@ class VectorSubs : public IRMutator { : var(std::move(v)), replacement(std::move(r)), target(t), in_hexagon(in_hexagon) { widening_suffix = ".x" + std::to_string(replacement.type().lanes()); } +}; // namespace + +class FindVectorizableExprsInAtomicNode : public IRMutator { + // An Atomic node protects all accesses to a given buffer. We + // consider a name "poisoned" if it depends on an access to this + // buffer. We can't lift or vectorize anything that has been + // poisoned. 
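    // For example, when the atomic node protects a buffer f, the value
    // g[x] * 2 in
    //     f[x] = f[x] + g[x] * 2
    // is liftable, but any expression that loads from f (or a let whose
    // value does) is poisoned and must stay inside the atomic node.
    // (f and g here are illustrative names only.)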
+ Scope<> poisoned_names; + bool poison = false; + + using IRMutator::visit; + + template + const T *visit_let(const T *op) { + mutate(op->value); + ScopedBinding<> bind_if(poison, poisoned_names, op->name); + mutate(op->body); + return op; + } + + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + + Expr visit(const Let *op) override { + return visit_let(op); + } + + Expr visit(const Load *op) override { + // Even if the load is bad, maybe we can lift the index + IRMutator::visit(op); + + poison |= poisoned_names.contains(op->name); + return op; + } + + Expr visit(const Variable *op) override { + poison = poisoned_names.contains(op->name); + return op; + } + + Stmt visit(const Store *op) override { + // A store poisons all subsequent loads, but loads before the + // first store can be lifted. + mutate(op->index); + mutate(op->value); + poisoned_names.push(op->name); + return op; + } + + Expr visit(const Call *op) override { + IRMutator::visit(op); + poison |= !op->is_pure(); + return op; + } + +public: + using IRMutator::mutate; + + Expr mutate(const Expr &e) override { + bool old_poison = poison; + poison = false; + IRMutator::mutate(e); + if (!poison) { + liftable.insert(e); + } + poison |= old_poison; + // We're not actually mutating anything. This class is only a + // mutator so that we can override a generic mutate() method. + return e; + } + + FindVectorizableExprsInAtomicNode(const string &buf, const map &env) { + poisoned_names.push(buf); + auto it = env.find(buf); + if (it != env.end()) { + // Handle tuples + size_t n = it->second.values().size(); + if (n > 1) { + for (size_t i = 0; i < n; i++) { + poisoned_names.push(buf + "." + std::to_string(i)); + } + } + } + } + + std::set liftable; +}; + +class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator { + const std::set &liftable; + + using IRMutator::visit; + + template + StmtOrExpr visit_let(const LetStmtOrLet *op) { + if (liftable.count(op->value)) { + // Lift it under its current name to avoid having to + // rewrite the variables in other lifted exprs. + // TODO: duplicate non-overlapping liftable let stmts due to unrolling. 
+ lifted.emplace_back(op->name, op->value); + return mutate(op->body); + } else { + return IRMutator::visit(op); + } + } + + Stmt visit(const LetStmt *op) override { + return visit_let(op); + } + + Expr visit(const Let *op) override { + return visit_let(op); + } + +public: + map already_lifted; + vector> lifted; + + using IRMutator::mutate; + + Expr mutate(const Expr &e) override { + if (liftable.count(e) && !is_const(e) && !e.as()) { + auto it = already_lifted.find(e); + string name; + if (it != already_lifted.end()) { + name = it->second; + } else { + name = unique_name('t'); + lifted.emplace_back(name, e); + already_lifted.emplace(e, name); + } + return Variable::make(e.type(), name); + } else { + return IRMutator::mutate(e); + } + } + + LiftVectorizableExprsOutOfSingleAtomicNode(const std::set &liftable) + : liftable(liftable) { + } +}; + +class LiftVectorizableExprsOutOfAllAtomicNodes : public IRMutator { + using IRMutator::visit; + + Stmt visit(const Atomic *op) override { + FindVectorizableExprsInAtomicNode finder(op->producer_name, env); + finder.mutate(op->body); + LiftVectorizableExprsOutOfSingleAtomicNode lifter(finder.liftable); + Stmt new_body = lifter.mutate(op->body); + new_body = Atomic::make(op->producer_name, op->mutex_name, new_body); + while (!lifter.lifted.empty()) { + auto p = lifter.lifted.back(); + new_body = LetStmt::make(p.first, p.second, new_body); + lifter.lifted.pop_back(); + } + return new_body; + } + + const map &env; + +public: + LiftVectorizableExprsOutOfAllAtomicNodes(const map &env) + : env(env) { + } }; // Vectorize all loops marked as such in a Stmt @@ -1040,10 +1457,73 @@ class VectorizeLoops : public IRMutator { } }; -} // Anonymous namespace +/** Check if all stores in a Stmt are to names in a given scope. Used + by RemoveUnnecessaryAtomics below. */ +class AllStoresInScope : public IRVisitor { + using IRVisitor::visit; + void visit(const Store *op) override { + result = result && s.contains(op->name); + } + +public: + bool result = true; + const Scope<> &s; + AllStoresInScope(const Scope<> &s) + : s(s) { + } +}; +bool all_stores_in_scope(const Stmt &stmt, const Scope<> &scope) { + AllStoresInScope checker(scope); + stmt.accept(&checker); + return checker.result; +} + +/** Drop any atomic nodes protecting buffers that are only accessed + * from a single thread. */ +class RemoveUnnecessaryAtomics : public IRMutator { + using IRMutator::visit; + + // Allocations made from within this same thread + bool in_thread = false; + Scope<> local_allocs; + + Stmt visit(const Allocate *op) override { + ScopedBinding<> bind(local_allocs, op->name); + return IRMutator::visit(op); + } + + Stmt visit(const Atomic *op) override { + if (!in_thread || all_stores_in_scope(op->body, local_allocs)) { + return mutate(op->body); + } else { + return op; + } + } + + Stmt visit(const For *op) override { + if (is_parallel(op->for_type)) { + ScopedValue old_in_thread(in_thread, true); + Scope<> old_local_allocs; + old_local_allocs.swap(local_allocs); + Stmt s = IRMutator::visit(op); + old_local_allocs.swap(local_allocs); + return s; + } else { + return IRMutator::visit(op); + } + } +}; + +} // namespace -Stmt vectorize_loops(const Stmt &s, const Target &t) { - return VectorizeLoops(t).mutate(s); +Stmt vectorize_loops(const Stmt &stmt, const map &env, const Target &t) { + // Limit the scope of atomic nodes to just the necessary stuff. + // TODO: Should this be an earlier pass? It's probably a good idea + // for non-vectorizing stuff too. 
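    // Sketch of the intended effect of the lifting step (hypothetical IR,
    // for illustration only):
    //     atomic (f) { f[x] = f[x] + g[x] * 2 }
    // becomes
    //     let t = g[x] * 2 in atomic (f) { f[x] = f[x] + t }
    // so the vectorizer can treat t as an ordinary value, and the atomic
    // body is left in the f[x] = f[x] + <var> shape that
    // VectorSubs::visit(const Atomic *) can turn into a VectorReduce.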
+ Stmt s = LiftVectorizableExprsOutOfAllAtomicNodes(env).mutate(stmt); + s = VectorizeLoops(t).mutate(s); + s = RemoveUnnecessaryAtomics().mutate(s); + return s; } } // namespace Internal diff --git a/src/VectorizeLoops.h b/src/VectorizeLoops.h index 52e7dcb73309..cbde217e608d 100644 --- a/src/VectorizeLoops.h +++ b/src/VectorizeLoops.h @@ -15,7 +15,7 @@ namespace Internal { * them into single statements that operate on vectors. The loops in * question must have constant extent. */ -Stmt vectorize_loops(const Stmt &s, const Target &t); +Stmt vectorize_loops(const Stmt &s, const std::map &env, const Target &t); } // namespace Internal } // namespace Halide From f63c800473642748a6f1207b96acc218c8094ffb Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 11 Jun 2020 15:37:34 -0700 Subject: [PATCH 3/4] Codegen for VectorReduce IR nodes --- python_bindings/src/PyEnums.cpp | 1 + src/CodeGen_ARM.cpp | 187 ++++++- src/CodeGen_ARM.h | 2 +- src/CodeGen_LLVM.cpp | 338 ++++++++++++- src/CodeGen_LLVM.h | 12 + src/CodeGen_PTX_Dev.cpp | 157 ++++++ src/CodeGen_PTX_Dev.h | 1 + src/CodeGen_X86.cpp | 56 +++ src/CodeGen_X86.h | 2 + src/Target.cpp | 1 + src/Target.h | 1 + src/runtime/HalideRuntime.h | 3 +- src/runtime/aarch64.ll | 510 ++++++++++++++++++++ src/runtime/arm.ll | 313 ++++++++++++ src/runtime/arm_cpu_features.cpp | 2 + src/runtime/ptx_dev.ll | 42 ++ test/correctness/CMakeLists.txt | 3 + test/correctness/atomics.cpp | 20 +- test/correctness/cuda_8_bit_dot_product.cpp | 90 ++++ test/correctness/simd_op_check.cpp | 179 ++++++- test/correctness/simd_op_check.h | 41 +- test/correctness/tuple_vector_reduce.cpp | 107 ++++ test/correctness/vector_reductions.cpp | 126 +++++ test/error/CMakeLists.txt | 1 - test/error/atomics_vectorized_mutex.cpp | 29 -- 25 files changed, 2157 insertions(+), 67 deletions(-) create mode 100644 test/correctness/cuda_8_bit_dot_product.cpp create mode 100644 test/correctness/tuple_vector_reduce.cpp create mode 100644 test/correctness/vector_reductions.cpp delete mode 100644 test/error/atomics_vectorized_mutex.cpp diff --git a/python_bindings/src/PyEnums.cpp b/python_bindings/src/PyEnums.cpp index ed453ca08d6f..26b8866035c6 100644 --- a/python_bindings/src/PyEnums.cpp +++ b/python_bindings/src/PyEnums.cpp @@ -142,6 +142,7 @@ void define_enums(py::module &m) { .value("WasmSignExt", Target::Feature::WasmSignExt) .value("SVE", Target::Feature::SVE) .value("SVE2", Target::Feature::SVE2) + .value("ARMDotProd", Target::Feature::ARMDotProd) .value("FeatureEnd", Target::Feature::FeatureEnd); py::enum_(m, "TypeCode") diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 5e9809d49210..9e9334f1bc30 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -1,6 +1,7 @@ #include #include +#include "CSE.h" #include "CodeGen_ARM.h" #include "ConciseCasts.h" #include "Debug.h" @@ -483,10 +484,6 @@ void CodeGen_ARM::visit(const Div *op) { CodeGen_Posix::visit(op); } -void CodeGen_ARM::visit(const Add *op) { - CodeGen_Posix::visit(op); -} - void CodeGen_ARM::visit(const Sub *op) { if (neon_intrinsics_disabled()) { CodeGen_Posix::visit(op); @@ -1063,6 +1060,184 @@ void CodeGen_ARM::visit(const LE *op) { CodeGen_Posix::visit(op); } +void CodeGen_ARM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { + if (neon_intrinsics_disabled() || + op->op == VectorReduce::Or || + op->op == VectorReduce::And || + op->op == VectorReduce::Mul) { + CodeGen_Posix::codegen_vector_reduce(op, init); + return; + } + + // ARM has a variety of pairwise reduction ops for +, min, + // max. 
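    // For instance (illustrative), the non-widening pairwise add on two
    // 4x16-bit inputs behaves like
    //     vpadd(<a0 a1 a2 a3>, <b0 b1 b2 b3>) = <a0+a1, a2+a3, b0+b1, b2+b3>
    // and the widening form takes a single input and doubles the bit width:
    //     vpaddl(<a0 a1 a2 a3>) = <a0+a1, a2+a3>   (32-bit lanes)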
The versions that do not widen take two 64-bit args and + // return one 64-bit vector of the same type. The versions that + // widen take one arg and return something with half the vector + // lanes and double the bit-width. + + int factor = op->value.type().lanes() / op->type.lanes(); + + // These are the types for which we have reduce intrinsics in the + // runtime. + bool have_reduce_intrinsic = (op->type.is_int() || + op->type.is_uint() || + op->type.is_float()); + + // We don't have 16-bit float or bfloat horizontal ops + if (op->type.is_bfloat() || (op->type.is_float() && op->type.bits() < 32)) { + have_reduce_intrinsic = false; + } + + // Only aarch64 has float64 horizontal ops + if (target.bits == 32 && op->type.element_of() == Float(64)) { + have_reduce_intrinsic = false; + } + + // For 64-bit integers, we only have addition, not min/max + if (op->type.bits() == 64 && + !op->type.is_float() && + op->op != VectorReduce::Add) { + have_reduce_intrinsic = false; + } + + // We only have intrinsics that reduce by a factor of two + if (factor != 2) { + have_reduce_intrinsic = false; + } + + if (have_reduce_intrinsic) { + Expr arg = op->value; + if (op->op == VectorReduce::Add && + op->type.bits() >= 16 && + !op->type.is_float()) { + Type narrower_type = arg.type().with_bits(arg.type().bits() / 2); + Expr narrower = lossless_cast(narrower_type, arg); + if (!narrower.defined() && arg.type().is_int()) { + // We can also safely accumulate from a uint into a + // wider int, because the addition uses at most one + // extra bit. + narrower = lossless_cast(narrower_type.with_code(Type::UInt), arg); + } + if (narrower.defined()) { + arg = narrower; + } + } + int output_bits; + if (target.bits == 32 && arg.type().bits() == op->type.bits()) { + // For the non-widening version, the output must be 64-bit + output_bits = 64; + } else if (op->type.bits() * op->type.lanes() <= 64) { + // No point using the 128-bit version of the instruction if the output is narrow. + output_bits = 64; + } else { + output_bits = 128; + } + + const int output_lanes = output_bits / op->type.bits(); + Type intrin_type = op->type.with_lanes(output_lanes); + Type arg_type = arg.type().with_lanes(output_lanes * 2); + if (op->op == VectorReduce::Add && + arg.type().bits() == op->type.bits() && + arg_type.is_uint()) { + // For non-widening additions, there is only a signed + // version (because it's equivalent). 
+ arg_type = arg_type.with_code(Type::Int); + intrin_type = intrin_type.with_code(Type::Int); + } else if (arg.type().is_uint() && intrin_type.is_int()) { + // Use the uint version + intrin_type = intrin_type.with_code(Type::UInt); + } + + std::stringstream ss; + vector args; + ss << "pairwise_" << op->op << "_" << intrin_type << "_" << arg_type; + Expr accumulator = init; + if (op->op == VectorReduce::Add && + accumulator.defined() && + arg_type.bits() < intrin_type.bits()) { + // We can use the accumulating variant + ss << "_accumulate"; + args.push_back(init); + accumulator = Expr(); + } + args.push_back(arg); + value = call_intrin(op->type, output_lanes, ss.str(), args); + + if (accumulator.defined()) { + // We still have an initial value to take care of + string n = unique_name('t'); + sym_push(n, value); + Expr v = Variable::make(accumulator.type(), n); + switch (op->op) { + case VectorReduce::Add: + accumulator += v; + break; + case VectorReduce::Min: + accumulator = min(accumulator, v); + break; + case VectorReduce::Max: + accumulator = max(accumulator, v); + break; + default: + internal_error << "unreachable"; + } + codegen(accumulator); + sym_pop(n); + } + + return; + } + + // Pattern-match 8-bit dot product instructions available on newer + // ARM cores. + if (target.has_feature(Target::ARMDotProd) && + factor % 4 == 0 && + op->op == VectorReduce::Add && + target.bits == 64 && + (op->type.element_of() == Int(32) || + op->type.element_of() == UInt(32))) { + const Mul *mul = op->value.as(); + if (mul) { + const int input_lanes = mul->type.lanes(); + Expr a = lossless_cast(UInt(8, input_lanes), mul->a); + Expr b = lossless_cast(UInt(8, input_lanes), mul->b); + if (!a.defined()) { + a = lossless_cast(Int(8, input_lanes), mul->a); + b = lossless_cast(Int(8, input_lanes), mul->b); + } + if (a.defined() && b.defined()) { + if (factor != 4) { + Expr equiv = VectorReduce::make(op->op, op->value, input_lanes / 4); + equiv = VectorReduce::make(op->op, equiv, op->type.lanes()); + codegen_vector_reduce(equiv.as(), init); + return; + } + Expr i = init; + if (!i.defined()) { + i = make_zero(op->type); + } + vector args{i, a, b}; + if (op->type.lanes() <= 2) { + if (op->type.is_uint()) { + value = call_intrin(op->type, 2, "llvm.aarch64.neon.udot.v2i32.v8i8", args); + } else { + value = call_intrin(op->type, 2, "llvm.aarch64.neon.sdot.v2i32.v8i8", args); + } + } else { + if (op->type.is_uint()) { + value = call_intrin(op->type, 4, "llvm.aarch64.neon.udot.v4i32.v16i8", args); + } else { + value = call_intrin(op->type, 4, "llvm.aarch64.neon.sdot.v4i32.v16i8", args); + } + } + return; + } + } + } + + CodeGen_Posix::codegen_vector_reduce(op, init); +} + string CodeGen_ARM::mcpu() const { if (target.bits == 32) { if (target.has_feature(Target::ARMv7s)) { @@ -1098,6 +1273,10 @@ string CodeGen_ARM::mattrs() const { arch_flags = "+sve"; } + if (target.has_feature(Target::ARMDotProd)) { + arch_flags += "+dotprod"; + } + if (target.os == Target::IOS || target.os == Target::OSX) { return arch_flags + "+reserve-x18"; } else { diff --git a/src/CodeGen_ARM.h b/src/CodeGen_ARM.h index fc1c3ed848f4..c8d868169296 100644 --- a/src/CodeGen_ARM.h +++ b/src/CodeGen_ARM.h @@ -24,7 +24,6 @@ class CodeGen_ARM : public CodeGen_Posix { /** Nodes for which we want to emit specific neon intrinsics */ // @{ void visit(const Cast *) override; - void visit(const Add *) override; void visit(const Sub *) override; void visit(const Div *) override; void visit(const Mul *) override; @@ -35,6 +34,7 @@ class CodeGen_ARM : public 
CodeGen_Posix { void visit(const Call *) override; void visit(const LT *) override; void visit(const LE *) override; + void codegen_vector_reduce(const VectorReduce *, const Expr &) override; // @} /** Various patterns to peephole match against */ diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index cc26207e3fb9..5fe39359e47c 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1394,6 +1394,16 @@ Value *CodeGen_LLVM::codegen(const Expr &e) { value = nullptr; e.accept(this); internal_assert(value) << "Codegen of an expr did not produce an llvm value\n"; + + // Halide's type system doesn't distinguish between scalars and + // vectors of size 1, so if a codegen method returned a vector of + // size one, just extract it out as a scalar. + if (e.type().is_scalar() && + value->getType()->isVectorTy()) { + internal_assert(get_vector_num_elements(value->getType()) == 1); + value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0)); + } + // TODO: skip this correctness check for bool vectors, // as eliminate_bool_vectors() will cause a discrepancy for some backends // (eg OpenCL, HVX); for now we're just ignoring the assert, but @@ -1534,6 +1544,27 @@ void CodeGen_LLVM::visit(const Variable *op) { value = sym_get(op->name); } +template +bool CodeGen_LLVM::try_to_fold_vector_reduce(const Op *op) { + const VectorReduce *red = op->a.template as(); + Expr b = op->b; + if (!red) { + red = op->b.template as(); + b = op->a; + } + if (red && + ((std::is_same::value && red->op == VectorReduce::Add) || + (std::is_same::value && red->op == VectorReduce::Min) || + (std::is_same::value && red->op == VectorReduce::Max) || + (std::is_same::value && red->op == VectorReduce::Mul) || + (std::is_same::value && red->op == VectorReduce::And) || + (std::is_same::value && red->op == VectorReduce::Or))) { + codegen_vector_reduce(red, b); + return true; + } + return false; +} + void CodeGen_LLVM::visit(const Add *op) { Type t = upgrade_type_for_arithmetic(op->type); if (t != op->type) { @@ -1541,6 +1572,11 @@ void CodeGen_LLVM::visit(const Add *op) { return; } + // Some backends can fold the add into a vector reduce + if (try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); if (op->type.is_float()) { @@ -1581,6 +1617,10 @@ void CodeGen_LLVM::visit(const Mul *op) { return; } + if (try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); if (op->type.is_float()) { @@ -1637,6 +1677,10 @@ void CodeGen_LLVM::visit(const Min *op) { return; } + if (try_to_fold_vector_reduce(op)) { + return; + } + string a_name = unique_name('a'); string b_name = unique_name('b'); Expr a = Variable::make(op->a.type(), a_name); @@ -1653,6 +1697,10 @@ void CodeGen_LLVM::visit(const Max *op) { return; } + if (try_to_fold_vector_reduce(op)) { + return; + } + string a_name = unique_name('a'); string b_name = unique_name('b'); Expr a = Variable::make(op->a.type(), a_name); @@ -1768,12 +1816,20 @@ void CodeGen_LLVM::visit(const GE *op) { } void CodeGen_LLVM::visit(const And *op) { + if (try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); value = builder->CreateAnd(a, b); } void CodeGen_LLVM::visit(const Or *op) { + if (try_to_fold_vector_reduce(op)) { + return; + } + Value *a = codegen(op->a); Value *b = codegen(op->b); value = builder->CreateOr(a, b); @@ -2352,7 +2408,7 @@ Value *CodeGen_LLVM::codegen_dense_vector_load(const Load *load, Value *vpred) { // For dense vector 
loads wider than the native vector // width, bust them up into native vectors int load_lanes = load->type.lanes(); - int native_lanes = native_bits / load->type.bits(); + int native_lanes = std::max(1, native_bits / load->type.bits()); vector slices; for (int i = 0; i < load_lanes; i += native_lanes) { int slice_lanes = std::min(native_lanes, load_lanes - i); @@ -4223,11 +4279,251 @@ void CodeGen_LLVM::visit(const Shuffle *op) { } } - if (op->type.is_scalar()) { + if (op->type.is_scalar() && value->getType()->isVectorTy()) { value = builder->CreateExtractElement(value, ConstantInt::get(i32_t, 0)); } } +void CodeGen_LLVM::visit(const VectorReduce *op) { + codegen_vector_reduce(op, Expr()); +} + +void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { + Expr val = op->value; + const int output_lanes = op->type.lanes(); + const int native_lanes = native_vector_bits() / op->type.bits(); + const int factor = val.type().lanes() / output_lanes; + + Expr (*binop)(Expr, Expr) = nullptr; + switch (op->op) { + case VectorReduce::Add: + binop = Add::make; + break; + case VectorReduce::Mul: + binop = Mul::make; + break; + case VectorReduce::Min: + binop = Min::make; + break; + case VectorReduce::Max: + binop = Max::make; + break; + case VectorReduce::And: + binop = And::make; + break; + case VectorReduce::Or: + binop = Or::make; + break; + } + + if (op->type.is_bool() && op->op == VectorReduce::Or) { + // Cast to u8, use max, cast back to bool. + Expr equiv = cast(op->value.type().with_bits(8), op->value); + equiv = VectorReduce::make(VectorReduce::Max, equiv, op->type.lanes()); + if (init.defined()) { + equiv = max(equiv, init); + } + equiv = cast(op->type, equiv); + equiv.accept(this); + return; + } + + if (op->type.is_bool() && op->op == VectorReduce::And) { + // Cast to u8, use min, cast back to bool. + Expr equiv = cast(op->value.type().with_bits(8), op->value); + equiv = VectorReduce::make(VectorReduce::Min, equiv, op->type.lanes()); + equiv = cast(op->type, equiv); + if (init.defined()) { + equiv = min(equiv, init); + } + equiv.accept(this); + return; + } + + if (op->type.element_of() == Float(16)) { + Expr equiv = cast(op->value.type().with_bits(32), op->value); + equiv = VectorReduce::make(op->op, equiv, op->type.lanes()); + if (init.defined()) { + equiv = binop(equiv, init); + } + equiv = cast(op->type, equiv); + equiv.accept(this); + return; + } + +#if LLVM_VERSION >= 90 + if (output_lanes == 1) { + const int input_lanes = val.type().lanes(); + const int input_bytes = input_lanes * val.type().bytes(); + const bool llvm_has_intrinsic = + // Must be one of these ops + ((op->op == VectorReduce::Add || + op->op == VectorReduce::Mul || + op->op == VectorReduce::Min || + op->op == VectorReduce::Max) && + // Must be a power of two lanes + (input_lanes >= 2) && + ((input_lanes & (input_lanes - 1)) == 0) && + // int versions exist up to 1024 bits + ((!op->type.is_float() && input_bytes <= 1024) || + // float versions exist up to 16 lanes + input_lanes <= 16) && + // As of the release of llvm 10, the 64-bit experimental total + // reductions don't seem to be done yet on arm. 
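+                // (For example, a 64-bit integer sum on ARM falls
+                //  through to the generic lowerings below instead.)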
+ (val.type().bits() != 64 || + target.arch != Target::ARM)); + + if (llvm_has_intrinsic) { + std::stringstream name; + name << "llvm.experimental.vector.reduce."; + const int bits = op->type.bits(); + bool takes_initial_value = false; + Expr initial_value = init; + if (op->type.is_float()) { + switch (op->op) { + case VectorReduce::Add: + name << "v2.fadd.f" << bits; + takes_initial_value = true; + if (!initial_value.defined()) { + initial_value = make_zero(op->type); + } + break; + case VectorReduce::Mul: + name << "v2.fmul.f" << bits; + takes_initial_value = true; + if (!initial_value.defined()) { + initial_value = make_one(op->type); + } + break; + case VectorReduce::Min: + name << "fmin"; + break; + case VectorReduce::Max: + name << "fmax"; + break; + default: + break; + } + } else if (op->type.is_int() || op->type.is_uint()) { + switch (op->op) { + case VectorReduce::Add: + name << "add"; + break; + case VectorReduce::Mul: + name << "mul"; + break; + case VectorReduce::Min: + name << (op->type.is_int() ? 's' : 'u') << "min"; + break; + case VectorReduce::Max: + name << (op->type.is_int() ? 's' : 'u') << "max"; + break; + default: + break; + } + } + name << ".v" << val.type().lanes() << (op->type.is_float() ? 'f' : 'i') << bits; + + string intrin_name = name.str(); + + vector args; + if (takes_initial_value) { + args.push_back(initial_value); + initial_value = Expr(); + } + args.push_back(op->value); + + // Make sure the declaration exists, or the codegen for + // call will assume that the args should scalarize. + if (!module->getFunction(intrin_name)) { + vector arg_types; + for (const Expr &e : args) { + arg_types.push_back(llvm_type_of(e.type())); + } + FunctionType *func_t = FunctionType::get(llvm_type_of(op->type), arg_types, false); + llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, intrin_name, module.get()); + } + + Expr equiv = Call::make(op->type, intrin_name, args, Call::PureExtern); + if (initial_value.defined()) { + equiv = binop(initial_value, equiv); + } + equiv.accept(this); + return; + } + } +#endif + + if (output_lanes == 1 && + factor > native_lanes && + factor % native_lanes == 0) { + // It's a total reduction of multiple native + // vectors. Start by adding the vectors together. + Expr equiv; + for (int i = 0; i < factor / native_lanes; i++) { + Expr next = Shuffle::make_slice(val, i * native_lanes, 1, native_lanes); + if (equiv.defined()) { + equiv = binop(equiv, next); + } else { + equiv = next; + } + } + equiv = VectorReduce::make(op->op, equiv, 1); + if (init.defined()) { + equiv = binop(equiv, init); + } + equiv = common_subexpression_elimination(equiv); + equiv.accept(this); + return; + } + + if (factor > 2 && ((factor & 1) == 0)) { + // Factor the reduce into multiple stages. If we're going to + // be widening the type by 4x or more we should also factor the + // widening into multiple stages. + Type intermediate_type = op->value.type().with_lanes(op->value.type().lanes() / 2); + Expr equiv = VectorReduce::make(op->op, op->value, intermediate_type.lanes()); + if (op->op == VectorReduce::Add && + (op->type.is_int() || op->type.is_uint()) && + op->type.bits() >= 32) { + Type narrower_type = op->value.type().with_bits(op->type.bits() / 4); + Expr narrower = lossless_cast(narrower_type, op->value); + if (!narrower.defined() && narrower_type.is_int()) { + // Maybe we can narrow to an unsigned int instead. 
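+                // (For example, uint8 data that was widened to int32
+                //  doesn't fit in int8, but does fit in uint8.)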
+ narrower_type = narrower_type.with_code(Type::UInt); + narrower = lossless_cast(narrower_type, op->value); + } + if (narrower.defined()) { + // Widen it by 2x before the horizontal add + narrower = cast(narrower.type().with_bits(narrower.type().bits() * 2), narrower); + equiv = VectorReduce::make(op->op, narrower, intermediate_type.lanes()); + // Then widen it by 2x again afterwards + equiv = cast(intermediate_type, equiv); + } + } + equiv = VectorReduce::make(op->op, equiv, op->type.lanes()); + if (init.defined()) { + equiv = binop(equiv, init); + } + equiv = common_subexpression_elimination(equiv); + codegen(equiv); + return; + } + + // Extract each slice and combine + Expr equiv = init; + for (int i = 0; i < factor; i++) { + Expr next = Shuffle::make_slice(val, i, factor, val.type().lanes() / factor); + if (equiv.defined()) { + equiv = binop(equiv, next); + } else { + equiv = next; + } + } + equiv = common_subexpression_elimination(equiv); + codegen(equiv); +} // namespace Internal + void CodeGen_LLVM::visit(const Atomic *op) { if (op->mutex_name != "") { internal_assert(!inside_atomic_mutex_node) @@ -4286,16 +4582,19 @@ Value *CodeGen_LLVM::call_intrin(const Type &result_type, int intrin_lanes, arg_values[i] = codegen(args[i]); } - return call_intrin(llvm_type_of(result_type), + llvm::Type *t = llvm_type_of(result_type); + + return call_intrin(t, intrin_lanes, name, arg_values); } Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, const string &name, vector arg_values) { - internal_assert(result_type->isVectorTy()) << "call_intrin is for vector intrinsics only\n"; - - int arg_lanes = get_vector_num_elements(result_type); + int arg_lanes = 1; + if (result_type->isVectorTy()) { + arg_lanes = get_vector_num_elements(result_type); + } if (intrin_lanes != arg_lanes) { // Cut up each arg into appropriately-sized pieces, call the @@ -4304,17 +4603,24 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, for (int start = 0; start < arg_lanes; start += intrin_lanes) { vector args; for (size_t i = 0; i < arg_values.size(); i++) { + int arg_i_lanes = 1; if (arg_values[i]->getType()->isVectorTy()) { - int arg_i_lanes = get_vector_num_elements(arg_values[i]->getType()); - internal_assert(arg_i_lanes >= arg_lanes); + arg_i_lanes = get_vector_num_elements(arg_values[i]->getType()); + } + if (arg_i_lanes >= arg_lanes) { // Horizontally reducing intrinsics may have // arguments that have more lanes than the // result. Assume that the horizontally reduce // neighboring elements... int reduce = arg_i_lanes / arg_lanes; args.push_back(slice_vector(arg_values[i], start * reduce, intrin_lanes * reduce)); - } else { + } else if (arg_i_lanes == 1) { + // It's a scalar arg to an intrinsic that returns + // a vector. Replicate it over the slices. 
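+                    // (For example, when a 16-lane call is split into
+                    //  four 4-lane calls, each of the four calls gets
+                    //  the same scalar operand.)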
args.push_back(arg_values[i]); + } else { + internal_error << "Argument in call_intrin has " << arg_i_lanes + << " with result type having " << arg_lanes << "\n"; } } @@ -4335,7 +4641,10 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, llvm::Function *fn = module->getFunction(name); if (!fn) { - llvm::Type *intrinsic_result_type = VectorType::get(result_type->getScalarType(), intrin_lanes); + llvm::Type *intrinsic_result_type = result_type->getScalarType(); + if (intrin_lanes > 1) { + intrinsic_result_type = VectorType::get(result_type->getScalarType(), intrin_lanes); + } FunctionType *func_t = FunctionType::get(intrinsic_result_type, arg_types, false); fn = llvm::Function::Create(func_t, llvm::Function::ExternalLinkage, name, module.get()); fn->setCallingConv(CallingConv::C); @@ -4350,12 +4659,21 @@ Value *CodeGen_LLVM::call_intrin(llvm::Type *result_type, int intrin_lanes, } Value *CodeGen_LLVM::slice_vector(Value *vec, int start, int size) { + // Force the arg to be an actual vector + if (!vec->getType()->isVectorTy()) { + vec = create_broadcast(vec, 1); + } + int vec_lanes = get_vector_num_elements(vec->getType()); if (start == 0 && size == vec_lanes) { return vec; } + if (size == 1) { + return builder->CreateExtractElement(vec, (uint64_t)start); + } + vector indices(size); for (int i = 0; i < size; i++) { int idx = start + i; diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index 487f3ba0effe..3983041c231c 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -401,6 +401,7 @@ class CodeGen_LLVM : public IRVisitor { void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; // @} @@ -512,6 +513,13 @@ class CodeGen_LLVM : public IRVisitor { virtual bool supports_atomic_add(const Type &t) const; + /** Compile a horizontal reduction that starts with an explicit + * initial value. There are lots of complex ways to peephole + * optimize this pattern, especially with the proliferation of + * dot-product instructions, and they can usefully share logic + * across backends. */ + virtual void codegen_vector_reduce(const VectorReduce *op, const Expr &init); + /** Are we inside an atomic node that uses mutex locks? This is used for detecting deadlocks from nested atomics & illegal vectorization. */ bool inside_atomic_mutex_node; @@ -560,6 +568,10 @@ class CodeGen_LLVM : public IRVisitor { void init_codegen(const std::string &name, bool any_strict_float = false); std::unique_ptr finish_codegen(); + + /** A helper routine for generating folded vector reductions. 
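+     * For example, an Add node with a VectorReduce::Add operand is
+     * lowered as a single horizontal reduction that treats the other
+     * operand as the initial value.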
*/ + template + bool try_to_fold_vector_reduce(const Op *op); }; } // namespace Internal diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 20ff7cf97860..008f9a678060 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -331,6 +331,163 @@ void CodeGen_PTX_Dev::visit(const Atomic *op) { CodeGen_LLVM::visit(op); } +void CodeGen_PTX_Dev::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { + // Pattern match 8/16-bit dot products + + const int input_lanes = op->value.type().lanes(); + const int factor = input_lanes / op->type.lanes(); + const Mul *mul = op->value.as(); + if (op->op == VectorReduce::Add && + mul && + (factor % 4 == 0) && + (op->type.element_of() == Int(32) || + op->type.element_of() == UInt(32))) { + Expr i = init; + if (!i.defined()) { + i = cast(mul->type, 0); + } + // Try to narrow the multiply args to 8-bit + Expr a = mul->a, b = mul->b; + if (op->type.is_uint()) { + a = lossless_cast(UInt(8, input_lanes), a); + b = lossless_cast(UInt(8, input_lanes), b); + } else { + a = lossless_cast(Int(8, input_lanes), a); + b = lossless_cast(Int(8, input_lanes), b); + if (!a.defined()) { + // try uint + a = lossless_cast(UInt(8, input_lanes), mul->a); + } + if (!b.defined()) { + b = lossless_cast(UInt(8, input_lanes), mul->b); + } + } + // If we only managed to narrow one of them, try to narrow the + // other to 16-bit. Swap the args so that it's always 'a'. + Expr a_orig = mul->a; + if (a.defined() && !b.defined()) { + std::swap(a, b); + a_orig = mul->b; + } + if (b.defined() && !a.defined()) { + // Try 16-bit instead + a = lossless_cast(UInt(16, input_lanes), a_orig); + if (!a.defined() && !op->type.is_uint()) { + a = lossless_cast(Int(16, input_lanes), a_orig); + } + } + + if (a.defined() && b.defined()) { + std::ostringstream ss; + if (a.type().bits() == 8) { + ss << "dp4a"; + } else { + ss << "dp2a"; + } + if (a.type().is_int()) { + ss << "_s32"; + } else { + ss << "_u32"; + } + if (b.type().is_int()) { + ss << "_s32"; + } else { + ss << "_u32"; + } + const int a_32_bit_words_per_sum = (factor * a.type().bits()) / 32; + const int b_32_bit_words_per_sum = (factor * b.type().bits()) / 32; + // Reinterpret a and b as 32-bit values with fewer + // lanes. If they're aligned dense loads we should just do a + // different load. + for (Expr *e : {&a, &b}) { + int sub_lanes = 32 / e->type().bits(); + const Load *load = e->as(); + const Ramp *idx = load ? load->index.as() : nullptr; + if (idx && + is_one(idx->stride) && + load->alignment.modulus % sub_lanes == 0 && + load->alignment.remainder % sub_lanes == 0) { + Expr new_idx = simplify(idx->base / sub_lanes); + int load_lanes = input_lanes / sub_lanes; + if (input_lanes > sub_lanes) { + new_idx = Ramp::make(new_idx, 1, load_lanes); + } + *e = Load::make(Int(32, load_lanes), + load->name, + new_idx, + load->image, + load->param, + const_true(load_lanes), + load->alignment / sub_lanes); + } else { + *e = reinterpret(Int(32, input_lanes / sub_lanes), *e); + } + } + string name = ss.str(); + vector result; + for (int l = 0; l < op->type.lanes(); l++) { + // To compute a single lane of the output, we'll + // extract the appropriate slice of the args, which + // have been reinterpreted as 32-bit vectors, then + // call either dp4a or dp2a the appropriate number of + // times, and finally sum the result. 
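+                // For example, with factor == 8 and 8-bit operands,
+                // each operand contributes two 32-bit words per output
+                // lane, so we chain two dp4a calls, each accumulating
+                // into i_slice.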
+ Expr i_slice, a_slice, b_slice; + if (i.type().is_scalar()) { + i_slice = i; + } else { + i_slice = Shuffle::make_extract_element(i, l); + } + if (a.type().is_scalar()) { + a_slice = a; + } else { + a_slice = Shuffle::make_slice(a, l * a_32_bit_words_per_sum, 1, a_32_bit_words_per_sum); + } + if (b.type().is_scalar()) { + b_slice = b; + } else { + b_slice = Shuffle::make_slice(b, l * b_32_bit_words_per_sum, 1, b_32_bit_words_per_sum); + } + for (int i = 0; i < b_32_bit_words_per_sum; i++) { + if (a_slice.type().lanes() == b_slice.type().lanes()) { + Expr a_lane, b_lane; + if (b_slice.type().is_scalar()) { + a_lane = a_slice; + b_lane = b_slice; + } else { + a_lane = Shuffle::make_extract_element(a_slice, i); + b_lane = Shuffle::make_extract_element(b_slice, i); + } + i_slice = Call::make(i_slice.type(), name, + {a_lane, b_lane, i_slice}, + Call::PureExtern); + } else { + internal_assert(a_slice.type().lanes() == 2 * b_slice.type().lanes()); + Expr a_lane_lo, a_lane_hi, b_lane; + if (b_slice.type().is_scalar()) { + b_lane = b_slice; + } else { + b_lane = Shuffle::make_extract_element(b_slice, i); + } + a_lane_lo = Shuffle::make_extract_element(a_slice, 2 * i); + a_lane_hi = Shuffle::make_extract_element(a_slice, 2 * i + 1); + i_slice = Call::make(i_slice.type(), name, + {a_lane_lo, a_lane_hi, b_lane, i_slice}, + Call::PureExtern); + } + } + i_slice = simplify(i_slice); + i_slice = common_subexpression_elimination(i_slice); + result.push_back(i_slice); + } + // Concatenate the per-lane results to get the full vector result + Expr equiv = Shuffle::make_concat(result); + equiv.accept(this); + return; + } + } + CodeGen_LLVM::codegen_vector_reduce(op, init); +} + string CodeGen_PTX_Dev::march() const { return "nvptx64"; } diff --git a/src/CodeGen_PTX_Dev.h b/src/CodeGen_PTX_Dev.h index 94925d0a4e0a..7f7c80669f47 100644 --- a/src/CodeGen_PTX_Dev.h +++ b/src/CodeGen_PTX_Dev.h @@ -65,6 +65,7 @@ class CodeGen_PTX_Dev : public CodeGen_LLVM, public CodeGen_GPU_Dev { void visit(const Load *) override; void visit(const Store *) override; void visit(const Atomic *) override; + void codegen_vector_reduce(const VectorReduce *op, const Expr &init) override; // @} std::string march() const; diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 884301c342a8..e8821d01a824 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -107,6 +107,33 @@ void CodeGen_X86::visit(const Sub *op) { } } +void CodeGen_X86::visit(const Mul *op) { + +#if LLVM_VERSION < 110 + // Widening integer multiply of non-power-of-two vector sizes is + // broken in older llvms for older x86: + // https://bugs.llvm.org/show_bug.cgi?id=44976 + const int lanes = op->type.lanes(); + if (!target.has_feature(Target::SSE41) && + (lanes & (lanes - 1)) && + (op->type.bits() >= 32) && + !op->type.is_float()) { + // Any fancy shuffles to pad or slice into smaller vectors + // just gets undone by LLVM and retriggers the bug. Just + // scalarize. + vector result; + for (int i = 0; i < lanes; i++) { + result.emplace_back(Shuffle::make_extract_element(op->a, i) * + Shuffle::make_extract_element(op->b, i)); + } + codegen(Shuffle::make_concat(result)); + return; + } +#endif + + return CodeGen_Posix::visit(op); +} + void CodeGen_X86::visit(const GT *op) { Type t = op->a.type(); @@ -390,6 +417,35 @@ void CodeGen_X86::visit(const Call *op) { CodeGen_Posix::visit(op); } +void CodeGen_X86::visit(const VectorReduce *op) { + const int factor = op->value.type().lanes() / op->type.lanes(); + + // Match pmaddwd. 
X86 doesn't have many horizontal reduction ops, + // and the ones that exist are hit by llvm automatically using the + // base class lowering of VectorReduce (see + // test/correctness/simd_op_check.cpp). + if (const Mul *mul = op->value.as()) { + Type narrower = Int(16, mul->type.lanes()); + Expr a = lossless_cast(narrower, mul->a); + Expr b = lossless_cast(narrower, mul->b); + if (op->type.is_int() && + op->type.bits() == 32 && + a.defined() && + b.defined() && + factor == 2 && + op->op == VectorReduce::Add) { + if (target.has_feature(Target::AVX2) && op->type.lanes() > 4) { + value = call_intrin(op->type, 8, "llvm.x86.avx2.pmadd.wd", {a, b}); + } else { + value = call_intrin(op->type, 4, "llvm.x86.sse2.pmadd.wd", {a, b}); + } + return; + } + } + + CodeGen_Posix::visit(op); +} + string CodeGen_X86::mcpu() const { if (target.has_feature(Target::AVX512_Cannonlake)) return "cannonlake"; if (target.has_feature(Target::AVX512_Skylake)) return "skylake-avx512"; diff --git a/src/CodeGen_X86.h b/src/CodeGen_X86.h index b76ffdeea359..e6e579ec071b 100644 --- a/src/CodeGen_X86.h +++ b/src/CodeGen_X86.h @@ -47,6 +47,8 @@ class CodeGen_X86 : public CodeGen_Posix { void visit(const EQ *) override; void visit(const NE *) override; void visit(const Select *) override; + void visit(const VectorReduce *) override; + void visit(const Mul *) override; // @} }; diff --git a/src/Target.cpp b/src/Target.cpp index 32d436526255..6591f646931e 100644 --- a/src/Target.cpp +++ b/src/Target.cpp @@ -360,6 +360,7 @@ const std::map feature_name_map = { {"wasm_signext", Target::WasmSignExt}, {"sve", Target::SVE}, {"sve2", Target::SVE2}, + {"arm_dot_prod", Target::ARMDotProd}, // NOTE: When adding features to this map, be sure to update PyEnums.cpp as well. }; diff --git a/src/Target.h b/src/Target.h index 182a27fa9418..4786c01de336 100644 --- a/src/Target.h +++ b/src/Target.h @@ -119,6 +119,7 @@ struct Target { WasmSignExt = halide_target_feature_wasm_signext, SVE = halide_target_feature_sve, SVE2 = halide_target_feature_sve2, + ARMDotProd = halide_target_feature_arm_dot_prod, FeatureEnd = halide_target_feature_end }; Target() diff --git a/src/runtime/HalideRuntime.h b/src/runtime/HalideRuntime.h index 9161c8da9678..1fcdf5f476a0 100644 --- a/src/runtime/HalideRuntime.h +++ b/src/runtime/HalideRuntime.h @@ -1313,7 +1313,8 @@ typedef enum halide_target_feature_t { halide_target_feature_sve2, ///< Enable ARM Scalable Vector Extensions v2 halide_target_feature_egl, ///< Force use of EGL support. - halide_target_feature_end ///< A sentinel. Every target is considered to have this feature, and setting this feature does nothing. + halide_target_feature_arm_dot_prod, ///< Enable ARMv8.2-a dotprod extension (i.e. udot and sdot instructions) + halide_target_feature_end ///< A sentinel. Every target is considered to have this feature, and setting this feature does nothing. } halide_target_feature_t; /** This function is called internally by Halide in some situations to determine diff --git a/src/runtime/aarch64.ll b/src/runtime/aarch64.ll index 8adb25eca59d..1472ddfc700d 100644 --- a/src/runtime/aarch64.ll +++ b/src/runtime/aarch64.ll @@ -231,3 +231,513 @@ define weak_odr <4 x float> @fast_inverse_sqrt_f32x4(<4 x float> %x) nounwind al %result = fmul <4 x float> %approx, %correction ret <4 x float> %result } + +; The way llvm represents intrinsics for horizontal addition are +; somewhat ad-hoc, and can be incompatible with the way we slice up +; intrinsics to meet the native vector width. 
We define wrappers for +; everything here instead. + +declare <2 x double> @llvm.aarch64.neon.faddp.v2f64(<2 x double>, <2 x double>) nounwind readnone +declare <2 x float> @llvm.aarch64.neon.faddp.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.addp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i64> @llvm.aarch64.neon.addp.v2i64(<2 x i64>, <2 x i64>) nounwind readnone +declare <4 x float> @llvm.aarch64.neon.faddp.v4f32(<4 x float>, <4 x float>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.addp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.addp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.addp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.addp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <16 x i8> @llvm.aarch64.neon.addp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Add_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.addp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Add_int8x16_int8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.addp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.addp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.addp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.addp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.addp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr i64 @pairwise_Add_int64_int64x2(<2 x i64> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.aarch64.neon.addp.v2i64(<2 x i64> %x, <2 x i64> undef) + %scalar = extractelement <2 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int64x4(<4 x i64> %x) nounwind alwaysinline { + %a = shufflevector <4 x i64> %x, <4 x i64> undef, <2 x i32> + %b = shufflevector <4 x i64> %x, <4 x i64> undef, <2 x i32> + %result = tail call <2 x i64> @llvm.aarch64.neon.addp.v2i64(<2 x i64> %a, <2 x i64> 
%b) + ret <2 x i64> %result +} + +define weak_odr <4 x float> @pairwise_Add_float32x4_float32x8(<8 x float> %x) nounwind alwaysinline { + %a = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %b = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %result = tail call <4 x float> @llvm.aarch64.neon.faddp.v4f32(<4 x float> %a, <4 x float> %b) + ret <4 x float> %result +} + +define weak_odr <2 x float> @pairwise_Add_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.aarch64.neon.faddp.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +define weak_odr double @pairwise_Add_float64_float64x2(<2 x double> %x) nounwind alwaysinline { + %result = tail call <2 x double> @llvm.aarch64.neon.faddp.v2f64(<2 x double> %x, <2 x double> undef) + %scalar = extractelement <2 x double> %result, i32 0 + ret double %scalar +} + +define weak_odr <2 x double> @pairwise_Add_float64x2_float64x4(<4 x double> %x) nounwind alwaysinline { + %a = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %b = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %result = tail call <2 x double> @llvm.aarch64.neon.faddp.v2f64(<2 x double> %a, <2 x double> %b) + ret <2 x double> %result +} + + +declare <1 x i64> @llvm.aarch64.neon.saddlp.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <1 x i64> @llvm.aarch64.neon.uaddlp.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.uaddlp.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i64> @llvm.aarch64.neon.saddlp.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <2 x i64> @llvm.aarch64.neon.uaddlp.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.saddlp.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.uaddlp.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.saddlp.v4i32.v8i16(<8 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.uaddlp.v4i32.v8i16(<8 x i16>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.saddlp.v8i16.v16i8(<16 x i8>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.uaddlp.v8i16.v16i8(<16 x i8>) nounwind readnone + + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.aarch64.neon.saddlp.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.aarch64.neon.saddlp.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.aarch64.neon.saddlp.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.aarch64.neon.saddlp.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2(<2 x i32> %x) nounwind 
alwaysinline { + %result = tail call <1 x i64> @llvm.aarch64.neon.saddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.aarch64.neon.uaddlp.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.aarch64.neon.uaddlp.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.aarch64.neon.uaddlp.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.aarch64.neon.uaddlp.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.aarch64.neon.uaddlp.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2(<2 x i32> %x) nounwind alwaysinline { + %result = tail call <1 x i64> @llvm.aarch64.neon.uaddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %y = tail call <8 x i16> @llvm.aarch64.neon.saddlp.v8i16.v16i8(<16 x i8> %x) + %result = add <8 x i16> %a, %y + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %y = tail call <4 x i16> @llvm.aarch64.neon.saddlp.v4i16.v8i8(<8 x i8> %x) + %result = add <4 x i16> %a, %y + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %y = tail call <4 x i32> @llvm.aarch64.neon.saddlp.v4i32.v8i16(<8 x i16> %x) + %result = add <4 x i32> %a, %y + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %y = tail call <2 x i32> @llvm.aarch64.neon.saddlp.v2i32.v4i16(<4 x i16> %x) + %result = add <2 x i32> %a, %y + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %y = tail call <2 x i64> @llvm.aarch64.neon.saddlp.v2i64.v4i32(<4 x i32> %x) + %result = add <2 x i64> %a, %y + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %y = tail call <1 x i64> @llvm.aarch64.neon.saddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %y, i32 0 + %result = add i64 %a, %scalar + ret i64 %result +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %y = tail call <8 x i16> @llvm.aarch64.neon.uaddlp.v8i16.v16i8(<16 x i8> %x) + %result = add <8 x i16> %a, %y + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %y = tail call <4 x i16> @llvm.aarch64.neon.uaddlp.v4i16.v8i8(<8 x i8> %x) + %result = add <4 x i16> %a, %y + ret 
<4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %y = tail call <4 x i32> @llvm.aarch64.neon.uaddlp.v4i32.v8i16(<8 x i16> %x) + %result = add <4 x i32> %a, %y + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %y = tail call <2 x i32> @llvm.aarch64.neon.uaddlp.v2i32.v4i16(<4 x i16> %x) + %result = add <2 x i32> %a, %y + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %y = tail call <2 x i64> @llvm.aarch64.neon.uaddlp.v2i64.v4i32(<4 x i32> %x) + %result = add <2 x i64> %a, %y + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %y = tail call <1 x i64> @llvm.aarch64.neon.uaddlp.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %y, i32 0 + %result = add i64 %a, %scalar + ret i64 %result +} + + + +declare <16 x i8> @llvm.aarch64.neon.smaxp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <16 x i8> @llvm.aarch64.neon.umaxp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <2 x double> @llvm.aarch64.neon.fmaxp.v2f64(<2 x double>, <2 x double>) nounwind readnone +declare <2 x float> @llvm.aarch64.neon.fmaxp.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.smaxp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.umaxp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x float> @llvm.aarch64.neon.fmaxp.v4f32(<4 x float>, <4 x float>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.smaxp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.umaxp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.smaxp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.umaxp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.smaxp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.umaxp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.smaxp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.umaxp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Max_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.smaxp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Max_int8x16_int8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.smaxp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.smaxp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Max_int16x8_int16x16(<16 x i16> %x) 
nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.smaxp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Max_int32x4_int32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.smaxp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Max_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.smaxp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <4 x float> @pairwise_Max_float32x4_float32x8(<8 x float> %x) nounwind alwaysinline { + %a = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %b = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %result = tail call <4 x float> @llvm.aarch64.neon.fmaxp.v4f32(<4 x float> %a, <4 x float> %b) + ret <4 x float> %result +} + +define weak_odr <2 x float> @pairwise_Max_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.aarch64.neon.fmaxp.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + + +define weak_odr <2 x double> @pairwise_Max_float64x2_float64x4(<4 x double> %x) nounwind alwaysinline { + %a = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %b = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %result = tail call <2 x double> @llvm.aarch64.neon.fmaxp.v2f64(<2 x double> %a, <2 x double> %b) + ret <2 x double> %result +} + +define weak_odr double @pairwise_Max_float64_float64x2(<2 x double> %x) nounwind alwaysinline { + %result = tail call <2 x double> @llvm.aarch64.neon.fmaxp.v2f64(<2 x double> %x, <2 x double> undef) + %scalar = extractelement <2 x double> %result, i32 0 + ret double %scalar +} + + +define weak_odr <8 x i8> @pairwise_Max_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.umaxp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Max_uint8x16_uint8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.umaxp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.umaxp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Max_uint16x8_uint16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 
x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.umaxp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Max_uint32x4_uint32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.umaxp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Max_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.umaxp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + + +declare <16 x i8> @llvm.aarch64.neon.sminp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <16 x i8> @llvm.aarch64.neon.uminp.v16i8(<16 x i8>, <16 x i8>) nounwind readnone +declare <2 x double> @llvm.aarch64.neon.fminp.v2f64(<2 x double>, <2 x double>) nounwind readnone +declare <2 x float> @llvm.aarch64.neon.fminp.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.sminp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.aarch64.neon.uminp.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x float> @llvm.aarch64.neon.fminp.v4f32(<4 x float>, <4 x float>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.sminp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> @llvm.aarch64.neon.uminp.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.sminp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <4 x i32> @llvm.aarch64.neon.uminp.v4i32(<4 x i32>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.sminp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i16> @llvm.aarch64.neon.uminp.v8i16(<8 x i16>, <8 x i16>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.sminp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.aarch64.neon.uminp.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Min_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.sminp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Min_int8x16_int8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.sminp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.sminp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Min_int16x8_int16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.sminp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + 
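+; All of these wrappers follow the naming convention that
+; CodeGen_ARM::codegen_vector_reduce constructs:
+; pairwise_<Op>_<output type>_<input type>, where the input vector has
+; twice as many lanes as the output. The _accumulate variants (Add only)
+; take an initial accumulator of the output type as their first argument.
+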
+define weak_odr <4 x i32> @pairwise_Min_int32x4_int32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.sminp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Min_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.sminp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + + +define weak_odr <4 x float> @pairwise_Min_float32x4_float32x8(<8 x float> %x) nounwind alwaysinline { + %a = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %b = shufflevector <8 x float> %x, <8 x float> undef, <4 x i32> + %result = tail call <4 x float> @llvm.aarch64.neon.fminp.v4f32(<4 x float> %a, <4 x float> %b) + ret <4 x float> %result +} + +define weak_odr <2 x float> @pairwise_Min_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.aarch64.neon.fminp.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + + +define weak_odr <2 x double> @pairwise_Min_float64x2_float64x4(<4 x double> %x) nounwind alwaysinline { + %a = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %b = shufflevector <4 x double> %x, <4 x double> undef, <2 x i32> + %result = tail call <2 x double> @llvm.aarch64.neon.fminp.v2f64(<2 x double> %a, <2 x double> %b) + ret <2 x double> %result +} + +define weak_odr double @pairwise_Min_float64_float64x2(<2 x double> %x) nounwind alwaysinline { + %result = tail call <2 x double> @llvm.aarch64.neon.fminp.v2f64(<2 x double> %x, <2 x double> undef) + %scalar = extractelement <2 x double> %result, i32 0 + ret double %scalar +} + +define weak_odr <8 x i8> @pairwise_Min_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.aarch64.neon.uminp.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <16 x i8> @pairwise_Min_uint8x16_uint8x32(<32 x i8> %x) nounwind alwaysinline { + %a = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %b = shufflevector <32 x i8> %x, <32 x i8> undef, <16 x i32> + %result = tail call <16 x i8> @llvm.aarch64.neon.uminp.v16i8(<16 x i8> %a, <16 x i8> %b) + ret <16 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.aarch64.neon.uminp.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <8 x i16> @pairwise_Min_uint16x8_uint16x16(<16 x i16> %x) nounwind alwaysinline { + %a = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %b = shufflevector <16 x i16> %x, <16 x i16> undef, <8 x i32> + %result = tail call <8 x i16> @llvm.aarch64.neon.uminp.v8i16(<8 x i16> %a, <8 x i16> %b) + ret <8 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Min_uint32x4_uint32x8(<8 x i32> %x) nounwind alwaysinline { + %a = shufflevector <8 x i32> 
%x, <8 x i32> undef, <4 x i32> + %b = shufflevector <8 x i32> %x, <8 x i32> undef, <4 x i32> + %result = tail call <4 x i32> @llvm.aarch64.neon.uminp.v4i32(<4 x i32> %a, <4 x i32> %b) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Min_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.aarch64.neon.uminp.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} diff --git a/src/runtime/arm.ll b/src/runtime/arm.ll index ee3f69a86518..42b590e514ca 100644 --- a/src/runtime/arm.ll +++ b/src/runtime/arm.ll @@ -398,3 +398,316 @@ define weak_odr void @strided_store_f32x4(float * %ptr, i32 %stride, <4 x float> ret void } +; The way llvm represents intrinsics for horizontal addition are +; somewhat ad-hoc, and can be incompatible with the way we slice up +; intrinsics to meet the native vector width. We define wrappers for +; everything here instead. + +declare <2 x float> @llvm.arm.neon.vpadd.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpadd.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpadd.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpadd.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Add_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpadd.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpadd.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpadd.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <2 x float> @pairwise_Add_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.arm.neon.vpadd.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +declare <1 x i64> @llvm.arm.neon.vpaddls.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <1 x i64> @llvm.arm.neon.vpaddlu.v1i64.v2i32(<2 x i32>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpaddls.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpaddlu.v2i32.v4i16(<4 x i16>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpaddls.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpaddlu.v2i64.v4i32(<4 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpaddls.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpaddlu.v4i16.v8i8(<8 x i8>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpaddls.v4i32.v8i16(<8 x i16>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpaddlu.v4i32.v8i16(<8 x 
i16>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpaddls.v8i16.v16i8(<16 x i8>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpaddlu.v8i16.v16i8(<16 x i8>) nounwind readnone + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpaddls.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpaddls.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpaddls.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpaddls.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpaddls.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2(<2 x i32> %x) nounwind alwaysinline { + %result = tail call <1 x i64> @llvm.arm.neon.vpaddls.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr i64 @pairwise_Add_int64_int64x2(<2 x i64> %x) nounwind alwaysinline { + ; There's no intrinsic for this on arm32, but we include an implementation for completeness. + %a = extractelement <2 x i64> %x, i32 0 + %b = extractelement <2 x i64> %x, i32 1 + %result = add i64 %a, %b + ret i64 %result +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpaddlu.v8i16.v16i8(<16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8(<8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpaddlu.v4i16.v8i8(<8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpaddlu.v4i32.v8i16(<8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4(<4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpaddlu.v2i32.v4i16(<4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpaddlu.v2i64.v4i32(<4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2(<2 x i32> %x) nounwind alwaysinline { + %result = tail call <1 x i64> @llvm.arm.neon.vpaddlu.v1i64.v2i32(<2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +declare <4 x i16> @llvm.arm.neon.vpadals.v4i16.v8i8(<4 x i16>, <8 x i8>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpadals.v2i32.v4i16(<2 x i32>, <4 x i16>) nounwind readnone +declare <1 x i64> @llvm.arm.neon.vpadals.v1i64.v2i32(<1 x i64>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpadalu.v4i16.v8i8(<4 x i16>, <8 x i8>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpadalu.v2i32.v4i16(<2 x i32>, <4 x i16>) nounwind readnone +declare <1 x i64> 
@llvm.arm.neon.vpadalu.v1i64.v2i32(<1 x i64>, <2 x i32>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpadals.v8i16.v16i8(<8 x i16>, <16 x i8>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpadals.v4i32.v8i16(<4 x i32>, <8 x i16>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpadals.v2i64.v4i32(<2 x i64>, <4 x i32>) nounwind readnone +declare <8 x i16> @llvm.arm.neon.vpadalu.v8i16.v16i8(<8 x i16>, <16 x i8>) nounwind readnone +declare <4 x i32> @llvm.arm.neon.vpadalu.v4i32.v8i16(<4 x i32>, <8 x i16>) nounwind readnone +declare <2 x i64> @llvm.arm.neon.vpadalu.v2i64.v4i32(<2 x i64>, <4 x i32>) nounwind readnone + + +define weak_odr <8 x i16> @pairwise_Add_int16x8_int8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpadals.v8i16.v16i8(<8 x i16> %a, <16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_int16x4_int8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpadals.v4i16.v8i8(<4 x i16> %a, <8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_int32x4_int16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpadals.v4i32.v8i16(<4 x i32> %a, <8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_int32x2_int16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpadals.v2i32.v4i16(<2 x i32> %a, <4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_int64x2_int32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpadals.v2i64.v4i32(<2 x i64> %a, <4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_int64_int32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %vec = insertelement <1 x i64> undef, i64 %a, i32 0 + %result = tail call <1 x i64> @llvm.arm.neon.vpadals.v1i64.v2i32(<1 x i64> %vec, <2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +define weak_odr <8 x i16> @pairwise_Add_uint16x8_uint8x16_accumulate(<8 x i16> %a, <16 x i8> %x) nounwind alwaysinline { + %result = tail call <8 x i16> @llvm.arm.neon.vpadalu.v8i16.v16i8(<8 x i16> %a, <16 x i8> %x) + ret <8 x i16> %result +} + +define weak_odr <4 x i16> @pairwise_Add_uint16x4_uint8x8_accumulate(<4 x i16> %a, <8 x i8> %x) nounwind alwaysinline { + %result = tail call <4 x i16> @llvm.arm.neon.vpadalu.v4i16.v8i8(<4 x i16> %a, <8 x i8> %x) + ret <4 x i16> %result +} + +define weak_odr <4 x i32> @pairwise_Add_uint32x4_uint16x8_accumulate(<4 x i32> %a, <8 x i16> %x) nounwind alwaysinline { + %result = tail call <4 x i32> @llvm.arm.neon.vpadalu.v4i32.v8i16(<4 x i32> %a, <8 x i16> %x) + ret <4 x i32> %result +} + +define weak_odr <2 x i32> @pairwise_Add_uint32x2_uint16x4_accumulate(<2 x i32> %a, <4 x i16> %x) nounwind alwaysinline { + %result = tail call <2 x i32> @llvm.arm.neon.vpadalu.v2i32.v4i16(<2 x i32> %a, <4 x i16> %x) + ret <2 x i32> %result +} + +define weak_odr <2 x i64> @pairwise_Add_uint64x2_uint32x4_accumulate(<2 x i64> %a, <4 x i32> %x) nounwind alwaysinline { + %result = tail call <2 x i64> @llvm.arm.neon.vpadalu.v2i64.v4i32(<2 x i64> %a, <4 x i32> %x) + ret <2 x i64> %result +} + +define weak_odr i64 @pairwise_Add_uint64_uint32x2_accumulate(i64 %a, <2 x i32> %x) nounwind alwaysinline { + %vec = insertelement <1 x 
i64> undef, i64 %a, i32 0 + %result = tail call <1 x i64> @llvm.arm.neon.vpadalu.v1i64.v2i32(<1 x i64> %vec, <2 x i32> %x) + %scalar = extractelement <1 x i64> %result, i32 0 + ret i64 %scalar +} + +declare <2 x float> @llvm.arm.neon.vpmaxs.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpmaxs.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpmaxu.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpmaxs.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpmaxu.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpmaxs.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpmaxu.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Max_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpmaxs.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpmaxs.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Max_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpmaxs.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <2 x float> @pairwise_Max_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.arm.neon.vpmaxs.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +define weak_odr <8 x i8> @pairwise_Max_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpmaxu.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Max_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpmaxu.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Max_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpmaxu.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + + +declare <2 x float> @llvm.arm.neon.vpmins.v2f32(<2 x float>, <2 x float>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpmins.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <2 x i32> @llvm.arm.neon.vpminu.v2i32(<2 x i32>, <2 x i32>) nounwind readnone +declare <4 x i16> @llvm.arm.neon.vpmins.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <4 x i16> 
@llvm.arm.neon.vpminu.v4i16(<4 x i16>, <4 x i16>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpmins.v8i8(<8 x i8>, <8 x i8>) nounwind readnone +declare <8 x i8> @llvm.arm.neon.vpminu.v8i8(<8 x i8>, <8 x i8>) nounwind readnone + +define weak_odr <8 x i8> @pairwise_Min_int8x8_int8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpmins.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_int16x4_int16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpmins.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Min_int32x2_int32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpmins.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} + +define weak_odr <2 x float> @pairwise_Min_float32x2_float32x4(<4 x float> %x) nounwind alwaysinline { + %a = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %b = shufflevector <4 x float> %x, <4 x float> undef, <2 x i32> + %result = tail call <2 x float> @llvm.arm.neon.vpmins.v2f32(<2 x float> %a, <2 x float> %b) + ret <2 x float> %result +} + +define weak_odr <8 x i8> @pairwise_Min_uint8x8_uint8x16(<16 x i8> %x) nounwind alwaysinline { + %a = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %b = shufflevector <16 x i8> %x, <16 x i8> undef, <8 x i32> + %result = tail call <8 x i8> @llvm.arm.neon.vpminu.v8i8(<8 x i8> %a, <8 x i8> %b) + ret <8 x i8> %result +} + +define weak_odr <4 x i16> @pairwise_Min_uint16x4_uint16x8(<8 x i16> %x) nounwind alwaysinline { + %a = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %b = shufflevector <8 x i16> %x, <8 x i16> undef, <4 x i32> + %result = tail call <4 x i16> @llvm.arm.neon.vpminu.v4i16(<4 x i16> %a, <4 x i16> %b) + ret <4 x i16> %result +} + +define weak_odr <2 x i32> @pairwise_Min_uint32x2_uint32x4(<4 x i32> %x) nounwind alwaysinline { + %a = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %b = shufflevector <4 x i32> %x, <4 x i32> undef, <2 x i32> + %result = tail call <2 x i32> @llvm.arm.neon.vpminu.v2i32(<2 x i32> %a, <2 x i32> %b) + ret <2 x i32> %result +} \ No newline at end of file diff --git a/src/runtime/arm_cpu_features.cpp b/src/runtime/arm_cpu_features.cpp index 647bbd024c03..7293f333cbf0 100644 --- a/src/runtime/arm_cpu_features.cpp +++ b/src/runtime/arm_cpu_features.cpp @@ -20,6 +20,8 @@ WEAK CpuFeatures halide_get_cpu_features() { // features.set_available(halide_target_feature_armv7s); // } + // TODO: add runtime detection for ARMDotProd extension + // https://github.com/halide/Halide/issues/4727 return features; } diff --git a/src/runtime/ptx_dev.ll b/src/runtime/ptx_dev.ll index 4125e8bd3938..e93d3ebc1253 100644 --- a/src/runtime/ptx_dev.ll +++ b/src/runtime/ptx_dev.ll @@ -345,3 +345,45 @@ define weak_odr i32 @halide_ptx_trap() nounwind uwtable alwaysinline { ret i32 0 } +; llvm doesn't expose dot product instructions as intrinsics +define weak_odr i32 @dp4a_s32_s32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.s32.s32 $0, $1, $2, 
$3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp4a_s32_u32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.s32.u32 $0, $1, $2, $3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp4a_u32_s32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.u32.s32 $0, $1, $2, $3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp4a_u32_u32(i32 %a, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp4a.u32.u32 $0, $1, $2, $3;", "=r,r,r,r"(i32 %a, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + + +define weak_odr i32 @dp2a_s32_s32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.s32.s32 $0, $1, $3, $4; dp2a.hi.s32.s32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp2a_s32_u32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.s32.u32 $0, $1, $3, $4; dp2a.hi.s32.u32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp2a_u32_s32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.u32.s32 $0, $1, $3, $4; dp2a.hi.u32.s32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + +define weak_odr i32 @dp2a_u32_u32(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone alwaysinline { + %d = tail call i32 asm "dp2a.lo.u32.u32 $0, $1, $3, $4; dp2a.hi.u32.u32 $0, $2, $3, $0;", "=r,r,r,r,r"(i32 %a_lo, i32 %a_hi, i32 %b, i32 %i) nounwind readnone + ret i32 %d +} + diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index 4ea68cc100c0..f1773fcf2ade 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -67,6 +67,7 @@ tests(GROUPS correctness convolution.cpp convolution_multiple_kernels.cpp cross_compilation.cpp + cuda_8_bit_dot_product.cpp custom_allocator.cpp custom_auto_scheduler.cpp custom_error_reporter.cpp @@ -308,6 +309,7 @@ tests(GROUPS correctness tuple_select.cpp tuple_undef.cpp tuple_update_ops.cpp + tuple_vector_reduce.cpp two_vector_args.cpp undef.cpp uninitialized_read.cpp @@ -323,6 +325,7 @@ tests(GROUPS correctness vector_extern.cpp vector_math.cpp vector_print_bug.cpp + vector_reductions.cpp vector_tile.cpp vectorize_guard_with_if.cpp vectorize_mixed_widths.cpp diff --git a/test/correctness/atomics.cpp b/test/correctness/atomics.cpp index 71b1c268217d..91e35c49d829 100644 --- a/test/correctness/atomics.cpp +++ b/test/correctness/atomics.cpp @@ -22,7 +22,9 @@ template::value>::type * = nullptr> inline void check(int line_number, T x, T target, T threshold = T(1e-6)) { _halide_user_assert(std::fabs((x) - (target)) < threshold) - << "Line " << line_number << ": Expected " << (target) << " instead of " << (x) << "\n"; + << "Line " << line_number + << ": Expected " << (target) + << " instead of " << (x) << "\n"; } inline void check(int line_number, float16_t x, float16_t target) { @@ -37,7 +39,9 @@ template::value, int>::type * = nullptr> inline void check(int line_number, T x, T target) { _halide_user_assert(x == target) - << "Line " << line_number << ": Expected " << (target) << " instead of " << (x) << "\n"; + << "Line " << line_number + 
<< ": Expected " << (int64_t)(target) + << " instead of " << (int64_t)(x) << "\n"; } template @@ -357,7 +361,7 @@ void test_predicated_hist(const Backend &backend) { case Backend::CUDAVectorize: { RVar ro, ri; RVar rio, rii; - hist.update() + hist.update(update_id) .atomic(true /*override_assciativity_test*/) .split(r, ro, ri, 32) .split(ri, rio, rii, 4) @@ -824,7 +828,7 @@ void test_hist_rfactor(const Backend &backend) { Func intermediate = hist.update() - .rfactor({{r.y, y}}); + .rfactor(r.y, y); intermediate.compute_root(); hist.compute_root(); switch (backend) { @@ -858,7 +862,13 @@ void test_hist_rfactor(const Backend &backend) { case Backend::CUDAVectorize: { RVar ro, ri; RVar rio, rii; - hist.update().atomic(true).split(r, ro, ri, 32).split(ri, rio, rii, 4).gpu_blocks(ro, DeviceAPI::CUDA).gpu_threads(rio, DeviceAPI::CUDA).vectorize(rii); + intermediate.update() + .atomic(true) + .split(r.x, ro, ri, 32) + .split(ri, rio, rii, 4) + .gpu_blocks(ro, DeviceAPI::CUDA) + .gpu_threads(rio, DeviceAPI::CUDA) + .vectorize(rii); } break; default: { _halide_user_assert(false) << "Unsupported backend.\n"; diff --git a/test/correctness/cuda_8_bit_dot_product.cpp b/test/correctness/cuda_8_bit_dot_product.cpp new file mode 100644 index 000000000000..8ec33f23458e --- /dev/null +++ b/test/correctness/cuda_8_bit_dot_product.cpp @@ -0,0 +1,90 @@ +#include "Halide.h" + +#include + +using namespace Halide; + +template +void test(Target t) { + for (int factor : {4, 16}) { + for (int vec : {1, 4}) { + std::cout + << "Testing dot product of " + << type_of() << " * " << type_of() << " -> " << type_of() + << " with vector width " << vec + << " and reduction factor " << factor << "\n"; + Func in_a, in_b; + Var x, y; + + in_a(x, y) = cast(x - y * 17); + in_a.compute_root(); + + in_b(x, y) = cast(x * 3 + y * 7); + in_b.compute_root(); + + Func g; + RDom r(0, factor * 4); + g(x, y) += cast(in_a(r, x)) * in_b(r, y); + + Func h; + h(x, y) = g(x, y); + + Var xi, yi; + g.update().atomic().vectorize(r, factor).unroll(r); + h.gpu_tile(x, y, xi, yi, 32, 8, TailStrategy::RoundUp); + + Buffer out(128, 128); + h.realize(out); + out.copy_to_host(); + + for (int y = 0; y < out.height(); y++) { + for (int x = 0; x < out.width(); x++) { + Out correct = 0; + for (int r = 0; r < factor * 4; r++) { + A in_a_r_x = (A)(r - x * 17); + B in_b_r_y = (B)(r * 3 + y * 7); + correct += ((Out)(in_a_r_x)) * in_b_r_y; + } + if (out(x, y) != correct) { + printf("out(%d, %d) = %d instead of %d\n", x, y, (int)(out(x, y)), (int)(correct)); + exit(-1); + } + } + } + + // Check the instruction was emitted intended by just grepping the + // compiled code (the PTX source is an embedded string). + Buffer buf = h.compile_to_module(std::vector(), "h", t).compile_to_buffer(); + std::basic_regex regex("dp[24]a[.lo]*[us]32[.][us]32"); + if (!std::regex_search((const char *)buf.begin(), (const char *)buf.end(), regex)) { + printf("Did not find use of dp2a or dp4a in compiled code. 
Rerun test with HL_DEBUG_CODEGEN=1 to debug\n"); + exit(-1); + } + } + } +} + +int main(int argc, char **argv) { + Target t = get_jit_target_from_environment(); + if (!t.has_feature(Target::CUDACapability61)) { + printf("[SKIP] Cuda (with compute capability 6.1) is not enabled in target: %s\n", + t.to_string().c_str()); + return 0; + } + + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + test(t); + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/simd_op_check.cpp b/test/correctness/simd_op_check.cpp index a712f25c0962..782881c18b96 100644 --- a/test/correctness/simd_op_check.cpp +++ b/test/correctness/simd_op_check.cpp @@ -187,7 +187,7 @@ class SimdOpCheck : public SimdOpCheckTest { // SSE 2 - for (int w = 2; w <= 4; w++) { + for (int w : {2, 4}) { check("addpd", w, f64_1 + f64_2); check("subpd", w, f64_1 - f64_2); check("mulpd", w, f64_1 * f64_2); @@ -225,11 +225,8 @@ class SimdOpCheck : public SimdOpCheckTest { check(std::string("packuswb") + check_suffix, 8 * w, u8_sat(i16_1)); } - // SSE 3 + // SSE 3 / SSSE 3 - // We don't do horizontal add/sub ops, so nothing new here - - // SSSE 3 if (use_ssse3) { for (int w = 2; w <= 4; w++) { check("pmulhrsw", 4 * w, i16((((i32(i16_1) * i32(i16_2)) + 16384)) / 32768)); @@ -237,15 +234,68 @@ class SimdOpCheck : public SimdOpCheckTest { check("pabsw", 4 * w, abs(i16_1)); check("pabsd", 2 * w, abs(i32_1)); } + +#if LLVM_VERSION >= 90 + // Horizontal ops. Our support for them uses intrinsics + // from LLVM 9+. + + // Paradoxically, haddps is a bad way to do horizontal + // adds down to a single scalar on most x86. A better + // sequence (according to Peter Cordes on stackoverflow) + // is movshdup, addps, movhlps, addss. haddps is still + // good if you're only partially reducing and your result + // is at least one native vector, if only to save code + // size, but LLVM really really tries to avoid it and + // replace it with shuffles whenever it can, so we won't + // test for it. + // + // See: + // https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-float-vector-sum-on-x86 + + // For reducing down to a scalar we expect to see addps + // and movshdup. We'll sniff for the movshdup. + check("movshdup", 1, sum(in_f32(RDom(0, 2) + 2 * x))); + check("movshdup", 1, sum(in_f32(RDom(0, 4) + 4 * x))); + check("movshdup", 1, sum(in_f32(RDom(0, 16) + 16 * x))); + + // The integer horizontal add operations are pretty + // terrible on all x86 variants, and LLVM does its best to + // avoid generating those too, so we won't test that here + // either. + + // Min reductions should use phminposuw when + // possible. This only exists for u16. X86 is weird. + check("phminposuw", 1, minimum(in_u16(RDom(0, 8) + 8 * x))); + + // Max reductions can use the same instruction by first + // flipping the bits. + check("phminposuw", 1, maximum(in_u16(RDom(0, 8) + 8 * x))); + + // Reductions over signed ints can flip the sign bit + // before and after (equivalent to adding 128). 
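+            // (Illustrative scalar sketch of the sign-flip trick, assuming
+            //  16-bit lanes; not part of the checks themselves:
+            //      uint16_t fa = (uint16_t)a ^ 0x8000;
+            //      uint16_t fb = (uint16_t)b ^ 0x8000;
+            //      int16_t smin = (int16_t)((fa < fb ? fa : fb) ^ 0x8000);
+            //  XORing with 0x8000 maps the signed range onto the unsigned
+            //  range while preserving order, so an unsigned min such as
+            //  phminposuw, followed by flipping back, gives the signed min.)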
+ check("phminposuw", 1, minimum(in_i16(RDom(0, 8) + 8 * x))); + check("phminposuw", 1, maximum(in_i16(RDom(0, 8) + 8 * x))); + + // Reductions over 8-bit ints can widen first + check("phminposuw", 1, minimum(in_u8(RDom(0, 16) + 16 * x))); + check("phminposuw", 1, maximum(in_u8(RDom(0, 16) + 16 * x))); + check("phminposuw", 1, minimum(in_i8(RDom(0, 16) + 16 * x))); + check("phminposuw", 1, maximum(in_i8(RDom(0, 16) + 16 * x))); +#endif } // SSE 4.1 - // skip dot product and argmin - for (int w = 2; w <= 4; w++) { - const char *check_pmaddwd = (use_avx2 && w > 3) ? "vpmaddwd*ymm" : "pmaddwd"; + for (int w = 2; w <= 8; w++) { + // We generated pmaddwd when we do a sum of widening multiplies + const char *check_pmaddwd = + (use_avx2 && w >= 4) ? "vpmaddwd" : "pmaddwd"; check(check_pmaddwd, 2 * w, i32(i16_1) * 3 + i32(i16_2) * 4); check(check_pmaddwd, 2 * w, i32(i16_1) * 3 - i32(i16_2) * 4); + + // And also for dot-products + RDom r(0, 4); + check(check_pmaddwd, 2 * w, sum(i32(in_i16(x * 4 + r)) * in_i16(x * 4 + r + 32))); } // llvm doesn't distinguish between signed and unsigned multiplies @@ -888,12 +938,113 @@ class SimdOpCheck : public SimdOpCheckTest { // VORR X - Bitwise OR // check("vorr", bool1 | bool2); - // VPADAL I - Pairwise Add and Accumulate Long - // VPADD I, F - Pairwise Add - // VPADDL I - Pairwise Add Long - // VPMAX I, F - Pairwise Maximum - // VPMIN I, F - Pairwise Minimum - // We don't do horizontal ops + for (int f : {2, 4}) { + RDom r(0, f); + + // A summation reduction that starts at something + // non-trivial, to avoid llvm simplifying accumulating + // widening summations into just widening summations. + auto sum_ = [&](Expr e) { + Func f; + f(x) = cast(e.type(), 123); + f(x) += e; + return f(x); + }; + + // VPADD I, F - Pairwise Add + check(arm32 ? "vpadd.i8" : "addp", 16, sum_(in_i8(f * x + r))); + check(arm32 ? "vpadd.i8" : "addp", 16, sum_(in_u8(f * x + r))); + check(arm32 ? "vpadd.i16" : "addp", 8, sum_(in_i16(f * x + r))); + check(arm32 ? "vpadd.i16" : "addp", 8, sum_(in_u16(f * x + r))); + check(arm32 ? "vpadd.i32" : "addp", 4, sum_(in_i32(f * x + r))); + check(arm32 ? "vpadd.i32" : "addp", 4, sum_(in_u32(f * x + r))); + check(arm32 ? "vpadd.f32" : "addp", 4, sum_(in_f32(f * x + r))); + // In 32-bit, we don't have a pairwise op for doubles, + // and expect to just get vadd instructions on d + // registers. + check(arm32 ? "vadd.f64" : "addp", 4, sum_(in_f64(f * x + r))); + + if (f == 2) { + // VPADAL I - Pairwise Add and Accumulate Long + + // If we're reducing by a factor of two, we can + // use the forms with an accumulator + check(arm32 ? "vpadal.s8" : "sadalp", 16, sum_(i16(in_i8(f * x + r)))); + check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(i16(in_u8(f * x + r)))); + check(arm32 ? "vpadal.u8" : "uadalp", 16, sum_(u16(in_u8(f * x + r)))); + + check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i16(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u16(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(u32(in_u16(f * x + r)))); + + check(arm32 ? "vpadal.s32" : "sadalp", 4, sum_(i64(in_i32(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(i64(in_u32(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(u64(in_u32(f * x + r)))); + } else { + // VPADDL I - Pairwise Add Long + + // If we're reducing by more than that, that's not + // possible. + check(arm32 ? "vpaddl.s8" : "saddlp", 16, sum_(i16(in_i8(f * x + r)))); + check(arm32 ? 
"vpaddl.u8" : "uaddlp", 16, sum_(i16(in_u8(f * x + r)))); + check(arm32 ? "vpaddl.u8" : "uaddlp", 16, sum_(u16(in_u8(f * x + r)))); + + check(arm32 ? "vpaddl.s16" : "saddlp", 8, sum_(i32(in_i16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 8, sum_(i32(in_u16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 8, sum_(u32(in_u16(f * x + r)))); + + check(arm32 ? "vpaddl.s32" : "saddlp", 4, sum_(i64(in_i32(f * x + r)))); + check(arm32 ? "vpaddl.u32" : "uaddlp", 4, sum_(i64(in_u32(f * x + r)))); + check(arm32 ? "vpaddl.u32" : "uaddlp", 4, sum_(u64(in_u32(f * x + r)))); + + // If we're widening the type by a factor of four + // as well as reducing by a factor of four, we + // expect vpaddl followed by vpadal + check(arm32 ? "vpaddl.s8" : "saddlp", 8, sum_(i32(in_i8(f * x + r)))); + check(arm32 ? "vpaddl.u8" : "uaddlp", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpaddl.u8" : "uaddlp", 8, sum_(u32(in_u8(f * x + r)))); + check(arm32 ? "vpaddl.s16" : "saddlp", 4, sum_(i64(in_i16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 4, sum_(i64(in_u16(f * x + r)))); + check(arm32 ? "vpaddl.u16" : "uaddlp", 4, sum_(u64(in_u16(f * x + r)))); + + // Note that when going from u8 to i32 like this, + // the vpaddl is unsigned and the vpadal is a + // signed, because the intermediate type is u16 + check(arm32 ? "vpadal.s16" : "sadalp", 8, sum_(i32(in_i8(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(i32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.u16" : "uadalp", 8, sum_(u32(in_u8(f * x + r)))); + check(arm32 ? "vpadal.s32" : "sadalp", 4, sum_(i64(in_i16(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(i64(in_u16(f * x + r)))); + check(arm32 ? "vpadal.u32" : "uadalp", 4, sum_(u64(in_u16(f * x + r)))); + } + + // VPMAX I, F - Pairwise Maximum + check(arm32 ? "vpmax.s8" : "smaxp", 16, maximum(in_i8(f * x + r))); + check(arm32 ? "vpmax.u8" : "umaxp", 16, maximum(in_u8(f * x + r))); + check(arm32 ? "vpmax.s16" : "smaxp", 8, maximum(in_i16(f * x + r))); + check(arm32 ? "vpmax.u16" : "umaxp", 8, maximum(in_u16(f * x + r))); + check(arm32 ? "vpmax.s32" : "smaxp", 4, maximum(in_i32(f * x + r))); + check(arm32 ? "vpmax.u32" : "umaxp", 4, maximum(in_u32(f * x + r))); + + // VPMIN I, F - Pairwise Minimum + check(arm32 ? "vpmin.s8" : "sminp", 16, minimum(in_i8(f * x + r))); + check(arm32 ? "vpmin.u8" : "uminp", 16, minimum(in_u8(f * x + r))); + check(arm32 ? "vpmin.s16" : "sminp", 8, minimum(in_i16(f * x + r))); + check(arm32 ? "vpmin.u16" : "uminp", 8, minimum(in_u16(f * x + r))); + check(arm32 ? "vpmin.s32" : "sminp", 4, minimum(in_i32(f * x + r))); + check(arm32 ? 
"vpmin.u32" : "uminp", 4, minimum(in_u32(f * x + r))); + } + + // UDOT/SDOT + if (target.has_feature(Target::ARMDotProd)) { + for (int f : {4, 8}) { + RDom r(0, f); + for (int v : {2, 4}) { + check("udot", v, sum(u32(in_u8(f * x + r)) * in_u8(f * x + r + 32))); + check("sdot", v, sum(i32(in_i8(f * x + r)) * in_i8(f * x + r + 32))); + } + } + } // VPOP X F, D Pop from Stack // VPUSH X F, D Push to Stack diff --git a/test/correctness/simd_op_check.h b/test/correctness/simd_op_check.h index ceec22221347..0232f95ef52e 100644 --- a/test/correctness/simd_op_check.h +++ b/test/correctness/simd_op_check.h @@ -130,6 +130,25 @@ class SimdOpCheckTest { TestResult check_one(const std::string &op, const std::string &name, int vector_width, Expr e) { std::ostringstream error_msg; + class HasInlineReduction : public Internal::IRVisitor { + using Internal::IRVisitor::visit; + void visit(const Internal::Call *op) override { + if (op->call_type == Internal::Call::Halide) { + Internal::Function f(op->func); + if (f.has_update_definition()) { + inline_reduction = f; + result = true; + } + } + IRVisitor::visit(op); + } + + public: + Internal::Function inline_reduction; + bool result = false; + } has_inline_reduction; + e.accept(&has_inline_reduction); + // Define a vectorized Halide::Func that uses the pattern. Halide::Func f(name); f(x, y) = e; @@ -142,10 +161,28 @@ class SimdOpCheckTest { f_scalar.bound(x, 0, W); f_scalar.compute_root(); + if (has_inline_reduction.result) { + // If there's an inline reduction, we want to vectorize it + // over the RVar. + Var xo, xi; + RVar rxi; + Func g{has_inline_reduction.inline_reduction}; + + // Do the reduction separately in f_scalar + g.clone_in(f_scalar); + + g.compute_at(f, x) + .update() + .split(x, xo, xi, vector_width) + .fuse(g.rvars()[0], xi, rxi) + .atomic() + .vectorize(rxi); + } + // The output to the pipeline is the maximum absolute difference as a double. - RDom r(0, W, 0, H); + RDom r_check(0, W, 0, H); Halide::Func error("error_" + name); - error() = Halide::cast(maximum(absd(f(r.x, r.y), f_scalar(r.x, r.y)))); + error() = Halide::cast(maximum(absd(f(r_check.x, r_check.y), f_scalar(r_check.x, r_check.y)))); setup_images(); { diff --git a/test/correctness/tuple_vector_reduce.cpp b/test/correctness/tuple_vector_reduce.cpp new file mode 100644 index 000000000000..5650f2fc5b68 --- /dev/null +++ b/test/correctness/tuple_vector_reduce.cpp @@ -0,0 +1,107 @@ +#include "Halide.h" + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + // Make sure a tuple-valued associative reduction can be + // horizontally vectorized. 
+ + { + // Tuple addition + Func in; + Var x; + in(x) = {x, 2 * x}; + + Func f; + f() = {0, 0}; + + const int N = 100; + + RDom r(1, N); + f() = {f()[0] + in(r)[0], f()[1] + in(r)[1]}; + + in.compute_root(); + f.update().atomic().vectorize(r, 8).parallel(r); + + class CheckIR : public IRMutator { + using IRMutator::visit; + Expr visit(const VectorReduce *op) override { + vector_reduces++; + return IRMutator::visit(op); + } + Stmt visit(const Atomic *op) override { + atomics++; + mutexes += (!op->mutex_name.empty()); + return IRMutator::visit(op); + } + + public: + int atomics = 0, mutexes = 0, vector_reduces = 0; + } checker; + + f.add_custom_lowering_pass(&checker, []() {}); + + Realization result = f.realize(); + int a = Buffer(result[0])(); + int b = Buffer(result[1])(); + if (a != (N * (N + 1)) / 2 || b != N * (N + 1)) { + printf("Incorrect output: %d %d\n", a, b); + return -1; + } + + if (!checker.atomics) { + printf("Expected VectorReduce nodes\n"); + return -1; + } + + if (!checker.atomics) { + printf("Expected atomic nodes\n"); + return -1; + } + + if (checker.mutexes) { + printf("Did not expect mutexes\n"); + return -1; + } + } + + { + // Complex multiplication is associative. Let's multiply a bunch + // of complex numbers together. + Func in; + Var x; + in(x) = {cos(cast(x)), sin(cast(x))}; + + Func f; + f() = {1.0f, 0.0f}; + + RDom r(1, 50); + Expr a_real = f()[0]; + Expr a_imag = f()[1]; + Expr b_real = in(r)[0]; + Expr b_imag = in(r)[1]; + f() = {a_real * b_real - a_imag * b_imag, + a_real * b_imag + b_real * a_imag}; + + in.compute_root(); + f.update().atomic().vectorize(r, 8); + + // Sadly, this won't actually vectorize, because it's not + // expressible as a horizontal reduction op on a single + // vector. You'd need to rfactor. We can at least check we get + // the right value back though. + Realization result = f.realize(); + float a = Buffer(result[0])(); + float b = Buffer(result[1])(); + // We multiplied a large number of complex numbers of magnitude 1. 
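+        // (Why this check is valid: |z * w| = |z| * |w|, and each factor
+        //  cos(x) + i sin(x) has magnitude 1, so the true product also has
+        //  magnitude 1; the loose 0.9..1.1 bounds allow for accumulated
+        //  floating-point rounding error.)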
+ float mag = a * a + b * b; + if (mag <= 0.9 || mag >= 1.1) { + printf("Should have been magnitude one: %f + %f i\n", a, b); + return -1; + } + } + + printf("Success!\n"); + return 0; +} diff --git a/test/correctness/vector_reductions.cpp b/test/correctness/vector_reductions.cpp new file mode 100644 index 000000000000..18c0bc259def --- /dev/null +++ b/test/correctness/vector_reductions.cpp @@ -0,0 +1,126 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + for (int dst_lanes : {1, 3}) { + for (int reduce_factor : {2, 3, 4}) { + std::vector types = + {UInt(8), Int(8), UInt(16), Int(16), UInt(32), Int(32), + UInt(64), Int(64), Float(16), Float(32), Float(64)}; + const int src_lanes = dst_lanes * reduce_factor; + for (Type src_type : types) { + for (int widen_factor : {1, 2, 4}) { + Type dst_type = src_type.with_bits(src_type.bits() * widen_factor); + if (std::find(types.begin(), types.end(), dst_type) == types.end()) { + continue; + } + + for (int op = 0; op < 7; op++) { + if (dst_type == Float(16) && reduce_factor > 2) { + // Reductions of float16s is really not very associative + continue; + } + + Var x, xo, xi; + RDom r(0, reduce_factor); + RVar rx; + Func in; + if (src_type.is_float()) { + in(x) = cast(src_type, random_float()); + } else { + in(x) = cast(src_type, random_int()); + } + in.compute_root(); + + Expr rhs = cast(dst_type, in(x * reduce_factor + r)); + Expr rhs2 = cast(dst_type, in(x * reduce_factor + r + 32)); + + if (op == 4 || op == 5) { + rhs = rhs > cast(rhs.type(), 5); + } + + Func f, ref("ref"); + switch (op) { + case 0: + f(x) += rhs; + ref(x) += rhs; + break; + case 1: + f(x) *= rhs; + ref(x) *= rhs; + break; + case 2: + // Widening min/max reductions are not interesting + if (widen_factor != 1) { + continue; + } + f(x) = rhs.type().min(); + ref(x) = rhs.type().min(); + f(x) = max(f(x), rhs); + ref(x) = max(f(x), rhs); + break; + case 3: + if (widen_factor != 1) { + continue; + } + f(x) = rhs.type().max(); + ref(x) = rhs.type().max(); + f(x) = min(f(x), rhs); + ref(x) = min(f(x), rhs); + break; + case 4: + if (widen_factor != 1) { + continue; + } + f(x) = cast(false); + ref(x) = cast(false); + f(x) = f(x) || rhs; + ref(x) = f(x) || rhs; + break; + case 5: + if (widen_factor != 1) { + continue; + } + f(x) = cast(true); + ref(x) = cast(true); + f(x) = f(x) && rhs; + ref(x) = f(x) && rhs; + break; + case 6: + // Dot product + f(x) += rhs * rhs2; + ref(x) += rhs * rhs2; + } + + f.compute_root() + .update() + .split(x, xo, xi, dst_lanes) + .fuse(r, xi, rx) + .atomic() + .vectorize(rx); + ref.compute_root(); + + RDom c(0, 128); + Expr err = cast(maximum(absd(f(c), ref(c)))); + + double e = evaluate(err); + + if (e > 1e-3) { + std::cerr + << "Horizontal reduction produced different output when vectorized!\n" + << "Maximum error = " << e << "\n" + << "Reducing from " << src_type.with_lanes(src_lanes) + << " to " << dst_type.with_lanes(dst_lanes) << "\n" + << "RHS: " << f.update_value() << "\n"; + exit(-1); + } + } + } + } + } + } + + printf("Success!\n"); + return 0; +} diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 304676fca36c..f2e8feb7775d 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -6,7 +6,6 @@ tests(GROUPS error atomics_gpu_8_bit.cpp atomics_gpu_mutex.cpp atomics_self_reference.cpp - atomics_vectorized_mutex.cpp auto_schedule_no_parallel.cpp auto_schedule_no_reorder.cpp autodiff_unbounded.cpp diff --git a/test/error/atomics_vectorized_mutex.cpp 
b/test/error/atomics_vectorized_mutex.cpp deleted file mode 100644 index 75a71840e9ea..000000000000 --- a/test/error/atomics_vectorized_mutex.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include "Halide.h" - -using namespace Halide; - -int main(int argc, char **argv) { - int img_size = 10000; - - Func f; - Var x; - RDom r(0, img_size); - - f(x) = Tuple(1, 0); - f(r) = Tuple(f(r)[1] + 1, f(r)[0] + 1); - - f.compute_root(); - - f.update() - .atomic() - .vectorize(r, 8); - - // f's update will be lowered to mutex locks, - // and we don't allow vectorization on mutex locks since - // it leads to deadlocks. - // This should throw an error - Realization out = f.realize(img_size); - - printf("Success!\n"); - return 0; -} From 51e4c6d953ec89c741b824a2d4d72cd74a0b8891 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 23 Jun 2020 13:32:55 -0700 Subject: [PATCH 4/4] Address review comments --- src/IRMatch.cpp | 14 ++++++++++--- src/IROperator.cpp | 25 +++++++++++++++--------- src/ModulusRemainder.cpp | 2 +- src/VectorizeLoops.cpp | 3 +++ test/correctness/tuple_vector_reduce.cpp | 2 +- test/correctness/vector_reductions.cpp | 2 ++ 6 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/IRMatch.cpp b/src/IRMatch.cpp index 6dfb5b8a2b68..8923a9546a81 100644 --- a/src/IRMatch.cpp +++ b/src/IRMatch.cpp @@ -285,7 +285,7 @@ class IRMatch : public IRVisitor { void visit(const VectorReduce *op) override { const VectorReduce *e = expr.as(); - if (result && e && op->op == e->op) { + if (result && e && op->op == e->op && types_match(op->type, e->type)) { expr = e->value; op->value.accept(this); } else { @@ -364,7 +364,10 @@ bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept { case IRNodeType::StringImm: return ((const StringImm &)a).value == ((const StringImm &)b).value; case IRNodeType::Cast: - return equal_helper(((const Cast &)a).value, ((const Cast &)b).value); + // While we know a and b have matching type, we don't know + // that the types of the values match, so use equal rather + // than equal_helper. + return equal(((const Cast &)a).value, ((const Cast &)b).value); case IRNodeType::Variable: return ((const Variable &)a).name == ((const Variable &)b).name; case IRNodeType::Add: @@ -424,8 +427,13 @@ bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept { return (equal_helper(((const Shuffle &)a).vectors, ((const Shuffle &)b).vectors) && equal_helper(((const Shuffle &)a).indices, ((const Shuffle &)b).indices)); case IRNodeType::VectorReduce: + // As with Cast above, we use equal instead of equal_helper + // here, because while we know a.type == b.type, we don't know + // if the types of the value fields also match. We could be + // comparing a reduction of an 8-vector down to a 4 vector to + // a reduction of a 16-vector down to a 4-vector. 
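+        // (Hypothetical example: VectorReduce::make(Add, vec8, 4) and
+        //  VectorReduce::make(Add, vec16, 4) both have 4-lane result types
+        //  and the same op, so only a full equal() on the value operands
+        //  can distinguish them.)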
return (((const VectorReduce &)a).op == ((const VectorReduce &)b).op && - equal_helper(((const VectorReduce &)a).value, ((const VectorReduce &)b).value)); + equal(((const VectorReduce &)a).value, ((const VectorReduce &)b).value)); // Explicitly list all the Stmts instead of using a default // clause so that if new Exprs are added without being handled diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 92d94ba96c82..5a3ef542a382 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -464,23 +464,30 @@ Expr lossless_cast(Type t, Expr e) { } } - if (const VectorReduce *red = e.as()) { - const int factor = red->value.type().lanes() / red->type.lanes(); - switch (red->op) { + if (const VectorReduce *reduce = e.as()) { + const int factor = reduce->value.type().lanes() / reduce->type.lanes(); + switch (reduce->op) { case VectorReduce::Add: - if (t.bits() >= 16 && factor < (1 << (t.bits() / 2))) { - Type narrower = red->value.type().with_bits(t.bits() / 2); - Expr val = lossless_cast(narrower, red->value); + // A horizontal add requires one extra bit per factor + // of two in the reduction factor. E.g. a reduction of + // 8 vector lanes down to 2 requires 2 extra bits in + // the output. We only deal with power-of-two types + // though, so just make sure the reduction factor + // isn't so large that it will more than double the + // number of bits required. + if (factor < (1 << (t.bits() / 2))) { + Type narrower = reduce->value.type().with_bits(t.bits() / 2); + Expr val = lossless_cast(narrower, reduce->value); if (val.defined()) { - return VectorReduce::make(red->op, val, red->type.lanes()); + return VectorReduce::make(reduce->op, val, reduce->type.lanes()); } } break; case VectorReduce::Max: case VectorReduce::Min: { - Expr val = lossless_cast(t, red->value); + Expr val = lossless_cast(t, reduce->value); if (val.defined()) { - return VectorReduce::make(red->op, val, red->type.lanes()); + return VectorReduce::make(reduce->op, val, reduce->type.lanes()); } break; } diff --git a/src/ModulusRemainder.cpp b/src/ModulusRemainder.cpp index 56b3afee9229..671af6e96dd3 100644 --- a/src/ModulusRemainder.cpp +++ b/src/ModulusRemainder.cpp @@ -491,7 +491,7 @@ void ComputeModulusRemainder::visit(const Let *op) { void ComputeModulusRemainder::visit(const Shuffle *op) { // It's possible that scalar expressions are extracting a lane of - // a vector - don't faiql in this case, but stop + // a vector - don't fail in this case, but stop internal_assert(op->indices.size() == 1) << "modulus_remainder of vector\n"; result = ModulusRemainder{}; } diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 0231e2592b0e..86521396167e 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -183,6 +183,9 @@ Interval bounds_of_lanes(const Expr &e) { }; // A ramp with the lanes repeated (e.g. 
<0 0 2 2 4 4 6 6>) +// TODO(vksnk): With nested vectorization, this will be representable +// as a ramp(broadcast(a, repetitions), broadcast(b, repetitions, +// lanes) struct InterleavedRamp { Expr base, stride; int lanes, repetitions; diff --git a/test/correctness/tuple_vector_reduce.cpp b/test/correctness/tuple_vector_reduce.cpp index 5650f2fc5b68..84be15f80de7 100644 --- a/test/correctness/tuple_vector_reduce.cpp +++ b/test/correctness/tuple_vector_reduce.cpp @@ -50,7 +50,7 @@ int main(int argc, char **argv) { return -1; } - if (!checker.atomics) { + if (!checker.vector_reduces) { printf("Expected VectorReduce nodes\n"); return -1; } diff --git a/test/correctness/vector_reductions.cpp b/test/correctness/vector_reductions.cpp index 18c0bc259def..d4d2acc43984 100644 --- a/test/correctness/vector_reductions.cpp +++ b/test/correctness/vector_reductions.cpp @@ -37,6 +37,8 @@ int main(int argc, char **argv) { Expr rhs2 = cast(dst_type, in(x * reduce_factor + r + 32)); if (op == 4 || op == 5) { + // Test cases 4 and 5 in the switch + // statement below require a Bool rhs. rhs = rhs > cast(rhs.type(), 5); }