From 4b60da54a621893e0d88078b154519dc86d9c5ce Mon Sep 17 00:00:00 2001 From: vmoens Date: Wed, 24 Jan 2024 17:55:33 +0000 Subject: [PATCH] init --- tensordict/_td.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 1ad770f13..c94f578cf 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1971,7 +1971,8 @@ def _exclude(self, *keys: str, inplace: bool = False, set_shared: bool = True) - if keys_to_exclude is None: # delay creation of defaultdict keys_to_exclude = defaultdict(list) - keys_to_exclude[key[0]].append(key[1:]) + if key[0] in self._tensordict: + keys_to_exclude[key[0]].append(key[1:]) if keys_to_exclude is not None: for key, cur_keys in keys_to_exclude.items(): val = _tensordict.get(key, None) @@ -1979,8 +1980,8 @@ def _exclude(self, *keys: str, inplace: bool = False, set_shared: bool = True) - val = val._exclude( *cur_keys, inplace=inplace, set_shared=set_shared ) - if not inplace: - _tensordict[key] = val + if not inplace: + _tensordict[key] = val if inplace: return self result = TensorDict(