Incompatible shapes error after stack error #815

Closed
akotlar opened this issue Nov 18, 2020 · 5 comments · Fixed by #818

Comments

@akotlar

akotlar commented Nov 18, 2020

Related to #790

At least under the circumstances listed in #790, once the AssertionError is raised, subsequent inference runs on the model fail in an IPython environment. Memoized (compiled?) tensor shapes appear to be corrupted and to persist into the next invocation.

Stack trace from #790:

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-4-6dacd0827004> in <module>
      2 mcmcPP = MCMC(NUTS(model), 100, 1000)
----> 3 mcmcPP.run(init_rng_key, altCounts)
      4 mcmcPP.print_summary()

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    448         if self.num_chains == 1:
--> 449             states_flat, last_state = partial_map_fn(map_args)
    450             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    312             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
--> 313                                            model_args=args, model_kwargs=kwargs)
    314         if self.postprocess_fn is None:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    449             rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
--> 450         init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
    451         if self._potential_fn and init_params is None:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    412                 model_args=model_args,
--> 413                 model_kwargs=model_kwargs)
    414             if self._init_fn is None:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs)
    429                                                                   model_kwargs=model_kwargs,
--> 430                                                                   prototype_params=prototype_params)
    431 

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params)
    256     if rng_key.ndim == 1:
--> 257         (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
    258     else:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    244             # where we can avoid compiling body_fn in while_loop.
--> 245             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    246             if not_jax_tracer(is_valid):

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in body_fn(state)
    234         potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
--> 235         pe, z_grad = value_and_grad(potential_fn)(params)
    236         z_grad_flat = ravel_pytree(z_grad)[0]

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    156     # no param is needed for log_density computation because we already substitute
--> 157     log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
    158     return - log_joint

~/miniconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    123     with plate_to_enum_plate():
--> 124         model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    125     log_factors = []

~/miniconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    157         """
--> 158         self(*args, **kwargs)
    159         return self.trace

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     67         with self:
---> 68             return self.fn(*args, **kwargs)
     69 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __exit__(self, *args, **kwargs)
     56     def __exit__(self, *args, **kwargs):
---> 57         assert _PYRO_STACK[-1] is self
     58         _PYRO_STACK.pop()

FilteredStackTrace: AssertionError

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception (in the shape error further below, (6, 4) is the shape of the pooled 'conc' sample statement, and 6 is K):

AssertionError                            Traceback (most recent call last)
<ipython-input-4-6dacd0827004> in <module>
      1 init_rng_key = random.PRNGKey(12273)
      2 mcmcPP = MCMC(NUTS(model), 100, 1000)
----> 3 mcmcPP.run(init_rng_key, altCounts)
      4 mcmcPP.print_summary()

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    447         map_args = (rng_key, init_state, init_params)
    448         if self.num_chains == 1:
--> 449             states_flat, last_state = partial_map_fn(map_args)
    450             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    451         else:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    311         if init_state is None:
    312             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
--> 313                                            model_args=args, model_kwargs=kwargs)
    314         if self.postprocess_fn is None:
    315             postprocess_fn = self.sampler.postprocess_fn(args, kwargs)

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    448         else:
    449             rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
--> 450         init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
    451         if self._potential_fn and init_params is None:
    452             raise ValueError('Valid value of `init_params` must be provided with'

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    411                 init_strategy=self._init_strategy,
    412                 model_args=model_args,
--> 413                 model_kwargs=model_kwargs)
    414             if self._init_fn is None:
    415                 self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn,

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs)
    428                                                                   model_args=model_args,
    429                                                                   model_kwargs=model_kwargs,
--> 430                                                                   prototype_params=prototype_params)
    431 
    432     if not_jax_tracer(is_valid):

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params)
    255     # Handle possible vectorization
    256     if rng_key.ndim == 1:
--> 257         (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
    258     else:
    259         (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    243             # Early return if valid params found. This is only helpful for single chain,
    244             # where we can avoid compiling body_fn in while_loop.
--> 245             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    246             if not_jax_tracer(is_valid):
    247                 if device_get(is_valid):

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in body_fn(state)
    233 
    234         potential_fn = partial(potential_energy, model, model_args, model_kwargs, enum=enum)
--> 235         pe, z_grad = value_and_grad(potential_fn)(params)
    236         z_grad_flat = ravel_pytree(z_grad)[0]
    237         is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

~/miniconda3/lib/python3.7/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

~/miniconda3/lib/python3.7/site-packages/jax/api.py in value_and_grad_f(*args, **kwargs)
    808     tree_map(partial(_check_input_dtype_grad, holomorphic, allow_int), dyn_args)
    809     if not has_aux:
--> 810       ans, vjp_py = _vjp(f_partial, *dyn_args)
    811     else:
    812       ans, vjp_py, aux = _vjp(f_partial, *dyn_args, has_aux=True)

~/miniconda3/lib/python3.7/site-packages/jax/api.py in _vjp(fun, has_aux, *primals)
   1839   if not has_aux:
   1840     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1841     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1842     out_tree = out_tree()
   1843   else:

~/miniconda3/lib/python3.7/site-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    111 def vjp(traceable, primals, has_aux=False):
    112   if not has_aux:
--> 113     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    114   else:
    115     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

~/miniconda3/lib/python3.7/site-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     98   _, in_tree = tree_flatten(((primals, primals), {}))
     99   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 100   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    101   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
    102   assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

~/miniconda3/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
    402   with core.new_main(JaxprTrace) as main:
    403     fun = trace_to_subjaxpr(fun, main, instantiate)
--> 404     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    405     assert not env
    406     del main

~/miniconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    154 
    155     try:
--> 156       ans = self.f(*args, **dict(self.params, **kwargs))
    157     except:
    158       # Some transformations yield from inside context managers, so we have to

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    155     substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
    156     # no param is needed for log_density computation because we already substitute
--> 157     log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
    158     return - log_joint
    159 

~/miniconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    122     model = substitute(model, data=params)
    123     with plate_to_enum_plate():
--> 124         model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    125     log_factors = []
    126     time_to_factors = defaultdict(list)  # log prob factors

~/miniconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    156         :return: `OrderedDict` containing the execution trace.
    157         """
--> 158         self(*args, **kwargs)
    159         return self.trace
    160 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     66     def __call__(self, *args, **kwargs):
     67         with self:
---> 68             return self.fn(*args, **kwargs)
     69 
     70 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __exit__(self, *args, **kwargs)
     55 
     56     def __exit__(self, *args, **kwargs):
---> 57         assert _PYRO_STACK[-1] is self
     58         _PYRO_STACK.pop()
     59 

AssertionError: 

After fixing the cause of the error (wrapping the pooled sample statement in a plate, or removing the offending plate), running inference again produces a shape error with the following stack trace:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    447         map_args = (rng_key, init_state, init_params)
    448         if self.num_chains == 1:
--> 449             states_flat, last_state = partial_map_fn(map_args)
    450             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    451         else:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    311         if init_state is None:
    312             init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
--> 313                                            model_args=args, model_kwargs=kwargs)
    314         if self.postprocess_fn is None:
    315             postprocess_fn = self.sampler.postprocess_fn(args, kwargs)

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    448         else:
    449             rng_key, rng_key_init_model = jnp.swapaxes(vmap(random.split)(rng_key), 0, 1)
--> 450         init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
    451         if self._potential_fn and init_params is None:
    452             raise ValueError('Valid value of `init_params` must be provided with'

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    411                 init_strategy=self._init_strategy,
    412                 model_args=model_args,
--> 413                 model_kwargs=model_kwargs)
    414             if self._init_fn is None:
    415                 self._init_fn, self._sample_fn = hmc(potential_fn_gen=potential_fn,

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs)
    393                                    substitute_fn=init_strategy)
    394     inv_transforms, replay_model, has_enumerate_support, model_trace = _get_model_transforms(
--> 395         substituted_model, model_args, model_kwargs)
    396     # substitute param sites from model_trace to model so
    397     # we don't need to generate again parameters of `numpyro.module`

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in _get_model_transforms(model, model_args, model_kwargs)
    263 def _get_model_transforms(model, model_args=(), model_kwargs=None):
    264     model_kwargs = {} if model_kwargs is None else model_kwargs
--> 265     model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    266     inv_transforms = {}
    267     # model code may need to be replayed in the presence of deterministic sites

~/miniconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    156         :return: `OrderedDict` containing the execution trace.
    157         """
--> 158         self(*args, **kwargs)
    159         return self.trace
    160 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     66     def __call__(self, *args, **kwargs):
     67         with self:
---> 68             return self.fn(*args, **kwargs)
     69 
     70 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     66     def __call__(self, *args, **kwargs):
     67         with self:
---> 68             return self.fn(*args, **kwargs)
     69 
     70 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     66     def __call__(self, *args, **kwargs):
     67         with self:
---> 68             return self.fn(*args, **kwargs)
     69 
     70 

<ipython-input-5-2f909eaf93fa> in modelPartialPooled(data)
    177 def modelPartialPooled(data):
    178     # This also works, for a single set of shraed parameters
--> 179     conc = numpyro.sample("conc", Gamma(pdsAllShaped, 1))
    180     # print("conc.shape", conc)
    181     with numpyro.plate("beta_plate", nHypotheses):

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in sample(name, fn, obs, rng_key, sample_shape, infer)
    112 
    113     # ...and use apply_stack to send it to the Messengers
--> 114     msg = apply_stack(initial_msg)
    115     return msg['value']
    116 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in apply_stack(msg)
     22     pointer = 0
     23     for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 24         handler.process_message(msg)
     25         # When a Messenger sets the "stop" field of a message,
     26         # it prevents any Messengers above it on the stack from being applied.

~/miniconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/enum_messenger.py in process_message(self, msg)
    470         if msg["type"] in ["to_funsor", "to_data"]:
    471             return super().process_message(msg)
--> 472         return OrigPlateMessenger.process_message(self, msg)
    473 
    474     def postprocess_message(self, msg):

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in process_message(self, msg)
    314             overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
    315             trailing_shape = expected_shape[overlap_idx:]
--> 316             broadcast_shape = lax.broadcast_shapes(trailing_shape, tuple(dist_batch_shape))
    317             batch_shape = expected_shape[:overlap_idx] + broadcast_shape
    318             msg['fn'] = msg['fn'].expand(batch_shape)

~/miniconda3/lib/python3.7/site-packages/jax/lax/lax.py in broadcast_shapes(*shapes)
     80   if result_shape is None:
     81     raise ValueError("Incompatible shapes for broadcasting: {}"
---> 82                      .format(tuple(map(tuple, shapes))))
     83   return result_shape
     84 

ValueError: Incompatible shapes for broadcasting: ((1, 6), (6, 4))
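
For readers unfamiliar with the broadcasting rule behind this message, here is a minimal pure-Python sketch. Per the primitives.py frame above, the first shape is the trailing part of the plate's expected shape and the second is the distribution's batch shape. The `broadcast_shapes` helper below is a hypothetical stand-in written for illustration, not JAX's implementation; it only assumes the usual NumPy-style rule that shapes are aligned on their trailing dimensions and each pair of sizes must match or contain a 1.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Hypothetical helper: NumPy-style broadcasting of shape tuples."""
    result = []
    # Align the shapes on their trailing dimensions, padding short ones with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            # Two different sizes, neither equal to 1, cannot be broadcast.
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# A plate dimension of size 6 placed at dim -2 is compatible with (6, 4).
compatible = broadcast_shapes((6, 1), (6, 4))
print(compatible)

# The shapes from the error: trailing sizes 6 and 4 clash.
try:
    broadcast_shapes((1, 6), (6, 4))
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)
```

In the failing pair, the plate expects its size-6 dimension in the last position, while the distribution's batch shape has 4 there, so no broadcast exists.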

Reproduction case:

import numpyro
from numpyro.distributions import Beta, Categorical, Dirichlet, Gamma, Multinomial

# K, N, alpha, and mix_weights are assumed to be defined earlier in the session (not shown).
def modelPartialPooled(data):
    # To avoid the error, remove beta_plate and z.
    conc = numpyro.sample("conc", Gamma(1, 1))
    with numpyro.plate("beta_plate", K):
        beta = numpyro.sample("beta", Beta(1, alpha, validate_args=False))

    with numpyro.plate("data", N):
        # z is deliberately left unused, to reduce the degrees of freedom of the test.
        z = numpyro.sample("z", Categorical(mix_weights(beta)))
        probs = numpyro.sample("probs", Dirichlet(conc[0], validate_args=False))
        return numpyro.sample("obs", Multinomial(probs=probs, validate_args=False), obs=data)

@fehiepsi
Member

It seems the error occurs when a discrete latent variable (here z) is sampled but not used anywhere else in the model. We should catch that case and raise a better error message. Thanks, @akotlar!

cc @eb8680

@fritzo
Member

fritzo commented Nov 18, 2020

[after] the AssertionError is raised, subsequent inference ...

I believe the _PYRO_STACK is not being correctly cleaned on error, due to logic in the __exit__() methods of handlers.

@fehiepsi
Member

Thanks, @fritzo! I understand the situation now. I will make the fix soon.

@fehiepsi
Member

fehiepsi commented Nov 19, 2020

@akotlar The stack should be cleaned after #818. Please reopen if this happens again. Thanks for your feedback!

@akotlar
Author

akotlar commented Dec 6, 2020

Thanks @fehiepsi @fritzo!
