[BugFix] Fix tensorclass indexing #1217

Merged
merged 1 commit on Feb 12, 2025
70 changes: 53 additions & 17 deletions tensordict/_lazy.py
@@ -63,6 +63,7 @@
_get_shape_from_args,
_getitem_batch_size,
_is_number,
_maybe_correct_neg_dim,
_parse_to,
_renamed_inplace_method,
_shape,
@@ -292,6 +293,49 @@ def __init__(
if stack_dim_name is not None:
self._td_dim_name = stack_dim_name

@classmethod
def _new_lazy_unsafe(
cls,
*tensordicts: T,
stack_dim: int = 0,
hook_out: callable | None = None,
hook_in: callable | None = None,
batch_size: Sequence[int] | None = None,
device: torch.device | None = None,
names: Sequence[str] | None = None,
stack_dim_name: str | None = None,
strict_shape: bool = False,
) -> None:
self = cls.__new__(cls)
self._is_locked = None

# sanity check
num_tds = len(tensordicts)
batch_size = torch.Size(batch_size) if batch_size is not None else None
if not num_tds:
# create an empty tensor
td0 = TensorDict(batch_size=batch_size, device=device, names=names)
self._device = torch.device(device) if device is not None else None
else:
td0 = tensordicts[0]
# device = td0.device
_batch_size = td0.batch_size

for td in tensordicts[1:]:
_bs = td.batch_size
if _bs != _batch_size:
_batch_size = torch.Size(
[s if _bs[i] == s else -1 for i, s in enumerate(_batch_size)]
)
self.tensordicts: list[TensorDictBase] = list(tensordicts)
self.stack_dim = stack_dim
self._batch_size = self._compute_batch_size(_batch_size, stack_dim, num_tds)
self.hook_out = hook_out
self.hook_in = hook_in
if stack_dim_name is not None:
self._td_dim_name = stack_dim_name
return self

# These attributes should never be set
@property
def _is_shared(self):
@@ -633,7 +677,9 @@ def _split_index(self, index):
encountered_tensor = False
for i, idx in enumerate(index): # noqa: B007
cursor_incr = 1
if idx is None:
# if idx is None:
# idx = True
if idx is None or idx is True:
out.append(None)
num_none += cursor <= self.stack_dim
continue
@@ -1675,6 +1721,8 @@ def _iterate_over_keys(self) -> None:

@cache # noqa: B019
def _key_list(self):
if not self.tensordicts:
return []
keys = set(self.tensordicts[0].keys())
for td in self.tensordicts[1:]:
keys = keys.intersection(td.keys())
@@ -2099,15 +2147,6 @@ def assign(converted_idx, value=value):
value_unbind,
):
if mask.any():
assert (
self.tensordicts[i][_idx].shape
== torch.zeros(self.tensordicts[i].shape)[_idx].shape
), (
self.tensordicts[i].shape,
_idx,
self.tensordicts[i][_idx],
torch.zeros(self.tensordicts[i].shape)[_idx].shape,
)
self.tensordicts[i][_idx] = _value
else:
for (i, _idx), _value in _zip_strict(
@@ -2160,7 +2199,7 @@ def __getitem__(self, index: IndexType) -> Any:
batch_size = _getitem_batch_size(self.batch_size, index)
else:
batch_size = None
return LazyStackedTensorDict(
return self._new_lazy_unsafe(
*result,
stack_dim=cat_dim,
device=self.device,
@@ -3203,10 +3242,7 @@ def _unsqueeze(self, dim):
return result

def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
if dim < 0:
dim = self.ndim + dim
if dim < 0 or dim > self.ndim - 1:
raise ValueError(f"Out-of-bounds dim value: {dim}.")
dim = _maybe_correct_neg_dim(dim, shape=self.shape)
if dim == self.stack_dim:
if isinstance(split_size, int):
split_size = [split_size] * -(len(self.tensordicts) // -split_size)
@@ -3217,15 +3253,15 @@ def iter_across_tds():
for s in split_size:
if s == 0:
batch_size = list(self._batch_size)
batch_size.pop(self.stack_dim)
batch_size[self.stack_dim] = 0
yield LazyStackedTensorDict(
batch_size=batch_size,
device=self.device,
stack_dim=self.stack_dim,
)
continue
stop = start + s
yield LazyStackedTensorDict(
yield self._new_lazy_unsafe(
*self.tensordicts[slice(start, stop)], stack_dim=self.stack_dim
)
start = stop
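For context on the pattern above (not part of the PR text): _new_lazy_unsafe follows the usual "trusted alternate constructor" idiom, building the object via cls.__new__ and filling its fields directly so that hot paths such as __getitem__ and split can skip the validation __init__ performs on inputs that are already known to be consistent. A generic, self-contained sketch of the idiom, with illustrative names that are not tensordict API:

    from __future__ import annotations


    class Stack:
        def __init__(self, *items: list, dim: int = 0):
            # Public constructor: validate that all items are mutually consistent.
            if len({len(it) for it in items}) > 1:
                raise ValueError("items must all have the same length")
            self.items, self.dim = list(items), dim

        @classmethod
        def _new_unsafe(cls, *items: list, dim: int = 0) -> Stack:
            # Trusted constructor: bypass __init__ (and its checks) for callers
            # that already hold validated items, e.g. slices of an existing Stack.
            self = cls.__new__(cls)
            self.items, self.dim = list(items), dim
            return self

        def __getitem__(self, idx: slice) -> Stack:
            # A slice of a validated stack is itself valid: take the fast path.
            return type(self)._new_unsafe(*self.items[idx], dim=self.dim)


    stack = Stack([1, 2], [3, 4], [5, 6])
    assert len(stack[0:2].items) == 2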
16 changes: 16 additions & 0 deletions tensordict/_td.py
@@ -2518,7 +2518,23 @@ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
tensor_in = self._get_str(key, NO_DEFAULT)

if is_non_tensor(value) and not (self._is_shared or self._is_memmap):
if isinstance(idx, tuple) and len(idx) == 1:
idx = idx[0]
dest = tensor_in
if (
isinstance(idx, torch.Tensor)
and idx.shape == ()
and self.shape == ()
and idx.dtype == torch.bool
and idx
):
self._set_str(
key,
dest.squeeze(0),
validated=True,
inplace=False,
ignore_lock=True,
)
is_diff = dest[idx].tolist() != value.tolist()
if is_diff:
dest_val = dest.maybe_to_stack()
2 changes: 1 addition & 1 deletion tensordict/base.py
@@ -6593,7 +6593,7 @@ def update(
value = tree_map(torch.clone, value)
# the key must be a string by now. Let's check if it is present
if target is not None:
if not is_leaf(type(target)):
if not is_leaf(type(target)) and not is_leaf(type(value)):
if subkey:
sub_keys_to_update = _prune_selected_keys(
keys_to_update, firstkey
25 changes: 20 additions & 5 deletions tensordict/tensorclass.py
@@ -1571,7 +1571,6 @@ def wrapped_func(self, *args, **kwargs):
td = super(type(self), self).__getattribute__("_tensordict")
else:
td = self._tensordict

result = getattr(td, funcname)(*args, **kwargs)
if no_wrap:
return result
@@ -1776,10 +1775,26 @@ def _setitem(self, item: NestedKey, value: Any) -> None: # noqa: D417
value (any): value to set for the item

"""
if isinstance(item, str) or (
isinstance(item, tuple) and all(isinstance(_item, str) for _item in item)
):
raise ValueError(f"Invalid indexing arguments: {item}.")
istuple = isinstance(item, tuple)
if istuple or isinstance(item, str):
# _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey
idx_unravel = _unravel_key_to_tuple(item)
if idx_unravel:
raise ValueError(f"Invalid indexing arguments: {item}.")

if istuple and len(item) == 1:
return _setitem(self, item[0], value)
if (
(
isinstance(item, torch.Tensor)
and item.dtype == torch.bool
and not item.shape
and item
)
or (item is True)
or (item is None)
) and self.batch_size == ():
return self.update(value.squeeze(0))

if not is_tensorclass(value) and not isinstance(
value, (TensorDictBase, numbers.Number, Tensor)
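The _setitem branch added above targets the edge case exercised by the new test further down: indexing a batch-size-() tensorclass with True, None, or a 0-dim boolean tensor that is true. Such an index behaves like an inserted size-1 dimension, so the incoming value carries one extra leading dimension that has to be squeezed away before the update. A plain-tensor analogue of that rule (a simplified sketch with an invented helper name, not the library code):

    import torch


    def set_with_new_dim_index(dest, idx, value):
        # True, None and a true 0-dim boolean tensor all act as an inserted
        # size-1 dimension, so on a 0-dim destination the assignment means
        # "replace everything with `value` minus its extra leading dim".
        scalar_true = (
            isinstance(idx, torch.Tensor)
            and idx.dtype == torch.bool
            and idx.shape == torch.Size([])
            and bool(idx)
        )
        if (scalar_true or idx is True or idx is None) and dest.shape == torch.Size([]):
            return value.squeeze(0)
        dest[idx] = value
        return dest


    d = torch.tensor(0.0)
    assert set_with_new_dim_index(d, True, torch.tensor([3.0])).item() == 3.0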
14 changes: 11 additions & 3 deletions tensordict/utils.py
@@ -1243,7 +1243,11 @@ def func_as_decorator(_self, *args, **kwargs):
if out is not None:
if _attr_post is not _attr_pre:
ref = weakref.ref(_self)
out._last_op = (
if is_tensorclass(out):
out_lo = out._tensordict
else:
out_lo = out
out_lo._last_op = (
func.__name__,
(
args,
@@ -1262,7 +1266,11 @@
out = func(_self, *args, **kwargs)
if out is not None:
ref = weakref.ref(_self)
out._last_op = (func.__name__, (args, kwargs, ref))
if is_tensorclass(out):
out_lo = out._tensordict
else:
out_lo = out
out_lo._last_op = (func.__name__, (args, kwargs, ref))
return out

return func_as_decorator
@@ -2023,7 +2031,7 @@ def _getitem_batch_size(batch_size, index):
out = []
count = -1
for i, idx in enumerate(index):
if idx is None:
if idx is True or idx is None:
out.append(1)
continue
count += 1 if not bools[i] else idx.ndim
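For reference, the one-line change to _getitem_batch_size above makes a literal True in an index contribute a size-1 batch dimension, exactly like None, which is how PyTorch itself resolves such indices. A standalone paraphrase of that rule (covering only ints, None/True and full slices; the invented helper below is not the library function):

    import torch


    def result_batch_size(batch_size, index):
        # None and a literal True insert a new size-1 dimension; an integer
        # index consumes a dimension; a (full) slice keeps it. Other index
        # types supported by the real helper are omitted for brevity.
        out, count = [], -1
        for idx in index:
            if idx is True or idx is None:
                out.append(1)
                continue
            count += 1
            if isinstance(idx, int):
                continue
            out.append(batch_size[count])
        out.extend(batch_size[count + 1:])
        return torch.Size(out)


    assert result_batch_size(torch.Size([3, 4]), (True, slice(None))) == (1, 3, 4)
    assert torch.zeros(3, 4)[True, :].shape == (1, 3, 4)  # same answer from real indexing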
12 changes: 12 additions & 0 deletions test/test_tensordict.py
@@ -11299,6 +11299,18 @@ def test_set(self, non_tensor_data):
== "another string!"
)

def test_setitem_edge_case(self):
s = NonTensorStack("a string")
t = NonTensorStack("another string")
s[0][True] = t
assert s[0].data == "another string"
for i in (None, True):
s = NonTensorStack("0", "1")
t = NonTensorStack(NonTensorStack("2", "3"), stack_dim=1)
assert t.batch_size == (2, 1)
s[:, i] = t
assert s.tolist() == ["2", "3"]

def test_stack(self, non_tensor_data):
assert (
LazyStackedTensorDict.lazy_stack([non_tensor_data, non_tensor_data], 0).get(