From 635c9c0a3c2e5dbe032797f3ad809e81d8ef61b3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 26 Feb 2025 09:46:57 +0000 Subject: [PATCH] [BugFix] Fix serialization of stacks of Tensorclasses ghstack-source-id: 8e47f46e83982d554237604f6ef7c845eeed1b50 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1236 --- tensordict/_lazy.py | 2 +- tensordict/_reductions.py | 8 ++++++-- tensordict/base.py | 1 + test/test_tensorclass.py | 21 +++++++++++++++++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index f13be041f..a918719fb 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -383,7 +383,7 @@ def from_dict( stack_dim_name=None, stack_dim=0, ): - return LazyStackedTensorDict( + return cls._new_lazy_unsafe( *( TensorDict.from_dict( input_dict[str(i)], diff --git a/tensordict/_reductions.py b/tensordict/_reductions.py index 1816143ed..15013b59d 100644 --- a/tensordict/_reductions.py +++ b/tensordict/_reductions.py @@ -91,8 +91,12 @@ def from_metadata(metadata=metadata, prefix=None): _ = metadata.pop("size", None) d = { - key: NonTensorData(data, batch_size=batch_size) - for (key, (data, batch_size)) in non_tensor.items() + key: NonTensorData( + data, + batch_size=batch_size, + device=torch.device(device) if device is not None else None, + ) + for (key, (data, batch_size, device)) in non_tensor.items() } for key, (dtype, local_shape, start, stop, pad) in leaves.items(): dtype = _STRDTYPE2DTYPE[dtype] diff --git a/tensordict/base.py b/tensordict/base.py index 6f3b1926c..428ab3545 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -5056,6 +5056,7 @@ def assign( metadata_dict["non_tensors"][key] = ( value.data, list(value.batch_size), + str(value.device) if value.device is not None else None, ) return elif _is_tensor_collection(cls): diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index bede17948..b1b15a430 100644 --- a/test/test_tensorclass.py +++ 
b/test/test_tensorclass.py @@ -234,6 +234,12 @@ class MyDataClass: MyTensorClass_autocast = MyTensorClass_nocast = MyTensorClass = None +@tensorclass +class TCStrings: + a: str + b: str + + class TestTensorClass: def test_get_default(self): @tensorclass @@ -1252,6 +1258,21 @@ def test_pickle(self): assert isinstance(data2, MyData) assert data2.z == data.z + @pytest.mark.parametrize("consolidate", [False, True]) + def test_pickle_consolidate(self, consolidate): + with set_capture_non_tensor_stack(False): + + tc = TCStrings(a="a", b="b") + + tcstack = TensorDict(tc=torch.stack([tc, tc.clone()])) + if consolidate: + tcstack = tcstack.consolidate() + assert isinstance(tcstack["tc"], TCStrings) + loaded = pickle.loads(pickle.dumps(tcstack)) + assert isinstance(loaded["tc"], TCStrings) + assert loaded["tc"].a == tcstack["tc"].a + assert loaded["tc"].b == tcstack["tc"].b + def test_post_init(self): @tensorclass class MyDataPostInit: