Skip to content

Commit

Permalink
handle empty (keys) sparse features case (#1883)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1883

handle empty keys on meta device path

Reviewed By: MarcioPorto

Differential Revision: D56175026

fbshipit-source-id: 7707cdecc4a68069ee809d630f20aeebbc1db1d7
  • Loading branch information
edqwerty10 authored and facebook-github-bot committed Apr 18, 2024
1 parent a80219a commit 737e283
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def _maybe_compute_length_per_key(
values: Optional[torch.Tensor],
) -> List[int]:
if length_per_key is None:
if values is not None and values.is_meta:
if len(keys) and values is not None and values.is_meta:
# create dummy lengths per key when on meta device
total_length = values.numel()
_length = [total_length // len(keys)] * len(keys)
Expand Down
7 changes: 7 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1836,6 +1836,13 @@ def test_meta_device_compatibility(self) -> None:
keys=keys, values=values, weights=weights, offsets=offsets
)

# test empty keys case
kjt = KeyedJaggedTensor.from_lengths_sync(
keys=[],
values=torch.tensor([], device=torch.device("meta")),
lengths=torch.tensor([], device=torch.device("meta")),
)


class TestKeyedJaggedTensorScripting(unittest.TestCase):
def test_scriptable_forward(self) -> None:
Expand Down

0 comments on commit 737e283

Please sign in to comment.