Skip to content

Commit

Permalink
Support empty tensor list for KeyedTensor.regroup (pytorch#2053)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2053

When we configure empty sparse feature list, the embedding lookup results could be an empty list.

This is for unblocking the experiment in D57450566 (leaving only dense features in APS models)

Reviewed By: dstaay-fb

Differential Revision: D57500720

fbshipit-source-id: 056a1171ca73dba9895958f540b793d86bdb946f
  • Loading branch information
Shihao Xu authored and facebook-github-bot committed May 29, 2024
1 parent b4f7649 commit a71f049
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
24 changes: 19 additions & 5 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,21 @@ def _all_keys_used_once(
return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)


@torch.fx.wrap
def _regroup(
keyed_tensors: List["KeyedTensor"],
groups: List[List[str]],
) -> List[torch.Tensor]:
if len(keyed_tensors) == 0:
return []

# Fast path, one-to-one correspondence between keyed_tensors and groups
if _all_keys_used_once(keyed_tensors, groups) is True:
return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
else: # Fallback to slow path otherwise
return _regroup_keyed_tensors(keyed_tensors, groups)


@torch.fx.wrap
def _fbgemm_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
Expand Down Expand Up @@ -2577,11 +2592,10 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
def regroup(
keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
) -> List[torch.Tensor]:
# Fast path, one-to-one correspondence between keyed_tensors and groups
if _all_keys_used_once(keyed_tensors, groups) is True:
return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
else: # Fallback to slow path otherwise
return _regroup_keyed_tensors(keyed_tensors, groups)
return _regroup(
keyed_tensors=keyed_tensors,
groups=groups,
)

@staticmethod
def regroup_as_dict(
Expand Down
11 changes: 11 additions & 0 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,6 +2150,17 @@ def test_to_dict_dim_1(self) -> None:
for key in keys:
self.assertTrue(torch.equal(kt[key], d[key]))

def test_regroup_empty_list(self) -> None:
keyed_tensors = []
groups = []

grouped_tensors = KeyedTensor.regroup(
keyed_tensors=keyed_tensors,
groups=groups,
)

self.assertEqual([], grouped_tensors)

def test_regroup_single_kt(self) -> None:
tensor_list = [torch.randn(2, 3) for i in range(5)]
key_dim = 1
Expand Down

0 comments on commit a71f049

Please sign in to comment.