Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(0.3.0) Cleaner syntax for ECCOFieldTimeSeries and ECCORestoring #284

Merged
merged 19 commits into from
Dec 14, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 82 additions & 100 deletions src/DataWrangling/ECCO/ECCO_restoring.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Oceananigans: location
using Oceananigans.Grids: node, on_architecture
using Oceananigans.Grids: AbstractGrid, node, on_architecture
using Oceananigans.Fields: interpolate!, interpolate, location, instantiated_location
using Oceananigans.OutputReaders: Cyclical, TotallyInMemory, AbstractInMemoryBackend, FlavorOfFTS, time_indices
using Oceananigans.Utils: Time
Expand Down Expand Up @@ -35,7 +35,7 @@ end
# Adapted copies of the backend keep only `start` and `length`; the trailing
# `inpainting` and `metadata` slots are replaced by `nothing`.
function Adapt.adapt_structure(to, b::ECCONetCDFBackend{N, C}) where {N, C}
    return ECCONetCDFBackend{N, C}(b.start, b.length, nothing, nothing)
end

"""
ECCONetCDFBackend(length; on_native_grid = false, inpainting = NearestNeighborInpainting(Inf))
ECCONetCDFBackend(length; on_native_grid=false, inpainting=NearestNeighborInpainting(Inf))

Represent an ECCO FieldTimeSeries backed by ECCO native netCDF files.
Each time instance is stored in an individual file.
Expand All @@ -51,15 +51,15 @@ end
# `length` of a backend is whatever is stored in its `length` field.
function Base.length(backend::ECCONetCDFBackend)
    return backend.length
end

# One-line description of the form `ECCONetCDFBackend(start, length)`.
function Base.summary(backend::ECCONetCDFBackend)
    first_index = backend.start
    n_indices = backend.length
    return "ECCONetCDFBackend($(first_index), $(n_indices))"
end

const ECCONetCDFFTS{N} = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:ECCONetCDFBackend{N}} where N
const ECCOFieldTimeSeries{N} = FlavorOfFTS{<:Any, <:Any, <:Any, <:Any, <:ECCONetCDFBackend{N}} where N

# Build a backend with a new `start` and `length`, carrying over the type
# parameters, `inpainting` and `metadata` of the existing backend `b`.
function new_backend(b::ECCONetCDFBackend{N, C}, start, length) where {N, C}
    return ECCONetCDFBackend{N, C}(start, length, b.inpainting, b.metadata)
end

# Read the first type parameter of the backend (the `native` flag).
function on_native_grid(::ECCONetCDFBackend{N}) where {N}
    return N
end

# Read the second type parameter of the backend (the `cache_data` flag).
function cache_inpainted_data(::ECCONetCDFBackend{N, C}) where {N, C}
    return C
end

function set!(fts::ECCONetCDFFTS)
function set!(fts::ECCOFieldTimeSeries)
backend = fts.backend
start = backend.start
inpainting = backend.inpainting
Expand Down Expand Up @@ -105,12 +105,10 @@ function ECCO_times(metadata; start_time = first(metadata).dates)
end

"""
ECCO_field_time_series(metadata::ECCOMetadata;
grid = nothing,
architecture = isnothing(grid) ? CPU() : architecture(grid),
time_indices_in_memory = 2,
time_indexing = Cyclical(),
inpainting_iterations = prod(size(metadata)),
ECCOFieldTimeSeries(metadata::ECCOMetadata [, arch_or_grid=CPU() ];
time_indices_in_memory = 2,
time_indexing = Cyclical(),
inpainting = nothing,
cache_inpainted_data = true)

Create a field time series object for ECCO data.

Expand All @@ -119,13 +117,12 @@ Arguments

- `metadata`: `ECCOMetadata` containing information about the ECCO dataset.

- `arch_or_grid`: Either a grid to interpolate ECCO data to, or an `arch`itecture
to use for the native ECCO grid. Default: CPU().

Keyword Arguments
=================

- `grid`: where ECCO data is interpolated. If `nothing`, the native `ECCO` grid is used.

- `architecture`: where data is stored. Should only be set if `isnothing(grid)`.

- `time_indices_in_memory`: The number of time indices to keep in memory. Default: 2.

- `time_indexing`: The time indexing scheme to use. Default: `Cyclical()`.
Expand All @@ -138,49 +135,47 @@ Keyword Arguments
Default: `true`.

"""
function ECCO_field_time_series(metadata::ECCOMetadata;
architecture = CPU(),
time_indices_in_memory = 2,
time_indexing = Cyclical(),
inpainting = NearestNeighborInpainting(prod(size(metadata))),
cache_inpainted_data = true,
grid = nothing)
# Architecture-only entry point: build the field time series on the grid of an
# empty ECCO field allocated on `arch`, then forward every keyword argument to
# the `grid`-based method.
function ECCOFieldTimeSeries(metadata::ECCOMetadata, arch::AbstractArchitecture=CPU(); kw...)
    download_dataset(metadata)

    # The keyword must be bound explicitly: the shorthand `; architecture` expands
    # to `architecture = architecture`, which would not refer to `arch`.
    ftmp = empty_ECCO_field(first(metadata); architecture = arch)
    grid = ftmp.grid

    return ECCOFieldTimeSeries(metadata, grid; kw...)
end

function ECCOFieldTimeSeries(metadata::ECCOMetadata, grid::AbstractGrid;
time_indices_in_memory = 2,
time_indexing = Cyclical(),
inpainting = nothing,
cache_inpainted_data = true)

# Make sure all the required individual files are downloaded
download_dataset(metadata)

inpainting isa Int && (inpainting = NearestNeighborInpainting(inpainting))

ftmp = empty_ECCO_field(first(metadata); architecture)
on_native_grid = isnothing(grid)
on_native_grid && (grid = ftmp.grid)
backend = ECCONetCDFBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data)

times = ECCO_times(metadata)
loc = LX, LY, LZ = location(metadata)
boundary_conditions = FieldBoundaryConditions(grid, loc)

backend = ECCONetCDFBackend(time_indices_in_memory, metadata; on_native_grid, inpainting, cache_inpainted_data)

fts = FieldTimeSeries{LX, LY, LZ}(grid, times; backend, time_indexing, boundary_conditions)
set!(fts)

return fts
end

ECCO_field_time_series(variable_name::Symbol, version=ECCO4Monthly(); kw...) =
ECCO_field_time_series(ECCOMetadata(variable_name, all_ECCO_dates(version), version); kw...)
# Convenience method: assemble the metadata for `variable_name` over all dates
# available for `version`, then defer to the metadata-based constructor.
function ECCOFieldTimeSeries(variable_name::Symbol, version=ECCO4Monthly(); kw...)
    dates = all_ECCO_dates(version)
    metadata = ECCOMetadata(variable_name, dates, version)
    return ECCOFieldTimeSeries(metadata; kw...)
end

# Variable names for restorable data.
# Each name is an empty singleton type; instances act as tags that select a
# `Base.getindex(fields, i, j, k, tag)` method for the matching model field.
struct Temperature end
struct Salinity end
struct UVelocity end
struct VVelocity end

oceananigans_fieldname = Dict(
:temperature => Temperature(),
:salinity => Salinity(),
:u_velocity => UVelocity(),
:v_velocity => VVelocity())
# Lookup table from a variable-name symbol to the tag instance used to index
# the model fields.
# NOTE(review): only these four names are present, so indexing the table with
# any other symbol throws a `KeyError` — confirm whether the sea-ice variables
# listed in the `ECCORestoring` docstring are meant to be supported here.
const oceananigans_fieldnames = Dict(:temperature => Temperature(),
:salinity => Salinity(),
:u_velocity => UVelocity(),
:v_velocity => VVelocity())

# Tag-based indexing: a `Temperature` tag reads from `fields.T`.
@inline function Base.getindex(fields, i, j, k, ::Temperature)
    return @inbounds fields.T[i, j, k]
end

# Tag-based indexing: a `Salinity` tag reads from `fields.S`.
@inline function Base.getindex(fields, i, j, k, ::Salinity)
    return @inbounds fields.S[i, j, k]
end
Expand All @@ -194,14 +189,14 @@ Base.summary(::VVelocity) = "v_velocity"

struct ECCORestoring{FTS, G, M, V, N}
field_time_series :: FTS
grid :: G
native_grid :: G
mask :: M
variable_name :: V
rate :: N
end

Adapt.adapt_structure(to, p::ECCORestoring) = ECCORestoring(Adapt.adapt(to, p.field_time_series),
Adapt.adapt(to, p.grid),
Adapt.adapt(to, p.native_grid),
Adapt.adapt(to, p.mask),
Adapt.adapt(to, p.variable_name),
Adapt.adapt(to, p.rate))
Expand All @@ -215,8 +210,7 @@ Adapt.adapt_structure(to, p::ECCORestoring) = ECCORestoring(Adapt.adapt(to, p.fi
# Possibly interpolate ECCO data from the ECCO grid to simulation grid.
# Otherwise, simply extract the pre-interpolated data from p.field_time_series.
backend = p.field_time_series.backend
interpolating = on_native_grid(backend)
ψ_ecco = maybe_interpolate(Val(interpolating), p.field_time_series, i, j, k, p.grid, grid, time)
ψ_ecco = maybe_interpolate(p.field_time_series, i, j, k, p.native_grid, grid, time)

ψ = @inbounds fields[i, j, k, p.variable_name]
μ = stateindex(p.mask, i, j, k, grid, clock.time, loc)
Expand All @@ -225,9 +219,9 @@ Adapt.adapt_structure(to, p::ECCORestoring) = ECCORestoring(Adapt.adapt(to, p.fi
return ω * μ * (ψ_ecco - ψ)
end

@inline maybe_interpolate(::Val{false}, fts, i, j, k, native_grid, grid, time) = @inbounds fts[i, j, k, time]
@inline maybe_interpolate(fts, i, j, k, ::Nothing, grid, time) = @inbounds fts[i, j, k, time]

@inline function maybe_interpolate(::Val{true}, fts, i, j, k, native_grid, grid, time)
@inline function maybe_interpolate(fts, i, j, k, native_grid, grid, time)
times = fts.times
data = fts.data
time_indexing = fts.time_indexing
Expand All @@ -240,40 +234,47 @@ end
end

"""
ECCORestoring([arch=CPU(),]
variable_name::Symbol;
version=ECCO4Monthly(),
dates=all_ECCO_dates(version),
dates = all_ECCO_dates(version),
ECCORestoring(variable_name::Symbol, [ arch_or_grid = CPU(), ];
version = ECCO4Monthly(),
dates = all_ECCO_dates(version),
time_indices_in_memory = 2,
time_indexing = Cyclical(),
mask = 1,
rate,
grid = nothing,
inpainting = NearestNeighborInpainting(prod(size(metadata))),
inpainting = NearestNeighborInpainting(Inf),
cache_inpainted_data = true)

Create a forcing term that restores to values stored in an ECCO field time series.
Build a forcing term that restores to values stored in an ECCO field time series.
The restoring is applied as a forcing on the right hand side of the evolution equations calculated as

```julia
F = mask ⋅ rate ⋅ (ECCO_variable - simulation_variable[i, j, k])
```math
F = μ r (ψ_{ECCO} - ψ)
```
where `ECCO_variable` is linearly interpolated in space and time from the ECCO dataset of choice to the
simulation grid and time.

where ``μ`` is the mask, ``r`` is the restoring rate, ``ψ`` is the simulation variable,
and the ECCO variable ``ψ_ECCO`` is linearly interpolated in space and time from the
ECCO dataset of choice to the simulation grid and time.

Arguments
=========

- `arch`: The architecture. Typically `CPU()` or `GPU()`. Default: `CPU()`.

- `variable_name`: The name of the variable to restore. Choices include:
* `:temperature`,
* `:salinity`,
* `:u_velocity`,
* `:v_velocity`,
* `:sea_ice_thickness`,
* `:sea_ice_area_fraction`.
* `:temperature`,
* `:salinity`,
* `:u_velocity`,
* `:v_velocity`,
* `:sea_ice_thickness`,
* `:sea_ice_area_fraction`.

Note that `ECCOMetadata` may be provided as the first argument instead
of `variable_name`. In this case the `version` and `dates` kwargs (described below)
cannot be provided.

- `arch_or_grid`: Either the architecture of the simulation, or a grid on which the ECCO data
is pre-interpolated when loaded. If an `arch`itecture is provided, such as
`arch_or_grid = CPU()` or `arch_or_grid = GPU()`, ECCO data
will be interpolated on-the-fly when the forcing tendency is computed.
Default: CPU().

Keyword Arguments
=================
Expand All @@ -282,77 +283,58 @@ Keyword Arguments

- `dates`: The dates to use for the ECCO dataset. Default: `all_ECCO_dates(version)`.

- `time_indices_in_memory`: The number of time indices to keep in memory; trade-off between performance
and memory footprint.
- `time_indices_in_memory`: The number of time indices to keep in memory. The number is chosen based on
a trade-off between increased performance (more indices in memory) and reduced
memory footprint (fewer indices in memory). Default: 2.

- `time_indexing`: The time indexing scheme for the field time series
- `time_indexing`: The time indexing scheme for the field time series.

- `mask`: The mask value. Can be a function of `(x, y, z, time)`, an array, or a number.

- `rate`: The restoring rate, i.e., the inverse of the restoring timescale (in s⁻¹). This keyword is required and has no default.

- `time_indices_in_memory:` how many time instances are loaded in memory; the remaining are loaded lazily.

- `inpainting`: inpainting algorithm, see [`inpaint_mask!`](@ref). Default: `NearestNeighborInpainting(Inf)`.

- `grid`: If `isnothing(grid)`, ECCO data is interpolated on-the-fly to the simulation grid.
If `!isnothing(grid)`, ECCO data is pre-interpolated to `grid`.
Default: nothing.

- `cache_inpainted_data`: If `true`, the data is cached to disk after inpainting for later retrieving.
Default: `true`.

It is possible to also pass an `ECCOMetadata` type as the first argument without the need for the
`variable_name` argument and the `version` and `dates` keyword arguments.
"""
function ECCORestoring(arch::AbstractArchitecture,
variable_name::Symbol;
version=ECCO4Monthly(),
dates=all_ECCO_dates(version),
function ECCORestoring(variable_name::Symbol,
arch_or_grid = CPU();
version = ECCO4Monthly(),
dates = all_ECCO_dates(version),
kw...)

metadata = ECCOMetadata(variable_name, dates, version)
return ECCORestoring(arch, metadata; kw...)
return ECCORestoring(metadata, arch_or_grid; kw...)
end

function ECCORestoring(arch::AbstractArchitecture,
metadata::ECCOMetadata;
function ECCORestoring(metadata::ECCOMetadata,
arch_or_grid = CPU();
rate,
mask = 1,
grid = nothing,
time_indices_in_memory = 2, # Not more than this if we want to use GPU!
time_indexing = Cyclical(),
inpainting = NearestNeighborInpainting(Inf),
cache_inpainted_data = true)

# Validate architecture
if !isnothing(grid) && architecture(grid) != arch
throw(ArgumentError("The architecture of ECCORestoring must match the architecture of the grid."))
end

fts = ECCO_field_time_series(metadata;
grid,
architecture = arch,
time_indices_in_memory,
time_indexing,
inpainting,
cache_inpainted_data)
fts = ECCOFieldTimeSeries(metadata, arch_or_grid;
time_indices_in_memory,
time_indexing,
inpainting,
cache_inpainted_data)

# Grab the correct Oceananigans field to restore
variable_name = metadata.name
field_name = oceananigans_fieldname[variable_name]
field_name = oceananigans_fieldnames[variable_name]

# If we pass the grid we do not need to interpolate
# so we can save parameter space by setting the native grid to nothing
native_grid = isnothing(grid) ? fts.grid : nothing
on_native_grid = arch_or_grid isa AbstractArchitecture
maybe_native_grid = on_native_grid ? fts.grid : nothing

return ECCORestoring(fts, native_grid, mask, field_name, rate)
return ECCORestoring(fts, maybe_native_grid, mask, field_name, rate)
end

# Make sure we can call ECCORestoring with architecture as the first positional argument
ECCORestoring(variable_name::Symbol; kw...) = ECCORestoring(CPU(), variable_name; kw...)
ECCORestoring(metadata::ECCOMetadata; kw...) = ECCORestoring(CPU(), metadata; kw...)

function Base.show(io::IO, p::ECCORestoring)
print(io, "ECCORestoring:", '\n',
"├── restored variable: ", summary(p.variable_name), '\n',
Expand Down
Loading