Skip to content

Commit

Permalink
add maybe_td_to_kjt function to convert td to kjt
Browse files Browse the repository at this point in the history
Summary:
# context
* add `tensordict` into torchrec's dependency tree
* the `maybe_td_to_kjt` function convert `TensorDict` into `KeyedJaggedTensor` with the correct key order.

Differential Revision: D64671782
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jan 14, 2025
1 parent 542b0b2 commit e1657f6
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
43 changes: 43 additions & 0 deletions torchrec/sparse/tensor_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional

import torch
from tensordict import TensorDict

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def maybe_td_to_kjt(
features: KeyedJaggedTensor, keys: Optional[List[str]] = None
) -> KeyedJaggedTensor:
if torch.jit.is_scripting():
assert isinstance(features, KeyedJaggedTensor)
return features
if isinstance(features, TensorDict):
if keys is None:
keys = list(features.keys())
values = torch.cat([features[key]._values for key in keys], dim=0)
lengths = torch.cat(
[
(
(features[key]._lengths)
if features[key]._lengths is not None
else torch.diff(features[key]._offsets)
)
for key in keys
],
dim=0,
)
return KeyedJaggedTensor(
keys=keys,
values=values,
lengths=lengths,
)
else:
return features
58 changes: 58 additions & 0 deletions torchrec/sparse/tests/test_tensor_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import unittest

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
from torchrec.sparse.tests.utils import repeat_test


class TestTensorDIct(unittest.TestCase):
@repeat_test(device=["cpu", "cuda", "meta"])
def test_kjt_input(self, device: str) -> None:
device = torch.device(device)
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
kjt = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3"],
values=values,
offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7], device=device),
)
features = maybe_td_to_kjt(kjt)
self.assertEqual(features, kjt)

@repeat_test(device=["cpu", "cuda", "meta"])
def test_td_kjt(self, device: str) -> None:
device = torch.device(device)
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device)
td = TensorDict(
{
"f2": torch.nested.nested_tensor_from_jagged(
torch.tensor([2, 3], device=device),
lengths=torch.tensor([1, 1], device=device),
),
"f1": torch.nested.nested_tensor_from_jagged(
torch.arange(2, device=device),
offsets=torch.tensor([0, 2, 2], device=device),
),
"f3": torch.nested.nested_tensor_from_jagged(
torch.tensor([2, 3, 4], device=device),
lengths=torch.tensor([1, 2], device=device),
),
},
device=device,
batch_size=[2],
)
features = maybe_td_to_kjt(td, ["f1", "f2", "f3"])
torch.testing.assert_close(features.values(), values)
torch.testing.assert_close(features.lengths(), lengths)

0 comments on commit e1657f6

Please sign in to comment.