From 49fbc5f3a5b3c8ad68bf2f7b5f9bab8f6628b9d8 Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Wed, 28 Feb 2024 09:40:30 -0800 Subject: [PATCH] Fix _length_per_key_from_stride_per_key empty cat Summary: torch.cat fails on empty list, guarding this case. Reviewed By: zainhuda Differential Revision: D54305327 fbshipit-source-id: 82877e4f307631eed816a60b35e8b1ca52104b32 --- torchrec/sparse/jagged_tensor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index ecff769da..d378aaef4 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -728,9 +728,13 @@ def _length_per_key_from_stride_per_key( 1, stride_per_key_offsets, lengths ).tolist() else: - return torch.cat( - [torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key)] - ).tolist() + tensor_list: List[torch.Tensor] = [ + torch.sum(chunk).view(1) for chunk in torch.split(lengths, stride_per_key) + ] + if len(tensor_list) == 0: + return [] + + return torch.cat(tensor_list).tolist() def _maybe_compute_length_per_key(