[Feature] LazyStackStorage #2723

Merged 2 commits on Jan 30, 2025
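In short, this PR adds LazyStackStorage, a ListStorage subclass whose get method returns LazyStackedTensorDict instances, so replay-buffer items with heterogeneous shapes (and even non-tensor fields) can be sampled as a single tensordict. A minimal usage sketch, condensed from the docstring and test added below (the seed and values are purely illustrative):

import torch
from tensordict import TensorDict
from torchrl.data import LazyStackStorage, ReplayBuffer

torch.manual_seed(0)
# The storage keeps items in a plain Python list and stacks them lazily on access.
rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))

# Entries may have different shapes ("a" has 10 vs. 11 elements) and non-tensor fields.
rb.add(TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!"))
rb.add(TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!"))

sample = rb.sample(10)  # a LazyStackedTensorDict
assert sample["b"].shape == torch.Size([10, 4])       # homogeneous field stacks cleanly
assert all(isinstance(s, str) for s in sample["c"])   # string field is preserved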
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
@@ -148,6 +148,7 @@ using the following components:
LazyMemmapStorage
LazyTensorStorage
ListStorage
LazyStackStorage
ListStorageCheckpointer
NestedStorageCheckpointer
PrioritizedSampler
26 changes: 26 additions & 0 deletions test/test_rb.py
@@ -81,6 +81,7 @@

from torchrl.data.replay_buffers.storages import (
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
StorageEnsemble,
@@ -1116,6 +1117,31 @@ def test_storage_inplace_writing_ndim(self, storage_type):
assert (rb[:, 10:20] == 0).all()
assert len(rb) == 100

@pytest.mark.parametrize("max_size", [1000, None])
@pytest.mark.parametrize("stack_dim", [-1, 0])
def test_lazy_stack_storage(self, max_size, stack_dim):
# Create an instance of LazyStackStorage with given parameters
storage = LazyStackStorage(max_size=max_size, stack_dim=stack_dim)
# Create a ReplayBuffer using the created storage
rb = ReplayBuffer(storage=storage)
# Generate some random data to add to the buffer
torch.manual_seed(0)
data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
# Add the data to the buffer
rb.add(data0)
rb.add(data1)
# Sample from the buffer
sample = rb.sample(10)
# Check that the sampled data has the correct shape and type
assert isinstance(sample, LazyStackedTensorDict)
assert sample["b"].shape[0] == 10
assert all(isinstance(item, str) for item in sample["c"])
# Densify the sample into nested (jagged) tensors and check the result
sample = sample.densify(layout=torch.jagged)
assert isinstance(sample["a"], torch.Tensor)
assert sample["a"].shape[0] == 10


@pytest.mark.parametrize("max_size", [1000])
@pytest.mark.parametrize("shape", [[3, 4]])
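A note on the densify step at the end of the new test: it converts the lazy stack into a tensordict backed by nested (jagged) tensors, which is how the heterogeneously sized "a" entries become addressable as a single tensor. A rough standalone sketch of the same round trip, assuming a tensordict version with nested-tensor support (shapes are illustrative):

import torch
from tensordict import LazyStackedTensorDict, TensorDict

# Two items whose "a" fields have different lengths (10 and 11).
td0 = TensorDict(a=torch.randn(10), b=torch.rand(4))
td1 = TensorDict(a=torch.randn(11), b=torch.rand(4))
lazy = LazyStackedTensorDict(td0, td1, stack_dim=0)

dense = lazy.densify(layout=torch.jagged)
# "a" is now a single nested tensor with a leading batch dimension of 2.
assert isinstance(dense["a"], torch.Tensor)
assert dense["a"].shape[0] == 2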
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -23,6 +23,7 @@
H5StorageCheckpointer,
ImmutableDatasetWriter,
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
ListStorageCheckpointer,
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/__init__.py
@@ -32,6 +32,7 @@
)
from .storages import (
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
Storage,
6 changes: 5 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -1049,6 +1049,10 @@ def __init__(
self._cache["stop-and-length"] = vals

else:
if traj_key is not None:
self._fetch_traj = True
elif end_key is not None:
self._fetch_traj = False
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
@@ -1331,7 +1335,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
if start_idx.shape[1] != storage.ndim:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be "
f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor "
f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} "
"instead."
)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
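The samplers.py change sets SliceSampler's internal `_fetch_traj` flag according to which of `traj_key` / `end_key` the caller provided, so an explicitly passed key decides whether trajectories are recovered from a trajectory id or from an end-of-trajectory signal. A hedged illustration of the two configurations this branch distinguishes (the key names below are conventional examples, not requirements):

from torchrl.data.replay_buffers import SliceSampler

# Trajectories identified by an episode-id entry: the sampler fetches that key
# from the storage (_fetch_traj=True in the patched branch).
sampler_by_traj = SliceSampler(num_slices=4, traj_key=("collector", "traj_ids"))

# Trajectories identified by an end-of-trajectory signal instead (_fetch_traj=False);
# ("next", "done") is also the default end_key set a few lines below in this branch.
sampler_by_done = SliceSampler(num_slices=4, end_key=("next", "done"))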
79 changes: 79 additions & 0 deletions torchrl/data/replay_buffers/storages.py
@@ -297,7 +297,15 @@ def set(
def get(self, index: Union[int, Sequence[int], slice]) -> Any:
if isinstance(index, (INT_CLASSES, slice)):
return self._storage[index]
elif isinstance(index, tuple):
if len(index) > 1:
raise RuntimeError(
f"{type(self).__name__} can only be indexed with one-length tuples."
)
return self.get(index[0])
else:
if isinstance(index, torch.Tensor) and index.device.type != "cpu":
index = index.cpu().tolist()
return [self._storage[i] for i in index]

def __len__(self):
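The `get` changes above let a ListStorage be indexed with a one-element tuple (unwrapped and forwarded) and with index tensors living on an accelerator (moved to CPU before the Python-list lookup). A small sketch of both paths, assuming the storage already holds a few items (plain strings here, purely for illustration):

import torch
from torchrl.data import ListStorage

storage = ListStorage(max_size=10)
for i, item in enumerate(["item0", "item1", "item2"]):
    storage.set(i, item)

# One-length tuples are unwrapped and dispatched to the inner index.
assert storage.get((slice(0, 2),)) == ["item0", "item1"]

# Tensor indices are looked up element-wise; non-CPU tensors are first moved to CPU
# (a no-op for the CPU tensor used here).
assert storage.get(torch.tensor([0, 2])) == ["item0", "item2"]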
@@ -353,6 +361,77 @@ def contains(self, item):
raise NotImplementedError(f"type {type(item)} is not supported yet.")


class LazyStackStorage(ListStorage):
"""A ListStorage that returns LazyStackTensorDict instances.

This storage allows heterogeneous structures to be indexed as a single `TensorDict` representation.
It uses :class:`~tensordict.LazyStackedTensorDict` which operates on non-contiguous lists of tensordicts,
lazily stacking items when queried.
This means that sampling from this storage is fast (the stack is built lazily), but accessing the sampled data may be slow, since the stack is materialized at query time.
Tensors of heterogeneous shapes can also be stored within the storage and stacked together.
Because the storage is represented as a list, the number of tensors to store in memory will grow linearly with
the size of the buffer.

If possible, nested tensors can also be created via :meth:`~tensordict.LazyStackedTensorDict.densify`
(see :mod:`~torch.nested`).

Args:
max_size (int, optional): the maximum number of elements stored in the storage.
If not provided, an unlimited storage is created.

Keyword Args:
compilable (bool, optional): if ``True``, the storage will be made compatible with :func:`~torch.compile` at
the cost of not being executable in multiprocessed settings.
stack_dim (int, optional): the stack dimension in terms of TensorDict batch sizes. Defaults to `-1`.

Examples:
>>> import torch
>>> from torchrl.data import ReplayBuffer, LazyStackStorage
>>> from tensordict import TensorDict
>>> _ = torch.manual_seed(0)
>>> rb = ReplayBuffer(storage=LazyStackStorage(max_size=1000, stack_dim=-1))
>>> data0 = TensorDict(a=torch.randn((10,)), b=torch.rand(4), c="a string!")
>>> data1 = TensorDict(a=torch.randn((11,)), b=torch.rand(4), c="another string!")
>>> _ = rb.add(data0)
>>> _ = rb.add(data1)
>>> rb.sample(10)
LazyStackedTensorDict(
fields={
a: Tensor(shape=torch.Size([10, -1]), device=cpu, dtype=torch.float32, is_shared=False),
b: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False),
c: NonTensorStack(
['another string!', 'another string!', 'another st...,
batch_size=torch.Size([10]),
device=None)},
exclusive_fields={
},
batch_size=torch.Size([10]),
device=None,
is_shared=False,
stack_dim=0)
"""

def __init__(
self,
max_size: int | None = None,
*,
compilable: bool = False,
stack_dim: int = -1,
):
super().__init__(max_size=max_size, compilable=compilable)
self.stack_dim = stack_dim

def get(self, index: Union[int, Sequence[int], slice]) -> Any:
out = super().get(index=index)
if isinstance(out, list):
stack_dim = self.stack_dim
if stack_dim < 0:
stack_dim = out[0].ndim + 1 + stack_dim
out = LazyStackedTensorDict(*out, stack_dim=stack_dim)
return out
return out


class TensorStorage(Storage):
"""A storage for tensors and tensordicts.

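Finally, a small worked example of the negative stack_dim handling in LazyStackStorage.get above: the formula `out[0].ndim + 1 + stack_dim` converts a negative stack dimension into a positive one relative to the batch dimensions of the resulting stack. The numbers below simply trace that formula (illustrative only):

import torch
from tensordict import LazyStackedTensorDict, TensorDict

# Items with batch_size [4] (ndim == 1): stack_dim=-1 resolves to 1 + 1 - 1 = 1,
# i.e. the new stack dimension becomes the last batch dimension.
items = [TensorDict(b=torch.rand(4, 3), batch_size=[4]) for _ in range(2)]
assert LazyStackedTensorDict(*items, stack_dim=1).batch_size == torch.Size([4, 2])

# Items with an empty batch size (ndim == 0): stack_dim=-1 resolves to 0,
# matching the stack_dim=0 shown in the docstring example output.
scalars = [TensorDict(a=torch.randn(10)), TensorDict(a=torch.randn(11))]
assert LazyStackedTensorDict(*scalars, stack_dim=0).batch_size == torch.Size([2])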