diff --git a/src/Fields/function_field.jl b/src/Fields/function_field.jl index 41b2c8b5fc..4f2ac384e2 100644 --- a/src/Fields/function_field.jl +++ b/src/Fields/function_field.jl @@ -58,9 +58,9 @@ fieldify_function(L, a::Function, grid) = FunctionField(L, a, grid) Adapt.adapt_structure(to, f::FunctionField{LX, LY, LZ}) where {LX, LY, LZ} = FunctionField{LX, LY, LZ}(Adapt.adapt(to, f.func), - Adapt.adapt(to, f.grid), - clock = Adapt.adapt(to, f.clock), - parameters = Adapt.adapt(to, f.parameters)) + Adapt.adapt(to, f.grid), + clock = Adapt.adapt(to, f.clock), + parameters = Adapt.adapt(to, f.parameters)) on_architecture(to, f::FunctionField{LX, LY, LZ}) where {LX, LY, LZ} = diff --git a/src/Fields/set!.jl b/src/Fields/set!.jl index 4dbb6ee96b..ea253a7b45 100644 --- a/src/Fields/set!.jl +++ b/src/Fields/set!.jl @@ -50,7 +50,7 @@ end ##### Setting to specific things ##### -function set_to_function!(u, f) +function set_to_function!(u::Field, f) # Supports serial and distributed arch = architecture(u) child_arch = child_architecture(u) @@ -60,7 +60,6 @@ function set_to_function!(u, f) cpu_arch = cpu_architecture(arch) cpu_grid = on_architecture(cpu_arch, u.grid) cpu_u = Field(location(u), cpu_grid; indices = indices(u)) - elseif child_arch isa CPU cpu_grid = u.grid cpu_u = u @@ -96,7 +95,7 @@ function set_to_function!(u, f) return u end -function set_to_array!(u, f) +function set_to_array!(u::Field, f) f = on_architecture(architecture(u), f) try @@ -118,7 +117,7 @@ function set_to_array!(u, f) return u end -function set_to_field!(u, v) +function set_to_field!(u::Field, v) # We implement some niceities in here that attempt to copy halo data, # and revert to copying just interior points if that fails. diff --git a/src/OutputReaders/field_time_series.jl b/src/OutputReaders/field_time_series.jl index 9d77c161c3..af2263aa91 100644 --- a/src/OutputReaders/field_time_series.jl +++ b/src/OutputReaders/field_time_series.jl @@ -25,7 +25,7 @@ using Oceananigans.Utils: launch! import Oceananigans.Architectures: architecture, on_architecture import Oceananigans.BoundaryConditions: fill_halo_regions!, BoundaryCondition, getbc -import Oceananigans.Fields: Field, set!, interior, indices, interpolate! +import Oceananigans.Fields: Field, interior, indices, interpolate! ##### ##### Data backends for FieldTimeSeries diff --git a/src/OutputReaders/field_time_series_indexing.jl b/src/OutputReaders/field_time_series_indexing.jl index d849192f67..e29ab8a404 100644 --- a/src/OutputReaders/field_time_series_indexing.jl +++ b/src/OutputReaders/field_time_series_indexing.jl @@ -283,3 +283,4 @@ function getindex(fts::InMemoryFTS, n::Int) return Field(location(fts), fts.grid; data, fts.boundary_conditions, fts.indices) end + diff --git a/src/OutputReaders/set_field_time_series.jl b/src/OutputReaders/set_field_time_series.jl index bde0968c8e..ce6a82311b 100644 --- a/src/OutputReaders/set_field_time_series.jl +++ b/src/OutputReaders/set_field_time_series.jl @@ -1,6 +1,70 @@ using Printf using Oceananigans.Architectures: cpu_architecture +import Oceananigans.Fields: set! + +function set!(u::InMemoryFTS, v::InMemoryFTS) + if child_architecture(u) === child_architecture(v) + # Note: we could try to copy first halo point even when halo + # regions are a different size. That's a bit more complicated than + # the below so we leave it for the future. + + try # to copy halo regions along with interior data + parent(u) .= parent(v) + catch # this could fail if the halo regions are different sizes? + # copy just the interior data + interior(u) .= interior(v) + end + else + v_data = on_architecture(child_architecture(u), v.data) + + # As above, we permit ourselves a little ambition and try to copy halo data: + try + parent(u) .= parent(v_data) + catch + interior(u) .= interior(v_data, location(v), v.grid, v.indices) + end + end + + return u +end + +function set!(u::InMemoryFTS, v::Function) + # Supports serial and distributed + arch = architecture(u) + child_arch = child_architecture(u) + LX, LY, LZ = location(u) + + # Determine cpu_grid and cpu_u + if child_arch isa GPU + cpu_arch = cpu_architecture(arch) + cpu_grid = on_architecture(cpu_arch, u.grid) + cpu_times = on_architecture(cpu_arch, u.times) + cpu_u = FieldTimeSeries{LX, LY, LZ}(cpu_grid, cpu_times; indices=indices(u)) + elseif child_arch isa CPU + cpu_arch = child_arch + cpu_grid = u.grid + cpu_times = u.times + cpu_u = u + end + + launch!(cpu_arch, cpu_grid, size(cpu_u), + _set_fts_to_function!, cpu_u, (LX(), LY(), LZ()), cpu_grid, cpu_times, v) + + # Transfer data to GPU if u is on the GPU + child_arch isa GPU && set!(u, cpu_u) + + return u +end + +@kernel function _set_fts_to_function!(fts, loc, grid, times, func) + i, j, k, n = @index(Global, NTuple) + X = node(i, j, k, grid, loc...) + @inbounds begin + fts[i, j, k, n] = func(X..., times[n]) + end +end + ##### ##### set! ##### @@ -46,10 +110,10 @@ function set!(fts::InMemoryFTS, path::String=fts.path, name::String=fts.name; wa end end - return nothing + return fts end -set!(fts::InMemoryFTS, value, n::Int) = set!(fts[n], value) +set!(fts::InMemoryFTS, v, n::Int) = set!(fts[n], value) function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField}) raw_data = parent(fts) @@ -63,7 +127,7 @@ function set!(fts::InMemoryFTS, fields_vector::AbstractVector{<:AbstractField}) close(file) - return nothing + return fts end # Write property only if it does not already exist @@ -92,6 +156,8 @@ function set!(fts::OnDiskFTS, field::Field, n::Int, time=fts.times[n]) maybe_write_property!(file, "timeseries/t/$n", time) maybe_write_property!(file, "timeseries/$name/$n", Array(parent(field))) end + + return fts end function initialize_file!(file, name, fts) @@ -102,4 +168,5 @@ function initialize_file!(file, name, fts) return nothing end -set!(fts::OnDiskFTS, path::String, name::String) = nothing +set!(fts::OnDiskFTS, path::String, name::String) = fts + diff --git a/src/OutputWriters/checkpointer.jl b/src/OutputWriters/checkpointer.jl index 083c9790a1..c97dbaf385 100644 --- a/src/OutputWriters/checkpointer.jl +++ b/src/OutputWriters/checkpointer.jl @@ -211,6 +211,8 @@ end ##### set! for checkpointer filepaths ##### +set!(model::AbstractModel, ::Nothing) = nothing + """ set!(model, filepath::AbstractString) @@ -218,7 +220,6 @@ Set data in `model.velocities`, `model.tracers`, `model.timestepper.Gⁿ`, and `model.timestepper.G⁻` to checkpointed data stored at `filepath`. """ function set!(model::AbstractModel, filepath::AbstractString) - addr = checkpointer_address(model) jldopen(filepath, "r") do file