Skip to content

Commit

Permalink
Move in-kernel specific adapt functions to cuda ext
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 8, 2025
1 parent b099c3e commit f1dbd7d
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 30 deletions.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import ClimaCore.RecursiveApply:
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh
import ClimaCore.DataLayouts: UniversalSize

include(joinpath("cuda", "adapt.jl"))
include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
include(joinpath("cuda", "fields.jl"))
Expand Down
44 changes: 44 additions & 0 deletions ext/cuda/adapt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import CUDA, Adapt
import ClimaCore
import ClimaCore: Grids, Spaces, Topologies, Devices

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
grid::Grids.ExtrudedFiniteDifferenceGrid,
) = Grids.DeviceExtrudedFiniteDifferenceGrid(
Adapt.adapt(to, vertical_topology(grid)),
Adapt.adapt(to, grid.horizontal_grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
grid::Grids.FiniteDifferenceGrid,
) = Grids.DeviceFiniteDifferenceGrid(
Adapt.adapt(to, grid.topology),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
grid::Grids.SpectralElementGrid2D,
) = Grids.DeviceSpectralElementGrid2D(
Adapt.adapt(to, grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.local_geometry),
)

Adapt.adapt_structure(to::CUDA.KernelAdaptor, space::Spaces.PointSpace) =
Spaces.PointSpace(
ClimaCore.DeviceSideContext(),
Adapt.adapt(to, Spaces.local_geometry_data(space)),
)

Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
topology::Topologies.IntervalTopology,
) = Topologies.DeviceIntervalTopology(topology.boundaries)
9 changes: 0 additions & 9 deletions src/Grids/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,6 @@ local_geometry_type(
::Type{DeviceExtrudedFiniteDifferenceGrid{VT, Q, GG, CLG, FLG}},
) where {VT, Q, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts

Adapt.adapt_structure(to, grid::ExtrudedFiniteDifferenceGrid) =
DeviceExtrudedFiniteDifferenceGrid(
Adapt.adapt(to, vertical_topology(grid)),
Adapt.adapt(to, grid.horizontal_grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

quadrature_style(grid::DeviceExtrudedFiniteDifferenceGrid) =
grid.quadrature_style
vertical_topology(grid::DeviceExtrudedFiniteDifferenceGrid) =
Expand Down
8 changes: 0 additions & 8 deletions src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,6 @@ local_geometry_type(
::Type{DeviceFiniteDifferenceGrid{T, GG, CLG, FLG}},
) where {T, GG, CLG, FLG} = eltype(CLG) # calls eltype from DataLayouts

Adapt.adapt_structure(to, grid::FiniteDifferenceGrid) =
DeviceFiniteDifferenceGrid(
Adapt.adapt(to, grid.topology),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.center_local_geometry),
Adapt.adapt(to, grid.face_local_geometry),
)

topology(grid::DeviceFiniteDifferenceGrid) = grid.topology
vertical_topology(grid::DeviceFiniteDifferenceGrid) = grid.topology

Expand Down
7 changes: 0 additions & 7 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -597,13 +597,6 @@ end
ClimaComms.context(grid::DeviceSpectralElementGrid2D) = DeviceSideContext()
ClimaComms.device(grid::DeviceSpectralElementGrid2D) = DeviceSideDevice()

Adapt.adapt_structure(to, grid::SpectralElementGrid2D) =
DeviceSpectralElementGrid2D(
Adapt.adapt(to, grid.quadrature_style),
Adapt.adapt(to, grid.global_geometry),
Adapt.adapt(to, grid.local_geometry),
)

## aliases
const RectilinearSpectralElementGrid2D =
SpectralElementGrid2D{<:Topologies.RectilinearTopology2D}
Expand Down
4 changes: 0 additions & 4 deletions src/Spaces/pointspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ function PointSpace(
return PointSpace(context, Adapt.adapt(ArrayType, local_geometry_data))
end


Adapt.adapt_structure(to, space::PointSpace) =
PointSpace(DeviceSideContext(), Adapt.adapt(to, local_geometry_data(space)))

function PointSpace(
context::ClimaComms.AbstractCommsContext,
coord::Geometry.Abstract1DPoint{FT},
Expand Down
2 changes: 0 additions & 2 deletions src/Topologies/interval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ end
struct DeviceIntervalTopology{B} <: AbstractIntervalTopology
boundaries::B
end
Adapt.adapt_structure(to, topology::IntervalTopology) =
DeviceIntervalTopology(topology.boundaries)

ClimaComms.context(topology::DeviceIntervalTopology) = DeviceSideContext()
ClimaComms.device(topology::DeviceIntervalTopology) = DeviceSideDevice()
Expand Down

0 comments on commit f1dbd7d

Please sign in to comment.