Skip to content

Commit

Permalink
[BugFix] Consolidate lazy stacks of non-tensors
Browse files Browse the repository at this point in the history
ghstack-source-id: d3d822dba235b74128f99e6cbff08989d13c1af4
Pull Request resolved: #1224
  • Loading branch information
vmoens committed Feb 20, 2025
1 parent 0b901a7 commit f67a15c
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
lazy_legacy,
lock_blocked,
prod,
set_capture_non_tensor_stack,
set_lazy_legacy,
strtobool,
TensorDictFuture,
Expand Down Expand Up @@ -9088,7 +9089,8 @@ def newfn(item_and_out):
from tensordict._lazy import LazyStackedTensorDict

# We want to be able to return whichever data structure
out = LazyStackedTensorDict.maybe_dense_stack(imaplist, dim)
with set_capture_non_tensor_stack(False):
out = LazyStackedTensorDict.maybe_dense_stack(imaplist, dim)
else:
out = torch.cat(imaplist, dim)
return out
Expand Down
29 changes: 28 additions & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from tensordict import (
capture_non_tensor_stack,
get_defaults_to_none,
lazy_legacy,
lazy_stack,
LazyStackedTensorDict,
make_tensordict,
PersistentTensorDict,
Expand Down Expand Up @@ -67,7 +69,6 @@
convert_ellipsis_to_idx,
is_non_tensor,
is_tensorclass,
lazy_legacy,
logger as tdlogger,
set_lazy_legacy,
)
Expand Down Expand Up @@ -11220,6 +11221,11 @@ def test_map_iter_interrupt_early(self, chunksize, num_chunks, shuffle):


class TestNonTensorData:
@tensorclass
class SomeTensorClass:
a: str
b: torch.Tensor

@pytest.fixture
def non_tensor_data(self):
return TensorDict(
Expand All @@ -11234,6 +11240,27 @@ def non_tensor_data(self):
batch_size=[],
)

@set_capture_non_tensor_stack(False)
def test_consolidate_nested(self):
import pickle

td = TensorDict(
a=TensorDict(b=self.SomeTensorClass(a="a string!", b=torch.randn(10))),
c=TensorDict(d=NonTensorData("another string!")),
)
td = lazy_stack([td.clone(), td.clone()])
td = lazy_stack([td.clone(), td.clone()], -1)

tdc = td.consolidate()

assert (tdc == td).all()

tdr = pickle.loads(pickle.dumps(td))
assert (tdr == td).all()

tdcr = pickle.loads(pickle.dumps(tdc))
assert (tdcr == td).all()

def test_comparison(self, non_tensor_data):
non_tensor_data = non_tensor_data.exclude(("nested", "str"))
assert (non_tensor_data | non_tensor_data).get_non_tensor(("nested", "bool"))
Expand Down

0 comments on commit f67a15c

Please sign in to comment.