diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 4d34fc664..cba8bb1fc 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -162,6 +162,21 @@ def _all_keys_used_once( return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups) +@torch.fx.wrap +def _regroup( + keyed_tensors: List["KeyedTensor"], + groups: List[List[str]], +) -> List[torch.Tensor]: + if len(keyed_tensors) == 0: + return [] + + # Fast path, one-to-one correspondence between keyed_tensors and groups + if _all_keys_used_once(keyed_tensors, groups) is True: + return _fbgemm_permute_pooled_embs(keyed_tensors, groups) + else: # Fallback to slow path otherwise + return _regroup_keyed_tensors(keyed_tensors, groups) + + @torch.fx.wrap def _fbgemm_permute_pooled_embs( keyed_tensors: List["KeyedTensor"], groups: List[List["str"]] @@ -2577,11 +2592,10 @@ def to_dict(self) -> Dict[str, torch.Tensor]: def regroup( keyed_tensors: List["KeyedTensor"], groups: List[List[str]] ) -> List[torch.Tensor]: - # Fast path, one-to-one correspondence between keyed_tensors and groups - if _all_keys_used_once(keyed_tensors, groups) is True: - return _fbgemm_permute_pooled_embs(keyed_tensors, groups) - else: # Fallback to slow path otherwise - return _regroup_keyed_tensors(keyed_tensors, groups) + return _regroup( + keyed_tensors=keyed_tensors, + groups=groups, + ) @staticmethod def regroup_as_dict( diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 9efeb444c..764f6e2a8 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -2150,6 +2150,17 @@ def test_to_dict_dim_1(self) -> None: for key in keys: self.assertTrue(torch.equal(kt[key], d[key])) + def test_regroup_empty_list(self) -> None: + keyed_tensors = [] + groups = [] + + grouped_tensors = KeyedTensor.regroup( + keyed_tensors=keyed_tensors, + groups=groups, + ) + + 
+        self.assertEqual([], grouped_tensors)
+
     def test_regroup_single_kt(self) -> None:
         tensor_list = [torch.randn(2, 3) for i in range(5)]
         key_dim = 1