Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
henrymoss committed Nov 7, 2023
1 parent cde11d8 commit dbe8287
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 0 additions & 3 deletions docs/examples/intro_to_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook, Float
Expand Down Expand Up @@ -235,7 +234,6 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(no_opt_posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit_scipy(
model=no_opt_posterior,
Expand Down Expand Up @@ -536,7 +534,6 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit(
model=posterior,
Expand Down
4 changes: 3 additions & 1 deletion gpjax/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def loss(model: Module, data: Dataset) -> ScalarFloat:
if result.success:
print("Optimization was successful")
else:
raise FailedScipyFitError("Optimization failed, try increasing max_iters or using a different optimiser.")
raise FailedScipyFitError(
"Optimization failed, try increasing max_iters or using a different optimiser."
)
print(f"Final loss is {result.fun_val} after {result.num_fun_eval} iterations")

# Constrained space.
Expand Down

0 comments on commit dbe8287

Please sign in to comment.