Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bump jaxns to >=2.0.1 #1546

Merged
merged 16 commits into from
Mar 15, 2023
Merged

Conversation

Joshuaalbert
Copy link
Contributor

@Joshuaalbert Joshuaalbert commented Mar 4, 2023

  • reqs are now jaxns>==2.0.1
  • adjusted the wrapped to use new structure.
  • Note, that jaxns likes the user to not jit-compile the top-level nested sampling run. This is because it breaks up the algorithm into static parts and non-static parts.

EDIT: I changed JAXNS to support 3.8, so the bumped version is to >=2.0.1 now.

* adjusted the wrapped to use new structure.
@Joshuaalbert
Copy link
Contributor Author

@fehiepsi here you go^. I haven't tested it. I'm not sure if you have a test for the wrapper. Otherwise, would you mind running the Gaussian Shells example in numpyro.

@fehiepsi
Copy link
Member

fehiepsi commented Mar 6, 2023

Woohoo, thanks @Joshuaalbert! We have some tests here - I can take a look later today if things work.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Joshuaalbert! The changes look great to me. After fixing a typo in PriorModel definition (missing Prior(...)), the tests passed and gaussian shell example also passed in my system.

You will need to run make format (with flake8, black, isort installed) to fix formating issues.

numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Show resolved Hide resolved
README.md Outdated Show resolved Hide resolved
docs/source/mcmc.rst Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
* make format run
# Conflicts:
#	docs/source/mcmc.rst
@Joshuaalbert
Copy link
Contributor Author

I notice that the test on 3.8 if failing. That's because jaxns requires 3.9+.

ERROR: Ignored the following versions that require a different python version: 2.0.0 Requires-Python >=3.9
ERROR: Could not find a version that satisfies the requirement jaxns>=2.0.0 (from versions: 0.0.1, 0.0.2, 0.0.3, 0.0.4, 0.0.5, 0.0.6, 0.0.7, 1.0.0, 1.0.1, 1.1.0, 1.1.1, 1.1.2)

fehiepsi
fehiepsi previously approved these changes Mar 8, 2023
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just have some nits left 💯

numpyro/contrib/nested_sampling.py Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
* make format run
@Joshuaalbert
Copy link
Contributor Author

Nits done

@fehiepsi
Copy link
Member

fehiepsi commented Mar 8, 2023

jaxns requires 3.9+

I guess we can just specify jaxns and skip the test on CI by replacing this line by

try:
    from numpyro.contrib.nested_sampling import NestedSampler, UniformReparam
except ImportError:
    pytestmark = pytest.mark.skip(reason="jaxns is not installed")

For the docs, please feel free to bump the version to 3.9 at this line.

@Joshuaalbert
Copy link
Contributor Author

Joshuaalbert commented Mar 8, 2023

Looks like there will still be a problem though because the setup.py will try to get jaxns>=2.0.0 this will fail before getting to the tests. Edit: The only solution would be to put a little logic in setup.py so that it only tries to install jaxns when the python version >= 3.9. This plus skipping the pytest would work.

@Joshuaalbert
Copy link
Contributor Author

@fehiepsi as expected it still fails. The only option would be for me to add a conditional check to include jaxns if python >= 3.9.

@Joshuaalbert Joshuaalbert changed the title bump jaxns to 2.x bump jaxns to >=2.0.1 Mar 8, 2023
@Joshuaalbert Joshuaalbert requested a review from fehiepsi March 9, 2023 01:34
@Joshuaalbert
Copy link
Contributor Author

@fehiepsi fixed the E402. You can run again.

@fehiepsi
Copy link
Member

the current issue is unrelated. let me fit it upstream

fehiepsi
fehiepsi previously approved these changes Mar 14, 2023
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @Joshuaalbert! It turns out that there are missing libraries in jaxns requirements like matplotlib and pylab-sdk. Could you merge the dev branch? Hopefully it will work now.

numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
numpyro/contrib/nested_sampling.py Outdated Show resolved Hide resolved
setup.py Outdated Show resolved Hide resolved
* make format
@fehiepsi fehiepsi merged commit bb9e1ba into pyro-ppl:master Mar 15, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants