Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restore and save optsum for GLMM #791

Merged
merged 15 commits into from
Nov 8, 2024
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
MixedModels v4.27.0 Release Notes
==============================
- `saveoptsum` and `restoreoptsum!` now support `GeneralizedLinearMixedModel`s [#791]
- `unfit!(::GeneralizedLinearMixedModel)` (called internally by `refit!`) now does a better job of fully resetting the model state [#791]

MixedModels v4.26.1 Release Notes
==============================
- lower and upper edges of profile confidence intervals for REML-fitted models are no longer flipped [#785]
Expand Down Expand Up @@ -569,3 +574,4 @@ Package dependencies
[#778]: https://github.com/JuliaStats/MixedModels.jl/issues/778
[#783]: https://github.com/JuliaStats/MixedModels.jl/issues/783
[#785]: https://github.com/JuliaStats/MixedModels.jl/issues/785
[#791]: https://github.com/JuliaStats/MixedModels.jl/issues/791
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MixedModels"
uuid = "ff71e718-51f3-5ec2-a782-8ffcbfa3c316"
author = ["Phillip Alday <[email protected]>", "Douglas Bates <[email protected]>", "Jose Bayoan Santiago Calderon <[email protected]>"]
version = "4.26.1"
version = "4.27.0"

[deps]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
Expand Down
6 changes: 4 additions & 2 deletions src/generalizedlinearmixedmodel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@
io::IO, ::MIME"text/plain", m::GeneralizedLinearMixedModel{T,D}
) where {T,D}
if m.optsum.feval < 0
@warn("Model has not been fit")

Check warning on line 722 in src/generalizedlinearmixedmodel.jl

View workflow job for this annotation

GitHub Actions / Documentation

Model has not been fit
return nothing
end
nAGQ = m.LMM.optsum.nAGQ
Expand Down Expand Up @@ -767,19 +767,21 @@
end

function unfit!(model::GeneralizedLinearMixedModel{T}) where {T}
deviance!(model, 1)
reevaluateAend!(model.LMM)

reterms = model.LMM.reterms
optsum = model.LMM.optsum
# we need to reset optsum so that it
# plays nice with the modifications fit!() does
optsum.lowerbd = mapfoldl(lowerbd, vcat, reterms)
optsum.initial = mapfoldl(getθ, vcat, reterms)
# for variances (bounded at zero), we have ones, while
# for everything else (bounded at -Inf), we have zeros
optsum.initial = map(T ∘ iszero, optsum.lowerbd)
optsum.final = copy(optsum.initial)
optsum.xtol_abs = fill!(copy(optsum.initial), 1.0e-10)
optsum.initial_step = T[]
optsum.feval = -1
deviance!(model, 1)

return model
end
Expand Down
6 changes: 6 additions & 0 deletions src/optsummary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,9 @@ function _check_nlopt_return(ret, failure_modes=_NLOPT_FAILURE_MODES)
@warn("NLopt optimization failure: $ret")
end
end

function Base.:(==)(o1::OptSummary{T}, o2::OptSummary{T}) where {T}
return all(fieldnames(OptSummary)) do fn
return getfield(o1, fn) == getfield(o2, fn)
end
end
93 changes: 72 additions & 21 deletions src/serialization.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,78 @@
"""
restoreoptsum!(m::LinearMixedModel, io::IO; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
restoreoptsum!(m::LinearMixedModel, filename; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
restoreoptsum!(m::MixedModel, io::IO; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)
restoreoptsum!(m::MixedModel, filename; atol::Real=0, rtol::Real=atol>0 ? 0 : √eps)

Read, check, and restore the `optsum` field from a JSON stream or filename.
"""
function restoreoptsum!(m::MixedModel, filename; kwargs...)
return open(filename, "r") do io
return restoreoptsum!(m, io; kwargs...)
end
end

function restoreoptsum!(
m::LinearMixedModel{T}, io::IO; atol::Real=zero(T),
rtol::Real=atol > 0 ? zero(T) : √eps(T),
) where {T}
dict = JSON3.read(io)
ops = restoreoptsum!(m.optsum, dict)
for (par, obj_at_par) in (:initial => :finitial, :final => :fmin)
if !isapprox(
objective(updateL!(setθ!(m, getfield(ops, par)))), getfield(ops, obj_at_par);
rtol, atol,
)
throw(
ArgumentError(
"model at $par does not match stored $obj_at_par within atol=$atol, rtol=$rtol"
),
)
end
end
return m
end

function restoreoptsum!(
m::GeneralizedLinearMixedModel{T}, io::IO; atol::Real=zero(T),
rtol::Real=atol > 0 ? zero(T) : √eps(T),
) where {T}
dict = JSON3.read(io)
ops = m.optsum

# need to accommodate fast and slow fits
resize!(ops.initial, length(dict.initial))
resize!(ops.final, length(dict.final))

theta_beta_len = length(m.θ) + length(m.β)
if length(dict.initial) == theta_beta_len # fast=false
if length(ops.lowerbd) == length(m.θ)
prepend!(ops.lowerbd, fill(-Inf, length(m.β)))
end
setpar! = setβθ!
varyβ = false
else # fast=true
setpar! = setθ!
varyβ = true
if length(ops.lowerbd) != length(m.θ)
deleteat!(ops.lowerbd, 1:length(m.β))
Comment on lines +55 to +56
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose it can never be the case that length(ops.lowerbd) < length(m.β) or that ops.lowerbd has indices that don't start at 1?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lowerbd is in a struct we control and it's always Vector, so the indices are fine.

There are two cases here:

  • the "slow" optimization, where the parameter vector contains both beta and theta, so length(lowerbd) needs to be length(beta) + length(theta).
  • the "fast" optimization, where the parameter vector contains only theta, so length(lowerbd) needs to be length(theta). (In this case, there is an inner optimization step that finds the best beta for a given theta using PIRLS -- this will generally give the same answer or a very close one to doing free optimization over both, but there are a few pathological cases where the two methods can diverge, in which case the extra freedom of the slow optimization will yield more accurate results.)

The initial model creation initializes the lower bound to the length of theta because that is shared with the initialization for the linear mixed model (where only theta is included in optimization and beta is solved for directly).

end
end
restoreoptsum!(ops, dict)
for (par, obj_at_par) in (:initial => :finitial, :final => :fmin)
if !isapprox(
deviance(pirls!(setpar!(m, getfield(ops, par)), varyβ), dict.nAGQ),
getfield(ops, obj_at_par); rtol, atol,
)
throw(

Check warning on line 65 in src/serialization.jl

View check run for this annotation

Codecov / codecov/patch

src/serialization.jl#L65

Added line #L65 was not covered by tests
ArgumentError(
"model at $par does not match stored $obj_at_par within atol=$atol, rtol=$rtol"
),
)
end
end
return m
end

function restoreoptsum!(ops::OptSummary{T}, dict::AbstractDict) where {T}
allowed_missing = (
:lowerbd, # never saved, -Inf not allowed in JSON
:xtol_zero_abs, # added in v4.25.0
Expand All @@ -27,7 +90,9 @@
if length(setdiff(allowed_missing, keys(dict))) > 1 # 1 because :lowerbd
@warn "optsum was saved with an older version of MixedModels.jl: consider resaving."
end

if any(ops.lowerbd .> dict.initial) || any(ops.lowerbd .> dict.final)
@debug "" ops.lowerbd dict.initial dict.final
throw(ArgumentError("initial or final parameters in io do not satisfy lowerbd"))
end
for fld in (:feval, :finitial, :fmin, :ftol_rel, :ftol_abs, :maxfeval, :nAGQ, :REML)
Expand All @@ -37,13 +102,6 @@
ops.xtol_rel = copy(dict.xtol_rel)
copyto!(ops.initial, dict.initial)
copyto!(ops.final, dict.final)
for (v, f) in (:initial => :finitial, :final => :fmin)
if !isapprox(
objective(updateL!(setθ!(m, getfield(ops, v)))), getfield(ops, f); rtol, atol
)
throw(ArgumentError("model m at $v does not give stored $f"))
end
end
ops.optimizer = Symbol(dict.optimizer)
ops.returnvalue = Symbol(dict.returnvalue)
# compatibility with fits saved before the introduction of various extensions
Expand All @@ -59,30 +117,23 @@
else
[(convert(Vector{T}, first(entry)), T(last(entry))) for entry in fitlog]
end
return m
end

function restoreoptsum!(m::LinearMixedModel{T}, filename; kwargs...) where {T}
open(filename, "r") do io
restoreoptsum!(m, io; kwargs...)
end
return ops
end

"""
saveoptsum(io::IO, m::LinearMixedModel)
saveoptsum(filename, m::LinearMixedModel)
saveoptsum(io::IO, m::MixedModel)
saveoptsum(filename, m::MixedModel)

Save `m.optsum` (w/o the `lowerbd` field) in JSON format to an IO stream or a file

The reason for omitting the `lowerbd` field is because it often contains `-Inf`
values that are not allowed in JSON.
"""
saveoptsum(io::IO, m::LinearMixedModel) = JSON3.write(io, m.optsum)
function saveoptsum(filename, m::LinearMixedModel)
saveoptsum(io::IO, m::MixedModel) = JSON3.write(io, m.optsum)
function saveoptsum(filename, m::MixedModel)
open(filename, "w") do io
saveoptsum(io, m)
end
end

# TODO: write methods for GLMM
# TODO, maybe: something nice for the MixedModelBootstrap
26 changes: 26 additions & 0 deletions test/pirls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,29 @@ end
@test isapprox(first(gm5.β), -0.13860166843315044, atol=1.e-3)
@test isapprox(last(gm5.β), -0.034414458364713504, atol=1.e-3)
end

@testset "GLMM saveoptsum" begin
cbpp = dataset(:cbpp)
gm_original = GeneralizedLinearMixedModel(first(gfms[:cbpp]), cbpp, Binomial(); wts=cbpp.hsz)
gm_restored = GeneralizedLinearMixedModel(first(gfms[:cbpp]), cbpp, Binomial(); wts=cbpp.hsz)
fit!(gm_original; progress=false, nAGQ=1)

io = IOBuffer()

saveoptsum(seekstart(io), gm_original)
restoreoptsum!(gm_restored, seekstart(io))
@test gm_original.optsum == gm_restored.optsum
@test deviance(gm_original) ≈ deviance(gm_restored)

refit!(gm_original; progress=false, nAGQ=3)
saveoptsum(seekstart(io), gm_original)
restoreoptsum!(gm_restored, seekstart(io))
@test gm_original.optsum == gm_restored.optsum
@test deviance(gm_original) ≈ deviance(gm_restored)

refit!(gm_original; progress=false, fast=true)
saveoptsum(seekstart(io), gm_original)
restoreoptsum!(gm_restored, seekstart(io))
@test gm_original.optsum == gm_restored.optsum
@test deviance(gm_original) ≈ deviance(gm_restored)
end
4 changes: 2 additions & 2 deletions test/pls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,8 @@ end
fm_mod = deepcopy(fm)
fm_mod.optsum.fmin += 1
saveoptsum(seekstart(io), fm_mod)
@test_throws(ArgumentError("model m at final does not give stored fmin"),
restoreoptsum!(m, seekstart(io)))
@test_throws(ArgumentError("model at final does not match stored fmin within atol=0.0, rtol=1.0e-8"),
restoreoptsum!(m, seekstart(io); atol=0.0, rtol=1e-8))
restoreoptsum!(m, seekstart(io); atol=1)
@test m.optsum.fmin - fm.optsum.fmin ≈ 1

Expand Down
Loading