Skip to content

Commit

Permalink
Better guesses for why logp has RVs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 28, 2025
1 parent fa43eba commit 4a0cbd1
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,6 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
return expr


RVS_IN_JOINT_LOGP_GRAPH_MSG = (
"Random variables detected in the logp graph: %s.\n"
"This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,\n"
"or when not all rvs have a corresponding value variable."
)


def conditional_logp(
rv_values: dict[TensorVariable, TensorVariable],
warn_rvs=None,
Expand Down Expand Up @@ -563,7 +556,11 @@ def conditional_logp(
if warn_rvs:
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
if rvs_in_logp_expressions:
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)
warnings.warn(
f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n"
"This can happen when not all random variables have a corresponding value variable.",
UserWarning,
)

return values_to_logprobs

Expand Down Expand Up @@ -611,7 +608,11 @@ def transformed_conditional_logp(

rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
if rvs_in_logp_expressions:
raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
raise ValueError(
f"Random variables detected in the logp graph: {rvs_in_logp_expressions}.\n"
"This can happen when mixing variables from different models, "
"or when CustomDist logp or Interval transform functions reference nonlocal variables."
)

return logp_terms_list

Expand Down

0 comments on commit 4a0cbd1

Please sign in to comment.