Skip to content

Commit

Permalink
Update (base update)
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Feb 6, 2025
1 parent 60e7179 commit d1b97e0
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 65 deletions.
42 changes: 40 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6384,6 +6384,11 @@ def _default_get(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleT
_KEY_ERROR.format(key, type(self).__name__, sorted(self.keys()))
)

@overload
def get(self, key): ...
@overload
def get(self, key, default): ...

def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
"""Gets the value stored with the input key.

Expand Down Expand Up @@ -6439,8 +6444,18 @@ def _get_tuple_maybe_non_tensor(self, key, default):
return result.data
return result


@overload
def get_at(self, key, index): ...

@overload
def get_at(self, key, index, default): ...

def get_at(
self, key: NestedKey, index: IndexType, default: CompatibleType = NO_DEFAULT
self,
key: NestedKey,
*args,
**kwargs,
) -> CompatibleType:
"""Get the value of a tensordict from the key `key` at the index `idx`.

Expand All @@ -6463,7 +6478,30 @@ def get_at(
key = _unravel_key_to_tuple(key)
if not key:
raise KeyError(_GENERIC_NESTED_ERR.format(key))
# must be a tuple

try:
if len(args):
index = args[0]
args = args[1:]
else:
index = kwargs.pop("index")
except KeyError:
raise TypeError("index argument missing from get_at")

# Find what the default is
if args:
default = args[0]
if len(args) > 1 or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif kwargs:
default = kwargs.pop("default")
if args or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif _GET_DEFAULTS_TO_NONE:
default = None
else:
default = NO_DEFAULT

return self._get_at_tuple(key, index, default)

def _get_at_str(self, key, idx, default):
Expand Down
74 changes: 58 additions & 16 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@
from tensordict._torch_func import TD_HANDLED_FUNCTIONS
from tensordict.base import (
_ACCEPTED_CLASSES,
_GET_DEFAULTS_TO_NONE,
_is_tensor_collection,
_register_tensor_class,
CompatibleType,
)
from tensordict.utils import ( # @manual=//pytorch/tensordict:_C
_GENERIC_NESTED_ERR,
_is_dataclass as is_dataclass,
_is_json_serializable,
_is_tensorclass,
Expand Down Expand Up @@ -2238,7 +2240,7 @@ def _set_at_(
return self._tensordict.set_at_(key, value, idx, non_blocking=non_blocking)


def _get(self, key: NestedKey, default: Any = NO_DEFAULT):
def _get(self, key: NestedKey, *args, **kwargs):
"""Gets the value stored with the input key.
Args:
Expand All @@ -2250,25 +2252,65 @@ def _get(self, key: NestedKey, default: Any = NO_DEFAULT):
value stored with the input key
"""
if isinstance(key, str):
key = (key,)
key = _unravel_key_to_tuple(key)
if not key:
raise KeyError(_GENERIC_NESTED_ERR.format(key))

# Find what the default is
if args:
default = args[0]
if len(args) > 1 or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif kwargs:
default = kwargs.pop("default")
if args or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif _GET_DEFAULTS_TO_NONE:
default = None
else:
default = NO_DEFAULT

if isinstance(key, tuple):
try:
if len(key) > 1:
return getattr(self, key[0]).get(key[1:])
return getattr(self, key[0])
except AttributeError:
if default is NO_DEFAULT:
raise
return default
raise ValueError(f"Supported type for key are str and tuple, got {type(key)}")
try:
if len(key) > 1:
return getattr(self, key[0]).get(key[1:], default=default)
return getattr(self, key[0])
except (AttributeError, KeyError):
if default is NO_DEFAULT:
raise
return default


def _get_at(self, key: NestedKey, *args, **kwargs):
key = _unravel_key_to_tuple(key)
if not key:
raise KeyError(_GENERIC_NESTED_ERR.format(key))

try:
if len(args):
index = args[0]
args = args[1:]
else:
index = kwargs.pop("index")
except KeyError:
raise TypeError("index argument missing from get_at")

# Find what the default is
if args:
default = args[0]
if len(args) > 1 or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif kwargs:
default = kwargs.pop("default")
if args or kwargs:
raise TypeError("only one (keyword) argument is allowed.")
elif _GET_DEFAULTS_TO_NONE:
default = None
else:
default = NO_DEFAULT

def _get_at(self, key: NestedKey, idx, default: Any = NO_DEFAULT):
try:
return self.get(key, NO_DEFAULT)[idx]
except AttributeError:
return self.get(key, NO_DEFAULT)[index]
except (AttributeError, KeyError):
if default is NO_DEFAULT:
raise
return default
Expand Down
10 changes: 9 additions & 1 deletion tensordict/tensorclass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,17 @@ class TensorClass:
def set_(
self, key: NestedKey, item: CompatibleType, *, non_blocking: bool = False
) -> T: ...
@overload
def get(self, key): ...
@overload
def get(self, key, default): ...
def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType: ...
@overload
def get_at(self, key, index):...
@overload
def get_at(self, key, index, default): ...
def get_at(
self, key: NestedKey, index: IndexType, default: CompatibleType = ...
self, key: NestedKey, *args, **kwargs,
) -> CompatibleType: ...
def get_item_shape(self, key: NestedKey): ...
def update(
Expand Down
118 changes: 72 additions & 46 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
TensorDictBase,
)
from tensordict._lazy import _PermutedTensorDict, _ViewedTensorDict
from tensordict.base import _GENERIC_NESTED_ERR
from torch import Tensor

# Capture all warnings
Expand Down Expand Up @@ -233,6 +234,77 @@ class MyDataClass:


class TestTensorClass:
def test_get_default(self):
@tensorclass
class Data:
td: TensorDict
a: torch.Tensor

data = Data(td=TensorDict(), a=torch.zeros(()))
assert data.get("a") is not None
assert data.get("b") is None
assert data.get("b", "else") == "else"

with pytest.raises(KeyError, match=_GENERIC_NESTED_ERR.format(())):
data.get(("td", str)) # something unexpected!

assert data.get(("td", "missing"), "else") == "else"
assert data.get(("td", "missing")) is None

data = data.expand(10)
assert data.get_at("a", 0) is not None
assert data.get_at("b", 0) is None
assert data.get_at("b", 0, "else") == "else"

assert data.get_at(("td", "missing"), 0, "else") == "else"
assert data.get_at(("td", "missing"), 0) is None

def test_decorator(self):
@tensorclass
class MyClass:
X: torch.Tensor
y: Any

obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[])
assert not obj.is_locked
with obj.lock_():
assert obj.is_locked
with obj.unlock_():
assert not obj.is_locked
assert obj.is_locked
assert not obj.is_locked

def test_to_dict(self):
@tensorclass
class TestClass:
my_tensor: torch.Tensor
my_str: str

test_class = TestClass(
my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3]
)

assert (
test_class
== TestClass.from_dict(test_class.to_dict(), auto_batch_size=True)
).all()

# Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such
# test_class2 = TestClass(
# my_tensor=torch.tensor([1, 2, 3]), my_str="goodbye", batch_size=[3]
# )
#
# assert not (test_class == TestClass.from_dict(test_class2.to_dict())).all()

test_class3 = TestClass(
my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3]
)

