Skip to content

Commit

Permalink
[Feature] strict kwarg in TDModule
Browse files Browse the repository at this point in the history
ghstack-source-id: ced22130bf45945e2671fa9c2e776d482fcd8b15
Pull Request resolved: #1234
  • Loading branch information
vmoens committed Feb 24, 2025
1 parent e23ce5c commit 06215b6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from cloudpickle import dumps as cloudpickle_dumps, loads as cloudpickle_loads
from tensordict._td import TensorDict

from tensordict.base import is_tensor_collection, TensorDictBase
from tensordict.base import is_tensor_collection, NO_DEFAULT, TensorDictBase
from tensordict.functional import make_tensordict
from tensordict.nn.utils import _dispatch_td_nn_modules, _set_skip_existing_None
from tensordict.utils import (
Expand Down Expand Up @@ -848,6 +848,8 @@ class TensorDictModule(TensorDictModuleBase):
method (str, optional): the method to be called in the module, if any. Defaults to `__call__`.
method_kwargs (Dict[str, Any], optional): additional keyword arguments to be passed to the module's method being called.
strict (bool, optional): if ``True``, the module will raise an exception if any of the inputs is missing from
the input tensordict. Otherwise, a `None` value will be used as placeholder. Defaults to ``False``.
Embedding a neural network in a TensorDictModule only requires to specify the input
and output keys. TensorDictModule support functional and regular :obj:`nn.Module`
Expand Down Expand Up @@ -1014,6 +1016,7 @@ def __init__(
inplace: bool | str = True,
method: str | None = None,
method_kwargs: dict | None = None,
strict: bool = False,
) -> None:
super().__init__()

Expand Down Expand Up @@ -1074,6 +1077,8 @@ def __init__(
self.out_keys = out_keys
self.in_keys = in_keys

self.strict = strict

if "_" in self.in_keys:
warnings.warn(
'key "_" is for ignoring output, it should not be used in input keys',
Expand Down Expand Up @@ -1158,18 +1163,19 @@ def forward(
raise ValueError(
"Got a non-empty list of extra agruments, when none was expected."
)
default = None if not self.strict else NO_DEFAULT
if self._kwargs is not None:
kwargs.update(
{
kwarg: tensordict.get(in_key)
kwarg: tensordict.get(in_key, default=default)
for kwarg, in_key in _zip_strict(self._kwargs, self.in_keys)
}
)
tensors = ()
else:
tensors = tuple(
tensordict._get_tuple_maybe_non_tensor(
_unravel_key_to_tuple(in_key), None
_unravel_key_to_tuple(in_key), default
)
for in_key in self.in_keys
)
Expand Down
17 changes: 17 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,23 @@ def test_stateful_probabilistic_kwargs(self, lazy, it, out_keys, max_dist):
assert td.shape == torch.Size([3])
assert td.get("out").shape == torch.Size([3, 4])

@pytest.mark.parametrize("strict", [True, False])
def test_strict(self, strict):
def check(a, b):
assert b is None
return a

tdm = TensorDictModule(
check,
in_keys=["present", "missing"],
out_keys=["new_present"],
strict=strict,
)
td = TensorDict(present=0)
with pytest.raises(KeyError) if strict else contextlib.nullcontext():
tdout = tdm(td)
assert tdout["new_present"] is td["present"]

def test_nontensor(self):
tdm = TensorDictModule(
lambda: NonTensorStack(NonTensorData(1), NonTensorData(2)),
Expand Down

0 comments on commit 06215b6

Please sign in to comment.