diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml index 539989456..433d4e345 100644 --- a/.github/workflows/build_docs.yml +++ b/.github/workflows/build_docs.yml @@ -47,7 +47,7 @@ jobs: - name: Install and configure Poetry uses: snok/install-poetry@v1 with: - version: 1.2.2 + version: 1.5.1 virtualenvs-create: false virtualenvs-in-project: false installer-parallel: true @@ -59,7 +59,6 @@ jobs: - name: Build the documentation with MKDocs run: | - cp docs/examples/gpjax.mplstyle . poetry install --all-extras --with docs conda install pandoc poetry run mkdocs build diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 4a71d8969..5aa90b148 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -29,7 +29,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1.3.3 with: - version: 1.4.0 + version: 1.5.1 # Configure Poetry to use the virtual environment in the project - name: Setup Poetry diff --git a/.github/workflows/test_docs.yml b/.github/workflows/test_docs.yml index bbf7f5b4c..9dc9c55d5 100644 --- a/.github/workflows/test_docs.yml +++ b/.github/workflows/test_docs.yml @@ -51,14 +51,13 @@ jobs: - name: Install and configure Poetry uses: snok/install-poetry@v1 with: - version: 1.2.2 + version: 1.5.1 virtualenvs-create: false virtualenvs-in-project: false installer-parallel: true - name: Build the documentation with MKDocs run: | - cp docs/examples/gpjax.mplstyle . poetry install --all-extras --with docs conda install pandoc - poetry run mkdocs build + poetry run python docs/scripts/gen_examples.py && poetry run mkdocs build diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index b3d747e90..967896bf5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -26,10 +26,13 @@ jobs: python-version: ${{ matrix.python-version }} # Install Poetry - - name: Install Poetry - uses: snok/install-poetry@v1.3.3 + - name: Install and configure Poetry + uses: snok/install-poetry@v1 with: - version: 1.4.0 + version: 1.5.1 + virtualenvs-create: false + virtualenvs-in-project: false + installer-parallel: true # Configure Poetry to use the virtual environment in the project - name: Setup Poetry @@ -39,7 +42,7 @@ jobs: # Install the dependencies - name: Install Package run: | - poetry install --with tests + poetry install --with dev - name: Check docstrings run: | diff --git a/.gitignore b/.gitignore index ae49fa031..e51d8509c 100644 --- a/.gitignore +++ b/.gitignore @@ -152,4 +152,4 @@ package-lock.json node_modules/ docs/api -docs/examples/*.md +docs/_examples diff --git a/docs/index.md b/docs/index.md index 2ee142c9f..7879bc8a3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ GPJax is a didactic Gaussian process (GP) library in JAX, supporting GPU acceleration and just-in-time compilation. We seek to provide a flexible API to enable researchers to rapidly prototype and develop new ideas. -![Gaussian process posterior.](./_static/GP.svg) +![Gaussian process posterior.](static/GP.svg) ## "Hello, GP!" @@ -40,7 +40,7 @@ would write on paper, as shown below. !!! Install - GPJax can be installed via pip. See our [installation guide](https://docs.jaxgaussianprocesses.com/installation/) for further details. + GPJax can be installed via pip. See our [installation guide](installation.md) for further details. ```bash pip install gpjax @@ -48,7 +48,7 @@ would write on paper, as shown below. !!! New - New to GPs? Then why not check out our [introductory notebook](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/) that starts from Bayes' theorem and univariate Gaussian distributions. + New to GPs? Then why not check out our [introductory notebook](_examples/intro_to_gps.md) that starts from Bayes' theorem and univariate Gaussian distributions. !!! Begin diff --git a/docs/scripts/gen_examples.py b/docs/scripts/gen_examples.py index 51dd0e8cd..8fbe71b46 100644 --- a/docs/scripts/gen_examples.py +++ b/docs/scripts/gen_examples.py @@ -1,20 +1,96 @@ +""" Convert python files in "examples" directory to markdown files using jupytext and nbconvert. + +There's only a minor inconvenience with how supporting files are handled by nbconvert, +see https://github.com/jupyter/nbconvert/issues/1164. But these will be under a private +directory `_examples` in the docs folder, so it's not a big deal. + +""" +from argparse import ArgumentParser from pathlib import Path import subprocess +from concurrent.futures import ThreadPoolExecutor, as_completed +import shutil + +EXCLUDE = ["utils.py"] -EXECUTE = False -EXCLUDE = ["docs/examples/utils.py"] -ALLOW_ERRORS = False +def process_file(file: Path, out_file: Path | None = None, execute: bool = False): + """Converts a python file to markdown using jupytext and nbconvert.""" -for file in Path("docs/").glob("examples/*.py"): - if file.as_posix() in EXCLUDE: - continue + out_dir = out_file.parent + command = f"cd {out_dir.as_posix()} && " - out_file = file.with_suffix(".md") + out_file = out_file.relative_to(out_dir).as_posix() - command = "jupytext --to markdown " - command += f"{'--execute ' if EXECUTE else ''}" - command += f"{'--allow-errors ' if ALLOW_ERRORS else ''}" - command += f"{file} --output {out_file}" + if execute: + command += f"jupytext --to ipynb {file} --output - " + command += ( + f"| jupyter nbconvert --to markdown --execute --stdin --output {out_file}" + ) + else: + command = f"jupytext --to markdown {file} --output {out_file}" subprocess.run(command, shell=True, check=False) + + +def is_modified(file: Path, out_file: Path): + """Check if the output file is older than the input file.""" + return out_file.exists() and out_file.stat().st_mtime < file.stat().st_mtime + + +def main(args): + # project root directory + wdir = Path(__file__).parents[2] + + # output directory + out_dir: Path = args.outdir + out_dir.mkdir(exist_ok=True, parents=True) + + # copy directories in "examples" to output directory + for dir in wdir.glob("examples/*"): + if dir.is_dir(): + (out_dir / dir.name).mkdir(exist_ok=True, parents=True) + for file in dir.glob("*"): + # copy, not move! + shutil.copy(file, out_dir / dir.name / file.name) + + # list of files to be processed + files = [f for f in wdir.glob("examples/*.py") if f.name not in EXCLUDE] + + # process only modified files + if args.only_modified: + files = [f for f in files if is_modified(f, out_dir / f"{f.stem}.md")] + + print(files) + + # process files in parallel + with ThreadPoolExecutor(max_workers=args.max_workers) as executor: + futures = [] + for file in files: + out_file = out_dir / f"{file.stem}.md" + futures.append( + executor.submit( + process_file, file, out_file=out_file, execute=args.execute + ) + ) + + for future in as_completed(futures): + try: + future.result() + except Exception as e: + print(f"Error processing file: {e}") + + +if __name__ == "__main__": + project_root = Path(__file__).parents[2] + + parser = ArgumentParser() + parser.add_argument("--max_workers", type=int, default=4) + parser.add_argument("--execute", action="store_true") + parser.add_argument("--only_modified", action="store_true") + parser.add_argument( + "--outdir", type=Path, default=project_root / "docs" / "_examples" + ) + args = parser.parse_args() + + main(args) diff --git a/docs/scripts/sharp_bits_figure.py b/docs/scripts/sharp_bits_figure.py index 49c2a2c7b..ef0622d52 100644 --- a/docs/scripts/sharp_bits_figure.py +++ b/docs/scripts/sharp_bits_figure.py @@ -20,7 +20,9 @@ import matplotlib as mpl from matplotlib import patches -plt.style.use("../examples/gpjax.mplstyle") +plt.style.use( + "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" +) cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% diff --git a/docs/sharp_bits.md b/docs/sharp_bits.md index be88beb0c..72aeb726b 100644 --- a/docs/sharp_bits.md +++ b/docs/sharp_bits.md @@ -60,7 +60,7 @@ learning rate is greater is than 0.03, we would end up with a negative variance We visualise this issue below where the red cross denotes the invalid lengthscale value that would be obtained, were we to optimise in the unconstrained parameter space. -![](_static/step_size_figure.svg) +![](static/step_size_figure.svg) A simple but impractical solution would be to use a tiny learning rate which would reduce the possibility of stepping outside of the parameter's support. However, this @@ -70,7 +70,7 @@ subspace of the real-line onto the entire real-line. Here, gradient updates are applied in the unconstrained parameter space before transforming the value back to the original support of the parameters. Such a transformation is known as a bijection. -![](_static/bijector_figure.svg) +![](static/bijector_figure.svg) To help understand this, we show the effect of using a log-exp bijector in the above figure. We have six points on the positive real line that range from 0.1 to 3 depicted @@ -81,8 +81,7 @@ value, we apply the inverse of the bijector, which is the exponential function i case. This gives us back the blue cross. In GPJax, we supply bijective functions using [Tensorflow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors). -In our [PyTrees doc](examples/pytrees.md) document, we detail how the user can define -their own bijectors and attach them to the parameter(s) of their model. + ## Positive-definiteness @@ -91,8 +90,7 @@ their own bijectors and attach them to the parameter(s) of their model. ### Why is positive-definiteness important? The Gram matrix of a kernel, a concept that we explore more in our -[kernels notebook](examples/constructing_new_kernels.py) and our [PyTree notebook](examples/pytrees.md), is a -symmetric positive definite matrix. As such, we +[kernels notebook](_examples/constructing_new_kernels.md). As such, we have a range of tools at our disposal to make subsequent operations on the covariance matrix faster. One of these tools is the Cholesky factorisation that uniquely decomposes any symmetric positive-definite matrix $\mathbf{\Sigma}$ by @@ -158,7 +156,7 @@ for some problems, this amount may need to be increased. ## Slow-to-evaluate Famously, a regular Gaussian process model (as detailed in -[our regression notebook](examples/regression.py)) will scale cubically in the number of data points. +[our regression notebook](_examples/regression.md)) will scale cubically in the number of data points. Consequently, if you try to fit your Gaussian process model to a data set containing more than several thousand data points, then you will likely incur a significant computational overhead. In such cases, we recommend using Sparse Gaussian processes to @@ -168,7 +166,7 @@ When the data contains less than around 50000 data points, we recommend using the collapsed evidence lower bound objective [@titsias2009] to optimise the parameters of your sparse Gaussian process model. Such a model will scale linearly in the number of data points and quadratically in the number of inducing points. We demonstrate its use -in [our sparse regression notebook](examples/collapsed_vi.py). +in [our sparse regression notebook](_examples/collapsed_vi.md). For data sets exceeding 50000 data points, even the sparse Gaussian process outlined above will become computationally infeasible. In such cases, we recommend using the @@ -176,4 +174,4 @@ uncollapsed evidence lower bound objective [@hensman2013gaussian] that allows st mini-batch optimisation of the parameters of your sparse Gaussian process model. Such a model will scale linearly in the batch size and quadratically in the number of inducing points. We demonstrate its use in -[our sparse stochastic variational inference notebook](examples/uncollapsed_vi.py). +[our sparse stochastic variational inference notebook](_examples/uncollapsed_vi.md). diff --git a/docs/_static/GP.pdf b/docs/static/GP.pdf similarity index 100% rename from docs/_static/GP.pdf rename to docs/static/GP.pdf diff --git a/docs/_static/GP.svg b/docs/static/GP.svg similarity index 100% rename from docs/_static/GP.svg rename to docs/static/GP.svg diff --git a/docs/_static/bijector_figure.svg b/docs/static/bijector_figure.svg similarity index 100% rename from docs/_static/bijector_figure.svg rename to docs/static/bijector_figure.svg diff --git a/docs/_static/css/gpjax_theme.css b/docs/static/css/gpjax_theme.css similarity index 95% rename from docs/_static/css/gpjax_theme.css rename to docs/static/css/gpjax_theme.css index d564f3b95..47e6a841a 100644 --- a/docs/_static/css/gpjax_theme.css +++ b/docs/static/css/gpjax_theme.css @@ -1,3 +1,3 @@ nav .bd-links a:hover{ color: #B5121B -} +} \ No newline at end of file diff --git a/docs/_static/favicon.ico b/docs/static/favicon.ico similarity index 100% rename from docs/_static/favicon.ico rename to docs/static/favicon.ico diff --git a/docs/_static/gpjax.mplstyle b/docs/static/gpjax.mplstyle similarity index 100% rename from docs/_static/gpjax.mplstyle rename to docs/static/gpjax.mplstyle diff --git a/docs/_static/gpjax_logo.pdf b/docs/static/gpjax_logo.pdf similarity index 100% rename from docs/_static/gpjax_logo.pdf rename to docs/static/gpjax_logo.pdf diff --git a/docs/_static/gpjax_logo.svg b/docs/static/gpjax_logo.svg similarity index 100% rename from docs/_static/gpjax_logo.svg rename to docs/static/gpjax_logo.svg diff --git a/docs/_static/jaxkern/lato.ttf b/docs/static/jaxkern/lato.ttf similarity index 100% rename from docs/_static/jaxkern/lato.ttf rename to docs/static/jaxkern/lato.ttf diff --git a/docs/_static/jaxkern/logo.png b/docs/static/jaxkern/logo.png similarity index 100% rename from docs/_static/jaxkern/logo.png rename to docs/static/jaxkern/logo.png diff --git a/docs/_static/jaxkern/logo.svg b/docs/static/jaxkern/logo.svg similarity index 100% rename from docs/_static/jaxkern/logo.svg rename to docs/static/jaxkern/logo.svg diff --git a/docs/_static/jaxkern/main.py b/docs/static/jaxkern/main.py similarity index 100% rename from docs/_static/jaxkern/main.py rename to docs/static/jaxkern/main.py diff --git a/docs/_static/step_size_figure.png b/docs/static/step_size_figure.png similarity index 100% rename from docs/_static/step_size_figure.png rename to docs/static/step_size_figure.png diff --git a/docs/_static/step_size_figure.svg b/docs/static/step_size_figure.svg similarity index 100% rename from docs/_static/step_size_figure.svg rename to docs/static/step_size_figure.svg diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index 14ac7b596..3e459d4b2 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -88,7 +88,16 @@ div.doc-contents:not(.first) { user-select: none; } +/* Centers all PNG images in markdown files */ +img[src$=".png"] { + display: block; + margin-left: auto; + margin-right: auto; +} + /* Maximum space for text block */ /* .md-grid { max-width: 65%; /* or 100%, if you want to stretch to full-width */ /* } + + diff --git a/docs/examples/barycentres.py b/examples/barycentres.py similarity index 99% rename from docs/examples/barycentres.py rename to examples/barycentres.py index 49836743b..83e664de0 100644 --- a/docs/examples/barycentres.py +++ b/examples/barycentres.py @@ -50,9 +50,12 @@ key = jr.key(123) + +# set the default style for plotting plt.style.use( "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" ) + cols = plt.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] diff --git a/docs/examples/barycentres/barycentre_gp.gif b/examples/barycentres/barycentre_gp.gif similarity index 100% rename from docs/examples/barycentres/barycentre_gp.gif rename to examples/barycentres/barycentre_gp.gif diff --git a/docs/examples/bayesian_optimisation.py b/examples/bayesian_optimisation.py similarity index 100% rename from docs/examples/bayesian_optimisation.py rename to examples/bayesian_optimisation.py diff --git a/docs/examples/classification.py b/examples/classification.py similarity index 100% rename from docs/examples/classification.py rename to examples/classification.py diff --git a/docs/examples/collapsed_vi.py b/examples/collapsed_vi.py similarity index 100% rename from docs/examples/collapsed_vi.py rename to examples/collapsed_vi.py diff --git a/docs/examples/constructing_new_kernels.py b/examples/constructing_new_kernels.py similarity index 100% rename from docs/examples/constructing_new_kernels.py rename to examples/constructing_new_kernels.py diff --git a/docs/examples/data/max_tempeature_switzerland.csv b/examples/data/max_tempeature_switzerland.csv similarity index 100% rename from docs/examples/data/max_tempeature_switzerland.csv rename to examples/data/max_tempeature_switzerland.csv diff --git a/docs/examples/data/yacht_hydrodynamics.data b/examples/data/yacht_hydrodynamics.data similarity index 100% rename from docs/examples/data/yacht_hydrodynamics.data rename to examples/data/yacht_hydrodynamics.data diff --git a/docs/examples/decision_making.py b/examples/decision_making.py similarity index 100% rename from docs/examples/decision_making.py rename to examples/decision_making.py diff --git a/docs/examples/deep_kernels.py b/examples/deep_kernels.py similarity index 100% rename from docs/examples/deep_kernels.py rename to examples/deep_kernels.py diff --git a/docs/examples/gpjax.mplstyle b/examples/gpjax.mplstyle similarity index 93% rename from docs/examples/gpjax.mplstyle rename to examples/gpjax.mplstyle index 38fe42c10..e62ef6ced 100644 --- a/docs/examples/gpjax.mplstyle +++ b/examples/gpjax.mplstyle @@ -14,10 +14,7 @@ axes.axisbelow: true ### Fonts mathtext.fontset: cm -font.family: serif -font.serif: Computer Modern Roman font.size: 10 -text.usetex: True # Axes ticks ytick.left: True diff --git a/docs/examples/graph_kernels.py b/examples/graph_kernels.py similarity index 100% rename from docs/examples/graph_kernels.py rename to examples/graph_kernels.py diff --git a/docs/examples/intro_to_gps.py b/examples/intro_to_gps.py similarity index 100% rename from docs/examples/intro_to_gps.py rename to examples/intro_to_gps.py diff --git a/docs/examples/intro_to_gps/decomposed_mll.png b/examples/intro_to_gps/decomposed_mll.png similarity index 100% rename from docs/examples/intro_to_gps/decomposed_mll.png rename to examples/intro_to_gps/decomposed_mll.png diff --git a/docs/examples/intro_to_gps/generating_process.png b/examples/intro_to_gps/generating_process.png similarity index 100% rename from docs/examples/intro_to_gps/generating_process.png rename to examples/intro_to_gps/generating_process.png diff --git a/docs/examples/intro_to_kernels.py b/examples/intro_to_kernels.py similarity index 100% rename from docs/examples/intro_to_kernels.py rename to examples/intro_to_kernels.py diff --git a/docs/examples/likelihoods_guide.py b/examples/likelihoods_guide.py similarity index 97% rename from docs/examples/likelihoods_guide.py rename to examples/likelihoods_guide.py index 8b3675ffd..2bff2fdfe 100644 --- a/docs/examples/likelihoods_guide.py +++ b/examples/likelihoods_guide.py @@ -25,12 +25,12 @@ # In this section we'll provide a short introduction to likelihoods and why they are # important. For users who are already familiar with likelihoods, feel free to skip to # the next section, and for users who would like more information than is provided -# here, please see our [introduction to Gaussian processes notebook](intro_to_gps.py). +# here, please see our [introduction to Gaussian processes notebook](intro_to_gps.md). # # ### What is a likelihood? # # We adopt the notation of our -# [introduction to Gaussian processes notebook](intro_to_gps.py) where we have a +# [introduction to Gaussian processes notebook](intro_to_gps.md) where we have a # Gaussian process (GP) $f(\cdot)\sim\mathcal{GP}(m(\cdot), k(\cdot, \cdot))$ and a # dataset $\mathbf{y} = \{y_n\}_{n=1}^N$ observed at corresponding inputs # $\mathbf{x} = \{x_n\}_{n=1}^N$. The evaluation of $f$ at $\mathbf{x}$ is denoted by @@ -128,9 +128,7 @@ gpx.likelihoods.Gaussian(num_datapoints=D.n, obs_stddev=0.5) # %% [markdown] -# To control other properties of the observation noise such as trainability and value -# constraints, see our [PyTree guide](pytrees.md). -# + # ### Prediction # # The `predict` method of a likelihood object transforms the latent distribution of @@ -224,7 +222,7 @@ # # The final method that is associated with a likelihood function in GPJax is the # expected log-likelihood. This term is evaluated in the -# [stochastic variational Gaussian process](uncollapsed_vi.py) in the ELBO term. For a +# [stochastic variational Gaussian process](uncollapsed_vi.md) in the ELBO term. For a # variational approximation $q(f)= \mathcal{N}(f\mid m, S)$, the ELBO can be written as # $$ # \begin{align} diff --git a/docs/examples/oceanmodelling.py b/examples/oceanmodelling.py similarity index 100% rename from docs/examples/oceanmodelling.py rename to examples/oceanmodelling.py diff --git a/docs/examples/poisson.py b/examples/poisson.py similarity index 100% rename from docs/examples/poisson.py rename to examples/poisson.py diff --git a/docs/examples/regression.py b/examples/regression.py similarity index 99% rename from docs/examples/regression.py rename to examples/regression.py index c060677c1..bf777b1e4 100644 --- a/docs/examples/regression.py +++ b/examples/regression.py @@ -30,15 +30,18 @@ from jaxtyping import install_import_hook import matplotlib as mpl import matplotlib.pyplot as plt -from docs.examples.utils import clean_legend +from examples.utils import clean_legend with install_import_hook("gpjax", "beartype.beartype"): import gpjax as gpx key = jr.key(123) + +# set the default style for plotting plt.style.use( "https://raw.githubusercontent.com/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle" ) + cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"] # %% [markdown] @@ -97,6 +100,7 @@ # smoothness of the outputs that our GP can generate. # # For simplicity, we consider a radial basis function (RBF) kernel: +# # $$k(x, x') = \sigma^2 \exp\left(-\frac{\lVert x - x' \rVert_2^2}{2 \ell^2}\right).$$ # # On paper a GP is written as $f(\cdot) \sim \mathcal{GP}(\textbf{0}, k(\cdot, \cdot'))$, @@ -146,7 +150,9 @@ # notion of a likelihood function $p(\mathcal{D} | f(\cdot))$. While the choice of # likelihood is a critical in Bayesian modelling, for simplicity we consider a # Gaussian with noise parameter $\alpha$ +# # $$p(\mathcal{D} | f(\cdot)) = \mathcal{N}(\boldsymbol{y}; f(\boldsymbol{x}), \textbf{I} \alpha^2).$$ +# # This is defined in GPJax through calling a `Gaussian` instance. # %% diff --git a/docs/examples/uncollapsed_vi.py b/examples/uncollapsed_vi.py similarity index 100% rename from docs/examples/uncollapsed_vi.py rename to examples/uncollapsed_vi.py diff --git a/docs/examples/utils.py b/examples/utils.py similarity index 97% rename from docs/examples/utils.py rename to examples/utils.py index 520e86a5a..b9ccaa65f 100644 --- a/docs/examples/utils.py +++ b/examples/utils.py @@ -1,3 +1,6 @@ +from pathlib import Path + +import matplotlib.pyplot as plt from matplotlib import transforms from matplotlib.patches import Ellipse import numpy as np diff --git a/docs/examples/yacht.py b/examples/yacht.py similarity index 100% rename from docs/examples/yacht.py rename to examples/yacht.py diff --git a/gpjax/flax_base/bijectors.py b/gpjax/flax_base/bijectors.py deleted file mode 100644 index 3b4d96589..000000000 --- a/gpjax/flax_base/bijectors.py +++ /dev/null @@ -1,8 +0,0 @@ -import tensorflow_probability.substrates.jax.bijectors as tfb - -from gpjax.flax_base.types import BijectorLookupType - -Bijectors: BijectorLookupType = { - "real": tfb.Identity(), - "positive": tfb.Softplus(), -} diff --git a/gpjax/flax_base/param.py b/gpjax/flax_base/param.py deleted file mode 100644 index ff5204a02..000000000 --- a/gpjax/flax_base/param.py +++ /dev/null @@ -1,19 +0,0 @@ -from flax import nnx -import jax.numpy as jnp - -from gpjax.flax_base.types import ( - A, - DomainType, -) - - -class AbstractParameter(nnx.Variable[A]): - domain: DomainType = "real" - static: bool = False - - def __init__(self, value: A, *args, **kwargs): - super().__init__(jnp.asarray(value), *args, **kwargs) - - -class PositiveParameter(AbstractParameter[A]): - domain: DomainType = "positive" diff --git a/gpjax/flax_base/types.py b/gpjax/flax_base/types.py deleted file mode 100644 index b51b8ec7f..000000000 --- a/gpjax/flax_base/types.py +++ /dev/null @@ -1,18 +0,0 @@ -import typing as tp - -import tensorflow_probability.substrates.jax.bijectors as tfb - -DomainType = tp.Literal["real", "positive"] -A = tp.TypeVar("A") - - -# class BijectorLookup(tp.TypedDict): -# domain: DomainType -# bijector: tfb.Bijector - - -class BijectorLookupType(tp.Dict[DomainType, tfb.Bijector]): - pass - - -__all__ = ["DomainType", "A", "BijectorLookupType"] diff --git a/mkdocs.yml b/mkdocs.yml index f9c1a95f9..87b51b122 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,7 +4,7 @@ site_url: https://docs.jaxgaussianprocesses.com/ repo_url: https://github.com/JaxGaussianProcesses/GPJax repo_name: JaxGaussianProcesses/GPJax -edit_uri: "" +# edit_uri: "" nav: - 🏡 Home: index.md @@ -13,28 +13,27 @@ nav: - 🎨 Design principles: design.md - 🤝 Contributing: contributing.md - 🔪 Sharp bits: sharp_bits.md - - 🌳 GPJax PyTrees: examples/pytrees.md - 📎 JAX 101 [External]: https://jax.readthedocs.io/en/latest/jax-101/index.html - 💡 Background: - - Intro to GPs: examples/intro_to_gps.md - - Intro to Kernels: examples/intro_to_kernels.md + - Intro to GPs: _examples/intro_to_gps.md + - Intro to Kernels: _examples/intro_to_kernels.md - 🎓 Tutorials: - - Regression: examples/regression.md - - Classification: examples/classification.md - - Poisson regression: examples/poisson.md - - Barycentres: examples/barycentres.md - - Deep kernel learning: examples/deep_kernels.md - - Graph kernels: examples/graph_kernels.md - - Sparse GPs: examples/uncollapsed_vi.md - - Stochastic sparse GPs: examples/collapsed_vi.md - - Bayesian Optimisation: examples/bayesian_optimisation.md - - Decision Making: examples/decision_making.md - - Multi-output GPs for Ocean Modelling: examples/oceanmodelling.md + - Regression: _examples/regression.md + - Classification: _examples/classification.md + - Poisson regression: _examples/poisson.md + - Barycentres: _examples/barycentres.md + - Deep kernel learning: _examples/deep_kernels.md + - Graph kernels: _examples/graph_kernels.md + - Sparse GPs: _examples/uncollapsed_vi.md + - Stochastic sparse GPs: _examples/collapsed_vi.md + - Bayesian Optimisation: _examples/bayesian_optimisation.md + - Decision Making: _examples/decision_making.md + - Multi-output GPs for Ocean Modelling: _examples/oceanmodelling.md - 📖 Guides for customisation: - - Kernels: examples/constructing_new_kernels.md - - Likelihoods: examples/likelihoods_guide.md - - UCI regression: examples/yacht.md - - 💻 Raw tutorial code: give_me_the_code.md + - Kernels: _examples/constructing_new_kernels.md + - Likelihoods: _examples/likelihoods_guide.md + - UCI regression: _examples/yacht.md + # - 💻 Raw tutorial code: give_me_the_code.md - Community: - 👥 Code of conduct: CODE_OF_CONDUCT.md - 📜 Governance: GOVERNANCE.md @@ -56,16 +55,23 @@ theme: - content.code.annotate # Allow individual lines of code to be annotated icon: repo: fontawesome/brands/github - logo: _static/favicon.ico - favicon: _static/favicon.ico + logo: static/favicon.ico + favicon: static/favicon.ico markdown_extensions: + - admonition + - markdown_katex: + no_inline_svg: True + insert_fonts_css: True - pymdownx.highlight: anchor_linenums: true line_spans: __span pygments_lang_class: true + - pymdownx.tabbed: + alternate_style: true - pymdownx.inlinehilite - - pymdownx.snippets + - pymdownx.snippets: + check_paths: true - pymdownx.superfences - pymdownx.arithmatex: generic: true @@ -81,12 +87,9 @@ plugins: - gen-files: scripts: - docs/scripts/gen_pages.py - - docs/scripts/gen_examples.py - # - docs/scripts/notebook_converter.py # or any other name or path - literate-nav: nav_file: SUMMARY.md - mkdocstrings: - default_handler: python handlers: python: paths: ["gpjax"] diff --git a/poetry.lock b/poetry.lock index eae61bea3..1141e1264 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11,6 +11,17 @@ files = [ {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, ] +[[package]] +name = "absolufy-imports" +version = "0.3.1" +description = "A tool to automatically replace relative imports with absolute ones." +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "absolufy_imports-0.3.1-py2.py3-none-any.whl", hash = "sha256:49bf7c753a9282006d553ba99217f48f947e3eef09e18a700f8a82f75dc7fc5c"}, + {file = "absolufy_imports-0.3.1.tar.gz", hash = "sha256:c90638a6c0b66826d1fb4880ddc20ef7701af34192c94faf40b95d32b59f9793"}, +] + [[package]] name = "appnope" version = "0.1.4" @@ -1797,18 +1808,17 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "markdown-katex" -version = "202112.1034" +version = "202406.1035" description = "katex extension for Python Markdown" optional = false python-versions = ">=2.7" files = [ - {file = "markdown-katex-202112.1034.tar.gz", hash = "sha256:27892f4cdd6763816f00e4187d0475500697c090aba16630ec4803a6564bf810"}, - {file = "markdown_katex-202112.1034-py2.py3-none-any.whl", hash = "sha256:9ccc5b4b37db7592cc3ea113d763fafe9ffd1b1587e2c217d6145e44a10b4f6d"}, + {file = "markdown-katex-202406.1035.tar.gz", hash = "sha256:e82f7bf9a8536451da8f01768d847516fa1827feb17140b8eaa0bea9826bdab0"}, + {file = "markdown_katex-202406.1035-py2.py3-none-any.whl", hash = "sha256:c1713e85854ddecb641ad96243a8b6cd67367bf1bf8d39b43b3680d7f2b1884d"}, ] [package.dependencies] Markdown = {version = ">=3.0", markers = "python_version >= \"3.6\""} -pathlib2 = "*" setuptools = "*" [[package]] @@ -2174,13 +2184,13 @@ files = [ [[package]] name = "mkdocstrings" -version = "0.24.3" +version = "0.25.2" description = "Automatic documentation from sources, for MkDocs." optional = false python-versions = ">=3.8" files = [ - {file = "mkdocstrings-0.24.3-py3-none-any.whl", hash = "sha256:5c9cf2a32958cd161d5428699b79c8b0988856b0d4a8c5baf8395fc1bf4087c3"}, - {file = "mkdocstrings-0.24.3.tar.gz", hash = "sha256:f327b234eb8d2551a306735436e157d0a22d45f79963c60a8b585d5f7a94c1d2"}, + {file = "mkdocstrings-0.25.2-py3-none-any.whl", hash = "sha256:9e2cda5e2e12db8bb98d21e3410f3f27f8faab685a24b03b06ba7daa5b92abfc"}, + {file = "mkdocstrings-0.25.2.tar.gz", hash = "sha256:5cf57ad7f61e8be3111a2458b4e49c2029c9cb35525393b179f9c916ca8042dc"}, ] [package.dependencies] @@ -2201,18 +2211,18 @@ python-legacy = ["mkdocstrings-python-legacy (>=0.2.1)"] [[package]] name = "mkdocstrings-python" -version = "1.10.0" +version = "1.10.8" description = "A Python handler for mkdocstrings." optional = false python-versions = ">=3.8" files = [ - {file = "mkdocstrings_python-1.10.0-py3-none-any.whl", hash = "sha256:ba833fbd9d178a4b9d5cb2553a4df06e51dc1f51e41559a4d2398c16a6f69ecc"}, - {file = "mkdocstrings_python-1.10.0.tar.gz", hash = "sha256:71678fac657d4d2bb301eed4e4d2d91499c095fd1f8a90fa76422a87a5693828"}, + {file = "mkdocstrings_python-1.10.8-py3-none-any.whl", hash = "sha256:bb12e76c8b071686617f824029cb1dfe0e9afe89f27fb3ad9a27f95f054dcd89"}, + {file = "mkdocstrings_python-1.10.8.tar.gz", hash = "sha256:5856a59cbebbb8deb133224a540de1ff60bded25e54d8beacc375bb133d39016"}, ] [package.dependencies] -griffe = ">=0.44" -mkdocstrings = ">=0.24.2" +griffe = ">=0.49" +mkdocstrings = ">=0.25" [[package]] name = "mktestdocs" @@ -2745,20 +2755,6 @@ files = [ qa = ["flake8 (==5.0.4)", "mypy (==0.971)", "types-setuptools (==67.2.0.1)"] testing = ["docopt", "pytest"] -[[package]] -name = "pathlib2" -version = "2.3.7.post1" -description = "Object-oriented filesystem paths" -optional = false -python-versions = "*" -files = [ - {file = "pathlib2-2.3.7.post1-py2.py3-none-any.whl", hash = "sha256:5266a0fd000452f1b3467d782f079a4343c63aaa119221fbdc4e39577489ca5b"}, - {file = "pathlib2-2.3.7.post1.tar.gz", hash = "sha256:9fe0edad898b83c0c3e199c842b27ed216645d2e177757b2dd67384d4113c641"}, -] - -[package.dependencies] -six = "*" - [[package]] name = "pathspec" version = "0.12.1" @@ -3702,28 +3698,29 @@ files = [ [[package]] name = "ruff" -version = "0.3.7" +version = "0.6.0" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0e8377cccb2f07abd25e84fc5b2cbe48eeb0fea9f1719cad7caedb061d70e5ce"}, - {file = "ruff-0.3.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:15a4d1cc1e64e556fa0d67bfd388fed416b7f3b26d5d1c3e7d192c897e39ba4b"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d28bdf3d7dc71dd46929fafeec98ba89b7c3550c3f0978e36389b5631b793663"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:379b67d4f49774ba679593b232dcd90d9e10f04d96e3c8ce4a28037ae473f7bb"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c060aea8ad5ef21cdfbbe05475ab5104ce7827b639a78dd55383a6e9895b7c51"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:ebf8f615dde968272d70502c083ebf963b6781aacd3079081e03b32adfe4d58a"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d48098bd8f5c38897b03604f5428901b65e3c97d40b3952e38637b5404b739a2"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:da8a4fda219bf9024692b1bc68c9cff4b80507879ada8769dc7e985755d662ea"}, - {file = "ruff-0.3.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c44e0149f1d8b48c4d5c33d88c677a4aa22fd09b1683d6a7ff55b816b5d074f"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3050ec0af72b709a62ecc2aca941b9cd479a7bf2b36cc4562f0033d688e44fa1"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a29cc38e4c1ab00da18a3f6777f8b50099d73326981bb7d182e54a9a21bb4ff7"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5b15cc59c19edca917f51b1956637db47e200b0fc5e6e1878233d3a938384b0b"}, - {file = "ruff-0.3.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e491045781b1e38b72c91247cf4634f040f8d0cb3e6d3d64d38dcf43616650b4"}, - {file = "ruff-0.3.7-py3-none-win32.whl", hash = "sha256:bc931de87593d64fad3a22e201e55ad76271f1d5bfc44e1a1887edd0903c7d9f"}, - {file = "ruff-0.3.7-py3-none-win_amd64.whl", hash = "sha256:5ef0e501e1e39f35e03c2acb1d1238c595b8bb36cf7a170e7c1df1b73da00e74"}, - {file = "ruff-0.3.7-py3-none-win_arm64.whl", hash = "sha256:789e144f6dc7019d1f92a812891c645274ed08af6037d11fc65fcbc183b7d59f"}, - {file = "ruff-0.3.7.tar.gz", hash = "sha256:d5c1aebee5162c2226784800ae031f660c350e7a3402c4d1f8ea4e97e232e3ba"}, + {file = "ruff-0.6.0-py3-none-linux_armv6l.whl", hash = "sha256:92dcce923e5df265781e5fc76f9a1edad52201a7aafe56e586b90988d5239013"}, + {file = "ruff-0.6.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:31b90ff9dc79ed476c04e957ba7e2b95c3fceb76148f2079d0d68a908d2cfae7"}, + {file = "ruff-0.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6d834a9ec9f8287dd6c3297058b3a265ed6b59233db22593379ee38ebc4b9768"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2089267692696aba342179471831a085043f218706e642564812145df8b8d0d"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aa62b423ee4bbd8765f2c1dbe8f6aac203e0583993a91453dc0a449d465c84da"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7344e1a964b16b1137ea361d6516ce4ee61a0403fa94252a1913ecc1311adcae"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:487f3a35c3f33bf82be212ce15dc6278ea854e35573a3f809442f73bec8b2760"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:75db409984077a793cf344d499165298a6f65449e905747ac65983b12e3e64b1"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84908bd603533ecf1db456d8fc2665d1f4335d722e84bc871d3bbd2d1116c272"}, + {file = "ruff-0.6.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f1749a0aef3ec41ed91a0e2127a6ae97d2e2853af16dbd4f3c00d7a3af726c5"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:016fea751e2bcfbbd2f8cb19b97b37b3fd33148e4df45b526e87096f4e17354f"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:6ae80f141b53b2e36e230017e64f5ea2def18fac14334ffceaae1b780d70c4f7"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:eaaaf33ea4b3f63fd264d6a6f4a73fa224bbfda4b438ffea59a5340f4afa2bb5"}, + {file = "ruff-0.6.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7667ddd1fc688150a7ca4137140867584c63309695a30016880caf20831503a0"}, + {file = "ruff-0.6.0-py3-none-win32.whl", hash = "sha256:ae48365aae60d40865a412356f8c6f2c0be1c928591168111eaf07eaefa6bea3"}, + {file = "ruff-0.6.0-py3-none-win_amd64.whl", hash = "sha256:774032b507c96f0c803c8237ce7d2ef3934df208a09c40fa809c2931f957fe5e"}, + {file = "ruff-0.6.0-py3-none-win_arm64.whl", hash = "sha256:a5366e8c3ae6b2dc32821749b532606c42e609a99b0ae1472cf601da931a048c"}, + {file = "ruff-0.6.0.tar.gz", hash = "sha256:272a81830f68f9bd19d49eaf7fa01a5545c5a2e86f32a9935bb0e4bb9a1db5b8"}, ] [[package]] @@ -4216,4 +4213,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.12" -content-hash = "11180d2b39dd5cdc9d98aa088303d2a59b00b64f958c33539a97e07188d5af72" +content-hash = "4b2ae7bad45029e7becc027913e0adbf40d21a2632971f4eb50c3a4096f20766" diff --git a/pyproject.toml b/pyproject.toml index 9919cab84..82a757eda 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ flax = "^0.8.4" numpy = "<2.0.0" [tool.poetry.group.dev.dependencies] -ruff = "^0.3.0" +ruff = "~0" pre-commit = "^3.2.2" interrogate = "^1.5.0" codespell = "^2.2.4" @@ -39,6 +39,7 @@ pytest-cov = "^4.0.0" pytest-pretty = "^1.1.1" pytest-xdist = "^3.2.1" coverage = "^7.2.2" +absolufy-imports = "^0.3.1" xdoctest = "^1.1.1" mktestdocs = "^0.2.1" asv = "^0.6.0" @@ -47,12 +48,11 @@ asv = "^0.6.0" [tool.poetry.group.docs.dependencies] mkdocs = "^1.5.3" mkdocs-material = "^9.5.12" -mkdocstrings = { version = "^0.24.1", extras = ["python"] } +mkdocstrings = { version = "^0.25.1", extras = ["python"] } mkdocs-jupyter = "^0.24.3" mkdocs-gen-files = "^0.5.0" mkdocs-literate-nav = "^0.6.0" mkdocs-git-authors-plugin = "^0.7.0" -markdown-katex = "^202112.1034" matplotlib = "^3.7.1" seaborn = "^0.12.2" networkx = "^3.0" @@ -65,6 +65,7 @@ ipywidgets = "^8.0.5" pandas = "^1.5.3" pymdown-extensions = "^10.7.1" nbconvert = "^7.16.2" +markdown-katex = "^202406.1035" [build-system] requires = ["poetry-core"] @@ -83,11 +84,13 @@ xfail_strict = true [tool.ruff] # https://github.com/charliermarsh/ruff fix = true cache-dir = "~/.cache/ruff" +exclude = ["docs/", "examples/"] line-length = 88 src = ["gpjax", "tests"] target-version = "py38" [tool.ruff.lint] +dummy-variable-rgx = "^_$" select = [ "F", # pycodestyle @@ -131,9 +134,6 @@ ignore = [ "PLR0913", ] unfixable = ["ERA001", "F401", "F841", "T201", "T203"] -exclude = ["docs/"] -ignore-init-module-imports = true -dummy-variable-rgx = "^_$" [tool.ruff.lint.pydocstyle] convention = "numpy" diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 2ec803a06..bc99facd0 100644 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -75,7 +75,7 @@ def test(self): regression = Result( - path="docs/examples/regression.py", + path="examples/regression.py", comparisons={ "history": (55.07405622, get_last), "predictive_mean": (36.24383416, jnp.sum), @@ -85,7 +85,7 @@ def test(self): regression.test() sparse = Result( - path="docs/examples/collapsed_vi.py", + path="examples/collapsed_vi.py", comparisons={ "history": (1924.7634809, get_last), "predictive_mean": (-8.39869652, jnp.sum), @@ -95,7 +95,7 @@ def test(self): sparse.test() stochastic = Result( - path="docs/examples/uncollapsed_vi.py", + path="examples/uncollapsed_vi.py", comparisons={ "history": (-2678.41302494, get_last), "meanf": (-54.14787028, jnp.sum),