Patch summary (free-form text ahead of the first `diff --git` header; patch tools skip it).

* Project.toml: bump the package version from 0.3.3 to 0.3.4.
* src/LaplaceApproximationModule.jl: stop allocating the Newton warm-start buffer with
  `similar(xs, length(xs))` (the removed TODO notes that this assumed `xs` has the same
  type as `mean(lfx.fx)`). The buffer now lives in a mutable `LaplaceObjectiveCache`:
  `cache.f` starts as `nothing` and is set to `mean(lfx.fx)` the first time the objective
  runs. `build_laplace_objective!` gains a method that takes a `Vector` and wraps it in a
  cache.
* examples/c-comparisons/script.jl and test/LaplaceApproximationModule.jl: read the
  warm-start point through `objective.cache.f` instead of `objective.f`, and add a
  testset for GitHub issue #109 that builds and calls the objective with `ColVecs` inputs.

NOTE(review): line breaks and indentation were reconstructed so that every hunk matches
the old/new line counts in its `@@` header. Indentation of context and removed lines is
inferred from the Julia structure; confirm it against the target files before applying
(or apply with whitespace-tolerant options).

diff --git a/Project.toml b/Project.toml
index e580847c..4d78bbd4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ApproximateGPs"
 uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
 authors = ["JuliaGaussianProcesses Team"]
-version = "0.3.3"
+version = "0.3.4"
 
 [deps]
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
diff --git a/examples/c-comparisons/script.jl b/examples/c-comparisons/script.jl
index 011c17fd..287ef47a 100644
--- a/examples/c-comparisons/script.jl
+++ b/examples/c-comparisons/script.jl
@@ -127,10 +127,10 @@ lf2.f.kernel
 # Finally, we need to construct again the (approximate) posterior given the
 # observations for the latent GP with optimised hyperparameters:
 
-f_post2 = posterior(LaplaceApproximation(; f_init=objective.f), lf2(X), Y)
+f_post2 = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf2(X), Y)
 
-# By passing `f_init=objective.f` we let the Laplace approximation "warm-start"
-# at the last point of the inner-loop Newton optimisation; `objective.f` is a
+# By passing `f_init=objective.cache.f` we let the Laplace approximation "warm-start"
+# at the last point of the inner-loop Newton optimisation; `objective.cache` is a
 # field on the `objective` closure.
 
 # Let's plot samples from the approximate posterior for the optimised hyperparameters:
diff --git a/src/LaplaceApproximationModule.jl b/src/LaplaceApproximationModule.jl
index d43615f7..e9e968b6 100644
--- a/src/LaplaceApproximationModule.jl
+++ b/src/LaplaceApproximationModule.jl
@@ -75,13 +75,25 @@ closure passes its arguments to `build_latent_gp`, which must return the
 - `newton_maxiter=100`: maximum number of Newton steps.
 """
 function build_laplace_objective(build_latent_gp, xs, ys; kwargs...)
-    # TODO assumes type of `xs` will be same as `mean(lfx.fx)`
-    f = similar(xs, length(xs))  # will be mutated in-place to "warm-start" the Newton steps
-    return build_laplace_objective!(f, build_latent_gp, xs, ys; kwargs...)
+    cache = LaplaceObjectiveCache(nothing)
+    # cache.f will be mutated in-place to "warm-start" the Newton steps
+    # f should be similar(mean(lfx.fx)), but to construct lfx we would need the arguments
+    # so we set it to `nothing` initially, and set it to mean(lfx.fx) within the objective
+    return build_laplace_objective!(cache, build_latent_gp, xs, ys; kwargs...)
+end
+
+function build_laplace_objective!(f_init::Vector, build_latent_gp, xs, ys; kwargs...)
+    return build_laplace_objective!(
+        LaplaceObjectiveCache(f_init), build_latent_gp, xs, ys; kwargs...
+    )
+end
+
+mutable struct LaplaceObjectiveCache
+    f::Union{Nothing,Vector}
 end
 
 function build_laplace_objective!(
-    f,
+    cache::LaplaceObjectiveCache,
     build_latent_gp,
     xs,
     ys;
@@ -98,16 +110,18 @@ function build_laplace_objective!(
             # Zygote does not like the try/catch within @info etc.
             @debug "Objective arguments: $args"
             # Zygote does not like in-place assignments either
-            if initialize_f
-                f .= mean(lfx.fx)
+            if cache.f === nothing
+                cache.f = mean(lfx.fx)
+            elseif initialize_f
+                cache.f .= mean(lfx.fx)
             end
         end
         f_opt, lml = laplace_f_and_lml(
-            lfx, ys; f_init=f, maxiter=newton_maxiter, callback=newton_callback
+            lfx, ys; f_init=cache.f, maxiter=newton_maxiter, callback=newton_callback
         )
         ignore_derivatives() do
             if newton_warmstart
-                f .= f_opt
+                cache.f .= f_opt
                 initialize_f = false
             end
         end
diff --git a/test/LaplaceApproximationModule.jl b/test/LaplaceApproximationModule.jl
index a3f705c2..547b60f3 100644
--- a/test/LaplaceApproximationModule.jl
+++ b/test/LaplaceApproximationModule.jl
@@ -23,7 +23,7 @@
         )
 
         lf = build_latent_gp(training_results.minimizer)
-        f_post = posterior(LaplaceApproximation(; f_init=objective.f), lf(xs), ys)
+        f_post = posterior(LaplaceApproximation(; f_init=objective.cache.f), lf(xs), ys)
         return f_post, training_results
     end
 
@@ -208,4 +208,14 @@
         res = res_array[end]
         @test res.q isa MvNormal
     end
+
+    @testset "GitHub issue #109" begin
+        build_latent_gp() = LatentGP(GP(SEKernel()), BernoulliLikelihood(), 1e-8)
+
+        x = ColVecs(randn(2, 5))
+        _, y = rand(build_latent_gp()(x))
+
+        objective = build_laplace_objective(build_latent_gp, x, y)
+        _ = objective()  # check that it works
+    end
 end