diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index 02b98cf30..040b7f50a 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1055,7 +1055,21 @@ def _write_to_tensordict( tensordict_out = TensorDict() else: tensordict_out = tensordict - for _out_key, _tensor in _zip_strict(out_keys, tensors): + if len(tensors) > len(out_keys): + incipit = "There are more tensors than out_keys. " + elif len(out_keys) > len(tensors): + incipit = "There are more out_keys than tensors. " + else: + incipit = None + if incipit is not None: + warnings.warn( + incipit + "This is currently " + "allowed but it will be deprecated in v0.9. To silence this warning, " + "make sure the number of out_keys matches the number of outputs of the " + "network.", + category=DeprecationWarning, + ) + for _out_key, _tensor in zip(out_keys, tensors): if _out_key != "_": tensordict_out.set(_out_key, TensorDict.from_any(_tensor)) return tensordict_out