assert not (
test_class
== TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True)
).all()

def test_all_any(self):
@tensorclass
class MyClass1:
Expand Down Expand Up @@ -2229,52 +2301,6 @@ def test_to(self):
assert td_double.device == torch.device("cpu")


def test_decorator():
@tensorclass
class MyClass:
X: torch.Tensor
y: Any

obj = MyClass(X=torch.zeros(2), y="a string!", batch_size=[])
assert not obj.is_locked
with obj.lock_():
assert obj.is_locked
with obj.unlock_():
assert not obj.is_locked
assert obj.is_locked
assert not obj.is_locked


def test_to_dict():
@tensorclass
class TestClass:
my_tensor: torch.Tensor
my_str: str

test_class = TestClass(
my_tensor=torch.tensor([1, 2, 3]), my_str="hello", batch_size=[3]
)

assert (
test_class == TestClass.from_dict(test_class.to_dict(), auto_batch_size=True)
).all()

# Currently we don't test non-tensor in __eq__ because __eq__ can break with arrays and such
# test_class2 = TestClass(
# my_tensor=torch.tensor([1, 2, 3]), my_str="goodbye", batch_size=[3]
# )
#
# assert not (test_class == TestClass.from_dict(test_class2.to_dict())).all()

test_class3 = TestClass(
my_tensor=torch.tensor([1, 2, 0]), my_str="hello", batch_size=[3]
)

assert not (
test_class == TestClass.from_dict(test_class3.to_dict(), auto_batch_size=True)
).all()


@tensorclass(autocast=True)
class AutoCast:
tensor: torch.Tensor
Expand Down

0 comments on commit d1b97e0

Please sign in to comment.