Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

a bug in SimplexToOrderedTransform #1580

Closed
yoshida-chem opened this issue Apr 21, 2023 · 0 comments · Fixed by #1583
Closed

a bug in SimplexToOrderedTransform #1580

yoshida-chem opened this issue Apr 21, 2023 · 0 comments · Fixed by #1583
Labels
bug Something isn't working

Comments

@yoshida-chem
Copy link

I ran the Ordinal Regression tutorial (https://num.pyro.ai/en/latest/tutorials/ordinal_regression.html),
and I am not sure how transforms.SimplexToOrderedTransform is supposed to behave.

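I have confirmed that sampling directly from the transformed distribution works as expected: a 3-element Dirichlet sample is mapped to 2 ordered values, as follows.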

d = dist.TransformedDistribution(dist.Dirichlet(np.ones((3,))),dist.transforms.SimplexToOrderedTransform(0))
d.sample(random.PRNGKey(0))
Array([-1.584826 ,  0.6460422], dtype=float32)

However, when running MCMC, it does not work correctly, as shown in the following code.
This is a bug in SimplexToOrderedTransform: the methods forward_shape and inverse_shape are not implemented correctly. The transform currently uses the default implementations, which map shape to shape rather than shape to shape[:-1] + (shape[-1] - 1,).

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from jax.experimental.ode import odeint
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

# Select the "arviz-darkgrid" plotting style for ArviZ.
az.style.use("arviz-darkgrid")

# Fail fast unless the installed NumPyro version string starts with "0.11.0",
# the version this reproduction was written against.
assert numpyro.__version__.startswith("0.11.0")

# Configure NumPyro for the CPU platform with a host device count of 1.
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
from numpyro.infer.reparam import TransformReparam


# data generation
# Split one seed into two PRNG keys: simkeys[0] drives the labels Y,
# simkeys[1] drives the predictor X.
simkeys = random.split(random.PRNGKey(1), 2)
nsim = 50  # number of simulated observations
nclasses = 3  # number of categories for Y
# All-zero logits give every class the same probability.
Y = dist.Categorical(logits=np.zeros(nclasses)).sample(simkeys[0], sample_shape=(nsim,))
# Draw X from Normal() with its default parameters, then shift each draw by
# its class label so that X carries information about Y.
X =dist. Normal().sample(simkeys[1], sample_shape=(nsim,))
X += Y
df = pd.DataFrame({"X": X, "Y": Y})


def model_ng(X, Y, nclasses, concentration, anchor_point=0.0):
    """Ordinal-regression model used to reproduce the reported shape error.

    Sample sites:
      * ``b_X_eta`` ~ Normal(0, 5), the coefficient multiplied with ``X``.
      * ``c_y``: a Dirichlet(``concentration``) draw pushed through
        ``SimplexToOrderedTransform(anchor_point)``.
      * ``Y`` ~ OrderedLogistic(X * b_X_eta, c_y), observed.

    Args:
        X: predictor values; ``X.shape[0]`` sets the size of the ``obs`` plate.
        Y: observed outcomes, passed as ``obs`` to the ``Y`` site.
        nclasses: accepted but not referenced anywhere in the body.
        concentration: concentration parameter of the base Dirichlet.
        anchor_point: argument forwarded to ``SimplexToOrderedTransform``.
    """
    b_X_eta = numpyro.sample("b_X_eta", dist.Normal(0, 5))

    # NOTE: the TransformReparam wrapper is commented out in this reproduction.
    #with numpyro.handlers.reparam(config={"c_y": TransformReparam()}):
    c_y = numpyro.sample(
        "c_y",
        dist.TransformedDistribution(
            dist.Dirichlet(concentration),
            dist.transforms.SimplexToOrderedTransform(anchor_point),
        )
    )
    # Debug output: in the log quoted in this report, this prints shape (3,)
    # for a 3-element concentration, whereas sampling the same kind of
    # distribution directly (earlier in the report) yields 2 values.
    print(c_y.shape, c_y)
    with numpyro.plate("obs", X.shape[0]):
        # Linear predictor: each X value scaled by the sampled coefficient.
        eta = X * b_X_eta
        numpyro.sample("Y", dist.OrderedLogistic(eta, c_y), obs=Y)


# Dirichlet concentration: 10.0 for each of the nclasses entries.
concentration = np.ones((nclasses,)) * 10.0

rng_key= random.PRNGKey(0)
# NUTS over model_ng: one chain, 1000 warmup iterations, 2000 samples, thinning=1.
kernel = NUTS(model_ng)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
# The traceback quoted in this report points at this call as the origin of the
# "incompatible shapes for broadcasting" TypeError.
mcmc.run(
    rng_key=rng_key,
    X=df["X"].values,
    Y=df["Y"].values,
    nclasses=nclasses,
    concentration=concentration,
)
# with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
mcmc.print_summary(exclude_deterministic=False)
(3,) [1.2993407 2.6367652 9.292245 ]
(3,) Traced<ConcreteArray([1.2993407 2.6367652 9.292245 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2993407, 2.6367652, 9.292245 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x3019bfba0>, in_tracers=(Traced<ShapedArray(float32[3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x3028781d0; to 'JaxprTracer' at 0x302878400>], out_avals=[ShapedArray(float32[3])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[3]. let b:f32[3] = cumsum[axis=0 reverse=False] a in (b,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'resource_env': None, 'donated_invars': (False,), 'name': '_cumulative_reduction', 'in_positional_semantics': (<_PositionalSemantics.GLOBAL: 1>,), 'out_positional_semantics': <_PositionalSemantics.GLOBAL: 1>, 'keep_unused': False, 'inline': False}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x174ac8330>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Output exceeds the size limit. Open the full output data in a text editor
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[58], line 62
     60 kernel = NUTS(model_ng)
     61 mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
---> 62 mcmc.run(
     63     rng_key=rng_key,
     64     X=df["X"].values,
     65     Y=df["Y"].values,
     66     nclasses=nclasses,
     67     concentration=concentration,
     68 )
     69 # with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
     70 mcmc.print_summary(exclude_deterministic=False)

File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/infer/mcmc.py:628, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    626 map_args = (rng_key, init_state, init_params)
    627 if self.num_chains == 1:
--> 628     states_flat, last_state = partial_map_fn(map_args)
    629     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    630 else:

File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/infer/mcmc.py:410, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    408 # Check if _sample_fn is None, then we need to initialize the sampler.
    409 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
...
-> 1617       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1618                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1620 return tuple(result_shape)

TypeError: mul got incompatible shapes for broadcasting: (4,), (3,).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging a pull request may close this issue.

2 participants