Skip to content

Commit

Permalink
Add some doctests to transforms (#1300)
Browse files Browse the repository at this point in the history
* add some doctest to transforms

* make format
  • Loading branch information
wataruhashimoto52 authored Jan 22, 2022
1 parent 99c3132 commit 0a308a4
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,19 @@ class LowerCholeskyAffine(Transform):
:param loc: a real vector.
:param scale_tril: a lower triangular matrix with positive diagonal.
**Example**
.. doctest::
>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import LowerCholeskyAffine
>>> base = jnp.ones(2)
>>> loc = jnp.zeros(2)
>>> scale_tril = jnp.array([[0.3, 0.0], [1.0, 0.5]])
>>> affine = LowerCholeskyAffine(loc=loc, scale_tril=scale_tril)
>>> affine(base)
DeviceArray([0.3, 1.5], dtype=float32)
"""
domain = constraints.real_vector
codomain = constraints.real_vector
Expand Down Expand Up @@ -773,6 +786,17 @@ class OrderedTransform(Transform):
1. *Stan Reference Manual v2.20, section 10.6*,
Stan Development Team
**Example**
.. doctest::
>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import OrderedTransform
>>> base = jnp.ones(3)
>>> transform = OrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([1., 3.7182817, 6.4365635]), rtol=1e-3, atol=1e-3)
"""

domain = constraints.real_vector
Expand Down Expand Up @@ -863,6 +887,16 @@ class SimplexToOrderedTransform(Transform):
1. *Ordinal Regression Case Study, section 2.2*,
M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
**Example**
.. doctest::
>>> import jax.numpy as jnp
>>> from numpyro.distributions.transforms import SimplexToOrderedTransform
>>> base = jnp.array([0.3, 0.1, 0.4, 0.2])
>>> transform = SimplexToOrderedTransform()
>>> assert jnp.allclose(transform(base), jnp.array([-0.8472978, -0.40546507, 1.3862944]), rtol=1e-3, atol=1e-3)
"""

domain = constraints.simplex
Expand Down

0 comments on commit 0a308a4

Please sign in to comment.