Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multidimensional HSGP (numpyro.contrib.hsgp) #1803

Merged
merged 8 commits into from
May 28, 2024

Conversation

brendancooley
Copy link
Contributor

  • Implements eigenindices function, following Eq 10 in Riutort-Mayol et al
  • Update upstream functionality to accept a $D \times m^\star$ array of eigenvalues
  • Update eigenfunctions to handle multidimensional x. Vector-valued inputs are treated as unidimensional problems. Otherwise the trailing dimension of x is inferred as the dimension of the input space. All preceding dims are treated as batch dims.
  • Update laplacian and approximation functions to accept list-valued inputs for m and ell. If an int (m) or float (ell) is passed when the problem is multidimensional, the same value is used for all dimensions.
  • Miscellaneous cleanup and extensions of docstrings and methodology overview in "Contributed Code"

The results in @juanitorduz notebook replicate with one change to function arguments.

TODO (on this PR or future)

wip(hsgp_nd): test list input for m, make dim required arg

wip(hsgp_nd): support batched eigenfunctions, swap dims for m and d
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

doc(hsgp_nd): polish contrib documentation

rm TODO

fix eigenindices docstring

cleanup nb
@brendancooley
Copy link
Contributor Author

^ fix pushed for python 3.9 typing issues causing CI failure

@brendancooley
Copy link
Contributor Author

^once more, fix for linting

@juanitorduz
Copy link
Contributor

This looks super cool @brendancooley ! I this one ready for review?

looks like matplotlib snuck in a deprecation here
Copy link
Contributor

@juanitorduz juanitorduz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brendancooley This is amazing! I could have never done this in a cleaner way 🙌 ! Thank you!

I left two suggestions regarding documentation and some additional simple unit-tests to make this bullet-proof and to help other devs get what are the auxiliary functions expected to output. Besides that is a ✅ from my side.

\times
\overbrace{\color{green}{\beta_{j}}}^{\sim \: \text{Normal}(0,1)}

where :math:`\lambda_j` are the eigenvalues of the Laplacian operator, :math:`\phi_{j}(x)` are the eigenfunctions of the
where :math:`\boldsymbol{x}` is a :math:`D` vector of inputs, :math:`\boldsymbol{\lambda}_j^\star` are the eigenvalues of the Laplacian operator, :math:`\phi_{j}(\boldsymbol{x})` are the eigenfunctions of the
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you define or explain what \boldsymbol{\lambda}_j^\ is :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks :)

m1 = jnp.tile(w0 * x[:, None], m)
m2 = jnp.diag(jnp.arange(m, dtype=jnp.float32))
mw0x = m1 @ m2
cosines = jnp.cos(mw0x)
sines = jnp.sin(mw0x)
return cosines, sines


def _convert_ell(ell: float | list[float] | ArrayImpl, dim: int) -> ArrayImpl:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a unit-test for this function? It would be great for documenting how is supposed to work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

import jax.numpy as jnp


# TODO: Adapt to dim >= 1.
def sqrt_eigenvalues(ell: float, m: int) -> ArrayImpl:
def eigenindices(m: list[int] | int, dim: int) -> ArrayImpl:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a unit-test for this function? It would be great for documenting how is supposed to work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added!

@juanitorduz
Copy link
Contributor

Regarding the TODOS:

  • The first one we can leave for other PRs so that we work on iterations (would you mind creating an issue?).
  • The second would be good to change after we merge this one.
  • I do not know if the third one is necessary as maybe this could serve as a "raw" example. No strong opinion. Still, we can open an issue but definitively not part of this PR.

@juanitorduz
Copy link
Contributor

@fehiepsi any hints on the examples CI failing because other tests not relevant to this PR 😄 ?

@brendancooley
Copy link
Contributor Author

brendancooley commented May 22, 2024

@brendancooley This is amazing! I could have never done this in a cleaner way 🙌 ! Thank you!

I left two suggestions regarding documentation and some additional simple unit-tests to make this bullet-proof and to help other devs get what are the auxiliary functions expected to output. Besides that is a ✅ from my side.

Thanks the feedback @juanitorduz! Will push a commit including responses to your suggestions shortly.

@brendancooley
Copy link
Contributor Author

@fehiepsi any hints on the examples CI failing because other tests not relevant to this PR 😄 ?

It looks like the stein_bnn example is sporadically failing to pull the boston housing data. I've been able to replicate locally but not consistently.

@brendancooley
Copy link
Contributor Author

Regarding the TODOS:

  • The first one we can leave for other PRs so that we work on iterations (would you mind creating an issue?).
  • The second would be good to change after we merge this one.
  • I do not know if the third one is necessary as maybe this could serve as a "raw" example. No strong opinion. Still, we can open an issue but definitively not part of this PR.

issue for vector-valued lengthscale up! #1805

Copy link
Contributor

@juanitorduz juanitorduz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great! very nice to review :D

@brendancooley
Copy link
Contributor Author

This looks great! very nice to review :D

Thanks again!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful work, @brendancooley!!

Thanks for reviewing, @juanitorduz!

@fehiepsi fehiepsi merged commit e8216d7 into pyro-ppl:master May 28, 2024
4 checks passed
@brendancooley brendancooley deleted the feat/hsgp-nd branch May 31, 2024 12:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants