Skip to content

Commit

Permalink
[BugFix] Fix serialization of stacks of Tensorclasses
Browse files Browse the repository at this point in the history
ghstack-source-id: 8e47f46e83982d554237604f6ef7c845eeed1b50
Pull Request resolved: #1236
  • Loading branch information
vmoens committed Feb 26, 2025
1 parent 06215b6 commit 635c9c0
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def from_dict(
stack_dim_name=None,
stack_dim=0,
):
return LazyStackedTensorDict(
return cls._new_lazy_unsafe(
*(
TensorDict.from_dict(
input_dict[str(i)],
Expand Down
8 changes: 6 additions & 2 deletions tensordict/_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,12 @@ def from_metadata(metadata=metadata, prefix=None):
_ = metadata.pop("size", None)

d = {
key: NonTensorData(data, batch_size=batch_size)
for (key, (data, batch_size)) in non_tensor.items()
key: NonTensorData(
data,
batch_size=batch_size,
device=torch.device(device) if device is not None else None,
)
for (key, (data, batch_size, device)) in non_tensor.items()
}
for key, (dtype, local_shape, start, stop, pad) in leaves.items():
dtype = _STRDTYPE2DTYPE[dtype]
Expand Down
1 change: 1 addition & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5056,6 +5056,7 @@ def assign(
metadata_dict["non_tensors"][key] = (
value.data,
list(value.batch_size),
str(value.device) if value.device is not None else None,
)
return
elif _is_tensor_collection(cls):
Expand Down
21 changes: 21 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ class MyDataClass:
MyTensorClass_autocast = MyTensorClass_nocast = MyTensorClass = None


@tensorclass
class TCStrings:
a: str
b: str


class TestTensorClass:
def test_get_default(self):
@tensorclass
Expand Down Expand Up @@ -1252,6 +1258,21 @@ def test_pickle(self):
assert isinstance(data2, MyData)
assert data2.z == data.z

@pytest.mark.parametrize("consolidate", [False, True])
def test_pickle_consolidate(self, consolidate):
with set_capture_non_tensor_stack(False):

tc = TCStrings(a="a", b="b")

tcstack = TensorDict(tc=torch.stack([tc, tc.clone()]))
if consolidate:
tcstack = tcstack.consolidate()
assert isinstance(tcstack["tc"], TCStrings)
loaded = pickle.loads(pickle.dumps(tcstack))
assert isinstance(loaded["tc"], TCStrings)
assert loaded["tc"].a == tcstack["tc"].a
assert loaded["tc"].b == tcstack["tc"].b

def test_post_init(self):
@tensorclass
class MyDataPostInit:
Expand Down

0 comments on commit 635c9c0

Please sign in to comment.