Skip to content

Commit

Permalink
[BugFix] Fix tensorclass update
Browse files Browse the repository at this point in the history
ghstack-source-id: 3f50604b340b7a7cdc710dfaceedf563295b2911
Pull Request resolved: #1255
  • Loading branch information
vmoens committed Mar 6, 2025
1 parent 55fab2a commit cb81b5e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,9 +1089,9 @@ def __init__(
)

self.module = module
if inplace not in (True, False, "empty"):
if inplace not in (None, True, False, "empty"):
raise ValueError(
f"The only accepted valued for inplace is `True`, `False`, or `'empty'`. Got inplace={inplace} "
f"The only accepted valued for inplace is `None`, `True`, `False`, or `'empty'`. Got inplace={inplace} "
"instead."
)
self.inplace = inplace
Expand Down
13 changes: 10 additions & 3 deletions tensordict/nn/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def __init__(
**{key: val for key, val in _zip_strict(modules[0], modules_vals)}
)
super().__init__(
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
module=nn.ModuleDict(modules),
in_keys=in_keys,
out_keys=out_keys,
)
elif len(modules) == 1 and isinstance(
modules[0], collections.abc.MutableSequence
Expand All @@ -227,20 +229,25 @@ def __init__(
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)
super().__init__(
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
module=nn.ModuleList(modules),
in_keys=in_keys,
out_keys=out_keys,
)
elif len(modules) == 1 and isinstance(modules[0], dict):
return self.__init__(
collections.OrderedDict(modules[0]),
partial_tolerant=partial_tolerant,
selected_out_keys=selected_out_keys,
inplace=inplace,
)
else:
modules = self._convert_modules(modules)
in_keys, out_keys = self._compute_in_and_out_keys(modules)
self._complete_out_keys = list(out_keys)
super().__init__(
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
module=nn.ModuleList(list(modules)),
in_keys=in_keys,
out_keys=out_keys,
)

self.inplace = inplace
Expand Down
6 changes: 6 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,12 @@ def _update(
update_batch_size=update_batch_size,
ignore_lock=ignore_lock,
)
# We also need to remove things from non_tensordict
if self._non_tensordict:
keys = set(self._tensordict.keys())
ntd = {k: val for k, val in self._non_tensordict.items() if k not in keys}
self._non_tensordict.clear()
self._non_tensordict.update(ntd)
return self


Expand Down

0 comments on commit cb81b5e

Please sign in to comment.