Skip to content

Commit

Permalink
Fix _length_per_key_from_stride_per_key empty cat
Browse files Browse the repository at this point in the history
Summary: torch.cat fails on empty list, guarding this case.

Reviewed By: zainhuda

Differential Revision: D54305327

fbshipit-source-id: 82877e4f307631eed816a60b35e8b1ca52104b32
  • Loading branch information
Ivan Kobzarev authored and facebook-github-bot committed Feb 28, 2024
1 parent f1c716a commit 49fbc5f
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,13 @@ def _length_per_key_from_stride_per_key(
1, stride_per_key_offsets, lengths
).tolist()
else:
return torch.cat(
[torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)]
).tolist()
tensor_list: List[torch.Tensor] = [
torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)
]
if len(tensor_list) == 0:
return []

return torch.cat(tensor_list).tolist()


def _maybe_compute_length_per_key(
Expand Down

0 comments on commit 49fbc5f

Please sign in to comment.