diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py new file mode 100644 index 000000000..5eadebd1b --- /dev/null +++ b/torchrec/sparse/tensor_dict.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +from tensordict import TensorDict + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def maybe_td_to_kjt( + features: KeyedJaggedTensor, keys: Optional[List[str]] = None +) -> KeyedJaggedTensor: + if torch.jit.is_scripting(): + assert isinstance(features, KeyedJaggedTensor) + return features + if isinstance(features, TensorDict): + if keys is None: + keys = list(features.keys()) + values = torch.cat([features[key]._values for key in keys], dim=0) + lengths = torch.cat( + [ + ( + (features[key]._lengths) + if features[key]._lengths is not None + else torch.diff(features[key]._offsets) + ) + for key in keys + ], + dim=0, + ) + return KeyedJaggedTensor( + keys=keys, + values=values, + lengths=lengths, + ) + else: + return features diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py new file mode 100644 index 000000000..2a7c02439 --- /dev/null +++ b/torchrec/sparse/tests/test_tensor_dict.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +import unittest + +import torch +from tensordict import TensorDict +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt +from torchrec.sparse.tests.utils import repeat_test + + +class TestTensorDIct(unittest.TestCase): + @repeat_test(device=["cpu", "cuda", "meta"]) + def test_kjt_input(self, device: str) -> None: + device = torch.device(device) + values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) + kjt = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=values, + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7], device=device), + ) + features = maybe_td_to_kjt(kjt) + self.assertEqual(features, kjt) + + @repeat_test(device=["cpu", "cuda", "meta"]) + def test_td_kjt(self, device: str) -> None: + device = torch.device(device) + values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) + lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device) + td = TensorDict( + { + "f2": torch.nested.nested_tensor_from_jagged( + torch.tensor([2, 3], device=device), + lengths=torch.tensor([1, 1], device=device), + ), + "f1": torch.nested.nested_tensor_from_jagged( + torch.arange(2, device=device), + offsets=torch.tensor([0, 2, 2], device=device), + ), + "f3": torch.nested.nested_tensor_from_jagged( + torch.tensor([2, 3, 4], device=device), + lengths=torch.tensor([1, 2], device=device), + ), + }, + device=device, + batch_size=[2], + ) + features = maybe_td_to_kjt(td, ["f1", "f2", "f3"]) + torch.testing.assert_close(features.values(), values) + torch.testing.assert_close(features.lengths(), lengths)