Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix correctness in cuda_mapreduce #2106

Merged
merged 1 commit into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ main

### ![][badge-🐛bugfix] Bug fixes

- Fixed writing/reading purely vertical spaces
- Fixed writing/reading purely vertical spaces. PR [2102](https://github.com/CliMA/ClimaCore.jl/pull/2102)
- Fixed correctness bug in reductions on GPUs. PR [2106](https://github.com/CliMA/ClimaCore.jl/pull/2106)

v0.14.20
--------
Expand Down
37 changes: 36 additions & 1 deletion ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,35 @@ function mapreduce_cuda(
weighted_jacobian = OnesArray(parent(data)),
opargs...,
)
# This function implements the following parallel reduction algorithm:
#
# Each thread in each blocks processes multiple data points at the same time
# (n_ops_on_load) each and we perform a block-wise reduction, with each
# block writing to an array of (block-)shared memory. This array has the
# same size as the block, ie, it is as long as many threads are available.
# Processing multiple points means that we apply the reduction to the point
# with index reduction[thread_index] = f(thread_index, thread_index +
# OFFSET), with various OFFSETS that depend on `n_ops_on_load` and block
# size.
#
# For the purpose of indexing, this is equivalent to having larger blocks
# with size effective_blksize = blksize * (n_ops_on_load + 1).
#
#
# After this operation, we have reduced all the data by a factor of
# 1/n_ops_on_load and have results in various arrays `reduction` (one per
# block)
#
# Once we have all the blocks reduced, we perform a tree reduction within
# the block and "move" the reduced value to the first element of the array.
# In this, one of the things to watch out for is that the last block might
# not necessarily have all threads doing work, so we have to be careful to
# not include data in `reduction` that did not have corresponding work.
# Threads of index 1 will write that array into an output array.
#
# The output array has size nblocks, so we do another round of reduction,
# but this time we put each Field in a different block.

S = eltype(data)
pdata = parent(data)
T = eltype(pdata)
Expand Down Expand Up @@ -112,7 +141,13 @@ function mapreduce_cuda_kernel!(
end
end
sync_threads()
_cuda_intrablock_reduce!(op, reduction, tidx, blksize)

# The last block might not have enough threads to fill `reduction`, so some
# of its elements might still have the value at initialization.
blksize_for_reduction =
min(blksize, nitems - effective_blksize * (bidx - 1))

_cuda_intrablock_reduce!(op, reduction, tidx, blksize_for_reduction)

tidx == 1 && (reduce_cuda[bidx, fidx] = reduction[1])
return nothing
Expand Down
32 changes: 32 additions & 0 deletions test/DataLayouts/unit_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,35 @@ end
# data = DataLayouts.IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nv,Nh); test_mapreduce_2!(context, data_view(data)) # TODO: test
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_mapreduce_2!(context, data_view(data)) # TODO: test
end

Sbozzolo marked this conversation as resolved.
Show resolved Hide resolved
@testset "mapreduce with space with some non-round blocks" begin
# https://github.com/CliMA/ClimaCore.jl/issues/2097
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
x_min = 0,
x_max = 1,
y_min = 0,
y_max = 1,
periodic_x = false,
periodic_y = false,
n_quad_points = 4,
x_elem = 129,
y_elem = 129,
)
@test minimum(ones(space)) == 1

if ClimaComms.context isa ClimaComms.SingletonCommsContext
# Less than 256 threads
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
x_min = 0,
x_max = 1,
y_min = 0,
y_max = 1,
periodic_x = false,
periodic_y = false,
n_quad_points = 2,
x_elem = 2,
y_elem = 2,
)
@test minimum(ones(space)) == 1
end
end
Loading