Skip to content

Commit

Permalink
Split code into submodules for the different approximations (#96)
Browse files Browse the repository at this point in the history
* stab at submodules

* move approx_lml to API

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* clean up imports

* remove unused imports

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* docfix

* missing import

* test: remove submodules from api docs

* subpages in docs

* hide private API docstrings

* mention types in api/index.md

* qualify

* remove extra line

* minor bump#

* change docs/ compat

* update test compat

* explicitly `@reexport` to avoid exporting the submodules!

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add likelihood imports

* add import

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add import

* Update src/SparseVariationalApproximationModule.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
st-- and github-actions[bot] authored Mar 11, 2022
1 parent 783edda commit e57ac47
Show file tree
Hide file tree
Showing 18 changed files with 146 additions and 52 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.2.8"
version = "0.3.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
ApproximateGPs = "0.2"
ApproximateGPs = "0.3"
Documenter = "0.27"
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ makedocs(;
pages=[
"Home" => "index.md",
"userguide.md",
"API" => "api.md",
"API" => ["api/index.md", "api/sparsevariational.md", "api/laplace.md"],
"Examples" =>
map(filter!(filename -> endswith(filename, ".md"), readdir(EXAMPLES_OUT))) do x
return joinpath("examples", x)
Expand Down
5 changes: 0 additions & 5 deletions docs/src/api.md

This file was deleted.

6 changes: 6 additions & 0 deletions docs/src/api/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# ApproximateGPs API

```@autodocs
Modules = [ApproximateGPs, ApproximateGPs.API]
Private = false
```
6 changes: 6 additions & 0 deletions docs/src/api/laplace.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Laplace Approximation

```@autodocs
Modules = [ApproximateGPs.LaplaceApproximationModule]
Private = false
```
6 changes: 6 additions & 0 deletions docs/src/api/sparsevariational.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Sparse Variational Approximation

```@autodocs
Modules = [ApproximateGPs.SparseVariationalApproximationModule]
Private = false
```
15 changes: 15 additions & 0 deletions src/API.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module API

export approx_lml # TODO move to AbstractGPs, see https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/issues/221

"""
approx_lml(approx::<Approximation>, lfx::LatentFiniteGP, ys)
Compute an approximation to the log of the marginal likelihood (also known as
"evidence") under the given `approx` to the posterior. This approximation can be used to optimise the hyperparameters of `lfx`.
This should become part of the AbstractGPs API (see JuliaGaussianProcesses/AbstractGPs.jl#221).
"""
function approx_lml end

end
36 changes: 13 additions & 23 deletions src/ApproximateGPs.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,25 @@
module ApproximateGPs

using Reexport

@reexport using AbstractGPs
@reexport using GPLikelihoods
using Distributions
using LinearAlgebra
using Statistics
using StatsBase
using FastGaussQuadrature
using SpecialFunctions
using ChainRulesCore
using FillArrays
using PDMats: chol_lower
using IrrationalConstants: sqrt2, invsqrtπ

using AbstractGPs: AbstractGP, FiniteGP, LatentFiniteGP, ApproxPosteriorGP, At_A, diag_At_A

include("utils.jl")

export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo
include("expected_loglik.jl")
include("API.jl")
@reexport using .API: approx_lml

export SparseVariationalApproximation, Centered, NonCentered
include("sparse_variational.jl")
include("utils.jl")

using ForwardDiff
include("SparseVariationalApproximationModule.jl")
@reexport using .SparseVariationalApproximationModule:
SparseVariationalApproximation, Centered, NonCentered
@reexport using .SparseVariationalApproximationModule:
DefaultQuadrature, Analytic, GaussHermite, MonteCarlo

export LaplaceApproximation
export build_laplace_objective, build_laplace_objective!
export approx_lml # TODO move to AbstractGPs, see https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/issues/221
include("laplace.jl")
include("LaplaceApproximationModule.jl")
@reexport using .LaplaceApproximationModule: LaplaceApproximation
@reexport using .LaplaceApproximationModule:
build_laplace_objective, build_laplace_objective!

include("deprecations.jl")

Expand Down
29 changes: 26 additions & 3 deletions src/laplace.jl → src/LaplaceApproximationModule.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
module LaplaceApproximationModule

using ..API

export LaplaceApproximation
export build_laplace_objective, build_laplace_objective!

using ForwardDiff: ForwardDiff
using Distributions
using LinearAlgebra
using Statistics
using StatsBase

using ChainRulesCore: ignore_derivatives, NoTangent, @thunk
using ChainRulesCore: ChainRulesCore

using AbstractGPs: AbstractGPs
using AbstractGPs: LatentFiniteGP, ApproxPosteriorGP

# Implementation follows Rasmussen & Williams, Gaussian Processes for Machine
# Learning, the MIT Press, 2006. In the following referred to as 'RW'.
# Online text:
Expand Down Expand Up @@ -36,7 +55,7 @@ Compute an approximation to the log of the marginal likelihood (also known as
This should become part of the AbstractGPs API (see JuliaGaussianProcesses/AbstractGPs.jl#221).
"""
function approx_lml(la::LaplaceApproximation, lfx::LatentFiniteGP, ys)
function API.approx_lml(la::LaplaceApproximation, lfx::LatentFiniteGP, ys)
return laplace_lml(lfx, ys; la.newton_kwargs...)
end

Expand Down Expand Up @@ -309,11 +328,13 @@ function ChainRulesCore.rrule(::typeof(newton_inner_loop), dist_y_given_f, ys, K
function newton_pullback(Δf_opt)
∂self = NoTangent()

∂dist_y_given_f = @not_implemented(
∂dist_y_given_f = ChainRulesCore.@not_implemented(
"gradient of Newton's method w.r.t. likelihood parameters"
)

∂ys = @not_implemented("gradient of Newton's method w.r.t. observations")
∂ys = ChainRulesCore.@not_implemented(
"gradient of Newton's method w.r.t. observations"
)

# ∂K = df/dK Δf
∂K = @thunk(cache.Wsqrt * (cache.B_ch \ (cache.Wsqrt \ Δf_opt)) * cache.d_loglik')
Expand Down Expand Up @@ -417,3 +438,5 @@ function Statistics.cov(f::LaplacePosteriorGP, x::AbstractVector, y::AbstractVec
vy = L \ (f.data.Wsqrt * cov(f.prior.f, f.prior.x, y))
return cov(f.prior.f, x, y) - vx' * vy
end

end
Original file line number Diff line number Diff line change
@@ -1,3 +1,33 @@
module SparseVariationalApproximationModule

using ..API

export SparseVariationalApproximation, Centered, NonCentered

using ..ApproximateGPs: _chol_cov, _cov
using Distributions
using LinearAlgebra
using Statistics
using StatsBase
using FillArrays: Fill
using PDMats: chol_lower
using IrrationalConstants: sqrt2, invsqrtπ

using AbstractGPs: AbstractGPs
using AbstractGPs:
AbstractGP,
FiniteGP,
LatentFiniteGP,
ApproxPosteriorGP,
posterior,
marginals,
At_A,
diag_At_A
using GPLikelihoods: GaussianLikelihood

export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo
include("expected_loglik.jl")

@doc raw"""
Centered()
Expand Down Expand Up @@ -341,3 +371,5 @@ function _prior_kl(sva::SparseVariationalApproximation{NonCentered})

return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2
end

end
5 changes: 5 additions & 0 deletions src/expected_loglik.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
using GPLikelihoods
using FastGaussQuadrature: gausshermite
using SpecialFunctions: loggamma
using ChainRulesCore: ChainRulesCore

abstract type QuadratureMethod end
struct DefaultQuadrature <: QuadratureMethod end
struct Analytic <: QuadratureMethod end
Expand Down
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
using Distributions
using LinearAlgebra

# These methods to create a Cholesky directly from the factorisation will be in Julia 1.7
# https://github.com/JuliaLang/julia/pull/39352
if VERSION < v"1.7"
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.4, 0.5"
ApproximateGPs = "0.2"
ApproximateGPs = "0.3"
ChainRulesCore = "1"
ChainRulesTestUtils = "1.2.3"
Distributions = "0.25"
Expand Down
26 changes: 16 additions & 10 deletions test/expected_loglik.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
# Test that we're not missing any analytic implementation in `likelihoods_to_test`!
implementation_types = [
(; quadrature=m.sig.types[2], lik=m.sig.types[5]) for
m in methods(ApproximateGPs.expected_loglik)
m in methods(SparseVariationalApproximationModule.expected_loglik)
]
analytic_likelihoods = [
m.lik for m in implementation_types if
m.quadrature == ApproximateGPs.Analytic && m.lik != Any
m.quadrature == SparseVariationalApproximationModule.Analytic && m.lik != Any
]
for lik_type in analytic_likelihoods
@test any(lik isa lik_type for lik in likelihoods_to_test)
Expand All @@ -28,23 +28,27 @@

@testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test
methods = [GaussHermite(100), MonteCarlo(1e7)]
def = ApproximateGPs._default_quadrature(lik)
def = SparseVariationalApproximationModule._default_quadrature(lik)
if def isa Analytic
push!(methods, def)
end
y = rand.(rng, lik.(zeros(10)))

results = map(m -> ApproximateGPs.expected_loglik(m, y, q_f, lik), methods)
results = map(
m -> SparseVariationalApproximationModule.expected_loglik(m, y, q_f, lik),
methods,
)
@test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results)
end

@test ApproximateGPs.expected_loglik(
@test SparseVariationalApproximationModule.expected_loglik(
MonteCarlo(), zeros(10), q_f, GaussianLikelihood()
) isa Real
@test ApproximateGPs.expected_loglik(
@test SparseVariationalApproximationModule.expected_loglik(
GaussHermite(), zeros(10), q_f, GaussianLikelihood()
) isa Real
@test ApproximateGPs._default_quadrature-> Normal(0, θ)) isa GaussHermite
@test SparseVariationalApproximationModule._default_quadrature-> Normal(0, θ)) isa
GaussHermite

@testset "testing Zygote compatibility with GaussHermite" begin # see issue #82
N = 10
Expand All @@ -55,7 +59,9 @@
for lik in likelihoods_to_test
y = rand.(rng, lik.(rand.(Normal.(μs, σs))))
gμ, glogσ = Zygote.gradient(μs, log.(σs)) do μ, logσ
ApproximateGPs.expected_loglik(gh, y, Normal.(μ, exp.(logσ)), lik)
SparseVariationalApproximationModule.expected_loglik(
gh, y, Normal.(μ, exp.(logσ)), lik
)
end
@test all(isfinite, gμ)
@test all(isfinite, glogσ)
Expand All @@ -66,7 +72,7 @@
y = randn(rng, N)
glogσ = only(
Zygote.gradient(log(σ)) do x
ApproximateGPs.expected_loglik(
SparseVariationalApproximationModule.expected_loglik(
gh, y, Normal.(μs, σs), GaussianLikelihood(exp(x))
)
end,
Expand All @@ -77,7 +83,7 @@
y = rand.(rng, Gamma.(α, rand(N)))
glogα = only(
Zygote.gradient(log(α)) do x
ApproximateGPs.expected_loglik(
SparseVariationalApproximationModule.expected_loglik(
gh, y, Normal.(μs, σs), GammaLikelihood(exp(x))
)
end,
Expand Down
16 changes: 11 additions & 5 deletions test/laplace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
function eval_newton_inner_loop(theta)
k = with_lengthscale(Matern52Kernel(), exp(theta))
K = kernelmatrix(k, xs)
f, cache = ApproximateGPs._newton_inner_loop(
f, cache = LaplaceApproximationModule._newton_inner_loop(
dist_y_given_f, ys, K; f_init=zero(xs), maxiter=100
)
return f
Expand All @@ -133,7 +133,9 @@

function newton_inner_loop_from_L(dist_y_given_f, ys, L; kwargs...)
K = L'L
return ApproximateGPs.newton_inner_loop(dist_y_given_f, ys, K; kwargs...)
return LaplaceApproximationModule.newton_inner_loop(
dist_y_given_f, ys, K; kwargs...
)
end

function ChainRulesCore.frule(
Expand All @@ -149,7 +151,7 @@
ΔK = ΔL'L + L'ΔL
return frule(
(Δself, Δdist_y_given_f, Δys, ΔK),
ApproximateGPs.newton_inner_loop,
LaplaceApproximationModule.newton_inner_loop,
dist_y_given_f,
ys,
K;
Expand All @@ -162,7 +164,11 @@
)
K = L'L
f_opt, newton_from_K_pullback = rrule(
ApproximateGPs.newton_inner_loop, dist_y_given_f, ys, K; kwargs...
LaplaceApproximationModule.newton_inner_loop,
dist_y_given_f,
ys,
K;
kwargs...,
)

function newton_from_L_pullback(Δf_opt)
Expand Down Expand Up @@ -254,7 +260,7 @@
lf = build_latent_gp(theta0)
lfx = lf(X)

res_array = ApproximateGPs.laplace_steps(lfx, Y)
res_array = LaplaceApproximationModule.laplace_steps(lfx, Y)
res = res_array[end]
@test res.q isa MvNormal
end
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using Zygote
using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences
using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule

# Writing tests:
# 1. The file structure of the test should match precisely the file structure of src.
Expand Down
4 changes: 2 additions & 2 deletions test/sparse_variational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@

@testset "Verify that the non-centered approximate posterior agrees with centered" begin
@test isapprox(
ApproximateGPs._prior_kl(approx_non_Centered),
ApproximateGPs._prior_kl(approx_Centered);
SparseVariationalApproximationModule._prior_kl(approx_non_Centered),
SparseVariationalApproximationModule._prior_kl(approx_Centered);
rtol=1e-5,
)
@test mean(f_approx_post_non_Centered, a) mean(f_approx_post_Centered, a)
Expand Down

2 comments on commit e57ac47

@st--
Copy link
Member Author

@st-- st-- commented on e57ac47 Mar 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/56416

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" e57ac47bb0540f7251f1d9e560417ae65dcbf90f
git push origin v0.3.0

Please sign in to comment.