Skip to content

Commit

Permalink
Merge pull request #309 from SciML/saveat
Browse files Browse the repository at this point in the history
Some automated conversions in saveat
  • Loading branch information
ChrisRackauckas authored Oct 23, 2023
2 parents 8bf5bb7 + e83d47e commit 304330d
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 352 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10"
KernelAbstractions = "0.9"
LinearSolve = "1.15, 2"
Metal = "0.4"
Metal = "0.5"
MuladdMacro = "0.2"
Parameters = "0.12"
RecursiveArrayTools = "2"
Expand Down
1 change: 1 addition & 0 deletions src/DiffEqGPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ include("ensemblegpukernel/integrators/stiff/interpolants.jl")
include("ensemblegpukernel/integrators/nonstiff/interpolants.jl")
include("ensemblegpukernel/nlsolve/type.jl")
include("ensemblegpukernel/nlsolve/utils.jl")
include("ensemblegpukernel/kernels.jl")

include("ensemblegpukernel/perform_step/gpu_tsit5_perform_step.jl")
include("ensemblegpukernel/perform_step/gpu_vern7_perform_step.jl")
Expand Down
6 changes: 5 additions & 1 deletion src/ensemblegpuarray/problem_generation.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
function generate_problem(prob::SciMLBase.AbstractODEProblem, u0, p, jac_prototype, colorvec)
function generate_problem(prob::SciMLBase.AbstractODEProblem,
u0,
p,
jac_prototype,
colorvec)
_f = let f = prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop
function (du, u, p, t)
version = get_backend(u)
Expand Down
114 changes: 114 additions & 0 deletions src/ensemblegpukernel/kernels.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@

# saveat is just a bool here:
# true: ts is a vector of timestamps to read from
# false: each ODE has its own timestamps, so ts is a vector to write to
@kernel function ode_solve_kernel(@Const(probs), alg, _us, _ts, dt, callback,
tstops, nsteps,
saveat, ::Val{save_everystep}) where {save_everystep}
i = @index(Global, Linear)

# get the actual problem for this thread
prob = @inbounds probs[i]

# get the input/output arrays for this thread
ts = @inbounds view(_ts, :, i)
us = @inbounds view(_us, :, i)

_saveat = get(prob.kwargs, :saveat, nothing)

saveat = _saveat === nothing ? saveat : _saveat

integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops,
callback, save_everystep, saveat)

u0 = prob.u0
tspan = prob.tspan

integ.cur_t = 0
if saveat !== nothing
integ.cur_t = 1
if prob.tspan[1] == saveat[1]
integ.cur_t += 1
@inbounds us[1] = u0
end
else
@inbounds ts[integ.step_idx] = prob.tspan[1]
@inbounds us[integ.step_idx] = prob.u0
end

integ.step_idx += 1
# FSAL
while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated
saved_in_cb = step!(integ, ts, us)
!saved_in_cb && savevalues!(integ, ts, us)
end
if integ.t > tspan[2] && saveat === nothing
## Intepolate to tf
@inbounds us[end] = integ(tspan[2])
@inbounds ts[end] = tspan[2]
end

if saveat === nothing && !save_everystep
@inbounds us[2] = integ.u
@inbounds ts[2] = integ.t
end
end

@kernel function ode_asolve_kernel(@Const(probs), alg, _us, _ts, dt, callback, tstops,
abstol, reltol,
saveat,
::Val{save_everystep}) where {save_everystep}
i = @index(Global, Linear)

# get the actual problem for this thread
prob = @inbounds probs[i]
# get the input/output arrays for this thread
ts = @inbounds view(_ts, :, i)
us = @inbounds view(_us, :, i)
# TODO: optimize contiguous view to return a CuDeviceArray

_saveat = get(prob.kwargs, :saveat, nothing)

saveat = _saveat === nothing ? saveat : _saveat

u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p

t = tspan[1]
tf = prob.tspan[2]

integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], prob.tspan[2], dt,
prob.p,
abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback,
saveat)

integ.cur_t = 0
if saveat !== nothing
integ.cur_t = 1
if tspan[1] == saveat[1]
integ.cur_t += 1
@inbounds us[1] = u0
end
else
@inbounds ts[1] = tspan[1]
@inbounds us[1] = u0
end

while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated
saved_in_cb = step!(integ, ts, us)
!saved_in_cb && savevalues!(integ, ts, us)
end

if integ.t > tspan[2] && saveat === nothing
## Intepolate to tf
@inbounds us[end] = integ(tspan[2])
@inbounds ts[end] = tspan[2]
end

if saveat === nothing && !save_everystep
@inbounds us[2] = integ.u
@inbounds ts[2] = integ.t
end
end
159 changes: 41 additions & 118 deletions src/ensemblegpukernel/lowerlevel_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
nsteps = length(timeseries)

