From 8ceb30966ab2aca2bfa2fa7a51fc403022e7ed2f Mon Sep 17 00:00:00 2001 From: st-- Date: Tue, 29 Mar 2022 21:55:33 +0200 Subject: [PATCH] move expected loglik to GPLikelihoods (#123) * move to GPLikelihoods 0.4; remove expected_loglik from here * remove unnecessary internal function --- Project.toml | 4 +- docs/Project.toml | 2 +- src/ApproximateGPs.jl | 2 - src/SparseVariationalApproximationModule.jl | 54 +++---- src/expected_loglik.jl | 168 -------------------- test/Project.toml | 2 +- test/expected_loglik.jl | 93 ----------- test/runtests.jl | 4 - 8 files changed, 24 insertions(+), 305 deletions(-) delete mode 100644 src/expected_loglik.jl delete mode 100644 test/expected_loglik.jl diff --git a/Project.toml b/Project.toml index 4d78bbd4..5df985ef 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ApproximateGPs" uuid = "298c2ebc-0411-48ad-af38-99e88101b606" authors = ["JuliaGaussianProcesses Team"] -version = "0.3.4" +version = "0.4.0" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" @@ -29,7 +29,7 @@ Distributions = "0.25" FastGaussQuadrature = "0.4" FillArrays = "0.12, 0.13" ForwardDiff = "0.10" -GPLikelihoods = "0.3" +GPLikelihoods = "0.4" IrrationalConstants = "0.1" LogExpFunctions = "0.3" PDMats = "0.11" diff --git a/docs/Project.toml b/docs/Project.toml index e3a5d366..ad869f1d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,5 +3,5 @@ ApproximateGPs = "298c2ebc-0411-48ad-af38-99e88101b606" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" [compat] -ApproximateGPs = "0.3" +ApproximateGPs = "0.3,0.4" Documenter = "0.27" diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 6d5e4d95..029be317 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -13,8 +13,6 @@ include("utils.jl") include("SparseVariationalApproximationModule.jl") @reexport using .SparseVariationalApproximationModule: SparseVariationalApproximation, Centered, NonCentered -@reexport using .SparseVariationalApproximationModule: - DefaultQuadrature, Analytic, GaussHermite, MonteCarlo include("LaplaceApproximationModule.jl") @reexport using .LaplaceApproximationModule: LaplaceApproximation diff --git a/src/SparseVariationalApproximationModule.jl b/src/SparseVariationalApproximationModule.jl index 0b6fdadb..132de1ab 100644 --- a/src/SparseVariationalApproximationModule.jl +++ b/src/SparseVariationalApproximationModule.jl @@ -4,14 +4,12 @@ using ..API export SparseVariationalApproximation, Centered, NonCentered -using ..ApproximateGPs: _chol_cov, _cov using Distributions using LinearAlgebra using Statistics using StatsBase using FillArrays: Fill using PDMats: chol_lower -using IrrationalConstants: sqrt2, invsqrtπ using AbstractGPs: AbstractGPs using AbstractGPs: @@ -24,10 +22,8 @@ using AbstractGPs: marginals, At_A, diag_At_A -using GPLikelihoods: GaussianLikelihood - -export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo -include("expected_loglik.jl") +using GPLikelihoods: GaussianLikelihood, DefaultExpectationMethod, expected_loglikelihood +using ..ApproximateGPs: _chol_cov, _cov @doc raw""" Centered() @@ -289,7 +285,7 @@ end fx::FiniteGP, y::AbstractVector{<:Real}; num_data=length(y), - quadrature=DefaultQuadrature(), + quadrature=GPLikelihoods.DefaultExpectationMethod(), ) Compute the Evidence Lower BOund from [1] for the process `f = fx.f == @@ -297,14 +293,12 @@ svgp.fz.f` where `y` are observations of `fx`, pseudo-inputs are given by `z = svgp.fz.x` and `q(u)` is a variational distribution over inducing points `u = f(z)`. 
-`quadrature` selects which method is used to calculate the expected loglikelihood in -the ELBO. The options are: `DefaultQuadrature()`, `Analytic()`, `GaussHermite()` and -`MonteCarlo()`. For likelihoods with an analytic solution, `DefaultQuadrature()` uses this -exact solution. If there is no such solution, `DefaultQuadrature()` either uses -`GaussHermite()` or `MonteCarlo()`, depending on the likelihood. +`quadrature` is passed on to `GPLikelihoods.expected_loglikelihood` and selects +which method is used to calculate the expected loglikelihood in the ELBO. See +`GPLikelihoods.expected_loglikelihood` for more details. N.B. the likelihood is assumed to be Gaussian with observation noise `fx.Σy`. -Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = α * I`. +Further, `fx.Σy` must be isotropic - i.e. `fx.Σy = σ² * I`. [1] - Hensman, James, Alexander Matthews, and Zoubin Ghahramani. "Scalable variational Gaussian process classification." Artificial Intelligence and @@ -315,10 +309,11 @@ function AbstractGPs.elbo( fx::FiniteGP{<:AbstractGP,<:AbstractVector,<:Diagonal{<:Real,<:Fill}}, y::AbstractVector{<:Real}; num_data=length(y), - quadrature=DefaultQuadrature(), + quadrature=DefaultExpectationMethod(), ) - @assert sva.fz.f === fx.f - return _elbo(quadrature, sva, fx, y, GaussianLikelihood(fx.Σy[1]), num_data) + σ² = fx.Σy[1] + lik = GaussianLikelihood(σ²) + return elbo(sva, LatentFiniteGP(fx, lik), y; num_data, quadrature) end function AbstractGPs.elbo( @@ -337,7 +332,7 @@ end lfx::LatentFiniteGP, y::AbstractVector; num_data=length(y), - quadrature=DefaultQuadrature(), + quadrature=GPLikelihoods.DefaultExpectationMethod(), ) Compute the ELBO for a LatentGP with a possibly non-conjugate likelihood. @@ -347,26 +342,17 @@ function AbstractGPs.elbo( lfx::LatentFiniteGP, y::AbstractVector; num_data=length(y), - quadrature=DefaultQuadrature(), -) - @assert sva.fz.f === lfx.fx.f - return _elbo(quadrature, sva, lfx.fx, y, lfx.lik, num_data) -end - -# Compute the common elements of the ELBO -function _elbo( - quadrature::QuadratureMethod, - sva::SparseVariationalApproximation, - fx::FiniteGP, - y::AbstractVector, - lik, - num_data::Integer, + quadrature=DefaultExpectationMethod(), ) - @assert sva.fz.f === fx.f + sva.fz.f === lfx.fx.f || throw( + ArgumentError( + "(Latent)FiniteGP prior is not consistent with SparseVariationalApproximation's", + ), + ) f_post = posterior(sva) - q_f = marginals(f_post(fx.x)) - variational_exp = expected_loglik(quadrature, y, q_f, lik) + q_f = marginals(f_post(lfx.fx.x)) + variational_exp = expected_loglikelihood(quadrature, lfx.lik, q_f, y) n_batch = length(y) scale = num_data / n_batch diff --git a/src/expected_loglik.jl b/src/expected_loglik.jl deleted file mode 100644 index 67958551..00000000 --- a/src/expected_loglik.jl +++ /dev/null @@ -1,168 +0,0 @@ -using GPLikelihoods -using FastGaussQuadrature: gausshermite -using SpecialFunctions: loggamma -using ChainRulesCore: ChainRulesCore - -abstract type QuadratureMethod end -struct DefaultQuadrature <: QuadratureMethod end -struct Analytic <: QuadratureMethod end - -struct GaussHermite <: QuadratureMethod - n_points::Int -end -GaussHermite() = GaussHermite(20) - -struct MonteCarlo <: QuadratureMethod - n_samples::Int -end -MonteCarlo() = MonteCarlo(20) - -_default_quadrature(_) = GaussHermite() - -""" - expected_loglik(quadrature::QuadratureMethod, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik) - -This function computes the expected log likelihood: - -```math - ∫ q(f) log p(y | f) df -``` -where `p(y | f)` is 
the process likelihood. This is described by `lik`, which should be a callable that takes `f` as input and returns a Distribution over `y` that supports `loglikelihood(lik(f), y)`. - -`q(f)` is an approximation to the latent function values `f` given by: -```math - q(f) = ∫ p(f | u) q(u) du -``` -where `q(u)` is the variational distribution over inducing points (see -[`elbo`](@ref)). The marginal distributions of `q(f)` are given by `q_f`. - -`quadrature` determines which method is used to calculate the expected log -likelihood - see [`elbo`](@ref) for more details. - -# Extended help - -`q(f)` is assumed to be an `MvNormal` distribution and `p(y | f)` is assumed to -have independent marginals such that only the marginals of `q(f)` are required. -""" -expected_loglik(quadrature, y, q_f, lik) - -""" - expected_loglik(::DefaultQuadrature, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik) - -The expected log likelihood. -Defaults to a closed form solution if it exists, otherwise defaults to -Gauss-Hermite quadrature. -""" -function expected_loglik( - ::DefaultQuadrature, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik -) - quadrature = _default_quadrature(lik) - return expected_loglik(quadrature, y, q_f, lik) -end - -function expected_loglik( - mc::MonteCarlo, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik -) - # take `n_samples` reparameterised samples - f_μ = mean.(q_f) - fs = f_μ .+ std.(q_f) .* randn(eltype(f_μ), length(q_f), mc.n_samples) - lls = loglikelihood.(lik.(fs), y) - return sum(lls) / mc.n_samples -end - -# Compute the expected_loglik over a collection of observations and marginal distributions -function expected_loglik( - gh::GaussHermite, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik -) - # Compute the expectation via Gauss-Hermite quadrature - # using a reparameterisation by change of variable - # (see e.g. en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) - xs, ws = gausshermite(gh.n_points) - return sum(Broadcast.instantiate( - Broadcast.broadcasted(y, q_f) do yᵢ, q_fᵢ # Loop over every pair - # of marginal distribution q(fᵢ) and observation yᵢ - expected_loglik(gh, yᵢ, q_fᵢ, lik, (xs, ws)) - end, - )) -end - -# Compute the expected_loglik for one observation and a marginal distributions -function expected_loglik( - gh::GaussHermite, y, q_f::Normal, lik, (xs, ws)=gausshermite(gh.n_points) -) - μ = mean(q_f) - σ̃ = sqrt2 * std(q_f) - return invsqrtπ * sum(Broadcast.instantiate( - Broadcast.broadcasted(xs, ws) do x, w # Loop over every - # pair of Gauss-Hermite point x with weight w - f = σ̃ * x + μ - loglikelihood(lik(f), y) * w - end, - )) -end - -ChainRulesCore.@non_differentiable gausshermite(n) - -function expected_loglik(::Analytic, y::AbstractVector, q_f::AbstractVector{<:Normal}, lik) - return error( - "No analytic solution exists for ", - typeof(lik), - ". 
Use `DefaultQuadrature()`, `GaussHermite()` or `MonteCarlo()` instead.", - ) -end - -# The closed form solution for independent Gaussian noise -function expected_loglik( - ::Analytic, - y::AbstractVector{<:Real}, - q_f::AbstractVector{<:Normal}, - lik::GaussianLikelihood, -) - return sum( - -0.5 * (log(2π) .+ log.(lik.σ²) .+ ((y .- mean.(q_f)) .^ 2 .+ var.(q_f)) / lik.σ²) - ) -end - -_default_quadrature(::GaussianLikelihood) = Analytic() - -# The closed form solution for a Poisson likelihood with an exponential inverse link function -function expected_loglik( - ::Analytic, - y::AbstractVector{<:Real}, - q_f::AbstractVector{<:Normal}, - ::PoissonLikelihood{ExpLink}, -) - f_μ = mean.(q_f) - return sum((y .* f_μ) - exp.(f_μ .+ (var.(q_f) / 2)) - loggamma.(y .+ 1)) -end - -_default_quadrature(::PoissonLikelihood{ExpLink}) = Analytic() - -# The closed form solution for an Exponential likelihood with an exponential inverse link function -function expected_loglik( - ::Analytic, - y::AbstractVector{<:Real}, - q_f::AbstractVector{<:Normal}, - ::ExponentialLikelihood{ExpLink}, -) - f_μ = mean.(q_f) - return sum(-f_μ - y .* exp.((var.(q_f) / 2) .- f_μ)) -end - -_default_quadrature(::ExponentialLikelihood{ExpLink}) = Analytic() - -# The closed form solution for a Gamma likelihood with an exponential inverse link function -function expected_loglik( - ::Analytic, - y::AbstractVector{<:Real}, - q_f::AbstractVector{<:Normal}, - lik::GammaLikelihood{ExpLink}, -) - f_μ = mean.(q_f) - return sum( - (lik.α - 1) * log.(y) .- y .* exp.((var.(q_f) / 2) .- f_μ) .- lik.α * f_μ .- - loggamma(lik.α), - ) -end - -_default_quadrature(::GammaLikelihood{ExpLink}) = Analytic() diff --git a/test/Project.toml b/test/Project.toml index 39e43130..d53777f7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,7 +17,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractGPs = "0.4, 0.5" -ApproximateGPs = "0.3" +ApproximateGPs = "0.4" ChainRulesCore = "1" ChainRulesTestUtils = "1.2.3" Distributions = "0.25" diff --git a/test/expected_loglik.jl b/test/expected_loglik.jl deleted file mode 100644 index a12f45ea..00000000 --- a/test/expected_loglik.jl +++ /dev/null @@ -1,93 +0,0 @@ -@testset "expected_loglik" begin - # Test that the various methods of computing expectations return the same - # result. - rng = MersenneTwister(123456) - q_f = Normal.(zeros(10), ones(10)) - - likelihoods_to_test = [ - ExponentialLikelihood(), - GammaLikelihood(), - PoissonLikelihood(), - GaussianLikelihood(), - ] - - @testset "testing all analytic implementations" begin - # Test that we're not missing any analytic implementation in `likelihoods_to_test`! 
- implementation_types = [ - (; quadrature=m.sig.types[2], lik=m.sig.types[5]) for - m in methods(SparseVariationalApproximationModule.expected_loglik) - ] - analytic_likelihoods = [ - m.lik for m in implementation_types if - m.quadrature == SparseVariationalApproximationModule.Analytic && m.lik != Any - ] - for lik_type in analytic_likelihoods - @test any(lik isa lik_type for lik in likelihoods_to_test) - end - end - - @testset "$(nameof(typeof(lik)))" for lik in likelihoods_to_test - methods = [GaussHermite(100), MonteCarlo(1e7)] - def = SparseVariationalApproximationModule._default_quadrature(lik) - if def isa Analytic - push!(methods, def) - end - y = rand.(rng, lik.(zeros(10))) - - results = map( - m -> SparseVariationalApproximationModule.expected_loglik(m, y, q_f, lik), - methods, - ) - @test all(x -> isapprox(x, results[end]; atol=1e-6, rtol=1e-3), results) - end - - @test SparseVariationalApproximationModule.expected_loglik( - MonteCarlo(), zeros(10), q_f, GaussianLikelihood() - ) isa Real - @test SparseVariationalApproximationModule.expected_loglik( - GaussHermite(), zeros(10), q_f, GaussianLikelihood() - ) isa Real - @test SparseVariationalApproximationModule._default_quadrature(θ -> Normal(0, θ)) isa - GaussHermite - - @testset "testing Zygote compatibility with GaussHermite" begin # see issue #82 - N = 10 - gh = GaussHermite(12) - μs = randn(rng, N) - σs = rand(rng, N) - # Test differentiation with variational parameters - for lik in likelihoods_to_test - y = rand.(rng, lik.(rand.(Normal.(μs, σs)))) - gμ, glogσ = Zygote.gradient(μs, log.(σs)) do μ, logσ - SparseVariationalApproximationModule.expected_loglik( - gh, y, Normal.(μ, exp.(logσ)), lik - ) - end - @test all(isfinite, gμ) - @test all(isfinite, glogσ) - end - # Test differentiation with likelihood parameters - # Test GaussianLikelihood parameter - σ = 1.0 - y = randn(rng, N) - glogσ = only( - Zygote.gradient(log(σ)) do x - SparseVariationalApproximationModule.expected_loglik( - gh, y, Normal.(μs, σs), GaussianLikelihood(exp(x)) - ) - end, - ) - @test isfinite(glogσ) - # Test GammaLikelihood parameter - α = 2.0 - y = rand.(rng, Gamma.(α, rand(N))) - glogα = only( - Zygote.gradient(log(α)) do x - SparseVariationalApproximationModule.expected_loglik( - gh, y, Normal.(μs, σs), GammaLikelihood(exp(x)) - ) - end, - ) - @test isfinite(glogα) - end -end diff --git a/test/runtests.jl b/test/runtests.jl index a1ec05f4..adaa3631 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,10 +51,6 @@ using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximation include("test_utils.jl") @testset "ApproximateGPs" begin - include("expected_loglik.jl") - println(" ") - @info "Ran expected_loglik tests" - include("SparseVariationalApproximationModule.jl") println(" ") @info "Ran sva tests"
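
For reference, a minimal migration sketch of the functionality moved out by this patch. The argument order `(method, lik, q_f, y)` and the qualified names `GPLikelihoods.expected_loglikelihood` / `GPLikelihoods.DefaultExpectationMethod` are taken directly from the diff above; the names `GaussHermiteExpectation` and `MonteCarloExpectation` are assumed to be the GPLikelihoods 0.4 counterparts of the removed `GaussHermite`/`MonteCarlo` types and should be checked against the GPLikelihoods docs.

    using GPLikelihoods
    using Distributions: Normal

    # Marginals q(fᵢ) of the approximate posterior, with matching observations:
    q_f = Normal.(zeros(10), ones(10))
    lik = PoissonLikelihood()   # exp inverse link, so a closed form exists
    y = rand.(lik.(zeros(10)))

    # Replaces the removed ApproximateGPs call
    #   expected_loglik(DefaultQuadrature(), y, q_f, lik)
    # -- note the different argument order:
    GPLikelihoods.expected_loglikelihood(
        GPLikelihoods.DefaultExpectationMethod(), lik, q_f, y
    )

    # Assumed GPLikelihoods 0.4 names for selecting an explicit method
    # via the `quadrature` keyword of `elbo`:
    # elbo(sva, lfx, y; quadrature=GPLikelihoods.GaussHermiteExpectation(20))
    # elbo(sva, lfx, y; quadrature=GPLikelihoods.MonteCarloExpectation(1000))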