Skip to content

Commit

Permalink
Refactor docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspinder committed Nov 6, 2023
1 parent 0ebaaf9 commit 8d95268
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 27 deletions.
6 changes: 3 additions & 3 deletions benchmarks/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.objective = gpx.ConjugateMLL()
self.posterior = self.prior * self.likelihood
Expand All @@ -48,7 +48,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
Expand All @@ -75,7 +75,7 @@ def setup(self, n_datapoints: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Poisson(num_datapoints=self.data.n)
self.objective = gpx.LogPosteriorDensity()
self.posterior = self.prior * self.likelihood
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
Expand All @@ -46,7 +46,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
Expand All @@ -71,7 +71,7 @@ def setup(self, n_test: int, n_dims: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(n_dims)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Bernoulli(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood
key, subkey = jr.split(key)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def setup(self, n_datapoints: int, n_inducing: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def setup(self, n_datapoints: int, n_inducing: int, batch_size: int):
self.data = gpx.Dataset(X=self.X, y=self.y)
kernel = gpx.kernels.RBF(active_dims=list(range(1)))
meanf = gpx.mean_functions.Constant()
self.prior = gpx.Prior(kernel=kernel, mean_function=meanf)
self.prior = gpx.gps.Prior(kernel=kernel, mean_function=meanf)
self.likelihood = gpx.likelihoods.Gaussian(num_datapoints=self.data.n)
self.posterior = self.prior * self.likelihood

Expand Down
8 changes: 4 additions & 4 deletions gpjax/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,17 @@ def fit( # noqa: PLR0913
>>> D = gpx.Dataset(X, y)
>>>
>>> # (2) Define your model:
>>> class LinearModel(gpx.Module):
weight: float = gpx.param_field()
bias: float = gpx.param_field()
>>> class LinearModel(gpx.base.Module):
weight: float = gpx.base.param_field()
bias: float = gpx.base.param_field()
def __call__(self, x):
return self.weight * x + self.bias
>>> model = LinearModel(weight=1.0, bias=1.0)
>>>
>>> # (3) Define your loss function:
>>> class MeanSquareError(gpx.AbstractObjective):
>>> class MeanSquareError(gpx.objectives.AbstractObjective):
def evaluate(self, model: LinearModel, train_data: gpx.Dataset) -> float:
return jnp.mean((train_data.y - model(train_data.X)) ** 2)
>>>
Expand Down
14 changes: 7 additions & 7 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class Prior(AbstractPrior):
>>> kernel = gpx.kernels.RBF()
>>> meanf = gpx.mean_functions.Zero()
>>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
```
"""

Expand Down Expand Up @@ -183,7 +183,7 @@ def __mul__(self, other):
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=100)
>>>
>>> prior * likelihood
Expand Down Expand Up @@ -244,7 +244,7 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> ReshapedGaussianDistributio
>>>
>>> kernel = gpx.kernels.RBF()
>>> meanf = gpx.mean_functions.Zero()
>>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>>
>>> prior.predict(jnp.linspace(0, 1, 100))
```
Expand Down Expand Up @@ -310,7 +310,7 @@ def sample_approx(
>>>
>>> meanf = gpx.mean_functions.Zero()
>>> kernel = gpx.kernels.RBF()
>>> prior = gpx.Prior(mean_function=meanf, kernel = kernel)
>>> prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
>>>
>>> sample_fn = prior.sample_approx(10, key)
>>> sample_fn(jnp.linspace(0, 1, 100).reshape(-1, 1))
Expand Down Expand Up @@ -434,7 +434,7 @@ class ConjugatePosterior(AbstractPosterior):
>>> import gpjax as gpx
>>> import jax.numpy as jnp
>>> prior = gpx.Prior(
>>> prior = gpx.gps.Prior(
mean_function = gpx.mean_functions.Zero(),
kernel = gpx.kernels.RBF()
)
Expand Down Expand Up @@ -482,8 +482,8 @@ def predict(
>>> D = gpx.Dataset(X=xtrain, y=ytrain)
>>> xtest = jnp.linspace(0, 1).reshape(-1, 1)
>>>
>>> prior = gpx.Prior(mean_function = gpx.Zero(), kernel = gpx.RBF())
>>> posterior = prior * gpx.Gaussian(num_datapoints = D.n)
>>> prior = gpx.gps.Prior(mean_function = gpx.mean_functions.Zero(), kernel = gpx.kernels.RBF())
>>> posterior = prior * gpx.likelihoods.Gaussian(num_datapoints = D.n)
>>> predictive_dist = posterior(xtest, D)
```
Expand Down
8 changes: 4 additions & 4 deletions gpjax/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def step(
>>> meanf = gpx.mean_functions.Constant()
>>> kernel = gpx.kernels.RBF()
>>> likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
>>> prior = gpx.Prior(mean_function = meanf, kernel=kernel)
>>> prior = gpx.gps.Prior(mean_function = meanf, kernel=kernel)
>>> posterior = prior * likelihood
>>>
>>> mll = gpx.ConjugateMLL(negative=True)
>>> mll = gpx.objectives.ConjugateMLL(negative=True)
>>> mll(posterior, train_data = D)
```
Expand All @@ -112,13 +112,13 @@ def step(
marginal log-likelihood. This can be realised through
```python
mll = gpx.ConjugateMLL(negative=True)
mll = gpx.objectives.ConjugateMLL(negative=True)
```
For optimal performance, the marginal log-likelihood should be ``jax.jit``
compiled.
```python
mll = jit(gpx.ConjugateMLL(negative=True))
mll = jit(gpx.objectives.ConjugateMLL(negative=True))
```
Args:
Expand Down
12 changes: 8 additions & 4 deletions tests/test_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def test_conjugate_mll(

# Build model
p = Prior(
kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
mean_function=gpx.mean_functions.Constant(),
)
likelihood = Gaussian(num_datapoints=num_datapoints)
post = p * likelihood
Expand Down Expand Up @@ -93,7 +94,8 @@ def test_non_conjugate_mll(

# Build model
p = Prior(
kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
mean_function=gpx.mean_functions.Constant(),
)
likelihood = Bernoulli(num_datapoints=num_datapoints)
post = p * likelihood
Expand Down Expand Up @@ -129,7 +131,8 @@ def test_collapsed_elbo(
)

p = Prior(
kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
mean_function=gpx.mean_functions.Constant(),
)
likelihood = Gaussian(num_datapoints=num_datapoints)
q = gpx.CollapsedVariationalGaussian(posterior=p * likelihood, inducing_inputs=z)
Expand Down Expand Up @@ -169,7 +172,8 @@ def test_elbo(
)

p = Prior(
kernel=gpx.RBF(active_dims=list(range(num_dims))), mean_function=gpx.Constant()
kernel=gpx.kernels.RBF(active_dims=list(range(num_dims))),
mean_function=gpx.mean_functions.Constant(),
)
if binary:
likelihood = Bernoulli(num_datapoints=num_datapoints)
Expand Down

0 comments on commit 8d95268

Please sign in to comment.