diff --git a/Project.toml b/Project.toml index a3aa4f4e..63b3ba98 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ApproximateGPs" uuid = "298c2ebc-0411-48ad-af38-99e88101b606" authors = ["JuliaGaussianProcesses Team"] -version = "0.2.3" +version = "0.2.4" [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40" KLDivergences = "3c9cd921-3d3f-41e2-830c-e020174918cc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -27,6 +28,7 @@ FillArrays = "0.12" ForwardDiff = "0.10" GPLikelihoods = "0.1, 0.2" KLDivergences = "0.2.1" +PDMats = "0.11" Reexport = "1" SpecialFunctions = "1, 2" StatsBase = "0.33" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index fc629d80..d8732bc3 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -12,10 +12,10 @@ uuid = "99985d1d-32ba-4be9-9821-2ec096f28918" version = "0.5.3" [[ApproximateGPs]] -deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] +deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "PDMats", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] path = ".." uuid = "298c2ebc-0411-48ad-af38-99e88101b606" -version = "0.2.1" +version = "0.2.3" [[ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" diff --git a/examples/a-regression/Manifest.toml b/examples/a-regression/Manifest.toml index cd42657c..8db5fb5c 100644 --- a/examples/a-regression/Manifest.toml +++ b/examples/a-regression/Manifest.toml @@ -24,10 +24,10 @@ uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.3.1" [[ApproximateGPs]] -deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] +deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "PDMats", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] path = "../.." uuid = "298c2ebc-0411-48ad-af38-99e88101b606" -version = "0.2.0" +version = "0.2.3" [[ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -69,9 +69,9 @@ version = "3.5.0" [[Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "f2202b55d816427cd385a9a4f3ffb226bee80f99" +git-tree-sha1 = "4b859a208b2397a7a623a03449e4636bdb17bcf2" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.16.1+0" +version = "1.16.1+1" [[ChainRules]] deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"] @@ -336,9 +336,9 @@ version = "0.21.0+0" [[Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "7bf67e9a481712b3dbe9cb3dac852dc4b1162e02" +git-tree-sha1 = "a32d672ac2c967f3deb8a81d828afc739c838a06" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.68.3+0" +version = "2.68.3+2" [[Graphite2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -359,9 +359,9 @@ version = "0.9.16" [[HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "8a954fed8ac097d5be04921d595f741115c1b2ad" +git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+0" +version = "2.8.1+1" [[IOCapture]] deps = ["Logging", "Random"] @@ -507,9 +507,9 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[Libffi_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "761a393aeccd6aa92ec3515e428c26bf99575b3b" +git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" -version = "3.2.2+0" +version = "3.2.2+1" [[Libgcrypt_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll", "Pkg"] diff --git a/examples/b-classification/Manifest.toml b/examples/b-classification/Manifest.toml index a8b06699..67972d0f 100644 --- a/examples/b-classification/Manifest.toml +++ b/examples/b-classification/Manifest.toml @@ -19,10 +19,10 @@ uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.3.1" [[ApproximateGPs]] -deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] +deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "PDMats", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] path = "../.." uuid = "298c2ebc-0411-48ad-af38-99e88101b606" -version = "0.2.0" +version = "0.2.3" [[ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -47,9 +47,9 @@ version = "1.0.8+0" [[Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "f2202b55d816427cd385a9a4f3ffb226bee80f99" +git-tree-sha1 = "4b859a208b2397a7a623a03449e4636bdb17bcf2" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.16.1+0" +version = "1.16.1+1" [[ChainRules]] deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"] @@ -291,9 +291,9 @@ version = "0.21.0+0" [[Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "7bf67e9a481712b3dbe9cb3dac852dc4b1162e02" +git-tree-sha1 = "a32d672ac2c967f3deb8a81d828afc739c838a06" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.68.3+0" +version = "2.68.3+2" [[Graphite2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -314,9 +314,9 @@ version = "0.9.16" [[HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "8a954fed8ac097d5be04921d595f741115c1b2ad" +git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+0" +version = "2.8.1+1" [[IOCapture]] deps = ["Logging", "Random"] @@ -440,9 +440,9 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[Libffi_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "761a393aeccd6aa92ec3515e428c26bf99575b3b" +git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" -version = "3.2.2+0" +version = "3.2.2+1" [[Libgcrypt_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll", "Pkg"] diff --git a/examples/c-comparisons/Manifest.toml b/examples/c-comparisons/Manifest.toml index 66937125..de5d6de3 100644 --- a/examples/c-comparisons/Manifest.toml +++ b/examples/c-comparisons/Manifest.toml @@ -19,10 +19,10 @@ uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" version = "3.3.1" [[ApproximateGPs]] -deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] +deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "PDMats", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"] path = "../.." uuid = "298c2ebc-0411-48ad-af38-99e88101b606" -version = "0.2.0" +version = "0.2.3" [[ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" @@ -47,9 +47,9 @@ version = "1.0.8+0" [[Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "f2202b55d816427cd385a9a4f3ffb226bee80f99" +git-tree-sha1 = "4b859a208b2397a7a623a03449e4636bdb17bcf2" uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.16.1+0" +version = "1.16.1+1" [[ChainRules]] deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "RealDot", "Statistics"] @@ -291,9 +291,9 @@ version = "0.21.0+0" [[Glib_jll]] deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "7bf67e9a481712b3dbe9cb3dac852dc4b1162e02" +git-tree-sha1 = "a32d672ac2c967f3deb8a81d828afc739c838a06" uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.68.3+0" +version = "2.68.3+2" [[Graphite2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -314,9 +314,9 @@ version = "0.9.16" [[HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "8a954fed8ac097d5be04921d595f741115c1b2ad" +git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+0" +version = "2.8.1+1" [[IOCapture]] deps = ["Logging", "Random"] @@ -440,9 +440,9 @@ uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" [[Libffi_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "761a393aeccd6aa92ec3515e428c26bf99575b3b" +git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" -version = "3.2.2+0" +version = "3.2.2+1" [[Libgcrypt_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll", "Pkg"] diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 40847914..1f30ef53 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -12,23 +12,16 @@ using SpecialFunctions using ChainRulesCore using FillArrays using KLDivergences +using PDMats: chol_lower -using AbstractGPs: - AbstractGP, - FiniteGP, - LatentFiniteGP, - ApproxPosteriorGP, - At_A, - diag_At_A, - Xt_A_X, - Xt_A_Y, - diag_Xt_A_X - -export SparseVariationalApproximation, Centered, NonCentered -export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo +using AbstractGPs: AbstractGP, FiniteGP, LatentFiniteGP, ApproxPosteriorGP, At_A, diag_At_A include("utils.jl") + +export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo include("expected_loglik.jl") + +export SparseVariationalApproximation, Centered, NonCentered include("sparse_variational.jl") using ForwardDiff diff --git a/src/sparse_variational.jl b/src/sparse_variational.jl index b27db88f..91364bed 100644 --- a/src/sparse_variational.jl +++ b/src/sparse_variational.jl @@ -56,7 +56,7 @@ end SparseVariationalApproximation(fz::FiniteGP, q::AbstractMvNormal) Packages the prior over the pseudo-points `fz`, and the approximate posterior at the -pseudo-points, which is `mean(fz) + cholesky(cov(fz)).U' * ε`, `ε ∼ q`. +pseudo-points, which is `mean(fz) + cholesky(cov(fz)).L * ε`, `ε ∼ q`. Shorthand for ```julia @@ -86,82 +86,28 @@ variational Gaussian process classification." Artificial Intelligence and Statistics. PMLR, 2015. """ function AbstractGPs.posterior(sva::SparseVariationalApproximation{Centered}) + # m* = K*u Kuu⁻¹ (mean(q) - mean(fz)) + # = K*u α + # Centered: α = Kuu⁻¹ (m - mean(fz)) + # [NonCentered: α = Lk⁻ᵀ m] + # V** = K** - K*u (Kuu⁻¹ - Kuu⁻¹ cov(q) Kuu⁻¹) Ku* + # = K** - K*u (Kuu⁻¹ - Kuu⁻¹ cov(q) Kuu⁻¹) Ku* + # = K** - (K*u Lk⁻ᵀ) (Lk⁻¹ Ku*) + (K*u Lk⁻ᵀ) Lk⁻¹ cov(q) Lk⁻ᵀ (Lk⁻¹ Ku*) + # = K** - A'A + A' Lk⁻¹ cov(q) Lk⁻ᵀ A + # = K** - A'A + A' Lk⁻¹ Lq Lqᵀ Lk⁻ᵀ A + # = K** - A'A + A' B B' A + # A = Lk⁻¹ Ku* + # Centered: B = Lk⁻¹ Lq + # [NonCentered: B = Lq] q, fz = sva.q, sva.fz m, S = mean(q), _chol_cov(q) Kuu = _chol_cov(fz) - B = Kuu.L \ S.L + B = chol_lower(Kuu) \ chol_lower(S) α = Kuu \ (m - mean(fz)) - data = (S=S, m=m, Kuu=Kuu, B=B, α=α) + data = (Kuu=Kuu, B=B, α=α) return ApproxPosteriorGP(sva, fz.f, data) end -function AbstractGPs.posterior( - sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real} -) - @assert sva.fz.f === fx.f - return posterior(sva) -end - -# -# Various methods implementing the Internal AbstractGPs API. -# See AbstractGPs.jl API docs for more info. -# - -function Statistics.mean( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector -) - return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α -end - -function Statistics.cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector -) - Cux = cov(f.prior, inducing_points(f), x) - D = f.data.Kuu.L \ Cux - return cov(f.prior, x) - At_A(D) + At_A(f.data.B' * D) -end - -function Statistics.var( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector -) - Cux = cov(f.prior, inducing_points(f), x) - D = f.data.Kuu.L \ Cux - return var(f.prior, x) - diag_At_A(D) + diag_At_A(f.data.B' * D) -end - -function Statistics.cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, - x::AbstractVector, - y::AbstractVector, -) - B = f.data.B - Cxu = cov(f.prior, x, inducing_points(f)) - Cuy = cov(f.prior, inducing_points(f), y) - D = f.data.Kuu.L \ Cuy - E = Cxu / f.data.Kuu.L' - return cov(f.prior, x, y) - (E * D) + (E * B * B' * D) -end - -function StatsBase.mean_and_cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector -) - Cux = cov(f.prior, inducing_points(f), x) - D = f.data.Kuu.L \ Cux - μ = Cux' * f.data.α - Σ = cov(f.prior, x) - At_A(D) + At_A(f.data.B' * D) - return μ, Σ -end - -function StatsBase.mean_and_var( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{Centered}}, x::AbstractVector -) - Cux = cov(f.prior, inducing_points(f), x) - D = f.data.Kuu.L \ Cux - μ = Cux' * f.data.α - Σ_diag = var(f.prior, x) - diag_At_A(D) + diag_At_A(f.data.B' * D) - return μ, Σ_diag -end - # # NonCentered Parametrization. # @@ -172,7 +118,7 @@ end Compute the approximate posterior [1] over the process `f = sva.fz.f`, given inducing inputs `z = sva.fz.x` and a variational distribution over inducing points `sva.q` (which represents ``q(ε)`` -where `ε = cholesky(cov(fz)).U' \ (f(z) - mean(f(z)))`). The approximate posterior at test +where `ε = cholesky(cov(fz)).L \ (f(z) - mean(f(z)))`). The approximate posterior at test points ``x^*`` where ``f^* = f(x^*)`` is then given by: ```math @@ -184,10 +130,40 @@ which can be found in closed form. variational Gaussian process classification." Artificial Intelligence and Statistics. PMLR, 2015. """ -function AbstractGPs.posterior(approx::SparseVariationalApproximation{NonCentered}) - fz = approx.fz - data = (Cuu=_chol_cov(fz), C_ε=_chol_cov(approx.q)) - return ApproxPosteriorGP(approx, fz.f, data) +function AbstractGPs.posterior(sva::SparseVariationalApproximation{NonCentered}) + # u = Lk v + mean(fz), v ~ q + # m* = K*u Kuu⁻¹ Lk (mean(u) - mean(fz)) + # = K*u (Lk Lkᵀ)⁻¹ Lk mean(q) + # = K*u Lk⁻ᵀ Lk⁻¹ Lk mean(q) + # = K*u Lk⁻ᵀ mean(q) + # = K*u α + # NonCentered: α = Lk⁻ᵀ m + # [Centered: α = Kuu⁻¹ (m - mean(fz))] + # V** = K** - K*u (Kuu⁻¹ - Kuu⁻¹ Lk cov(q) Lkᵀ Kuu⁻¹) Ku* + # = K** - K*u (Kuu⁻¹ - (Lk Lkᵀ)⁻¹ Lk cov(q) Lkᵀ (Lk Lkᵀ)⁻¹) Ku* + # = K** - K*u (Kuu⁻¹ - Lk⁻ᵀ Lk⁻¹ Lk cov(q) Lkᵀ Lk⁻ᵀ Lk⁻¹) Ku* + # = K** - K*u (Kuu⁻¹ - Lk⁻ᵀ cov(q) Lk⁻¹) Ku* + # = K** - (K*u Lk⁻ᵀ) (Lk⁻¹ Ku*) - (K*u Lk⁻ᵀ) Lq Lqᵀ (Lk⁻¹ Ku*) + # = K** - A'A - (K*u Lk⁻ᵀ) Lq Lqᵀ (Lk⁻¹ Ku*) + # = K** - A'A - A' B B' A + # A = Lk⁻¹ Ku* + # NonCentered: B = Lq + # [Centered: B = Lk⁻¹ Lq] + q, fz = sva.q, sva.fz + m = mean(q) + Kuu = _chol_cov(fz) + α = chol_lower(Kuu)' \ m + Sv = _chol_cov(q) + B = chol_lower(Sv) + data = (Kuu=Kuu, B=B, α=α) + return ApproxPosteriorGP(sva, fz.f, data) +end + +function AbstractGPs.posterior( + sva::SparseVariationalApproximation, fx::FiniteGP, ::AbstractVector{<:Real} +) + @assert sva.fz.f === fx.f + return posterior(sva) end # @@ -195,57 +171,62 @@ end # See AbstractGPs.jl API docs for more info. # -# Produces a matrix that is consistently referred to as A in this file. A more descriptive -# name is, unfortunately, not obvious. It's just an intermediate quantity that happens to -# get used a lot. -_A(f, x) = f.data.Cuu.U' \ cov(f.prior, inducing_points(f), x) - function Statistics.mean( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector + f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector ) - return mean(f.prior, x) + _A(f, x)' * mean(f.approx.q) + return mean(f.prior, x) + cov(f.prior, x, inducing_points(f)) * f.data.α +end + +# A = Lk⁻¹ Ku* is the projection matrix used in computing the predictive variance of the SparseVariationalApproximation posterior. +function _A_and_Kuf(f, x) + Kuf = cov(f.prior, inducing_points(f), x) + A = chol_lower(f.data.Kuu) \ Kuf + return A, Kuf end +_A(f, x) = first(_A_and_Kuf(f, x)) + function Statistics.cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector + f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector ) A = _A(f, x) - return cov(f.prior, x) - At_A(A) + Xt_A_X(f.data.C_ε, A) + return cov(f.prior, x) - At_A(A) + At_A(f.data.B' * A) end function Statistics.var( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector + f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector ) A = _A(f, x) - return var(f.prior, x) - diag_At_A(A) + diag_Xt_A_X(f.data.C_ε, A) -end - -function Statistics.cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, - x::AbstractVector, - y::AbstractVector, -) - Ax = _A(f, x) - Ay = _A(f, y) - return cov(f.prior, x, y) - Ax'Ay + Xt_A_Y(Ax, f.data.C_ε, Ay) + return var(f.prior, x) - diag_At_A(A) + diag_At_A(f.data.B' * A) end function StatsBase.mean_and_cov( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector + f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector ) - A = _A(f, x) - μ = mean(f.prior, x) + A' * mean(f.approx.q) - Σ = cov(f.prior, x) - At_A(A) + Xt_A_X(f.data.C_ε, A) + A, Kuf = _A_and_Kuf(f, x) + μ = mean(f.prior, x) + Kuf' * f.data.α + Σ = cov(f.prior, x) - At_A(A) + At_A(f.data.B' * A) return μ, Σ end function StatsBase.mean_and_var( - f::ApproxPosteriorGP{<:SparseVariationalApproximation{NonCentered}}, x::AbstractVector + f::ApproxPosteriorGP{<:SparseVariationalApproximation}, x::AbstractVector ) - A = _A(f, x) - μ = mean(f.prior, x) + A' * mean(f.approx.q) - Σ = var(f.prior, x) - diag_At_A(A) + diag_Xt_A_X(f.data.C_ε, A) - return μ, Σ + A, Kuf = _A_and_Kuf(f, x) + μ = mean(f.prior, x) + Kuf' * f.data.α + Σ_diag = var(f.prior, x) - diag_At_A(A) + diag_At_A(f.data.B' * A) + return μ, Σ_diag +end + +function Statistics.cov( + f::ApproxPosteriorGP{<:SparseVariationalApproximation}, + x::AbstractVector, + y::AbstractVector, +) + B = f.data.B + Ax = _A(f, x) + Ay = _A(f, y) + return cov(f.prior, x, y) - Ax'Ay + Ax' * B * B' * Ay end # @@ -338,18 +319,25 @@ function _elbo( num_data::Integer, ) @assert sva.fz.f === fx.f - post = posterior(sva) - q_f = marginals(post(fx.x)) + + f_post = posterior(sva) + q_f = marginals(f_post(fx.x)) variational_exp = expected_loglik(quadrature, y, q_f, lik) n_batch = length(y) scale = num_data / n_batch - return sum(variational_exp) * scale - kl_term(sva, post) + return sum(variational_exp) * scale - _prior_kl(sva) end -kl_term(sva::SparseVariationalApproximation{Centered}, post) = KL(sva.q, sva.fz) +_prior_kl(sva::SparseVariationalApproximation{Centered}) = KL(sva.q, sva.fz) -function kl_term(sva::SparseVariationalApproximation{NonCentered}, post) +function _prior_kl(sva::SparseVariationalApproximation{NonCentered}) m_ε = mean(sva.q) - return (tr(cov(sva.q)) + m_ε'm_ε - length(m_ε) - logdet(post.data.C_ε)) / 2 + C_ε = _cov(sva.q) + + # trace_term = tr(C_ε) # does not work due to PDMat / Zygote issues + L = chol_lower(_chol_cov(sva.q)) + trace_term = sum(L .^ 2) # TODO remove AD workaround + + return (trace_term + m_ε'm_ε - length(m_ε) - logdet(C_ε)) / 2 end diff --git a/src/utils.jl b/src/utils.jl index 4db93f31..6b145f2b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,3 +11,5 @@ end _chol_cov(q::AbstractMvNormal) = cholesky(Symmetric(cov(q))) _chol_cov(q::MvNormal) = cholesky(q.Σ) + +_cov(q::MvNormal) = q.Σ diff --git a/test/sparse_variational.jl b/test/sparse_variational.jl index 339e0aa3..43cc5898 100644 --- a/test/sparse_variational.jl +++ b/test/sparse_variational.jl @@ -26,7 +26,10 @@ # Check that approximate posterior is self-consistent. a = collect(range(-1.0, 1.0; length=N_a)) b = randn(rng, N_b) - TestUtils.test_internal_abstractgps_interface(rng, f_approx_post_Centered, a, b) + + @testset "AbstractGPs interface - Centered" begin + TestUtils.test_internal_abstractgps_interface(rng, f_approx_post_Centered, a, b) + end @testset "NonCentered" begin @@ -37,28 +40,32 @@ Cuu.L \ (mean(q) - mean(fz)), Symmetric((Cuu.L \ cov(q)) / Cuu.U) ) - # Check that q_ε has been properly constructed. - @test mean(q) ≈ mean(fz) + Cuu.L * mean(q_ε) - @test cov(q) ≈ Cuu.L * cov(q_ε) * Cuu.U + @testset "Check that q_ε has been properly constructed" begin + @test mean(q) ≈ mean(fz) + Cuu.L * mean(q_ε) + @test cov(q) ≈ Cuu.L * cov(q_ε) * Cuu.U + end # Construct equivalent approximate posteriors. approx_non_Centered = SparseVariationalApproximation(NonCentered(), fz, q_ε) f_approx_post_non_Centered = posterior(approx_non_Centered) - TestUtils.test_internal_abstractgps_interface( - rng, f_approx_post_non_Centered, a, b - ) - # Unit-test kl_term. - @test isapprox( - ApproximateGPs.kl_term(approx_non_Centered, f_approx_post_non_Centered), - ApproximateGPs.kl_term(approx_Centered, f_approx_post_Centered); - rtol=1e-5, - ) + @testset "AbstractGPs interface - NonCentered" begin + TestUtils.test_internal_abstractgps_interface( + rng, f_approx_post_non_Centered, a, b + ) + end - # Verify that the non-centered approximate posterior agrees with centered. - @test mean(f_approx_post_non_Centered, a) ≈ mean(f_approx_post_Centered, a) - @test cov(f_approx_post_non_Centered, a, b) ≈ cov(f_approx_post_Centered, a, b) - @test elbo(approx_non_Centered, fx, y) ≈ elbo(approx_Centered, fx, y) + @testset "Verify that the non-centered approximate posterior agrees with centered" begin + @test isapprox( + ApproximateGPs._prior_kl(approx_non_Centered), + ApproximateGPs._prior_kl(approx_Centered); + rtol=1e-5, + ) + @test mean(f_approx_post_non_Centered, a) ≈ mean(f_approx_post_Centered, a) + @test cov(f_approx_post_non_Centered, a, b) ≈ + cov(f_approx_post_Centered, a, b) + @test elbo(approx_non_Centered, fx, y) ≈ elbo(approx_Centered, fx, y) + end end end @@ -87,7 +94,7 @@ @test elbo(sva, lfx, y) ≈ elbo(sva, fx, y) atol = 1e-10 end - @testset "equivalences" begin + @testset "GPR/VFE equivalences" begin rng, N = MersenneTwister(654321), 20 x = rand(rng, N) * 10 y = sin.(x) + 0.9 * cos.(x * 1.6) + 0.4 * rand(rng, N)