From 3129b563822bd569a8d47fe57c2f618374b14225 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 5 Feb 2025 11:46:49 +0000
Subject: [PATCH] [BE] No warning if user sets the log_prob_key explicitly and
 only one variable is sampled from the ProbTDMod

ghstack-source-id: ccc966d5e698a4fb394081a92bafc31649951ab7
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1209
---
 tensordict/nn/distributions/composite.py |  4 +-
 tensordict/nn/probabilistic.py           | 15 ++++--
 test/test_nn.py                          | 66 ++++++++++++++++++++++++
 3 files changed, 80 insertions(+), 5 deletions(-)

diff --git a/tensordict/nn/distributions/composite.py b/tensordict/nn/distributions/composite.py
index 1eeb7e990..0de644ad5 100644
--- a/tensordict/nn/distributions/composite.py
+++ b/tensordict/nn/distributions/composite.py
@@ -130,7 +130,9 @@ def __init__(
             dist_params = params.get(name)
             kwargs = extra_kwargs.get(name, {})
             if dist_params is None:
-                raise KeyError(f"no param {name} found in params with keys {params.keys(True, True)}")
+                raise KeyError(
+                    f"no param {name} found in params with keys {params.keys(True, True)}"
+                )
             dist = dist_class(**dist_params, **kwargs)
             dists[write_name] = dist
         self.dists = dists
diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py
index 9ae1133f3..ee7b72076 100644
--- a/tensordict/nn/probabilistic.py
+++ b/tensordict/nn/probabilistic.py
@@ -328,6 +328,9 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
 
     """
 
+    # To be removed in v0.9
+    _trigger_warning_lpk: bool = False
+
     def __init__(
         self,
         in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey],
@@ -396,9 +399,11 @@ def __init__(
                     "composite_lp_aggregate is set to True but log_prob_keys were passed. "
                     "When composite_lp_aggregate() returns ``True``, log_prob_key must be used instead."
                 )
+        self._trigger_warning_lpk = len(self._out_keys) > 1
         if log_prob_key is None:
             if composite_lp_aggregate(nowarn=True):
                 log_prob_key = "sample_log_prob"
+                self._trigger_warning_lpk = True
             elif len(out_keys) == 1:
                 log_prob_key = _add_suffix(out_keys[0], "_log_prob")
             elif len(out_keys) > 1 and not composite_lp_aggregate(nowarn=True):
@@ -451,13 +456,15 @@ def log_prob_key(self):
                 f"unless there is one and only one element in log_prob_keys (got log_prob_keys={self.log_prob_keys}). "
                 f"When composite_lp_aggregate() returns ``False``, try to use {type(self).__name__}.log_prob_keys instead."
             )
-        if _composite_lp_aggregate.get_mode() is None:
+        if _composite_lp_aggregate.get_mode() is None and self._trigger_warning_lpk:
             warnings.warn(
                 f"You are querying the log-probability key of a {type(self).__name__} where the "
-                f"composite_lp_aggregate has not been set. "
+                f"composite_lp_aggregate has not been set and the log-prob key has not been chosen. "
                 f"Currently, it is assumed that composite_lp_aggregate() will return True: the log-probs will be aggregated "
-                f"in a {self._log_prob_key} entry. From v0.9, this behaviour will be changed and individual log-probs will "
-                f"be written in `('path', 'to', 'leaf', '_log_prob')`. To prepare for this change, "
+                f"in a {self._log_prob_key} entry. "
+                f"From v0.9, this behaviour will be changed and individual log-probs will "
+                f"be written in `('path', 'to', 'leaf', '_log_prob')`. "
+                f"To prepare for this change, "
                 f"call `set_composite_lp_aggregate(mode: bool).set()` at the beginning of your script (or set the "
                 f"COMPOSITE_LP_AGGREGATE env variable). Use mode=True "
                 f"to keep the current behaviour, and mode=False to use per-leaf log-probs.",
diff --git a/test/test_nn.py b/test/test_nn.py
index a2c3597b7..dd91ab60b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2270,6 +2270,72 @@ def test_index_prob_seq(self):
         assert isinstance(seq[:2], ProbabilisticTensorDictSequential)
         assert isinstance(seq[-2:], ProbabilisticTensorDictSequential)
 
+    def test_no_warning_single_key(self):
+        # Check that there is no warning if the number of out keys is 1 and sample log prob is set
+        torch.manual_seed(0)
+        with set_composite_lp_aggregate(None):
+            mod = ProbabilisticTensorDictModule(
+                in_keys=["loc", "scale"],
+                distribution_class=torch.distributions.Normal,
+                out_keys=[("an", "action")],
+                log_prob_key="sample_log_prob",
+                return_log_prob=True,
+            )
+            td = TensorDict(loc=torch.randn(()), scale=torch.rand(()))
+            mod(td.copy())
+            mod.log_prob(mod(td.copy()))
+            mod.log_prob_key
+
+            # Don't set the key and trigger the warning
+            mod = ProbabilisticTensorDictModule(
+                in_keys=["loc", "scale"],
+                distribution_class=torch.distributions.Normal,
+                out_keys=[("an", "action")],
+                return_log_prob=True,
+            )
+            with pytest.warns(
+                DeprecationWarning, match="You are querying the log-probability key"
+            ):
+                mod(td.copy())
+                mod.log_prob(mod(td.copy()))
+                mod.log_prob_key
+
+            # add another variable, and trigger the warning
+            mod = ProbabilisticTensorDictModule(
+                in_keys=["params"],
+                distribution_class=CompositeDistribution,
+                distribution_kwargs={
+                    "distribution_map": {
+                        "dirich": torch.distributions.Dirichlet,
+                        "categ": torch.distributions.Categorical,
+                    }
+                },
+                out_keys=[("dirich", "categ")],
+                return_log_prob=True,
+            )
+            with pytest.warns(
+                DeprecationWarning, match="You are querying the log-probability key"
+            ), pytest.warns(
+                DeprecationWarning,
+                match="Composite log-prob aggregation wasn't defined explicitly",
+            ):
+                td = TensorDict(
+                    params=TensorDict(
+                        dirich=TensorDict(
+                            concentration=torch.rand(
+                                (
+                                    10,
+                                    11,
+                                )
+                            )
+                        ),
+                        categ=TensorDict(logits=torch.rand((5,))),
+                    )
+                )
+                mod(td.copy())
+                mod.log_prob(mod(td.copy()))
+                mod.log_prob_key
+
 
 class TestEnsembleModule:
     def test_init(self):
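
For reference, a minimal sketch of the behaviour this patch targets, assuming the tensordict API shown in the diff (ProbabilisticTensorDictModule, set_composite_lp_aggregate); the ("an", "action") out key and the Normal parameters are illustrative only, not part of the patch:

    import warnings

    import torch
    from tensordict import TensorDict
    from tensordict.nn import ProbabilisticTensorDictModule, set_composite_lp_aggregate

    with set_composite_lp_aggregate(None):  # leave the global mode unset, as in the test
        # Explicit log_prob_key with a single sampled variable: per this patch,
        # querying log_prob_key no longer emits the DeprecationWarning, because
        # _trigger_warning_lpk stays False when the key was chosen by the user.
        mod = ProbabilisticTensorDictModule(
            in_keys=["loc", "scale"],
            distribution_class=torch.distributions.Normal,
            out_keys=[("an", "action")],
            log_prob_key="sample_log_prob",
            return_log_prob=True,
        )
        td = TensorDict(loc=torch.randn(()), scale=torch.rand(()))
        with warnings.catch_warnings():
            warnings.simplefilter("error", DeprecationWarning)  # escalate to errors
            mod(td.copy())
            _ = mod.log_prob_key  # silent: the key was set explicitly

Leaving log_prob_key unset, or sampling more than one variable (the CompositeDistribution case in the test), keeps the deprecation warning, since the resolved key then depends on the future composite_lp_aggregate default.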
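The warning's suggested migration path is to pin the aggregation mode once per process; a sketch under the same assumptions, where the expected per-leaf key is inferred from the `_add_suffix` branch in the diff above:

    import torch
    from tensordict.nn import ProbabilisticTensorDictModule, set_composite_lp_aggregate

    # Opt in to the per-leaf behaviour announced for v0.9, so the deprecation
    # warning can never fire: each sampled leaf gets its own "*_log_prob" entry.
    set_composite_lp_aggregate(False).set()

    mod = ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        distribution_class=torch.distributions.Normal,
        out_keys=[("an", "action")],
        return_log_prob=True,
    )
    print(mod.log_prob_keys)  # expected: [("an", "action_log_prob")]

Using `set_composite_lp_aggregate(True).set()` (or the COMPOSITE_LP_AGGREGATE env variable) instead keeps the current aggregated "sample_log_prob" behaviour.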