I ran the Ordinal Regression tutorial (https://num.pyro.ai/en/latest/tutorials/ordinal_regression.html) and I am not sure how transforms.SimplexToOrderedTransform behaves. I have confirmed that it works as expected on its own:
d = dist.TransformedDistribution(dist.Dirichlet(np.ones((3,))), dist.transforms.SimplexToOrderedTransform(0))
d.sample(random.PRNGKey(0))
Array([-1.584826 , 0.6460422], dtype=float32)
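For reference, my understanding is that the transform maps a point on a K-simplex to K - 1 ordered cutpoints (the logit of the cumulative sums, shifted by the anchor point); the exact formula below is my reading of the transform and should be treated as an assumption:
import jax.numpy as jnp
from jax.scipy.special import logit
import numpyro.distributions as dist

p = jnp.array([0.2, 0.3, 0.5])                      # a point on the 3-simplex
t = dist.transforms.SimplexToOrderedTransform(0.0)
print(t(p))                                         # 2 ordered cutpoints
print(logit(jnp.cumsum(p)[:-1]))                    # same values: [-1.386, 0.]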
However, when executing MCMC it does not seem to work correctly, as shown in the code below.
import os
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from jax.experimental.ode import odeint
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model
import arviz as az
az.style.use("arviz-darkgrid")
assert numpyro.__version__.startswith("0.11.0")
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
from numpyro.infer.reparam import TransformReparam
# data generation
simkeys = random.split(random.PRNGKey(1), 2)
nsim = 50
nclasses = 3
Y = dist.Categorical(logits=np.zeros(nclasses)).sample(simkeys[0], sample_shape=(nsim,))
X = dist.Normal().sample(simkeys[1], sample_shape=(nsim,))
X += Y
df = pd.DataFrame({"X": X, "Y": Y})
def model_ng(X, Y, nclasses, concentration, anchor_point=0.0):
    b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 5))
    # with numpyro.handlers.reparam(config={"c_y": TransformReparam()}):
    c_y = numpyro.sample(
        "c_y",
        dist.TransformedDistribution(
            dist.Dirichlet(concentration),
            dist.transforms.SimplexToOrderedTransform(anchor_point),
        ),
    )
    print(c_y.shape, c_y)
    with numpyro.plate("obs", X.shape[0]):
        eta = X * b_X_eta
        numpyro.sample("Y", dist.OrderedLogistic(eta, c_y), obs=Y)
concentration = np.ones((nclasses,)) * 10.0
rng_key = random.PRNGKey(0)
kernel = NUTS(model_ng)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
mcmc.run(
    rng_key=rng_key,
    X=df["X"].values,
    Y=df["Y"].values,
    nclasses=nclasses,
    concentration=concentration,
)
# with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (viz. `c_y_base`)
mcmc.print_summary(exclude_deterministic=False)
During initialization, the model's print statement already shows c_y with shape (3,) rather than the expected (2,):
(3,) [1.2993407 2.6367652 9.292245 ]
(3,) Traced<ConcreteArray([1.2993407 2.6367652 9.292245 ], dtype=float32)>with<JVPTrace(level=2/0)> with
primal = Array([1.2993407, 2.6367652, 9.292245 ], dtype=float32)
tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
pval = (ShapedArray(float32[3]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x3019bfba0>, in_tracers=(Traced<ShapedArray(float32[3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x3028781d0; to 'JaxprTracer' at 0x302878400>], out_avals=[ShapedArray(float32[3])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[3]. let b:f32[3] = cumsum[axis=0 reverse=False] a in (b,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'resource_env': None, 'donated_invars': (False,), 'name': '_cumulative_reduction', 'in_positional_semantics': (<_PositionalSemantics.GLOBAL: 1>,), 'out_positional_semantics': <_PositionalSemantics.GLOBAL: 1>, 'keep_unused': False, 'inline': False}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x174ac8330>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[58], line 62
60 kernel = NUTS(model_ng)
61 mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
---> 62 mcmc.run(
63 rng_key=rng_key,
64 X=df["X"].values,
65 Y=df["Y"].values,
66 nclasses=nclasses,
67 concentration=concentration,
68 )
69 # with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
70 mcmc.print_summary(exclude_deterministic=False)
File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/infer/mcmc.py:628, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
626 map_args = (rng_key, init_state, init_params)
627 if self.num_chains == 1:
--> 628 states_flat, last_state = partial_map_fn(map_args)
629 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
630 else:
File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/infer/mcmc.py:410, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
408 # Check if _sample_fn is None, then we need to initialize the sampler.
409 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
...
-> 1617 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1618 f'{", ".join(map(str, map(tuple, shapes)))}.')
1620 return tuple(result_shape)
TypeError: mul got incompatible shapes for broadcasting: (4,), (3,).
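The shape issue can also be seen by querying the transform's shape methods directly (a quick check; on numpyro 0.11.0 I would expect the inherited defaults to report the unchanged shape):
import numpyro.distributions as dist

t = dist.transforms.SimplexToOrderedTransform(0.0)
# With the inherited defaults these return the input shape unchanged;
# for a 3-simplex the forward shape should be (2,), and the inverse of (2,) should be (3,).
print(t.forward_shape((3,)))   # (3,) here, (2,) expected
print(t.inverse_shape((2,)))   # (2,) here, (3,) expected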
This is a bug in SimplexToOrderedTransform: the methods forward_shape and inverse_shape are not implemented correctly (it is currently using the default implementations, which map shape to shape rather than to shape[:-1] + (shape[-1] - 1,)).
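For illustration only, a minimal sketch of what the missing overrides on SimplexToOrderedTransform would look like, following the forward_shape/inverse_shape interface named above (not the actual patch):
def forward_shape(self, shape):
    # a simplex of size K maps to K - 1 ordered cutpoints
    return shape[:-1] + (shape[-1] - 1,)

def inverse_shape(self, shape):
    # K - 1 cutpoints map back to a simplex of size K
    return shape[:-1] + (shape[-1] + 1,)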