[TIR] ThreadAllreduce warp-level primitive support with multi-warp
This PR enhances the implementation of the LowerThreadAllreduce pass.

Prior to this PR, for the CUDA backend we leveraged warp-level
primitives only when
* the reducing threads form a sub-warp (i.e., of size 16, 8, 4 or 2), or
* the number of reducing threads is less than 32 and equals the
  reduction extent.

Under these requirements, for reductions with a large number of
reducing threads (e.g., reducing over 128, 256 or more threads),
the generated code is inefficient.
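
For reference, the warp-level path these conditions enable looks roughly
like the single-warp sketch below. This is illustrative only: the kernel
name, the grouping of 4 sub-warps of 8 reducing threads into one 32-thread
warp, and the sum reduction are assumptions, not code the pass emits
verbatim.

// Sub-warp reduction sketch: 4 independent groups of 8 reducing threads
// packed into one 32-thread warp, each group summing its 8 values.
__global__ void subwarp_sum_8(const float* __restrict__ in,
                              float* __restrict__ out) {
  const unsigned kFullMask = 0xffffffffu;
  float val = in[threadIdx.x];

  // Reduce within each 8-lane segment via the `width` argument of the
  // warp shuffle; no shared memory or __syncthreads() is needed.
  for (int offset = 4; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(kFullMask, val, offset, /*width=*/8);
  }

  // The first lane of each 8-lane group now holds that group's sum.
  if ((threadIdx.x & 7) == 0) out[threadIdx.x >> 3] = val;
}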

This PR improves the LowerThreadAllreduce pass so that, when the number
of reducing threads is a multiple of the warp size, we now generate
more efficient CUDA code with the help of warp-level primitives.

Specifically, in such cases we first reduce the 32 elements within each
warp, storing each warp's result in shared memory. We then run a second
round of warp-level reduction within the first warp to obtain the final
reduction result.

In addition to using warp-level primitives, this also reduces the
required shared memory size. For example, even when reducing over
1024 threads, we now only need shared memory of size 32, compared
with 1024 prior to this PR.
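
A minimal hand-written sketch of this two-round pattern follows, assuming
256 reducing threads (8 warps), a sum reduction, and made-up names; the
code actually emitted by LowerThreadAllreduce differs in details.

// Two-round block-wide allreduce sketch for 256 threads.
__global__ void block_allreduce_sum(const float* __restrict__ in,
                                    float* __restrict__ out) {
  const unsigned kFullMask = 0xffffffffu;
  const int lane = threadIdx.x & 31;  // lane index within the warp
  const int warp = threadIdx.x >> 5;  // warp index within the block

  float val = in[blockIdx.x * blockDim.x + threadIdx.x];

  // Round 1: reduce the 32 lanes of each warp with warp shuffles.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(kFullMask, val, offset);
  }

  // Lane 0 of every warp stores its partial result. Only
  // blockDim.x / 32 = 8 shared-memory slots are needed here,
  // instead of 256 before this PR.
  __shared__ float warp_sums[8];
  if (lane == 0) warp_sums[warp] = val;
  __syncthreads();

  // Round 2: the first warp reduces the 8 per-warp partial results.
  if (warp == 0) {
    val = (lane < 8) ? warp_sums[lane] : 0.0f;
    for (int offset = 4; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(kFullMask, val, offset);
    }
    if (lane == 0) warp_sums[0] = val;  // final block-wide sum
  }
  __syncthreads();

  // Allreduce: every thread can now read the final value; thread 0
  // additionally writes it out in this example.
  float total = warp_sums[0];
  if (threadIdx.x == 0) out[blockIdx.x] = total;
}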

Tests are added to ensure correctness.
MasterJH5574 committed Jul 16, 2023
1 parent 9af8efc commit 5efb770
Showing 4 changed files with 613 additions and 137 deletions.
2 changes: 1 addition & 1 deletion python/tvm/tir/op.py
@@ -616,7 +616,7 @@ def tvm_storage_sync(storage_scope):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("handle", "tir.tvm_storage_sync", storage_scope)
+    return call_intrin("int32", "tir.tvm_storage_sync", storage_scope)
 
 
 def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):
13 changes: 7 additions & 6 deletions src/te/operation/cross_thread_reduction.cc
@@ -181,22 +181,23 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
     freduce_args.push_back(dummy_load);
   }
 
+  // Checks for the thread.
+  std::vector<PrimExpr> output_preds;
+  if (stage->store_predicate.defined()) {
+    output_preds.emplace_back(stage->store_predicate);
+  }
+
   for (IterVar iv : stage->leaf_iter_vars) {
     if (iv->iter_type == kCommReduce) {
       auto it = stage->iter_var_attrs.find(iv);
       if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
         IterVar tv = (*it).second->bind_thread;
         freduce_args.push_back(tv->var);
+        output_preds.push_back(tv->var == make_const(tv->var->dtype, 0));
       }
     }
   }
 
-  // Checks for the thread.
-  std::vector<PrimExpr> output_preds;
-  if (stage->store_predicate.defined()) {
-    output_preds.emplace_back(stage->store_predicate);
-  }
-
   // Apply the existing input predicate if any.
   output_preds.push_back(input_pred);
 