SVGP in stheno #24
-
Hey @patel-zeel! I think an example of SVGP with Stheno would be a super good addition. Perhaps we could also generalise the example slightly to non-Gaussian likelihoods; see also #22. I'm currently a little overloaded, but I should be able to help out with this in a week or two, if that would be fine. :)
It should be possible to leverage the current machinery of Stheno. Here's roughly the structure that I have in mind:

```python
import jax.numpy as jnp
import lab.jax as B
from stheno.jax import GP, EQ, Normal
from varz.jax import Vars, minimise_adam

# Inducing inputs.
z = B.randn(jnp.float64, 50, 1)
m = B.shape(z, 0)

# Toy data: 100 observations with noise variance 0.1 from a GP with an EQ kernel.
x = B.randn(jnp.float64, 100, 1)
f = GP(EQ())
y = f(x, jnp.float64(0.1)).sample()


def objective(vs, state):
    p = vs.struct

    # Prior with learnable variance, length scale, and observation noise.
    f = GP(p.variance.positive() * EQ().stretch(p.scale.positive()))
    noise = p.noise.positive()

    # Prior p(u) and variational distribution q(u) at the inducing inputs.
    p_z = f(z)
    q_z = Normal(
        p.q.mean.unbounded(p_z.mean),
        p.q.cov.positive_definite(B.dense(p_z.var)),
    )

    # Single-sample estimate: draw u ~ q(u), then f(x) | u with correlations dropped.
    # Assumes Stheno's `sample(state)` overload, which returns `(state, sample)`;
    # threading `state` keeps the draws fresh under `jit=True`.
    state, z_sample = q_z.sample(state)
    state, f_sample = (f | (f(z), z_sample))(x).diagonalise().sample(state)
    lik = Normal(f_sample, B.fill_diag(noise, B.shape(f_sample, 0)))

    # Negative single-sample ELBO estimate: -(log p(y | f) - KL(q(u) || p(u))).
    return -(lik.logpdf(y) - q_z.kl(p_z)), state


vs = Vars(jnp.float64)
state = B.create_random_state(jnp.float64, seed=0)
state = minimise_adam(objective, (vs, state), rate=5e-3, jit=True, trace=True, iters=10_000)
vs.print()
```

Currently, this attempt still has many issues, but hopefully it illustrates the approach. I'd be happy to spend a little time on this in a week or two!
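The objective above estimates the expected log-likelihood with a single joint sample of u and f(x). As a rough illustration of that term written out by hand (plain JAX, no Stheno or varz), here is a minimal sketch that computes the marginals of q(f(x)) induced by q(u) = N(m, S) at the inducing inputs z, then averages log p(y | f) over reparameterised samples. Because the likelihood enters only through a log-density function, a non-Gaussian likelihood can be swapped in directly. It assumes a fixed unit-variance EQ kernel and a likelihood that factorises over data points; `eq_kernel`, `svgp_marginals`, `expected_loglik`, `gaussian_loglik`, and `bernoulli_loglik` are hypothetical helpers for illustration, not Stheno API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# Match the float64 precision used in the Stheno sketch.
jax.config.update("jax_enable_x64", True)

# Sketch assumptions (hypothetical helpers): unit-variance EQ kernel,
# likelihood that factorises over data points.


def eq_kernel(a, b, variance=1.0, scale=1.0):
    """Exponentiated-quadratic (EQ) kernel matrix between the rows of a and b."""
    sq_dists = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return variance * jnp.exp(-0.5 * sq_dists / scale**2)


def svgp_marginals(x, z, q_mean, q_cov, jitter=1e-8):
    """Means and variances of q(f(x)) = int p(f(x) | u) q(u) du, with u = f(z), q(u) = N(q_mean, q_cov)."""
    k_zz = eq_kernel(z, z) + jitter * jnp.eye(z.shape[0])
    k_xz = eq_kernel(x, z)
    chol = jnp.linalg.cholesky(k_zz)
    # A = K_xz K_zz^{-1}, computed with a Cholesky solve instead of an explicit inverse.
    a = cho_solve((chol, True), k_xz.T).T
    mean = a @ q_mean
    # diag(K_xx) - diag(A K_zx) + diag(A S A^T), with diag(K_xx) = 1 for the unit-variance EQ kernel.
    var = 1.0 - jnp.sum(a * k_xz, axis=1) + jnp.sum((a @ q_cov) * a, axis=1)
    return mean, var


def expected_loglik(key, y, mean, var, loglik, num_samples=1000):
    """Monte Carlo estimate of sum_i E_{q(f_i)}[log p(y_i | f_i)] via the reparameterisation trick."""
    eps = jax.random.normal(key, (num_samples,) + mean.shape)
    f = mean + jnp.sqrt(jnp.maximum(var, 0.0)) * eps  # shape (num_samples, n)
    return jnp.mean(jnp.sum(loglik(y, f), axis=-1))


def gaussian_loglik(y, f, noise=0.1):
    """log N(y | f, noise), where `noise` is the observation noise variance."""
    return -0.5 * jnp.log(2 * jnp.pi * noise) - 0.5 * (y - f) ** 2 / noise


def bernoulli_loglik(y, f):
    """log p(y | f) for labels y in {0, 1} under a logistic link."""
    return y * jax.nn.log_sigmoid(f) + (1 - y) * jax.nn.log_sigmoid(-f)
```

Since the likelihood factorises, sampling each f_i from its marginal gives the same expectation as sampling u first and then f(x) | u. For a Gaussian likelihood the expectation also has a closed form, −½ log(2πσ²) − ((y − μ)² + v)/(2σ²) per point, so sampling is only strictly needed for likelihoods such as the Bernoulli one.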
-
Hi @wesselb,
Would you like to see an implementation of SVGP [Hensman et al. 2013] in Stheno? I think it is very similar to PseudoObs (VFE), but as far as I understand from these GPSS slides, the KL divergence between the variational distribution and the prior needs to be computed in closed form. Could you please suggest how to implement SVGP with the existing machinery of Stheno or, if that is not possible, what the best way to go about it would be?
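For two multivariate Gaussians the KL divergence does have a closed form, so it can be computed exactly in SVGP, where q(u) = N(m, S) and p(u) = N(0, K_zz): KL(N(m_q, S_q) ‖ N(m_p, S_p)) = ½[tr(S_p⁻¹ S_q) + (m_p − m_q)ᵀ S_p⁻¹ (m_p − m_q) − k + log det S_p − log det S_q], with k the dimension. Below is a minimal pure-JAX sketch of that formula using Cholesky factors; `gaussian_kl` is a hypothetical helper, not part of Stheno, and it assumes both covariance matrices are positive definite.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular

jax.config.update("jax_enable_x64", True)


def gaussian_kl(m_q, s_q, m_p, s_p):
    """Closed-form KL(N(m_q, s_q) || N(m_p, s_p)); assumes s_q and s_p are positive definite."""
    k = m_q.shape[0]
    chol_p = jnp.linalg.cholesky(s_p)
    chol_q = jnp.linalg.cholesky(s_q)
    # tr(S_p^{-1} S_q) via a Cholesky solve.
    trace_term = jnp.trace(cho_solve((chol_p, True), s_q))
    # Mahalanobis term (m_p - m_q)^T S_p^{-1} (m_p - m_q) = ||L_p^{-1} (m_p - m_q)||^2.
    alpha = solve_triangular(chol_p, (m_p - m_q)[:, None], lower=True)
    maha = jnp.sum(alpha**2)
    # log det S = 2 * sum(log diag(L)) for S = L L^T.
    logdet_p = 2.0 * jnp.sum(jnp.log(jnp.diag(chol_p)))
    logdet_q = 2.0 * jnp.sum(jnp.log(jnp.diag(chol_q)))
    return 0.5 * (trace_term + maha - k + logdet_p - logdet_q)
```

In the SVGP setting this would be called as `gaussian_kl(q_mean, q_cov, jnp.zeros(num_inducing), k_zz)`, which plays the role of the `q_z.kl(p_z)` term in the sketch in the reply above.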