Skip to content

Commit

Permalink
move expected loglik to GPLikelihoods (#123)
Browse files Browse the repository at this point in the history
* move to GPLikelihoods 0.4; remove expected_loglik from here
* remove unnecessary internal function
  • Loading branch information
st-- authored Mar 29, 2022
1 parent 490ece8 commit 8ceb309
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 305 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.3.4"
version = "0.4.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down Expand Up @@ -29,7 +29,7 @@ Distributions = "0.25"
FastGaussQuadrature = "0.4"
FillArrays = "0.12, 0.13"
ForwardDiff = "0.10"
GPLikelihoods = "0.3"
GPLikelihoods = "0.4"
IrrationalConstants = "0.1"
LogExpFunctions = "0.3"
PDMats = "0.11"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"

[compat]
ApproximateGPs = "0.3"
ApproximateGPs = "0.3,0.4"
Documenter = "0.27"
2 changes: 0 additions & 2 deletions src/ApproximateGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ include("utils.jl")
include("SparseVariationalApproximationModule.jl")
@reexport using .SparseVariationalApproximationModule:
SparseVariationalApproximation, Centered, NonCentered
@reexport using .SparseVariationalApproximationModule:
DefaultQuadrature, Analytic, GaussHermite, MonteCarlo

include("LaplaceApproximationModule.jl")
@reexport using .LaplaceApproximationModule: LaplaceApproximation
Expand Down
54 changes: 20 additions & 34 deletions src/SparseVariationalApproximationModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@ using ..API

export SparseVariationalApproximation, Centered, NonCentered

using ..ApproximateGPs: _chol_cov, _cov
using Distributions
using LinearAlgebra
using Statistics
using StatsBase
using FillArrays: Fill
using PDMats: chol_lower
using IrrationalConstants: sqrt2, invsqrtπ

using AbstractGPs: AbstractGPs
using AbstractGPs:
Expand All @@ -24,10 +22,8 @@ using AbstractGPs:
marginals,
At_A,
diag_At_A
using GPLikelihoods: GaussianLikelihood

export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo
include("expected_loglik.jl")
using GPLikelihoods: GaussianLikelihood, DefaultExpectationMethod, expected_loglikelihood
using ..ApproximateGPs: _chol_cov, _cov

@doc raw"""
Centered()
Expand Down Expand Up @@ -289,22 +285,20 @@ end
fx::FiniteGP,
y::AbstractVector{<:Real};
num_data=length(y),
quadrature=DefaultQuadrature(),
quadrature=GPLikelihoods.DefaultExpectationMethod(),
)
Compute the Evidence Lower BOund from [1] for the process `f = fx.f ==
svgp.fz.f` where `y` are observations of `fx`, pseudo-inputs are given by `z =
svgp.fz.x` and `q(u)` is a variational distribution over inducing points `u =
f(z)`.
`quadrature` selects which method is used to calculate the expected loglikelihood in
the ELBO. The options are: `DefaultQuadrature()`, `Analytic()`, `GaussHermite()` and
`MonteCarlo()`. For likelihoods with an analytic solution, `DefaultQuadrature()` uses this
exact solution. If there is no such solution, `DefaultQuadrature()` either uses
`GaussHermite()` or `MonteCarlo()`, depending on the likelihood.
`quadrature` is passed on to `GPLikelihoods.expected_loglikelihood` and selects
which method is used to calculate the expected loglikelihood in the ELBO. See
`GPLikelihoods.expected_loglikelihood` for more details.
N.B. the likelihood is assumed to be Gaussian with observation noise `fx.Σy`.
Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = α * I`.
Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = σ² * I`.
[1] - Hensman, James, Alexander Matthews, and Zoubin Ghahramani. "Scalable
variational Gaussian process classification." Artificial Intelligence and
Expand All @@ -315,10 +309,11 @@ function AbstractGPs.elbo(
fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Diagonal{<:Real,<:Fill}},
y::AbstractVector{<:Real};
num_data=length(y),
quadrature=DefaultQuadrature(),
quadrature=DefaultExpectationMethod(),
)
@assert sva.fz.f === fx.f
return _elbo(quadrature, sva, fx, y, GaussianLikelihood(fx.Σy[1]), num_data)
σ² = fx.Σy[1]
lik = GaussianLikelihood(σ²)
return elbo(sva, LatentFiniteGP(fx, lik), y; num_data, quadrature)
end

function AbstractGPs.elbo(
Expand All @@ -337,7 +332,7 @@ end
lfx::LatentFiniteGP,
y::AbstractVector;
num_data=length(y),
quadrature=DefaultQuadrature(),
quadrature=GPLikelihoods.DefaultExpectationMethod(),
)
Compute the ELBO for a LatentGP with a possibly non-conjugate likelihood.
Expand All @@ -347,26 +342,17 @@ function AbstractGPs.elbo(
lfx::LatentFiniteGP,
y::AbstractVector;
num_data=length(y),
quadrature=DefaultQuadrature(),
)
@assert sva.fz.f === lfx.fx.f
return _elbo(quadrature, sva, lfx.fx, y, lfx.lik, num_data)
end

# Compute the common elements of the ELBO
function _elbo(
quadrature::QuadratureMethod,
sva::SparseVariationalApproximation,
fx::FiniteGP,
y::AbstractVector,
lik,
num_data::Integer,
quadrature=DefaultExpectationMethod(),
)
@assert sva.fz.f === fx.f
sva.fz.f === lfx.fx.f || throw(
ArgumentError(
"(Latent)FiniteGP prior is not consistent with SparseVariationalApproximation's",
),
)

f_post = posterior(sva)
q_f = marginals(f_post(fx.x))
variational_exp = expected_loglik(quadrature, y, q_f, lik)
q_f = marginals(f_post(lfx.fx.x))
variational_exp = expected_loglikelihood(quadrature, lfx.lik, q_f, y)

n_batch = length(y)
scale = num_data / n_batch
Expand Down
168 changes: 0 additions & 168 deletions src/expected_loglik.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.4, 0.5"
ApproximateGPs = "0.3"
ApproximateGPs = "0.4"
ChainRulesCore = "1"
ChainRulesTestUtils = "1.2.3"
Distributions = "0.25"
Expand Down
Loading

2 comments on commit 8ceb309

@st--
Copy link
Member Author

@st-- st-- commented on 8ceb309 Mar 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/57585

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.0 -m "<description of version>" 8ceb30966ab2aca2bfa2fa7a51fc403022e7ed2f
git push origin v0.4.0

Please sign in to comment.