refactor: use cleaner structs and stable Lux APIs
avik-pal committed Sep 10, 2024
1 parent e795d81 commit 96a2d34
Showing 5 changed files with 32 additions and 43 deletions.
2 changes: 2 additions & 0 deletions Project.toml
@@ -20,6 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
+Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
@@ -61,6 +62,7 @@ ReTestItems = "1.25.1"
Reexport = "0.2, 1"
SciMLBase = "2"
SciMLSensitivity = "7"
+Setfield = "1.1.1"
Statistics = "1.10"
StochasticDiffEq = "6.68.0"
Test = "1.10"
1 change: 1 addition & 0 deletions src/DiffEqFlux.jl
@@ -23,6 +23,7 @@ using SciMLSensitivity: SciMLSensitivity, AdjointLSS, BacksolveAdjoint, EnzymeVJ
NILSS, QuadratureAdjoint, ReverseDiffAdjoint, ReverseDiffVJP,
SteadyStateAdjoint, TrackerAdjoint, TrackerVJP, ZygoteAdjoint,
ZygoteVJP
+using Setfield: @set!
using Zygote: Zygote

const CRC = ChainRulesCore
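The new Setfield dependency is pulled in for the `@set!` macro used below in src/ffjord.jl. A minimal sketch of what `@set!` does (the `Point` struct is hypothetical, not part of this repository):

    using Setfield

    struct Point            # immutable struct, purely for illustration
        x::Int
        y::Int
    end

    p = Point(1, 2)
    @set! p.x = 10          # rebinds `p` to a copy with the `x` field replaced
    @assert p == Point(10, 2)

The struct itself stays immutable; `@set!` just saves spelling out a full constructor call when a single field changes.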
29 changes: 9 additions & 20 deletions src/ffjord.jl
@@ -48,10 +48,9 @@ Information Processing Systems, pp. 6572-6583. 2018.
"Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv
preprint arXiv:1810.01367 (2018).
"""
-@concrete struct FFJORD{M <: AbstractExplicitLayer, D <: Union{Nothing, Distribution}} <:
-                 CNFLayer
-    model::M
-    basedist::D
+@concrete struct FFJORD <: CNFLayer
+    model <: AbstractExplicitLayer
+    basedist <: Union{Nothing, Distribution}
ad
input_dims
tspan
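The struct rewrite above is the pattern repeated throughout this commit: hand-written type parameters (`model::M` with `M <: ...`) give way to ConcreteStructs' constrained-field syntax, which still generates a fully concrete parametric struct. A rough sketch with made-up names (`AbstractThing`, `Thing`, and the two wrappers are illustrative, not from the package):

    using ConcreteStructs

    abstract type AbstractThing end      # stand-in for AbstractExplicitLayer
    struct Thing <: AbstractThing end

    # Hand-written parametric form, as in the removed code:
    struct WrapperExplicit{M <: AbstractThing}
        model::M
    end

    # Constrained-field form adopted in this commit; `@concrete` still creates a
    # type parameter for `model`, bounded by the given supertype.
    @concrete struct WrapperConcrete
        model <: AbstractThing
    end

    w = WrapperConcrete(Thing())
    @assert isconcretetype(typeof(w))    # e.g. WrapperConcrete{Thing}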
@@ -76,11 +75,11 @@ end

@inline __norm_batched(x) = sqrt.(sum(abs2, x; dims = 1:(ndims(x) - 1)))

-function __ffjord(_model::StatefulLuxLayer{FST}, u::AbstractArray{T, N}, p, ad = nothing,
-        regularize::Bool = false, monte_carlo::Bool = true) where {T, N, FST}
+function __ffjord(model::StatefulLuxLayer, u::AbstractArray{T, N}, p, ad = nothing,
+        regularize::Bool = false, monte_carlo::Bool = true) where {T, N}
L = size(u, N - 1)
z = selectdim(u, N - 1, 1:(L - ifelse(regularize, 3, 1)))
-    model = StatefulLuxLayer{FST}(_model.model, p, ifelse(FST, _model.st, _model.st_any))
+    @set! model.ps = p
mz = model(z, p)
@assert size(mz) == size(z)
if monte_carlo
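For context on the `monte_carlo` flag in the signature above: FFJORD estimates the divergence term (the trace of the model's Jacobian) stochastically. A rough, unbatched sketch of that Hutchinson-style estimator, assuming standard-normal probes; the actual implementation in this file is batched and supports several AD backends:

    using LinearAlgebra, Statistics, Zygote

    # E[eᵀ (Jᵀ e)] = tr(J) for random probes e with E[e eᵀ] = I.
    function hutchinson_trace(f, z; nprobes = 64)
        _, back = Zygote.pullback(f, z)
        return mean(1:nprobes) do _
            e = randn(eltype(z), size(z))
            dot(e, only(back(e)))        # eᵀ Jᵀ e, equal to tr(J) in expectation
        end
    end

    f(z) = tanh.(2 .* z)                 # toy vector field
    z = randn(4)
    hutchinson_trace(f, z), tr(Zygote.jacobian(f, z)[1])   # estimate vs. exact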
@@ -209,24 +208,14 @@ Arguments:
- `regularize`: Whether we use regularization (default: `false`).
- `monte_carlo`: Whether we use monte carlo (default: `true`).
"""
-@concrete struct FFJORDDistribution{F <: FFJORD} <: ContinuousMultivariateDistribution
-    model::F
+@concrete struct FFJORDDistribution <: ContinuousMultivariateDistribution
+    model <: FFJORD
ps
st
end

Base.length(d::FFJORDDistribution) = prod(d.model.input_dims)
-Base.eltype(d::FFJORDDistribution) = __eltype(d.ps)
-
-__eltype(x::AbstractArray) = eltype(x)
-function __eltype(x)
-    T = Ref(Bool)
-    fmap(x) do x_
-        T[] = promote_type(T[], __eltype(x_))
-        return x_
-    end
-    return T[]
-end
+Base.eltype(d::FFJORDDistribution) = Lux.recursive_eltype(d.ps)

function Distributions._logpdf(d::FFJORDDistribution, x::AbstractVector)
return first(first(__forward_ffjord(d.model, reshape(x, :, 1), d.ps, d.st)))
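The hand-rolled `__eltype`/`fmap` helper removed above is replaced by `Lux.recursive_eltype`, which walks a nested parameter structure and, like the helper it replaces, reports the promoted element type of its leaves; this is one of the stable Lux APIs the commit title refers to. A small sketch (the parameter NamedTuple is made up):

    using Lux

    ps = (layer_1 = (weight = rand(Float32, 4, 2), bias = zeros(Float32, 4)),
          layer_2 = (weight = rand(Float32, 1, 4), bias = zeros(Float32, 1)))
    @assert Lux.recursive_eltype(ps) == Float32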
13 changes: 6 additions & 7 deletions src/hnn.jl
@@ -29,9 +29,8 @@ References:
[1] Greydanus, Samuel, Misko Dzamba, and Jason Yosinski. "Hamiltonian Neural Networks."
Advances in Neural Information Processing Systems 32 (2019): 15379-15389.
"""
-@concrete struct HamiltonianNN{M <: AbstractExplicitLayer} <:
-                 AbstractExplicitContainerLayer{(:model,)}
-    model::M
+@concrete struct HamiltonianNN <: AbstractExplicitContainerLayer{(:model,)}
+    model <: AbstractExplicitLayer
ad
end

@@ -41,11 +40,11 @@ function HamiltonianNN(model; ad = AutoZygote())
return HamiltonianNN(model, ad)
end

-function __hamiltonian_forward(ad::AutoForwardDiff, model, x)
+function __hamiltonian_forward(::AutoForwardDiff, model, x)
    return ForwardDiff.gradient(sum ∘ model, x)
end

-function __hamiltonian_forward(ad::AutoZygote, model::StatefulLuxLayer, x)
+function __hamiltonian_forward(::AutoZygote, model::StatefulLuxLayer, x)
    return only(Zygote.gradient(sum ∘ model, x))
end

@@ -71,8 +70,8 @@ Arguments:
[Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
documentation for more details.
"""
-@concrete struct NeuralHamiltonianDE{M <: HamiltonianNN} <: NeuralDELayer
-    model::M
+@concrete struct NeuralHamiltonianDE <: NeuralDELayer
+    model <: HamiltonianNN
tspan
args
kwargs
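The `sum ∘ model` composition in `__hamiltonian_forward` collapses the batched scalar Hamiltonian into one scalar, so a single gradient call returns ∂H/∂x for every column at once, and the ForwardDiff and Zygote paths agree. A toy sketch with a hand-written Hamiltonian standing in for the Lux model:

    using ForwardDiff, Zygote

    H(x) = sum(abs2, x; dims = 1) ./ 2       # H(q, p) = (q² + p²) / 2, per column
    x = randn(2, 5)                          # five (q, p) samples

    g_zygote  = only(Zygote.gradient(sum ∘ H, x))
    g_forward = ForwardDiff.gradient(sum ∘ H, x)
    @assert g_zygote ≈ g_forward ≈ x         # ∇H = (q, p) for this Hamiltonian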
30 changes: 14 additions & 16 deletions src/neural_de.jl
@@ -32,8 +32,8 @@ References:
[1] Pontryagin, Lev Semenovich. Mathematical theory of optimal processes. CRC press, 1987.
"""
-@concrete struct NeuralODE{M <: AbstractExplicitLayer} <: NeuralDELayer
-    model::M
+@concrete struct NeuralODE <: NeuralDELayer
+    model <: AbstractExplicitLayer
tspan
args
kwargs
@@ -77,10 +77,9 @@ Arguments:
[Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
documentation for more details.
"""
-@concrete struct NeuralDSDE{M1 <: AbstractExplicitLayer, M2 <: AbstractExplicitLayer} <:
-                 NeuralSDELayer
-    drift::M1
-    diffusion::M2
+@concrete struct NeuralDSDE <: NeuralSDELayer
+    drift <: AbstractExplicitLayer
+    diffusion <: AbstractExplicitLayer
tspan
args
kwargs
@@ -126,10 +125,9 @@ Arguments:
[Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
documentation for more details.
"""
-@concrete struct NeuralSDE{M1 <: AbstractExplicitLayer, M2 <: AbstractExplicitLayer} <:
-                 NeuralSDELayer
-    drift::M1
-    diffusion::M2
+@concrete struct NeuralSDE <: NeuralSDELayer
+    drift <: AbstractExplicitLayer
+    diffusion <: AbstractExplicitLayer
tspan
nbrown::Int
args
@@ -181,8 +179,8 @@ Arguments:
[Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
documentation for more details.
"""
-@concrete struct NeuralCDDE{M <: AbstractExplicitLayer} <: NeuralDELayer
-    model::M
+@concrete struct NeuralCDDE <: NeuralDELayer
+    model <: AbstractExplicitLayer
tspan
hist
lags
@@ -232,8 +230,8 @@ Arguments:
[Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
documentation for more details.
"""
-@concrete struct NeuralDAE{M <: AbstractExplicitLayer} <: NeuralDELayer
-    model::M
+@concrete struct NeuralDAE <: NeuralDELayer
+    model <: AbstractExplicitLayer
constraints_model
tspan
args
@@ -307,8 +305,8 @@ Arguments:
[Common Solver Arguments](https://docs.sciml.ai/DiffEqDocs/stable/basics/common_solver_opts/)
documentation for more details.
"""
-@concrete struct NeuralODEMM{M <: AbstractExplicitLayer} <: NeuralDELayer
-    model::M
+@concrete struct NeuralODEMM <: NeuralDELayer
+    model <: AbstractExplicitLayer
constraints_model
tspan
mass_matrix
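The struct changes in this file are internal; construction and the Lux call convention look the same as before. A rough usage sketch for `NeuralODE` (network size, solver, and `saveat` are arbitrary choices here):

    using DiffEqFlux, Lux, OrdinaryDiffEq, Random

    rng = Xoshiro(0)
    model = Chain(Dense(2 => 16, tanh), Dense(16 => 2))
    node = NeuralODE(model, (0.0f0, 1.0f0), Tsit5(); saveat = 0.1f0)

    ps, st = Lux.setup(rng, node)                 # Lux-style parameters and states
    sol, st_ = node(Float32[1.0, 0.0], ps, st)    # ODE solution plus updated state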
