Skip to content

Commit

Permalink
Revert _regroup in jagged_tensor (#2089)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2089

Fix S422574
backout D57500720 D58001114'

Post: https://fb.workplace.com/groups/gpuinference/permalink/2814805982001385/
Example failed job: f567662663

Reviewed By: xush6528

Differential Revision: D58310586

fbshipit-source-id: 1deacc6318298bf5c18e024560b86250b64a8709
  • Loading branch information
22quinn authored and facebook-github-bot committed Jun 8, 2024
1 parent 32cc3dd commit 733e42d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 30 deletions.
24 changes: 5 additions & 19 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,21 +162,6 @@ def _all_keys_used_once(
return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)


@torch.fx.wrap
def _regroup(
keyed_tensors: List["KeyedTensor"],
groups: List[List[str]],
) -> List[torch.Tensor]:
if len(keyed_tensors) == 0:
return []

# Fast path, one-to-one correspondence between keyed_tensors and groups
if _all_keys_used_once(keyed_tensors, groups) is True:
return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
else: # Fallback to slow path otherwise
return _regroup_keyed_tensors(keyed_tensors, groups)


@torch.fx.wrap
def _fbgemm_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
Expand Down Expand Up @@ -2589,10 +2574,11 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
def regroup(
keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
) -> List[torch.Tensor]:
return _regroup(
keyed_tensors=keyed_tensors,
groups=groups,
)
# Fast path, one-to-one correspondence between keyed_tensors and groups
if _all_keys_used_once(keyed_tensors, groups) is True:
return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
else: # Fallback to slow path otherwise
return _regroup_keyed_tensors(keyed_tensors, groups)

@staticmethod
def regroup_as_dict(
Expand Down
11 changes: 0 additions & 11 deletions torchrec/sparse/tests/test_jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,17 +2150,6 @@ def test_to_dict_dim_1(self) -> None:
for key in keys:
self.assertTrue(torch.equal(kt[key], d[key]))

def test_regroup_empty_list(self) -> None:
keyed_tensors = []
groups = []

grouped_tensors = KeyedTensor.regroup(
keyed_tensors=keyed_tensors,
groups=groups,
)

self.assertEqual([], grouped_tensors)

def test_regroup_single_kt(self) -> None:
tensor_list = [torch.randn(2, 3) for i in range(5)]
key_dim = 1
Expand Down

0 comments on commit 733e42d

Please sign in to comment.