Pure pytensor GP #395

Draft: wants to merge 7 commits into base: main
pymc_experimental/gp/pytensor_gp.py (160 additions, 0 deletions)
@@ -0,0 +1,160 @@
from collections.abc import Sequence

import pymc as pm
import pytensor.tensor as pt

from pymc.distributions.distribution import Continuous
from pymc.model.fgraph import fgraph_from_model, model_free_rv, model_from_fgraph
from pytensor import Variable
from pytensor.compile.builders import OpFromGraph


class GPCovariance(OpFromGraph):
"""OFG representing a GP covariance"""

@staticmethod
def square_dist(X, Xs, ls):
        assert X.ndim == 2, "X must be a 2D array of shape (n, d)"
        assert Xs.ndim == 2, "Xs must be a 2D array of shape (m, d)"

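        # Squared Euclidean distance via the expansion
        # ||x - x'||^2 = ||x||^2 + ||x'||^2 - 2 <x, x'>,
        # applied to the lengthscale-scaled inputs below.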
X = X / ls
Xs = Xs / ls

X2 = pt.sum(pt.square(X), axis=-1)
Xs2 = pt.sum(pt.square(Xs), axis=-1)

sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :])
return pt.clip(sqd, 0, pt.inf)


class ExpQuadCov(GPCovariance):
"""
ExpQuad covariance function
"""

@classmethod
def exp_quad_full(cls, X, Xs, ls):
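        # ExpQuad kernel: k(x, x') = exp(-||x - x'||^2 / (2 * ls^2))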
return pt.exp(-0.5 * cls.square_dist(X, Xs, ls))

@classmethod
def build_covariance(cls, X, Xs=None, *, ls):
X = pt.as_tensor(X)
if Xs is None:
Xs = X
else:
Xs = pt.as_tensor(Xs)
ls = pt.as_tensor(ls)

out = cls.exp_quad_full(X, Xs, ls)
if Xs is X:
return cls(inputs=[X, ls], outputs=[out])(X, ls)
else:
return cls(inputs=[X, Xs, ls], outputs=[out])(X, Xs, ls)


def ExpQuad(X, X_new=None, *, ls=1.0):
return ExpQuadCov.build_covariance(X, X_new, ls=ls)


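# A GP prior evaluated at a fixed set of inputs is a multivariate normal, so
# the RV subclasses MvNormal's rv_type and reuses its sampling and logp machinery.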
class GP_RV(pm.MvNormal.rv_type):
name = "gaussian_process"
signature = "(n),(n,n)->(n)"
dtype = "floatX"
_print_name = ("GP", "\\operatorname{GP}")


class GP(Continuous):
rv_type = GP_RV
rv_op = GP_RV()

@classmethod
def dist(cls, cov, **kwargs):
cov = pt.as_tensor(cov)
mu = pt.zeros(cov.shape[-1])
return super().dist([mu, cov], **kwargs)


def conditional_gp(
model,
gp: Variable | str,
Xnew,
*,
jitter=1e-6,
dims: Sequence[str] = (),
inline: bool = False,
):
"""
Condition a GP on new data.

Parameters
----------
    model: Model
        The model in which the GP was defined.
gp: Variable | str
The GP to condition on.
Xnew: Tensor-like
New data to condition the GP on.
    jitter: float, default=1e-6
        Jitter added to the diagonal of the training covariance matrix before
        Cholesky factorization, for numerical stability.
dims: Sequence[str], default=()
Dimensions of the new GP.
inline: bool, default=False
Whether to inline the new GP in place of the old one. This is not always a safe operation.
If True, any variables that depend on the GP will be updated to depend on the new GP.

Returns
-------
Conditional model: Model
A new model with a GP free RV named f"{gp.name}_star" conditioned on the new data.

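    Examples
    --------
    A minimal sketch of the intended usage (mirrors the tests; not the only
    supported pattern):

    .. code-block:: python

        import numpy as np
        import pymc as pm

        from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp

        with pm.Model() as m:
            X = pm.Data("X", np.arange(3)[:, None])
            gp = GP("gp", cov=ExpQuad(X, ls=1.0))
            pm.Normal("obs", mu=gp, sigma=1.0, observed=np.zeros(3))

        X_new = np.array([3, 4])[:, None]
        with conditional_gp(m, "gp", X_new) as cm:
            # cm contains a new free RV named "gp_star" evaluated at X_new
            ...
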
"""

def _build_conditional(Xnew, f, cov, jitter):
if not isinstance(cov.owner.op, GPCovariance):
# TODO: Look for xx kernels in the ancestors of f
raise NotImplementedError(f"Cannot build conditional of {cov.owner.op} operation")

X, ls = cov.owner.inputs

Kxx = cov
# Kxs = toposort_replace(cov, tuple(zip(xx_kernels, xs_kernels)), rebuild=True)
Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
# Kss = toposort_replace(cov, tuple(zip(xx_kernels, ss_kernels)), rebuild=True)
Kss = cov.owner.op.build_covariance(Xnew, ls=ls)

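        # Standard GP conditioning, computed stably through the Cholesky
        # factor L of Kxx:
        #   mu*  = Kxs.T @ inv(Kxx) @ f
        #   cov* = Kss - Kxs.T @ inv(Kxx) @ Kxs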
L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
# TODO: Use cho_solve
A = pt.linalg.solve_triangular(L, Kxs, lower=True)
v = pt.linalg.solve_triangular(L, f, lower=True)

        mu = A.mT @ v  # f is a vector, so this is already the (n_star,) conditional mean
cov = Kss - (A.mT @ A)

return mu, cov

if isinstance(gp, Variable):
assert model[gp.name] is gp
else:
        gp = model[gp]

fgraph, memo = fgraph_from_model(model)
gp_model_var = memo[gp]
gp_rv = gp_model_var.owner.inputs[0]

if isinstance(gp_rv.owner.op, pm.MvNormal.rv_type):
        _, cov = gp_rv.owner.op.dist_params(gp_rv.owner)
else:
raise NotImplementedError("Can only condition on pure GPs")

mu_star, cov_star = _build_conditional(Xnew, gp_model_var, cov, jitter)
gp_rv_star = pm.MvNormal.dist(mu_star, cov_star, name=f"{gp.name}_star")

value = gp_rv_star.clone()
transform = None
gp_model_var_star = model_free_rv(gp_rv_star, value, transform, *dims)

if inline:
fgraph.replace(gp_model_var, gp_model_var_star, import_missing=True)
else:
fgraph.add_output(gp_model_var_star, import_missing=True)

return model_from_fgraph(fgraph, mutate_fgraph=True)
tests/test_gp.py (163 additions, 0 deletions)
@@ -0,0 +1,163 @@
import arviz as az
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import pytest

from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp


def build_latent_model():
with pm.Model() as m:
X = pm.Data("X", np.arange(3)[:, None])
y = np.full(3, np.pi)
ls = 1.0
cov = ExpQuad(X, ls=ls)
gp = GP("gp", cov=cov)

sigma = 1.0
obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y)

return m


def build_latent_model_old_API():
with pm.Model() as m:
X = pm.Data("X", np.arange(3)[:, None])
y = np.full(3, np.pi)
ls = 1.0
cov = pm.gp.cov.ExpQuad(1, ls)
gp_class = pm.gp.Latent(cov_func=cov)
gp = gp_class.prior("gp", X, reparameterize=False)

sigma = 1.0
obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y)

return m, gp_class


def test_exp_quad():
x = pt.arange(3)[:, None]
ls = pt.ones(())
cov = ExpQuad(x, ls=ls).eval()
expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
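    # Pairwise squared distances between the inputs [0, 1, 2], e.g. (0 - 2)**2 = 4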

np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance))


def test_latent_model_prior():
m = build_latent_model()
ref_m, _ = build_latent_model_old_API()

prior = pm.draw(m["gp"], draws=1000)
prior_ref = pm.draw(ref_m["gp"], draws=1000)

np.testing.assert_allclose(
prior.mean(),
prior_ref.mean(),
atol=0.1,
)

np.testing.assert_allclose(
prior.std(),
prior_ref.std(),
rtol=0.1,
)


def test_latent_model_logp():
m = build_latent_model()
ip = m.initial_point()

ref_m, _ = build_latent_model_old_API()

np.testing.assert_allclose(
m.compile_logp()(ip),
ref_m.compile_logp()(ip),
rtol=1e-6,
)


@pytest.mark.parametrize("inline", (False, True))
def test_latent_model_conditional(inline):
rng = np.random.default_rng(0)
posterior = az.from_dict(
posterior={"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 3))},
constant_data={"X": np.arange(3)[:, None]},
)

new_x = np.array([3, 4])[:, None]

m = build_latent_model()
with m:
pm.Deterministic("gp_exp", m["gp"].exp())

with conditional_gp(m, m["gp"], new_x, inline=inline) as cgp:
pred = pm.sample_posterior_predictive(
posterior,
var_names=["gp_star", "gp_exp"],
progressbar=False,
).posterior_predictive

ref_m, ref_gp_class = build_latent_model_old_API()
with ref_m:
gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x)
pred_ref = pm.sample_posterior_predictive(
posterior,
var_names=["gp_star"],
progressbar=False,
).posterior_predictive

np.testing.assert_allclose(
pred["gp_star"].mean(),
pred_ref["gp_star"].mean(),
atol=0.1,
)

np.testing.assert_allclose(
pred["gp_star"].std(),
pred_ref["gp_star"].std(),
rtol=0.1,
)

if inline:
        np.testing.assert_allclose(
pred["gp_exp"],
np.exp(pred["gp_star"]),
)
else:
np.testing.assert_allclose(
pred["gp_exp"],
np.exp(posterior.posterior["gp"]),
)


#
# def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model, ):
# obs = marginal_model["obs"]
#
# # TODO: Bring these checks back after we implement marginalization of the GP RV
# #
# # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
# # for var in ancestors([obs])
# # if var.owner is not None) == 1
# #
# f = pm.compile_pymc([], obs)
# #
# # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
#
# draws = np.stack([f() for _ in range(10_000)])
# empirical_cov = np.cov(draws.T)
#
# expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
#
# np.testing.assert_allclose(
# empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1
# )
#
#
# def test_marginal_gp_logp(marginal_model):
# expected_logps = {"obs": -8.8778}
# point_logps = marginal_model.point_logps(round_vals=4)
# for v1, v2 in zip(point_logps.values(), expected_logps.values()):
# np.testing.assert_allclose(v1, v2, atol=1e-6)