Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix build_laplace_objective behaviour #115

Merged
merged 16 commits
Mar 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.3.3"
version = "0.3.4"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
6 changes: 3 additions & 3 deletions examples/c-comparisons/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ lf2.f.kernel
# Finally, we need to construct again the (approximate) posterior given the
# observations for the latent GP with optimised hyperparameters:

f_post2 = posterior(LaplaceApproximation(; f_init=objective.f), lf2(X), Y)
f_post2 = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf2(X), Y)

# By passing `f_init=objective.f` we let the Laplace approximation "warm-start"
# at the last point of the inner-loop Newton optimisation; `objective.f` is a
# By passing `f_init=objective.cache.f` we let the Laplace approximation "warm-start"
# at the last point of the inner-loop Newton optimisation; `objective.cache` is a
# field on the `objective` closure.

# Let's plot samples from the approximate posterior for the optimised hyperparameters:
Expand Down
30 changes: 22 additions & 8 deletions src/LaplaceApproximationModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,25 @@ closure passes its arguments to `build_latent_gp`, which must return the
- `newton_maxiter=100`: maximum number of Newton steps.
"""
function build_laplace_objective(build_latent_gp, xs, ys; kwargs...)
    # The warm-start vector cannot be allocated up front: its proper shape is
    # that of `mean(lfx.fx)`, but constructing `lfx` requires the objective's
    # (not yet known) hyperparameter arguments. We therefore hand over a cache
    # holding `nothing`; the objective closure allocates `cache.f` on its first
    # evaluation and mutates it in place thereafter to warm-start Newton.
    return build_laplace_objective!(
        LaplaceObjectiveCache(nothing), build_latent_gp, xs, ys; kwargs...
    )
end

# Backwards-compatible method: accepts a pre-allocated warm-start vector
# directly, wraps it in a `LaplaceObjectiveCache`, and delegates to the
# cache-based method.
function build_laplace_objective!(f_init::Vector, build_latent_gp, xs, ys; kwargs...)
    wrapped_cache = LaplaceObjectiveCache(f_init)
    return build_laplace_objective!(wrapped_cache, build_latent_gp, xs, ys; kwargs...)
end

"""
    LaplaceObjectiveCache(f)

Mutable container for the vector used to "warm-start" the inner-loop Newton
optimisation of the Laplace objective.

The field `f` is `nothing` until the objective closure is first evaluated —
at that point the closure sets it (to `mean(lfx.fx)`, whose shape is only
known once the latent GP has been constructed) — and afterwards it holds the
most recent Newton optimum, which is mutated in place between evaluations.
"""
mutable struct LaplaceObjectiveCache
    # `nothing` ⇒ not yet initialized; replaced with a `Vector` on the first
    # objective evaluation, then updated in place (`cache.f .= f_opt`).
    f::Union{Nothing,Vector}
end

function build_laplace_objective!(
f,
cache::LaplaceObjectiveCache,
build_latent_gp,
xs,
ys;
Expand All @@ -98,16 +110,18 @@ function build_laplace_objective!(
# Zygote does not like the try/catch within @info etc.
@debug "Objective arguments: $args"
# Zygote does not like in-place assignments either
if initialize_f
f .= mean(lfx.fx)
if cache.f === nothing
cache.f = mean(lfx.fx)
elseif initialize_f
cache.f .= mean(lfx.fx)
end
end
f_opt, lml = laplace_f_and_lml(
lfx, ys; f_init=f, maxiter=newton_maxiter, callback=newton_callback
lfx, ys; f_init=cache.f, maxiter=newton_maxiter, callback=newton_callback
)
ignore_derivatives() do
if newton_warmstart
f .= f_opt
cache.f .= f_opt
initialize_f = false
end
end
Expand Down
12 changes: 11 additions & 1 deletion test/LaplaceApproximationModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)

lf = build_latent_gp(training_results.minimizer)
f_post = posterior(LaplaceApproximation(; f_init=objective.f), lf(xs), ys)
f_post = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf(xs), ys)
return f_post, training_results
end

Expand Down Expand Up @@ -208,4 +208,14 @@
res = res_array[end]
@test res.q isa MvNormal
end

@testset "GitHub issue #109" begin
    # Regression check: `build_laplace_objective` should accept non-Vector
    # input containers such as `ColVecs` (it previously allocated the
    # warm-start buffer with `similar(xs, length(xs))`, which assumed `xs`
    # had the same type as `mean(lfx.fx)`).
    make_latent_gp() = LatentGP(GP(SEKernel()), BernoulliLikelihood(), 1e-8)

    inputs = ColVecs(randn(2, 5))
    _, observations = rand(make_latent_gp()(inputs))

    laplace_objective = build_laplace_objective(make_latent_gp, inputs, observations)
    _ = laplace_objective()  # smoke test: evaluation must not error
end
end