From 585bdc7c24b4f532b2a6db77e9546e2494ed5319 Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Thu, 11 Jun 2020 15:36:35 -0700
Subject: [PATCH] Handle combination of atomic() and vectorize() in lowering

---
 src/Lower.cpp             |   2 +-
 src/ScheduleFunctions.cpp |  19 +-
 src/VectorizeLoops.cpp    | 528 ++++++++++++++++++++++++++++++++++++--
 src/VectorizeLoops.h      |   2 +-
 4 files changed, 522 insertions(+), 29 deletions(-)

diff --git a/src/Lower.cpp b/src/Lower.cpp
index 52f970f4fac7..62c3aa1e00cb 100644
--- a/src/Lower.cpp
+++ b/src/Lower.cpp
@@ -335,7 +335,7 @@ Module lower(const vector<Function> &output_funcs,
              << s << "\n\n";

     debug(1) << "Vectorizing...\n";
-    s = vectorize_loops(s, t);
+    s = vectorize_loops(s, env, t);
     s = simplify(s);
     debug(2) << "Lowering after vectorizing:\n"
              << s << "\n\n";
diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp
index 59f4269b3427..2b4ab53d5050 100644
--- a/src/ScheduleFunctions.cpp
+++ b/src/ScheduleFunctions.cpp
@@ -371,9 +371,22 @@ Stmt build_provide_loop_nest(const map<string, Function> &env,
     // Make the (multi-dimensional multi-valued) store node.
     Stmt body = Provide::make(func.name(), values, site);
     if (def.schedule().atomic()) {  // Add atomic node.
-        // If required, we will allocate a mutex buffer called func.name() + ".mutex"
-        // The buffer is added in the AddAtomicMutex pass.
-        body = Atomic::make(func.name(), func.name() + ".mutex", body);
+        bool any_unordered_parallel = false;
+        for (auto d : def.schedule().dims()) {
+            any_unordered_parallel |= is_unordered_parallel(d.for_type);
+        }
+        if (any_unordered_parallel) {
+            // If required, we will allocate a mutex buffer called func.name() + ".mutex"
+            // The buffer is added in the AddAtomicMutex pass.
+            body = Atomic::make(func.name(), func.name() + ".mutex", body);
+        } else {
+            // No mutex is required if there is no parallelism, and it
+            // wouldn't work if all parallelism is synchronous
+            // (e.g. vectorization). Vectorization and the like will
+            // need to handle atomic nodes specially, by either
+            // emitting VectorReduce ops or scalarizing.
+            body = Atomic::make(func.name(), std::string{}, body);
+        }
     }

     // Default schedule/values if there is no specialization
diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp
index 9bb04cfb9bc8..0231e2592b0e 100644
--- a/src/VectorizeLoops.cpp
+++ b/src/VectorizeLoops.cpp
@@ -18,6 +18,7 @@ namespace Halide {
 namespace Internal {

+using std::map;
 using std::pair;
 using std::string;
 using std::vector;
@@ -170,21 +171,95 @@ Interval bounds_of_lanes(const Expr &e) {
     }

     // Take the explicit min and max over the lanes
-    Expr min_lane = extract_lane(e, 0);
-    Expr max_lane = min_lane;
-    for (int i = 1; i < e.type().lanes(); i++) {
-        Expr next_lane = extract_lane(e, i);
-        if (e.type().is_bool()) {
-            min_lane = And::make(min_lane, next_lane);
-            max_lane = Or::make(max_lane, next_lane);
-        } else {
-            min_lane = Min::make(min_lane, next_lane);
-            max_lane = Max::make(max_lane, next_lane);
-        }
+    if (e.type().is_bool()) {
+        Expr min_lane = VectorReduce::make(VectorReduce::And, e, 1);
+        Expr max_lane = VectorReduce::make(VectorReduce::Or, e, 1);
+        return {min_lane, max_lane};
+    } else {
+        Expr min_lane = VectorReduce::make(VectorReduce::Min, e, 1);
+        Expr max_lane = VectorReduce::make(VectorReduce::Max, e, 1);
+        return {min_lane, max_lane};
     }
-    return {min_lane, max_lane};
 };

+// A ramp with the lanes repeated (e.g. <0 0 2 2 4 4 6 6>)
+struct InterleavedRamp {
+    Expr base, stride;
+    int lanes, repetitions;
+};
+
+bool is_interleaved_ramp(const Expr &e, const Scope<Expr> &scope, InterleavedRamp *result) {
+    if (const Ramp *r = e.as<Ramp>()) {
+        result->base = r->base;
+        result->stride = r->stride;
+        result->lanes = r->lanes;
+        result->repetitions = 1;
+        return true;
+    } else if (const Broadcast *b = e.as<Broadcast>()) {
+        result->base = b->value;
+        result->stride = 0;
+        result->lanes = b->lanes;
+        result->repetitions = 0;
+        return true;
+    } else if (const Add *add = e.as<Add>()) {
+        InterleavedRamp ra;
+        if (is_interleaved_ramp(add->a, scope, &ra) &&
+            is_interleaved_ramp(add->b, scope, result) &&
+            (ra.repetitions == 0 ||
+             result->repetitions == 0 ||
+             ra.repetitions == result->repetitions)) {
+            result->base = simplify(result->base + ra.base);
+            result->stride = simplify(result->stride + ra.stride);
+            if (!result->repetitions) {
+                result->repetitions = ra.repetitions;
+            }
+            return true;
+        }
+    } else if (const Sub *sub = e.as<Sub>()) {
+        InterleavedRamp ra;
+        if (is_interleaved_ramp(sub->a, scope, &ra) &&
+            is_interleaved_ramp(sub->b, scope, result) &&
+            (ra.repetitions == 0 ||
+             result->repetitions == 0 ||
+             ra.repetitions == result->repetitions)) {
+            result->base = simplify(ra.base - result->base);
+            result->stride = simplify(ra.stride - result->stride);
+            if (!result->repetitions) {
+                result->repetitions = ra.repetitions;
+            }
+            return true;
+        }
+    } else if (const Mul *mul = e.as<Mul>()) {
+        const int64_t *b = nullptr;
+        if (is_interleaved_ramp(mul->a, scope, result) &&
+            (b = as_const_int(mul->b))) {
+            result->base = simplify(result->base * (int)(*b));
+            result->stride = simplify(result->stride * (int)(*b));
+            return true;
+        }
+    } else if (const Div *div = e.as<Div>()) {
+        const int64_t *b = nullptr;
+        if (is_interleaved_ramp(div->a, scope, result) &&
+            (b = as_const_int(div->b)) &&
+            is_one(result->stride) &&
+            (result->repetitions == 1 ||
+             result->repetitions == 0) &&
+            can_prove((result->base % (int)(*b)) == 0)) {
+            // TODO: Generalize this. Currently only matches
+            // ramp(base*b, 1, lanes) / b
+            // broadcast(base * b, lanes) / b
+            result->base = simplify(result->base / (int)(*b));
+            result->repetitions *= (int)(*b);
+            return true;
+        }
+    } else if (const Variable *var = e.as<Variable>()) {
+        if (scope.contains(var->name)) {
+            return is_interleaved_ramp(scope.get(var->name), scope, result);
+        }
+    }
+    return false;
+}
+
 // Allocations inside vectorized loops grow an additional inner
 // dimension to represent the separate copy of the allocation per
 // vector lane. This means loads and stores to them need to be
@@ -384,6 +459,9 @@ class VectorSubs : public IRMutator {
     // vectors.
     Scope<Expr> scope;

+    // The same set of Exprs, indexed by the vectorized var name
+    Scope<Expr> vector_scope;
+
     // A stack of all containing lets. We need to reinject the scalar
     // version of them if we scalarize inner code.
     vector<pair<string, Expr>> containing_lets;
@@ -613,19 +691,24 @@ class VectorSubs : public IRMutator {

         // If the value was vectorized by this mutator, add a new name to
         // the scope for the vectorized value expression.
-        std::string vectorized_name;
+        string vectorized_name;
         if (was_vectorized) {
             vectorized_name = op->name + widening_suffix;
             scope.push(op->name, mutated_value);
+            vector_scope.push(vectorized_name, mutated_value);
         }

         Expr mutated_body = mutate(op->body);

-        if (mutated_value.same_as(op->value) &&
-            mutated_body.same_as(op->body)) {
+        InterleavedRamp ir;
+        if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) {
+            return substitute(vectorized_name, mutated_value, mutated_body);
+        } else if (mutated_value.same_as(op->value) &&
+                   mutated_body.same_as(op->body)) {
             return op;
         } else if (was_vectorized) {
             scope.pop(op->name);
+            vector_scope.pop(vectorized_name);
             return Let::make(vectorized_name, mutated_value, mutated_body);
         } else {
             return Let::make(op->name, mutated_value, mutated_body);
@@ -634,7 +717,7 @@ class VectorSubs : public IRMutator {

     Stmt visit(const LetStmt *op) override {
         Expr mutated_value = mutate(op->value);
-        std::string mutated_name = op->name;
+        string mutated_name = op->name;

         // Check if the value was vectorized by this mutator.
         bool was_vectorized = (!op->value.type().is_vector() &&
@@ -643,6 +726,7 @@ class VectorSubs : public IRMutator {
         if (was_vectorized) {
             mutated_name += widening_suffix;
             scope.push(op->name, mutated_value);
+            vector_scope.push(mutated_name, mutated_value);
             // Also keep track of the original let, in case inner code scalarizes.
             containing_lets.emplace_back(op->name, op->value);
         }
@@ -652,6 +736,7 @@ class VectorSubs : public IRMutator {
         if (was_vectorized) {
             containing_lets.pop_back();
             scope.pop(op->name);
+            vector_scope.pop(mutated_name);

             // Inner code might have extracted my lanes using
             // extract_lane, which introduces a shuffle_vector. If
@@ -688,8 +773,11 @@ class VectorSubs : public IRMutator {
             }
         }

-        if (mutated_value.same_as(op->value) &&
-            mutated_body.same_as(op->body)) {
+        InterleavedRamp ir;
+        if (is_interleaved_ramp(mutated_value, vector_scope, &ir)) {
+            return substitute(mutated_name, mutated_value, mutated_body);
+        } else if (mutated_value.same_as(op->value) &&
+                   mutated_body.same_as(op->body)) {
             return op;
         } else {
             return LetStmt::make(mutated_name, mutated_value, mutated_body);
@@ -893,7 +981,7 @@ class VectorSubs : public IRMutator {
     }

     Stmt visit(const Allocate *op) override {
-        std::vector<Expr> new_extents;
+        vector<Expr> new_extents;
         Expr new_expr;

         int lanes = replacement.type().lanes();
@@ -941,6 +1029,171 @@ class VectorSubs : public IRMutator {
         return Allocate::make(op->name, op->type, op->memory_type, new_extents,
                               op->condition, body, new_expr, op->free_function);
     }

+    Stmt visit(const Atomic *op) override {
+        // Recognize a few special cases that we can handle as within-vector reduction trees.
+        do {
+            if (!op->mutex_name.empty()) {
+                // We can't vectorize over a mutex
+                break;
+            }
+
+            // f[x] = f[x] <op> y
+            const Store *store = op->body.as<Store>();
+            if (!store) break;
+
+            VectorReduce::Operator reduce_op = VectorReduce::Add;
+            Expr a, b;
+            if (const Add *add = store->value.as<Add>()) {
+                a = add->a;
+                b = add->b;
+                reduce_op = VectorReduce::Add;
+            } else if (const Mul *mul = store->value.as<Mul>()) {
+                a = mul->a;
+                b = mul->b;
+                reduce_op = VectorReduce::Mul;
+            } else if (const Min *min = store->value.as<Min>()) {
+                a = min->a;
+                b = min->b;
+                reduce_op = VectorReduce::Min;
+            } else if (const Max *max = store->value.as<Max>()) {
+                a = max->a;
+                b = max->b;
+                reduce_op = VectorReduce::Max;
+            } else if (const Cast *cast_op = store->value.as<Cast>()) {
+                if (cast_op->type.element_of() == UInt(8) &&
+                    cast_op->value.type().is_bool()) {
+                    if (const And *and_op = cast_op->value.as<And>()) {
+                        a = and_op->a;
+                        b = and_op->b;
+                        reduce_op = VectorReduce::And;
+                    } else if (const Or *or_op = cast_op->value.as<Or>()) {
+                        a = or_op->a;
+                        b = or_op->b;
+                        reduce_op = VectorReduce::Or;
+                    }
+                }
+            }
+
+            if (!a.defined() || !b.defined()) {
+                break;
+            }
+
+            // Bools get cast to uint8 for storage. Strip off that
+            // cast around any load.
+            if (b.type().is_bool()) {
+                const Cast *cast_op = b.as<Cast>();
+                if (cast_op) {
+                    b = cast_op->value;
+                }
+            }
+            if (a.type().is_bool()) {
+                const Cast *cast_op = a.as<Cast>();
+                if (cast_op) {
+                    a = cast_op->value;
+                }
+            }
+
+            if (a.as<Variable>() && !b.as<Variable>()) {
+                std::swap(a, b);
+            }
+
+            // We require b to be a var, because it should have been lifted.
+            const Variable *var_b = b.as<Variable>();
+            const Load *load_a = a.as<Load>();
+
+            if (!var_b ||
+                !scope.contains(var_b->name) ||
+                !load_a ||
+                load_a->name != store->name ||
+                !is_one(load_a->predicate) ||
+                !is_one(store->predicate)) {
+                break;
+            }
+
+            b = scope.get(var_b->name);
+            Expr store_index = mutate(store->index);
+            Expr load_index = mutate(load_a->index);
+
+            // The load and store indices must be the same interleaved
+            // ramp (or the same scalar, in the total reduction case).
+            InterleavedRamp store_ir, load_ir;
+            Expr test;
+            if (store_index.type().is_scalar()) {
+                test = simplify(load_index == store_index);
+            } else if (is_interleaved_ramp(store_index, vector_scope, &store_ir) &&
+                       is_interleaved_ramp(load_index, vector_scope, &load_ir) &&
+                       store_ir.repetitions == load_ir.repetitions &&
+                       store_ir.lanes == load_ir.lanes) {
+                test = simplify(store_ir.base == load_ir.base &&
+                                store_ir.stride == load_ir.stride);
+            }
+
+            if (!test.defined()) {
+                break;
+            }
+
+            if (is_zero(test)) {
+                break;
+            } else if (!is_one(test)) {
+                // TODO: try harder by substituting in more things in scope
+                break;
+            }
+
+            int output_lanes = 1;
+            if (store_index.type().is_scalar()) {
+                // The index doesn't depend on the value being
+                // vectorized, so it's a total reduction.
+
+                b = VectorReduce::make(reduce_op, b, 1);
+            } else {
+
+                output_lanes = store_index.type().lanes() / store_ir.repetitions;
+
+                store_index = Ramp::make(store_ir.base, store_ir.stride, output_lanes);
+                b = VectorReduce::make(reduce_op, b, output_lanes);
+            }
+
+            Expr new_load = Load::make(load_a->type.with_lanes(output_lanes),
+                                       load_a->name, store_index, load_a->image,
+                                       load_a->param, const_true(output_lanes),
+                                       ModulusRemainder{});
+
+            switch (reduce_op) {
+            case VectorReduce::Add:
+                b = new_load + b;
+                break;
+            case VectorReduce::Mul:
+                b = new_load * b;
+                break;
+            case VectorReduce::Min:
+                b = min(new_load, b);
+                break;
+            case VectorReduce::Max:
+                b = max(new_load, b);
+                break;
+            case VectorReduce::And:
+                b = cast(new_load.type(), cast(b.type(), new_load) && b);
+                break;
+            case VectorReduce::Or:
+                b = cast(new_load.type(), cast(b.type(), new_load) || b);
+                break;
+            }
+
+            Stmt s = Store::make(store->name, b, store_index, store->param,
+                                 const_true(b.type().lanes()), store->alignment);
+
+            // We may still need the atomic node, if there was more
+            // parallelism than just the vectorization.
+            s = Atomic::make(op->producer_name, op->mutex_name, s);
+
+            return s;
+        } while (0);
+
+        // In the general case, if a whole stmt has to be done
+        // atomically, we need to serialize.
+        return scalarize(op);
+    }
+
     Stmt scalarize(Stmt s) {
         // Wrap a serial loop around it. Maybe LLVM will have
         // better luck vectorizing it.
@@ -984,8 +1237,6 @@ class VectorSubs : public IRMutator {
             }
         }

-        debug(0) << e << " -> " << result << "\n";
-
         return result;
     }

@@ -994,6 +1245,172 @@ class VectorSubs : public IRMutator {
         : var(std::move(v)), replacement(std::move(r)), target(t), in_hexagon(in_hexagon) {
         widening_suffix = ".x" + std::to_string(replacement.type().lanes());
     }
+};  // namespace
+
+class FindVectorizableExprsInAtomicNode : public IRMutator {
+    // An Atomic node protects all accesses to a given buffer. We
+    // consider a name "poisoned" if it depends on an access to this
+    // buffer. We can't lift or vectorize anything that has been
+    // poisoned.
+    Scope<> poisoned_names;
+    bool poison = false;
+
+    using IRMutator::visit;
+
+    template<typename T>
+    const T *visit_let(const T *op) {
+        mutate(op->value);
+        ScopedBinding<> bind_if(poison, poisoned_names, op->name);
+        mutate(op->body);
+        return op;
+    }
+
+    Stmt visit(const LetStmt *op) override {
+        return visit_let(op);
+    }
+
+    Expr visit(const Let *op) override {
+        return visit_let(op);
+    }
+
+    Expr visit(const Load *op) override {
+        // Even if the load is bad, maybe we can lift the index
+        IRMutator::visit(op);
+
+        poison |= poisoned_names.contains(op->name);
+        return op;
+    }
+
+    Expr visit(const Variable *op) override {
+        poison = poisoned_names.contains(op->name);
+        return op;
+    }
+
+    Stmt visit(const Store *op) override {
+        // A store poisons all subsequent loads, but loads before the
+        // first store can be lifted.
+        mutate(op->index);
+        mutate(op->value);
+        poisoned_names.push(op->name);
+        return op;
+    }
+
+    Expr visit(const Call *op) override {
+        IRMutator::visit(op);
+        poison |= !op->is_pure();
+        return op;
+    }
+
+public:
+    using IRMutator::mutate;
+
+    Expr mutate(const Expr &e) override {
+        bool old_poison = poison;
+        poison = false;
+        IRMutator::mutate(e);
+        if (!poison) {
+            liftable.insert(e);
+        }
+        poison |= old_poison;
+        // We're not actually mutating anything. This class is only a
+        // mutator so that we can override a generic mutate() method.
+        return e;
+    }
+
+    FindVectorizableExprsInAtomicNode(const string &buf, const map<string, Function> &env) {
+        poisoned_names.push(buf);
+        auto it = env.find(buf);
+        if (it != env.end()) {
+            // Handle tuples
+            size_t n = it->second.values().size();
+            if (n > 1) {
+                for (size_t i = 0; i < n; i++) {
+                    poisoned_names.push(buf + "." + std::to_string(i));
+                }
+            }
+        }
+    }
+
+    std::set<Expr, ExprCompare> liftable;
+};
+
+class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator {
+    const std::set<Expr, ExprCompare> &liftable;
+
+    using IRMutator::visit;
+
+    template<typename StmtOrExpr, typename LetStmtOrLet>
+    StmtOrExpr visit_let(const LetStmtOrLet *op) {
+        if (liftable.count(op->value)) {
+            // Lift it under its current name to avoid having to
+            // rewrite the variables in other lifted exprs.
+            // TODO: duplicate non-overlapping liftable let stmts due to unrolling.
+            lifted.emplace_back(op->name, op->value);
+            return mutate(op->body);
+        } else {
+            return IRMutator::visit(op);
+        }
+    }
+
+    Stmt visit(const LetStmt *op) override {
+        return visit_let<Stmt>(op);
+    }
+
+    Expr visit(const Let *op) override {
+        return visit_let<Expr>(op);
+    }
+
+public:
+    map<Expr, string, ExprCompare> already_lifted;
+    vector<pair<string, Expr>> lifted;
+
+    using IRMutator::mutate;
+
+    Expr mutate(const Expr &e) override {
+        if (liftable.count(e) && !is_const(e) && !e.as<Variable>()) {
+            auto it = already_lifted.find(e);
+            string name;
+            if (it != already_lifted.end()) {
+                name = it->second;
+            } else {
+                name = unique_name('t');
+                lifted.emplace_back(name, e);
+                already_lifted.emplace(e, name);
+            }
+            return Variable::make(e.type(), name);
+        } else {
+            return IRMutator::mutate(e);
+        }
+    }
+
+    LiftVectorizableExprsOutOfSingleAtomicNode(const std::set<Expr, ExprCompare> &liftable)
+        : liftable(liftable) {
+    }
+};
+
+class LiftVectorizableExprsOutOfAllAtomicNodes : public IRMutator {
+    using IRMutator::visit;
+
+    Stmt visit(const Atomic *op) override {
+        FindVectorizableExprsInAtomicNode finder(op->producer_name, env);
+        finder.mutate(op->body);
+        LiftVectorizableExprsOutOfSingleAtomicNode lifter(finder.liftable);
+        Stmt new_body = lifter.mutate(op->body);
+        new_body = Atomic::make(op->producer_name, op->mutex_name, new_body);
+        while (!lifter.lifted.empty()) {
+            auto p = lifter.lifted.back();
+            new_body = LetStmt::make(p.first, p.second, new_body);
+            lifter.lifted.pop_back();
+        }
+        return new_body;
+    }
+
+    const map<string, Function> &env;
+
+public:
+    LiftVectorizableExprsOutOfAllAtomicNodes(const map<string, Function> &env)
+        : env(env) {
+    }
 };

 // Vectorize all loops marked as such in a Stmt
 class VectorizeLoops : public IRMutator {
@@ -1040,10 +1457,73 @@
     }
 };

-} // Anonymous namespace
+/** Check if all stores in a Stmt are to names in a given scope. Used
+    by RemoveUnnecessaryAtomics below. */
+class AllStoresInScope : public IRVisitor {
+    using IRVisitor::visit;
+    void visit(const Store *op) override {
+        result = result && s.contains(op->name);
+    }
+
+public:
+    bool result = true;
+    const Scope<> &s;
+    AllStoresInScope(const Scope<> &s)
+        : s(s) {
+    }
+};
+bool all_stores_in_scope(const Stmt &stmt, const Scope<> &scope) {
+    AllStoresInScope checker(scope);
+    stmt.accept(&checker);
+    return checker.result;
+}
+
+/** Drop any atomic nodes protecting buffers that are only accessed
+ * from a single thread. */
+class RemoveUnnecessaryAtomics : public IRMutator {
+    using IRMutator::visit;
+
+    // Allocations made from within this same thread
+    bool in_thread = false;
+    Scope<> local_allocs;
+
+    Stmt visit(const Allocate *op) override {
+        ScopedBinding<> bind(local_allocs, op->name);
+        return IRMutator::visit(op);
+    }
+
+    Stmt visit(const Atomic *op) override {
+        if (!in_thread || all_stores_in_scope(op->body, local_allocs)) {
+            return mutate(op->body);
+        } else {
+            return op;
+        }
+    }
+
+    Stmt visit(const For *op) override {
+        if (is_parallel(op->for_type)) {
+            ScopedValue<bool> old_in_thread(in_thread, true);
+            Scope<> old_local_allocs;
+            old_local_allocs.swap(local_allocs);
+            Stmt s = IRMutator::visit(op);
+            old_local_allocs.swap(local_allocs);
+            return s;
+        } else {
+            return IRMutator::visit(op);
+        }
+    }
+};
+
+}  // namespace

-Stmt vectorize_loops(const Stmt &s, const Target &t) {
-    return VectorizeLoops(t).mutate(s);
+Stmt vectorize_loops(const Stmt &stmt, const map<string, Function> &env, const Target &t) {
+    // Limit the scope of atomic nodes to just the necessary stuff.
+    // TODO: Should this be an earlier pass? It's probably a good idea
+    // for non-vectorizing stuff too.
+    Stmt s = LiftVectorizableExprsOutOfAllAtomicNodes(env).mutate(stmt);
+    s = VectorizeLoops(t).mutate(s);
+    s = RemoveUnnecessaryAtomics().mutate(s);
+    return s;
 }

 }  // namespace Internal
diff --git a/src/VectorizeLoops.h b/src/VectorizeLoops.h
index 52e7dcb73309..cbde217e608d 100644
--- a/src/VectorizeLoops.h
+++ b/src/VectorizeLoops.h
@@ -15,7 +15,7 @@ namespace Internal {
  * them into single statements that operate on vectors. The loops in
  * question must have constant extent. */
-Stmt vectorize_loops(const Stmt &s, const Target &t);
+Stmt vectorize_loops(const Stmt &s, const std::map<std::string, Function> &env, const Target &t);

 }  // namespace Internal
 }  // namespace Halide
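
A rough usage sketch follows (not part of the patch). It shows the front-end pattern this lowering path targets: an atomic() update vectorized over an RVar with no surrounding parallel loop, so build_provide_loop_nest emits an Atomic node with an empty mutex name, and the new VectorSubs::visit(const Atomic *) can turn the update into a VectorReduce instead of scalarizing. The pipeline below, including the names im, total, and r, is an illustrative assumption rather than a test from this commit.

```cpp
// Hedged sketch, not from the commit: a total reduction scheduled with
// atomic().vectorize(). Assumes the Halide front end; names are made up.
#include "Halide.h"
#include <cstdio>

using namespace Halide;

int main() {
    // Assumed input buffer.
    Buffer<int> im(1024);
    for (int i = 0; i < im.width(); i++) {
        im(i) = i;
    }

    RDom r(0, im.width());
    Func total("total");
    total() = 0;
    total() += im(r.x);

    // There is no unordered-parallel loop here, so lowering should produce an
    // Atomic node with an empty mutex name; the vectorizer can then recognize
    // the matching load/store pair and emit a VectorReduce::Add over the eight
    // lanes rather than falling back to a scalar loop.
    total.update().atomic().vectorize(r.x, 8);

    Buffer<int> out = total.realize();
    printf("sum = %d\n", out());
    return 0;
}
```

If the pass works as described above, the lowered IR for the update should contain a vector_reduce over the vectorized RVar feeding a scalar load/modify/store of total, which RemoveUnnecessaryAtomics is then free to strip of its Atomic wrapper since no other thread touches the buffer.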