-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support multidimensional HSGP (numpyro.contrib.hsgp) #1803
Conversation
wip(hsgp_nd): test list input for m, make dim required arg wip(hsgp_nd): support batched eigenfunctions, swap dims for m and d
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
doc(hsgp_nd): polish contrib documentation rm TODO fix eigenindices docstring cleanup nb
fda9312
to
8b064eb
Compare
^ fix pushed for python 3.9 typing issues causing CI failure |
^once more, fix for linting |
This looks super cool @brendancooley ! I this one ready for review? |
looks like matplotlib snuck in a deprecation here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@brendancooley This is amazing! I could have never done this in a cleaner way 🙌 ! Thank you!
I left two suggestions regarding documentation and some additional simple unit-tests to make this bullet-proof and to help other devs get what are the auxiliary functions expected to output. Besides that is a ✅ from my side.
docs/source/contrib.rst
Outdated
\times | ||
\overbrace{\color{green}{\beta_{j}}}^{\sim \: \text{Normal}(0,1)} | ||
|
||
where :math:`\lambda_j` are the eigenvalues of the Laplacian operator, :math:`\phi_{j}(x)` are the eigenfunctions of the | ||
where :math:`\boldsymbol{x}` is a :math:`D` vector of inputs, :math:`\boldsymbol{\lambda}_j^\star` are the eigenvalues of the Laplacian operator, :math:`\phi_{j}(\boldsymbol{x})` are the eigenfunctions of the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you define or explain what \boldsymbol{\lambda}_j^\
is :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch, thanks :)
numpyro/contrib/hsgp/laplacian.py
Outdated
m1 = jnp.tile(w0 * x[:, None], m) | ||
m2 = jnp.diag(jnp.arange(m, dtype=jnp.float32)) | ||
mw0x = m1 @ m2 | ||
cosines = jnp.cos(mw0x) | ||
sines = jnp.sin(mw0x) | ||
return cosines, sines | ||
|
||
|
||
def _convert_ell(ell: float | list[float] | ArrayImpl, dim: int) -> ArrayImpl: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a unit-test for this function? It would be great for documenting how is supposed to work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added!
import jax.numpy as jnp | ||
|
||
|
||
# TODO: Adapt to dim >= 1. | ||
def sqrt_eigenvalues(ell: float, m: int) -> ArrayImpl: | ||
def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we add a unit-test for this function? It would be great for documenting how is supposed to work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added!
Regarding the TODOS:
|
@fehiepsi any hints on the examples CI failing because other tests not relevant to this PR 😄 ? |
Thanks the feedback @juanitorduz! Will push a commit including responses to your suggestions shortly. |
It looks like the |
issue for vector-valued |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great! very nice to review :D
Thanks again! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Beautiful work, @brendancooley!!
Thanks for reviewing, @juanitorduz!
eigenindices
function, following Eq 10 in Riutort-Mayol et aleigenfunctions
to handle multidimensionalx
. Vector-valued inputs are treated as unidimensional problems. Otherwise the trailing dimension ofx
is inferred as the dimension of the input space. All preceding dims are treated as batch dims.laplacian
andapproximation
functions to accept list-valued inputs form
andell
. If an int (m
) or float (ell
) is passed when the problem is multidimensional, the same value is used for all dimensions.The results in @juanitorduz notebook replicate with one change to function arguments.
TODO (on this PR or future)
numpyro.contrib.hsgp
numpyro.contrib.hsgp