Skip to content

Commit

Permalink
Implement whitened parametrisation (#71)
Browse files Browse the repository at this point in the history
* Implement whitened parametrisation

* Bump patch

* Improve docs

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update docs

* SVGP -> SparseVariationalApproximation

* Fix docstring typo

Co-authored-by: Ross Viljoen <[email protected]>

* Update docs

* Refactor to use type parameter

* Test kl_term

* Run all tests

* Clarify tests

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Stabilise numerics in sparse_variational tests

* Stabilise tests

* Improve docs

* Add Gorinova reference

* Add whitening transformation ref

* Fix tests

* Use American English :(

* Update test/sparse_variational.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Typos

* Apply Theo's suggestions

Co-authored-by: Théo Galy-Fajou <[email protected]>

* Update src/sparse_variational.jl

Co-authored-by: Théo Galy-Fajou <[email protected]>

* Fix for docs

* Fix rest of the docs

* Apply Ti's formatting suggestions

Co-authored-by: st-- <[email protected]>

* Add Paciorek reference

* Bump patch

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Ross Viljoen <[email protected]>
Co-authored-by: Théo Galy-Fajou <[email protected]>
Co-authored-by: st-- <[email protected]>
  • Loading branch information
5 people authored Nov 16, 2021
1 parent 8b4d798 commit a3db592
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 59 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ApproximateGPs"
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
authors = ["JuliaGaussianProcesses Team"]
version = "0.2.1"
version = "0.2.2"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
48 changes: 30 additions & 18 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ version = "0.5.3"
deps = ["AbstractGPs", "ChainRulesCore", "Distributions", "FastGaussQuadrature", "FillArrays", "ForwardDiff", "GPLikelihoods", "KLDivergences", "LinearAlgebra", "Reexport", "SpecialFunctions", "Statistics", "StatsBase"]
path = ".."
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
version = "0.2.0"
version = "0.2.1"

[[ArgTools]]
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
Expand All @@ -32,6 +32,12 @@ git-tree-sha1 = "f885e7e7c124f8c92650d61b9477b9ac2ee607dd"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.11.1"

[[ChangesOfVariables]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "9a1d594397670492219635b35a3d830b04730d62"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.1"

[[CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
Expand Down Expand Up @@ -72,6 +78,12 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
deps = ["Mmap"]
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DensityInterface]]
deps = ["InverseFunctions", "Test"]
git-tree-sha1 = "794daf62dce7df839b8ed446fc59c68db4b5182f"
uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
version = "0.3.3"

[[DiffResults]]
deps = ["StaticArrays"]
git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
Expand All @@ -95,10 +107,10 @@ deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["ChainRulesCore", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "72dcda9e19f88d09bf21b5f9507a0bb430bce2aa"
deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
git-tree-sha1 = "cce8159f0fee1281335a04bbf876572e46c921ba"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.25.24"
version = "0.25.29"

[[DocStringExtensions]]
deps = ["LibGit2"]
Expand Down Expand Up @@ -129,21 +141,21 @@ uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.12.7"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "ef3fec65f9db26fa2cf8f4133c697c5b7ce63c1d"
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "6406b5112809c08b1baa5703ad274e1dded0652f"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.22"
version = "0.10.23"

[[Functors]]
git-tree-sha1 = "e4768c3b7f597d5a352afa09874d16e3c3f6ead2"
uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
version = "0.2.7"

[[GPLikelihoods]]
deps = ["Distributions", "Functors", "LinearAlgebra", "Random", "StatsFuns"]
git-tree-sha1 = "bdfe8a65b3ca3aa92812d74138264570f33aa66e"
deps = ["Distributions", "Functors", "InverseFunctions", "LinearAlgebra", "Random", "StatsFuns"]
git-tree-sha1 = "561e03fc0dc1d38560dc1403ad95b308418f0ed6"
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
version = "0.2.4"
version = "0.2.5"

[[IOCapture]]
deps = ["Logging", "Random"]
Expand All @@ -157,9 +169,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[InverseFunctions]]
deps = ["Test"]
git-tree-sha1 = "f0c6489b12d28fb4c2103073ec7452f3423bd308"
git-tree-sha1 = "a7254c0acd8e62f1ac75ad24d5db43f5f19f3c65"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.1"
version = "0.1.2"

[[IrrationalConstants]]
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
Expand Down Expand Up @@ -214,10 +226,10 @@ deps = ["Libdl"]
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["ChainRulesCore", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "6193c3815f13ba1b78a51ce391db8be016ae9214"
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "be9eef9f9d78cecb6f262f3c10da151a6c5ab827"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.4"
version = "0.3.5"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand Down Expand Up @@ -390,10 +402,10 @@ uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.33.12"

[[StatsFuns]]
deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "95072ef1a22b057b1e80f73c2a89ad238ae4cfff"
deps = ["ChainRulesCore", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"]
git-tree-sha1 = "385ab64e64e79f0cd7cfcf897169b91ebbb2d6c8"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.12"
version = "0.9.13"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
Expand Down
35 changes: 31 additions & 4 deletions docs/src/userguide.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,49 @@ To construct a sparse approximation to the exact posterior, we first need to sel
M = 15 # The number of inducing points
z = x[1:M]
```
The inducing inputs `z` imply some latent function values `u = f(z)`, sometimes called pseudo-points. The stochastic variational Gaussian process (SVGP) approximation is defined by a variational distribution `q(u)` over the pseudo-points. In the case of GP regression, the optimal form for `q(u)` is a multivariate Gaussian, which is the only form of `q` currently supported by this package.
The inducing inputs `z` imply some latent function values `u = f(z)`, sometimes called pseudo-points. The [`SparseVariationalApproximation`](@ref) specifies a distribution `q(u)` over the pseudo-points. In the case of GP regression, the optimal form for `q(u)` is a multivariate Gaussian, which is the only form of `q` currently supported by this package.
```julia
using Distributions, LinearAlgebra
q = MvNormal(zeros(length(z)), I)
```
Finally, we pass our `q` along with the inputs `f(z)` to obtain an approximate posterior GP:
```julia
fz = f(z, 1e-6) # 'observe' the process at z with some jitter for numerical stability
approx = SVGP(fz, q) # Instantiate everything needed for the svgp approximation
approx = SparseVariationalApproximation(fz, q) # Instantiate everything needed for the approximation

svgp_posterior = posterior(approx) # Create the approximate posterior
sva_posterior = posterior(approx) # Create the approximate posterior
```

## The Evidence Lower Bound (ELBO)
The approximate posterior constructed above will be a very poor approximation, since `q` was simply chosen to have zero mean and covariance `I`. A measure of the quality of the approximation is given by the ELBO. Optimising this term with respect to the parameters of `q` and the inducing input locations `z` will improve the approximation.
```julia
elbo(SVGP(fz, q), fx, y)
elbo(SparseVariationalApproximation(fz, q), fx, y)
```
A detailed example of how to carry out such optimisation is given in [Regression: Sparse Variational Gaussian Process for Stochastic Optimisation with Flux.jl](@ref). For an example of non-conjugate inference, see [Classification: Sparse Variational Approximation for Non-Conjugate Likelihoods with Optim's L-BFGS](@ref).

# Available Parametrizations

Two parametrizations of `q(u)` are presently available: [`Centered`](@ref) and [`NonCentered`](@ref).
The `Centered` parametrization expresses `q(u)` directly in terms of its mean and covariance.
The `NonCentered` parametrization instead parametrizes the mean and covariance of
`ε := cholesky(cov(u)).U' \ (u - mean(u))`.
These parametrizations are also known respectively as "Unwhitened" and "Whitened".

The choice of parametrization can have a substantial impact on the time it takes for ELBO
optimization to converge, and which parametrization is better in a particular situation is
not generally obvious.
That being said, the `NonCentered` parametrization often converges in fewer iterations, so it is the default --
it is what is used in all of the examples above.

If you require a particular parametrization, simply use the 3-argument version of the
approximation constructor:
```julia
SparseVariationalApproximation(Centered(), fz, q)
SparseVariationalApproximation(NonCentered(), fz, q)
```

For a general discussion around these two parametrizations, see e.g. [^Gorinova].
For a GP-specific discussion, see e.g. section 3.4 of [^Paciorek].

[^Gorinova]: Gorinova, Maria and Moore, Dave and Hoffman, Matthew [Automatic Reparameterisation of Probabilistic Programs](http://proceedings.mlr.press/v119/gorinova20a)
[^Paciorek]: [Paciorek, Christopher Joseph. Nonstationary Gaussian processes for regression and spatial modelling. Diss. Carnegie Mellon University, 2003.](https://www.stat.berkeley.edu/~paciorek/diss/paciorek-thesis.pdf)
15 changes: 12 additions & 3 deletions src/ApproximateGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@ using ChainRulesCore
using FillArrays
using KLDivergences

using AbstractGPs: AbstractGP, FiniteGP, LatentFiniteGP, ApproxPosteriorGP, At_A, diag_At_A

export SparseVariationalApproximation
using AbstractGPs:
AbstractGP,
FiniteGP,
LatentFiniteGP,
ApproxPosteriorGP,
At_A,
diag_At_A,
Xt_A_X,
Xt_A_Y,
diag_Xt_A_X

export SparseVariationalApproximation, Centered, NonCentered
export DefaultQuadrature, Analytic, GaussHermite, MonteCarlo

include("utils.jl")
Expand Down
Loading

2 comments on commit a3db592

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/48891

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.2 -m "<description of version>" a3db592a69078c2df90cb0b891df0669d08a0534
git push origin v0.2.2

Please sign in to comment.