From b66dcb8dd2302c4f2e2d657abe4787f4716433f2 Mon Sep 17 00:00:00 2001
From: Shihao Xu
Date: Tue, 28 May 2024 17:04:20 -0700
Subject: [PATCH] Support empty tensor list for KeyedTensor.regroup (#2053)

Summary:
When we configure empty sparse feature list, the embedding lookup results could be an empty list.

This is for unblocking the experiment in D57450566 (leaving only dense features in APS models)

Reviewed By: dstaay-fb

Differential Revision: D57500720
---
 torchrec/sparse/jagged_tensor.py            | 24 ++++++++++++++++-----
 torchrec/sparse/tests/test_jagged_tensor.py | 11 ++++++++++
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 4d34fc664..cba8bb1fc 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -162,6 +162,21 @@ def _all_keys_used_once(
     return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)
 
 
+@torch.fx.wrap
+def _regroup(
+    keyed_tensors: List["KeyedTensor"],
+    groups: List[List[str]],
+) -> List[torch.Tensor]:
+    if len(keyed_tensors) == 0:
+        return []
+
+    # Fast path, one-to-one correspondence between keyed_tensors and groups
+    if _all_keys_used_once(keyed_tensors, groups) is True:
+        return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
+    else:  # Fallback to slow path otherwise
+        return _regroup_keyed_tensors(keyed_tensors, groups)
+
+
 @torch.fx.wrap
 def _fbgemm_permute_pooled_embs(
     keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
@@ -2577,11 +2592,10 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
     def regroup(
         keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
     ) -> List[torch.Tensor]:
-        # Fast path, one-to-one correspondence between keyed_tensors and groups
-        if _all_keys_used_once(keyed_tensors, groups) is True:
-            return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
-        else:  # Fallback to slow path otherwise
-            return _regroup_keyed_tensors(keyed_tensors, groups)
+        return _regroup(
+            keyed_tensors=keyed_tensors,
+            groups=groups,
+        )
 
     @staticmethod
     def regroup_as_dict(
diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py
index 9efeb444c..764f6e2a8 100644
--- a/torchrec/sparse/tests/test_jagged_tensor.py
+++ b/torchrec/sparse/tests/test_jagged_tensor.py
@@ -2150,6 +2150,17 @@ def test_to_dict_dim_1(self) -> None:
         for key in keys:
             self.assertTrue(torch.equal(kt[key], d[key]))
 
+    def test_regroup_empty_list(self) -> None:
+        keyed_tensors = []
+        groups = []
+
+        grouped_tensors = KeyedTensor.regroup(
+            keyed_tensors=keyed_tensors,
+            groups=groups,
+        )
+
+        self.assertEqual([], grouped_tensors)
+
     def test_regroup_single_kt(self) -> None:
         tensor_list = [torch.randn(2, 3) for i in range(5)]
         key_dim = 1