diff --git a/pymc_experimental/gp/pytensor_gp.py b/pymc_experimental/gp/pytensor_gp.py
new file mode 100644
index 00000000..edcb1b14
--- /dev/null
+++ b/pymc_experimental/gp/pytensor_gp.py
@@ -0,0 +1,160 @@
+from collections.abc import Sequence
+
+import pymc as pm
+import pytensor.tensor as pt
+
+from pymc.distributions.distribution import Continuous
+from pymc.model.fgraph import fgraph_from_model, model_free_rv, model_from_fgraph
+from pytensor import Variable
+from pytensor.compile.builders import OpFromGraph
+
+
+class GPCovariance(OpFromGraph):
+    """OFG representing a GP covariance"""
+
+    @staticmethod
+    def square_dist(X, Xs, ls):
+        assert X.ndim == 2, "Complain to Bill about it"
+        assert Xs.ndim == 2, "Complain to Bill about it"
+
+        X = X / ls
+        Xs = Xs / ls
+
+        X2 = pt.sum(pt.square(X), axis=-1)
+        Xs2 = pt.sum(pt.square(Xs), axis=-1)
+
+        sqd = -2.0 * X @ Xs.mT + (X2[..., :, None] + Xs2[..., None, :])
+        return pt.clip(sqd, 0, pt.inf)
+
+
+class ExpQuadCov(GPCovariance):
+    """
+    ExpQuad covariance function
+    """
+
+    @classmethod
+    def exp_quad_full(cls, X, Xs, ls):
+        return pt.exp(-0.5 * cls.square_dist(X, Xs, ls))
+
+    @classmethod
+    def build_covariance(cls, X, Xs=None, *, ls):
+        X = pt.as_tensor(X)
+        if Xs is None:
+            Xs = X
+        else:
+            Xs = pt.as_tensor(Xs)
+        ls = pt.as_tensor(ls)
+
+        out = cls.exp_quad_full(X, Xs, ls)
+        if Xs is X:
+            return cls(inputs=[X, ls], outputs=[out])(X, ls)
+        else:
+            return cls(inputs=[X, Xs, ls], outputs=[out])(X, Xs, ls)
+
+
+def ExpQuad(X, X_new=None, *, ls=1.0):
+    return ExpQuadCov.build_covariance(X, X_new, ls=ls)
+
+
+class GP_RV(pm.MvNormal.rv_type):
+    name = "gaussian_process"
+    signature = "(n),(n,n)->(n)"
+    dtype = "floatX"
+    _print_name = ("GP", "\\operatorname{GP}")
+
+
+class GP(Continuous):
+    rv_type = GP_RV
+    rv_op = GP_RV()
+
+    @classmethod
+    def dist(cls, cov, **kwargs):
+        cov = pt.as_tensor(cov)
+        mu = pt.zeros(cov.shape[-1])
+        return super().dist([mu, cov], **kwargs)
+
+
+def conditional_gp(
+    model,
+    gp: Variable | str,
+    Xnew,
+    *,
+    jitter=1e-6,
+    dims: Sequence[str] = (),
+    inline: bool = False,
+):
+    """
+    Condition a GP on new data.
+
+    Parameters
+    ----------
+    model: Model
+    gp: Variable | str
+        The GP to condition on.
+    Xnew: Tensor-like
+        New data to condition the GP on.
+    jitter: float, default=1e-6
+        Jitter to add to the new GP covariance matrix.
+    dims: Sequence[str], default=()
+        Dimensions of the new GP.
+    inline: bool, default=False
+        Whether to inline the new GP in place of the old one. This is not always a safe operation.
+        If True, any variables that depend on the GP will be updated to depend on the new GP.
+
+    Returns
+    -------
+    Conditional model: Model
+        A new model with a GP free RV named f"{gp.name}_star" conditioned on the new data.
+
+    """
+
+    def _build_conditional(Xnew, f, cov, jitter):
+        if not isinstance(cov.owner.op, GPCovariance):
+            # TODO: Look for xx kernels in the ancestors of f
+            raise NotImplementedError(f"Cannot build conditional of {cov.owner.op} operation")
+
+        X, ls = cov.owner.inputs
+
+        Kxx = cov
+        # Kxs = toposort_replace(cov, tuple(zip(xx_kernels, xs_kernels)), rebuild=True)
+        Kxs = cov.owner.op.build_covariance(X, Xnew, ls=ls)
+        # Kss = toposort_replace(cov, tuple(zip(xx_kernels, ss_kernels)), rebuild=True)
+        Kss = cov.owner.op.build_covariance(Xnew, ls=ls)
+
+        L = pt.linalg.cholesky(Kxx + pt.eye(X.shape[0]) * jitter)
+        # TODO: Use cho_solve
+        A = pt.linalg.solve_triangular(L, Kxs, lower=True)
+        v = pt.linalg.solve_triangular(L, f, lower=True)
+
+        mu = (A.mT @ v).T  # Vector?
+        cov = Kss - (A.mT @ A)
+
+        return mu, cov
+
+    if isinstance(gp, Variable):
+        assert model[gp.name] is gp
+    else:
+        gp = model[gp]
+
+    fgraph, memo = fgraph_from_model(model)
+    gp_model_var = memo[gp]
+    gp_rv = gp_model_var.owner.inputs[0]
+
+    if isinstance(gp_rv.owner.op, pm.MvNormal.rv_type):
+        _, cov = gp_rv.owner.op.dist_params(gp.owner)
+    else:
+        raise NotImplementedError("Can only condition on pure GPs")
+
+    mu_star, cov_star = _build_conditional(Xnew, gp_model_var, cov, jitter)
+    gp_rv_star = pm.MvNormal.dist(mu_star, cov_star, name=f"{gp.name}_star")
+
+    value = gp_rv_star.clone()
+    transform = None
+    gp_model_var_star = model_free_rv(gp_rv_star, value, transform, *dims)
+
+    if inline:
+        fgraph.replace(gp_model_var, gp_model_var_star, import_missing=True)
+    else:
+        fgraph.add_output(gp_model_var_star, import_missing=True)
+
+    return model_from_fgraph(fgraph, mutate_fgraph=True)
diff --git a/tests/test_gp.py b/tests/test_gp.py
new file mode 100644
index 00000000..e461ea90
--- /dev/null
+++ b/tests/test_gp.py
@@ -0,0 +1,163 @@
+import arviz as az
+import numpy as np
+import pymc as pm
+import pytensor.tensor as pt
+import pytest
+
+from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp
+
+
+def build_latent_model():
+    with pm.Model() as m:
+        X = pm.Data("X", np.arange(3)[:, None])
+        y = np.full(3, np.pi)
+        ls = 1.0
+        cov = ExpQuad(X, ls=ls)
+        gp = GP("gp", cov=cov)
+
+        sigma = 1.0
+        obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y)
+
+    return m
+
+
+def build_latent_model_old_API():
+    with pm.Model() as m:
+        X = pm.Data("X", np.arange(3)[:, None])
+        y = np.full(3, np.pi)
+        ls = 1.0
+        cov = pm.gp.cov.ExpQuad(1, ls)
+        gp_class = pm.gp.Latent(cov_func=cov)
+        gp = gp_class.prior("gp", X, reparameterize=False)
+
+        sigma = 1.0
+        obs = pm.Normal("obs", mu=gp, sigma=sigma, observed=y)
+
+    return m, gp_class
+
+
+def test_exp_quad():
+    x = pt.arange(3)[:, None]
+    ls = pt.ones(())
+    cov = ExpQuad(x, ls=ls).eval()
+    expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+
+    np.testing.assert_allclose(cov, np.exp(-0.5 * expected_distance))
+
+
+def test_latent_model_prior():
+    m = build_latent_model()
+    ref_m, _ = build_latent_model_old_API()
+
+    prior = pm.draw(m["gp"], draws=1000)
+    prior_ref = pm.draw(ref_m["gp"], draws=1000)
+
+    np.testing.assert_allclose(
+        prior.mean(),
+        prior_ref.mean(),
+        atol=0.1,
+    )
+
+    np.testing.assert_allclose(
+        prior.std(),
+        prior_ref.std(),
+        rtol=0.1,
+    )
+
+
+def test_latent_model_logp():
+    m = build_latent_model()
+    ip = m.initial_point()
+
+    ref_m, _ = build_latent_model_old_API()
+
+    np.testing.assert_allclose(
+        m.compile_logp()(ip),
+        ref_m.compile_logp()(ip),
+        rtol=1e-6,
+    )
+
+
+@pytest.mark.parametrize("inline", (False, True))
+def test_latent_model_conditional(inline):
+    rng = np.random.default_rng(0)
+    posterior = az.from_dict(
+        posterior={"gp": rng.normal(np.pi, 1e-3, size=(4, 1000, 3))},
+        constant_data={"X": np.arange(3)[:, None]},
+    )
+
+    new_x = np.array([3, 4])[:, None]
+
+    m = build_latent_model()
+    with m:
+        pm.Deterministic("gp_exp", m["gp"].exp())
+
+    with conditional_gp(m, m["gp"], new_x, inline=inline) as cgp:
+        pred = pm.sample_posterior_predictive(
+            posterior,
+            var_names=["gp_star", "gp_exp"],
+            progressbar=False,
+        ).posterior_predictive
+
+    ref_m, ref_gp_class = build_latent_model_old_API()
+    with ref_m:
+        gp_star = ref_gp_class.conditional("gp_star", Xnew=new_x)
+        pred_ref = pm.sample_posterior_predictive(
+            posterior,
+            var_names=["gp_star"],
+            progressbar=False,
+        ).posterior_predictive
+
+    np.testing.assert_allclose(
+        pred["gp_star"].mean(),
+        pred_ref["gp_star"].mean(),
+        atol=0.1,
+    )
+
+    np.testing.assert_allclose(
+        pred["gp_star"].std(),
+        pred_ref["gp_star"].std(),
+        rtol=0.1,
+    )
+
+    if inline:
+        np.testing.assert_allclose(
+            pred["gp_exp"],
+            np.exp(pred["gp_star"]),
+        )
+    else:
+        np.testing.assert_allclose(
+            pred["gp_exp"],
+            np.exp(posterior.posterior["gp"]),
+        )
+
+
+#
+# def test_marginal_sigma_rewrites_to_white_noise_cov(marginal_model):
+#     obs = marginal_model["obs"]
+#
+#     # TODO: Bring these checks back after we implement marginalization of the GP RV
+#     #
+#     # assert sum(isinstance(var.owner.op, pm.Normal.rv_type)
+#     #            for var in ancestors([obs])
+#     #            if var.owner is not None) == 1
+#     #
+#     f = pm.compile_pymc([], obs)
+#     #
+#     # assert not any(isinstance(node.op, pm.Normal.rv_type) for node in f.maker.fgraph.apply_nodes)
+#
+#     draws = np.stack([f() for _ in range(10_000)])
+#     empirical_cov = np.cov(draws.T)
+#
+#     expected_distance = np.array([[0.0, 1.0, 4.0], [1.0, 0.0, 1.0], [4.0, 1.0, 0.0]])
+#
+#     np.testing.assert_allclose(
+#         empirical_cov, np.exp(-0.5 * expected_distance) + np.eye(3), atol=0.1, rtol=0.1
+#     )
+#
+#
+# def test_marginal_gp_logp(marginal_model):
+#     expected_logps = {"obs": -8.8778}
+#     point_logps = marginal_model.point_logps(round_vals=4)
+#     for v1, v2 in zip(point_logps.values(), expected_logps.values()):
+#         np.testing.assert_allclose(v1, v2, atol=1e-6)
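
A note on the linear algebra in `_build_conditional`: it computes the standard GP posterior-predictive equations via a Cholesky factor of the training covariance rather than an explicit matrix inverse. In the code's notation (`Kxx`, `Kxs`, `Kss`, `L`, `A`, `v`, with `f` the conditioned GP value and \epsilon the `jitter` argument), the intermediates correspond to:

    L L^\top = K_{xx} + \epsilon I, \qquad A = L^{-1} K_{xs}, \qquad v = L^{-1} f

    \mu_\star = A^\top v = K_{xs}^\top (K_{xx} + \epsilon I)^{-1} f

    \Sigma_\star = K_{ss} - A^\top A = K_{ss} - K_{xs}^\top (K_{xx} + \epsilon I)^{-1} K_{xs}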
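
Since the diff alone does not show the intended workflow end to end, here is a minimal usage sketch assembled from the tests above; the toy data, the lengthscale, and the `X_new` grid are illustrative only, not part of the API:

    import numpy as np
    import pymc as pm

    from pymc_experimental.gp.pytensor_gp import GP, ExpQuad, conditional_gp

    # Toy data: three 1-D inputs with noisy constant observations.
    X = np.arange(3)[:, None]
    y = np.full(3, np.pi)

    with pm.Model() as model:
        X_data = pm.Data("X", X)
        cov = ExpQuad(X_data, ls=1.0)  # ExpQuadCov OpFromGraph over (X, ls)
        gp = GP("gp", cov=cov)  # zero-mean multivariate normal prior
        pm.Normal("obs", mu=gp, sigma=1.0, observed=y)
        idata = pm.sample(progressbar=False)

    # conditional_gp returns a new model with an extra free RV "gp_star",
    # the GP conditioned on the new inputs.
    X_new = np.array([3, 4])[:, None]
    with conditional_gp(model, model["gp"], X_new) as cmodel:
        pred = pm.sample_posterior_predictive(
            idata, var_names=["gp_star"], progressbar=False
        )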