Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 27, 2025
1 parent 2984055 commit 6755ee6
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ using the following components:
LazyMemmapStorage
LazyTensorStorage
ListStorage
LazyStackStorage
ListStorageCheckpointer
NestedStorageCheckpointer
PrioritizedSampler
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
H5StorageCheckpointer,
ImmutableDatasetWriter,
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
ListStorageCheckpointer,
Expand Down
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from .storages import (
LazyMemmapStorage,
LazyStackStorage,
LazyTensorStorage,
ListStorage,
Storage,
Expand Down
6 changes: 5 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,10 @@ def __init__(
self._cache["stop-and-length"] = vals

else:
if traj_key is not None:
self._fetch_traj = True
elif end_key is not None:
self._fetch_traj = False
if end_key is None:
end_key = ("next", "done")
if traj_key is None:
Expand Down Expand Up @@ -1331,7 +1335,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
if start_idx.shape[1] != storage.ndim:
raise RuntimeError(
f"Expected the end-of-trajectory signal to be "
f"{storage.ndim}-dimensional. Got a {start_idx.shape[1]} tensor "
f"{storage.ndim}-dimensional. Got a tensor with shape[1]={start_idx.shape[1]} "
"instead."
)
seq_length, num_slices = self._adjusted_batch_size(batch_size)
Expand Down
37 changes: 37 additions & 0 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,15 @@ def set(
def get(self, index: Union[int, Sequence[int], slice]) -> Any:
    """Retrieve one or more items from the underlying list storage.

    Args:
        index: an integer or slice (returns the item(s) directly), a
            one-element tuple (unwrapped and re-dispatched), or a sequence /
            tensor of indices (returns a list of items).

    Returns:
        The stored item for int/slice indices, or a list of stored items
        for a sequence of indices.

    Raises:
        RuntimeError: if a tuple index does not have exactly one element
            (a list storage is one-dimensional).
    """
    if isinstance(index, (INT_CLASSES, slice)):
        return self._storage[index]
    elif isinstance(index, tuple):
        # A list storage is 1D: only single-element tuples are valid.
        # Use != 1 so that an empty tuple raises this explicit error
        # rather than an opaque IndexError from index[0] below.
        if len(index) != 1:
            raise RuntimeError(
                f"{type(self).__name__} can only be indexed with one-length tuples."
            )
        return self.get(index[0])
    else:
        # Non-CPU tensors cannot be iterated to index a Python list
        # efficiently; move to CPU and convert to a plain list of ints.
        if isinstance(index, torch.Tensor) and index.device.type != "cpu":
            index = index.cpu().tolist()
        return [self._storage[i] for i in index]

def __len__(self):
Expand Down Expand Up @@ -353,6 +361,35 @@ def contains(self, item):
raise NotImplementedError(f"type {type(item)} is not supported yet.")


class LazyStackStorage(ListStorage):
    """A ListStorage that returns :class:`~tensordict.LazyStackedTensorDict` instances.

    Single-index queries return the stored item unchanged; queries that
    resolve to a list of items are lazily stacked along ``stack_dim``
    (and optionally densified) before being returned.

    Args:
        max_size (int): the maximum number of elements stored in the storage.
        compilable (bool, optional): whether the storage should be made
            compatible with ``torch.compile``. Passed through to
            :class:`ListStorage`. Defaults to ``False``.
        stack_dim (int, optional): the dimension along which retrieved items
            are stacked. Negative values are resolved against the number of
            dimensions of the *stacked* result. Defaults to ``-1``.
        densify (bool, optional): if ``True``, the lazy stack is densified
            via ``densify(layout=dense_layout)`` before being returned.
            Defaults to ``False``.
        dense_layout (torch.layout, optional): layout used when ``densify``
            is ``True``. Defaults to ``torch.jagged``.
    """

    def __init__(
        self,
        max_size: int,
        compilable: bool = False,
        stack_dim: int = -1,
        densify: bool = False,
        dense_layout: torch.layout = torch.jagged,
    ):
        super().__init__(max_size=max_size, compilable=compilable)
        self.stack_dim = stack_dim
        self.densify = densify
        self.dense_layout = dense_layout

    def get(self, index: Union[int, Sequence[int], slice]) -> Any:
        """Retrieve items; list results are wrapped in a LazyStackedTensorDict."""
        out = super().get(index=index)
        if isinstance(out, list):
            # NOTE(review): an empty result list would fail at out[0] below;
            # presumably callers never request an empty batch — confirm.
            stack_dim = self.stack_dim
            if stack_dim < 0:
                # Stacking adds one dimension, hence the +1 offset when
                # resolving a negative stack_dim against the stacked result.
                stack_dim = out[0].ndim + 1 + stack_dim
            out = LazyStackedTensorDict(*out, stack_dim=stack_dim)
            if self.densify:
                return out.densify(layout=self.dense_layout)
        return out


class TensorStorage(Storage):
"""A storage for tensors and tensordicts.
Expand Down

0 comments on commit 6755ee6

Please sign in to comment.