[TIR] ThreadAllreduce warp-level primitive support with multi-warp
This PR enhances the implementation of the LowerThreadAllreduce pass. Prior to this PR, for the CUDA backend we leveraged warp-level primitives only when

* the reducing threads form a sub-warp (i.e., of size 16, 8, 4 or 2), or
* the number of reducing threads is less than 32 and equals the reduction extent.

Under these requirements, reductions over a large number of reducing threads (e.g., 128, 256 or more) produce inefficient generated code.

This PR improves the LowerThreadAllreduce pass so that, with the help of warp-level primitives, it now generates more efficient CUDA code whenever the number of reducing threads is a multiple of the warp size. Specifically, in such cases we first reduce the 32 elements within each warp and store each warp's result in shared memory. We then trigger a second round of warp-level reduction within the first warp to obtain the final reduction result; a sketch of this pattern is shown below. Besides using warp-level primitives, this also reduces the required shared memory size. For example, even when reducing over 1024 threads, we now only require shared memory of size 32, compared with 1024 prior to this PR.

Tests are added to ensure correctness.
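For illustration, below is a minimal hand-written CUDA sketch of the two-stage pattern described above (per-warp shuffle reduction, per-warp partials in shared memory, then a second shuffle reduction in the first warp). It is not the actual code emitted by the pass; the function name `block_allreduce_sum` and the choice of `float`/`+` are assumptions for the example, shown here for a block of up to 1024 threads.

```cuda
#include <cuda_runtime.h>

// Hypothetical two-stage block-wide allreduce (sum) using warp shuffles.
__device__ float block_allreduce_sum(float val) {
  // Stage 1: reduce 32 values within each warp via shuffle intrinsics.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }

  // One shared-memory slot per warp (at most 32 slots for 1024 threads),
  // instead of one slot per thread as before this PR.
  __shared__ float warp_results[32];
  int lane = threadIdx.x & 31;
  int warp_id = threadIdx.x >> 5;
  if (lane == 0) {
    warp_results[warp_id] = val;
  }
  __syncthreads();

  // Stage 2: the first warp reduces the per-warp partial results.
  int num_warps = blockDim.x >> 5;
  if (warp_id == 0) {
    val = (lane < num_warps) ? warp_results[lane] : 0.0f;
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(0xffffffff, val, offset);
    }
    if (lane == 0) {
      warp_results[0] = val;
    }
  }
  __syncthreads();

  // Broadcast the final result so every thread sees it (allreduce semantics).
  return warp_results[0];
}
```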