Skip to content

Commit

Permalink
Add option to pass inverse indices to new KJT after permute
Browse files Browse the repository at this point in the history
Summary: allow inverse indices to be included in new KJT after permute. For modifying vbe KJT outside of input dist.

Differential Revision: D53732141
  • Loading branch information
joshuadeng authored and facebook-github-bot committed Feb 13, 2024
1 parent a860519 commit 2dd4359
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,7 +1645,10 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
return split_list

def permute(
self, indices: List[int], indices_tensor: Optional[torch.Tensor] = None
self,
indices: List[int],
indices_tensor: Optional[torch.Tensor] = None,
include_inverse_indices: bool = False,
) -> "KeyedJaggedTensor":

if indices_tensor is None:
Expand Down Expand Up @@ -1715,7 +1718,9 @@ def permute(
offset_per_key=None,
index_per_key=None,
jt_dict=None,
inverse_indices=None,
inverse_indices=self.inverse_indices_or_none()
if include_inverse_indices
else None,
)
return kjt

Expand Down

0 comments on commit 2dd4359

Please sign in to comment.