Skip to content

Commit

Permalink
Bump to version 0.5.0 (#888)
Browse files Browse the repository at this point in the history
* Support .to_event() in TransformReparam

* Strengthen test

* add forward shape for transforms

* add forward shapes and tests

* simplify nested independent Constraint

* shape check for affine transform

* add missing commit

* Bump to version 0.5.0

* fix tests, addresses comments

* port some fixes from hmcgibbs PR

* pin to specific version

* run isort

Co-authored-by: Fritz Obermeyer <[email protected]>
  • Loading branch information
fehiepsi and fritzo authored Jan 24, 2021
1 parent f052ea6 commit 6a1f522
Show file tree
Hide file tree
Showing 26 changed files with 27 additions and 27 deletions.
2 changes: 1 addition & 1 deletion examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith("0.4.1")
assert numpyro.__version__.startswith("0.5.0")
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/baseball.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC")
parser.add_argument("-n", "--num-samples", nargs="?", default=3000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Bayesian neural network example")
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples')
parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")')
Expand Down
2 changes: 1 addition & 1 deletion examples/funnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Non-centered reparameterization example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model')
parser.add_argument('--num-categories', default=3, type=int)
parser.add_argument('--num-words', default=10, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/minipyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def body_fn(i, val):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Mini Pyro demo")
parser.add_argument("-f", "--full-pyro", action="store_true", default=False)
parser.add_argument("-n", "--num-steps", default=1001, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/neutra.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="NeuTra HMC")
parser.add_argument('-n', '--num-samples', nargs='?', default=4000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description='Predator-Prey Model')
parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/proportion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description='Testing whether ')
parser.add_argument('-n', '--num-samples', nargs='?', default=500, type=int)
parser.add_argument('--num-warmup', nargs='?', default=1500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/sparse_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Gaussian Process example")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/stochastic_volatility.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def main(args):


if __name__ == "__main__":
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument('-n', '--num-samples', nargs='?', default=600, type=int)
parser.add_argument('--num-warmup', nargs='?', default=600, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/ucbadmit.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def main(args):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC')
parser.add_argument('-n', '--num-samples', nargs='?', default=2000, type=int)
parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
Expand Down
2 changes: 1 addition & 1 deletion examples/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def reconstruct_img(epoch, rng_key):


if __name__ == '__main__':
assert numpyro.__version__.startswith('0.4.1')
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=15, type=int, help='number of training epochs')
parser.add_argument('-lr', '--learning-rate', default=1.0e-3, type=float, help='learning rate')
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_imputation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats(\"svg\")\n",
"\n",
"assert numpyro.__version__.startswith(\"0.4.1\")"
"assert numpyro.__version__.startswith(\"0.5.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/bayesian_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
"if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n",
" set_matplotlib_formats('svg')\n",
"\n",
"assert numpyro.__version__.startswith('0.4.1')"
"assert numpyro.__version__.startswith('0.5.0')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/logistic_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import numpyro.distributions as dist\n",
"from numpyro.examples.datasets import COVTYPE, load_dataset\n",
"from numpyro.infer import HMC, MCMC, NUTS\n",
"assert numpyro.__version__.startswith('0.4.1')\n",
"assert numpyro.__version__.startswith('0.5.0')\n",
"\n",
"# NB: replace gpu by cpu to run this notebook in cpu\n",
"numpyro.set_platform(\"gpu\")"
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/ordinal_regression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"from numpyro.infer import MCMC, NUTS\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"assert numpyro.__version__.startswith('0.4.1')"
"assert numpyro.__version__.startswith('0.5.0')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/source/time_series_forecasting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
" set_matplotlib_formats(\"svg\")\n",
"\n",
"numpyro.set_host_device_count(4)\n",
"assert numpyro.__version__.startswith(\"0.4.1\")"
"assert numpyro.__version__.startswith(\"0.5.0\")"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion numpyro/contrib/einstein/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import jax.scipy.linalg
import jax.scipy.stats

import numpyro.distributions as dist
from numpyro.contrib.einstein.utils import posdef, safe_norm, sqrth
import numpyro.distributions as dist


class PrecondMatrix(ABC):
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
from functools import partial

from jax import device_put, jacfwd, grad, ops, random, value_and_grad
from jax import device_put, grad, jacfwd, ops, random, value_and_grad
import jax.numpy as jnp
from jax.scipy.special import expit

Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections import namedtuple

from jax import jacfwd, grad, random, value_and_grad, vmap
from jax import grad, jacfwd, random, value_and_grad, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.ops import index_update
Expand Down
2 changes: 1 addition & 1 deletion numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax import lax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
Expand Down
2 changes: 1 addition & 1 deletion numpyro/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

__version__ = '0.4.1'
__version__ = '0.5.0'
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@
author='Uber AI Labs',
install_requires=[
# TODO: pin to a specific version for the release (until JAX's API becomes stable)
'jax>=0.2.7',
'jax==0.2.8',
# check min version here: https://github.com/google/jax/blob/master/jax/lib/__init__.py#L26
'jaxlib>=0.1.56',
'jaxlib==0.1.59',
'tqdm',
],
extras_require={
Expand Down

0 comments on commit 6a1f522

Please sign in to comment.