diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index ecff769da..d378aaef4 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -728,9 +728,13 @@ def _length_per_key_from_stride_per_key( 1, stride_per_key_offsets, lengths ).tolist() else: - return torch.cat( - [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)] - ).tolist() + tensor_list: List[torch.Tensor] = [ + torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key) + ] + if len(tensor_list) == 0: + return [] + + return torch.cat(tensor_list).tolist() def _maybe_compute_length_per_key(