-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR] Allreduce broadcast result to each thread in multi-warp case (#…
…15373) PR #15327 introduces the warp-level primitive support in multi-warp allreduce. However, due to the specialty of the two-stage shuffle-down reduction implementation of the allreduce in multi-warp scenarios, PR #15327 did not broadcast the allreduce result to each reduction thread. This behavior does not align with the semantics of allreduce and is not ideal for many use cases. Therefore, this PR completes the implementation by inserting a stage of writing the reduction results to shared memory, so that each reduction thread across all the reduction warps can access the reduction results. This shared memory write-back stage will only be inserted in multi-warp allreduce cases. In single-warp allreduce, a `shfl_sync` is used to broadcast the reduction results across reduction threads. Since in multi-warp settings we cannot leverage warp-level primitives to broadcast the value, we can only make use of shared memory. The numerical correctness are verified locally.
- Loading branch information
1 parent
7ebc802
commit 5029477
Showing
2 changed files
with
70 additions
and
76 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.