prob = convert(ImmutableODEProblem, prob)

dt = convert(eltype(prob.tspan), dt)

if saveat === nothing
Expand All @@ -52,7 +51,29 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
_saveat = range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
convert(StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64,
},
_saveat)
elseif saveat isa AbstractVector
adapt(backend, convert.(eltype(prob.tspan), saveat))
else
_saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
convert(StepRangeLen{
eltype(_saveat),
eltype(_saveat),
eltype(_saveat),
eltype(_saveat) === Float32 ? Int32 : Int64,
},
_saveat)
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -99,7 +120,15 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
elseif saveat isa AbstractVector
convert.(eltype(prob.tspan), adapt(backend, saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -176,7 +205,15 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
else
saveat = adapt(backend, saveat)
saveat = if saveat isa AbstractRange
range(convert(eltype(prob.tspan), first(saveat)),
convert(eltype(prob.tspan), last(saveat)),
length = length(saveat))
elseif saveat isa AbstractVector
adapt(backend, convert.(eltype(prob.tspan), saveat))
else
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
end
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
fill!(ts, prob.tspan[1])
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
Expand Down Expand Up @@ -211,117 +248,3 @@ function vectorized_asolve(probs, prob::SDEProblem, alg;
kwargs...)
error("Adaptive time-stepping is not supported yet with GPUEM.")
end

# saveat is just a bool here:
# true: ts is a vector of timestamps to read from
# false: each ODE has its own timestamps, so ts is a vector to write to
@kernel function ode_solve_kernel(@Const(probs), alg, _us, _ts, dt, callback,
tstops, nsteps,
saveat, ::Val{save_everystep}) where {save_everystep}
i = @index(Global, Linear)

# get the actual problem for this thread
prob = @inbounds probs[i]

# get the input/output arrays for this thread
ts = @inbounds view(_ts, :, i)
us = @inbounds view(_us, :, i)

_saveat = get(prob.kwargs, :saveat, nothing)

saveat = _saveat === nothing ? saveat : _saveat

integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], dt, prob.p, tstops,
callback, save_everystep, saveat)

u0 = prob.u0
tspan = prob.tspan

integ.cur_t = 0
if saveat !== nothing
integ.cur_t = 1
if prob.tspan[1] == saveat[1]
integ.cur_t += 1
@inbounds us[1] = u0
end
else
@inbounds ts[integ.step_idx] = prob.tspan[1]
@inbounds us[integ.step_idx] = prob.u0
end

integ.step_idx += 1
# FSAL
while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated
saved_in_cb = step!(integ, ts, us)
!saved_in_cb && savevalues!(integ, ts, us)
end
if integ.t > tspan[2] && saveat === nothing
## Intepolate to tf
@inbounds us[end] = integ(tspan[2])
@inbounds ts[end] = tspan[2]
end

if saveat === nothing && !save_everystep
@inbounds us[2] = integ.u
@inbounds ts[2] = integ.t
end
end

@kernel function ode_asolve_kernel(probs, alg, _us, _ts, dt, callback, tstops,
abstol, reltol,
saveat,
::Val{save_everystep}) where {save_everystep}
i = @index(Global, Linear)

# get the actual problem for this thread
prob = @inbounds probs[i]
# get the input/output arrays for this thread
ts = @inbounds view(_ts, :, i)
us = @inbounds view(_us, :, i)
# TODO: optimize contiguous view to return a CuDeviceArray

_saveat = get(prob.kwargs, :saveat, nothing)

saveat = _saveat === nothing ? saveat : _saveat

u0 = prob.u0
tspan = prob.tspan
f = prob.f
p = prob.p

t = tspan[1]
tf = prob.tspan[2]

integ = init(alg, prob.f, false, prob.u0, prob.tspan[1], prob.tspan[2], dt,
prob.p,
abstol, reltol, DiffEqBase.ODE_DEFAULT_NORM, tstops, callback,
saveat)

integ.cur_t = 0
if saveat !== nothing
integ.cur_t = 1
if tspan[1] == saveat[1]
integ.cur_t += 1
@inbounds us[1] = u0
end
else
@inbounds ts[1] = tspan[1]
@inbounds us[1] = u0
end

while integ.t < tspan[2] && integ.retcode != DiffEqBase.ReturnCode.Terminated
saved_in_cb = step!(integ, ts, us)
!saved_in_cb && savevalues!(integ, ts, us)
end

if integ.t > tspan[2] && saveat === nothing
## Intepolate to tf
@inbounds us[end] = integ(tspan[2])
@inbounds ts[end] = tspan[2]
end

if saveat === nothing && !save_everystep
@inbounds us[2] = integ.u
@inbounds ts[2] = integ.t
end
end
Loading

0 comments on commit 304330d

Please sign in to comment.