Commit

Avoid unnecessary multi-modal input data copy when len(batch) == 1 (#12722)

Signed-off-by: imkero <[email protected]>
imkero authored Feb 4, 2025
1 parent 6469038 commit 62467a8
Showing 1 changed file with 16 additions and 0 deletions.
vllm/multimodal/inputs.py (16 additions, 0 deletions)
@@ -212,6 +212,11 @@ def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produce exactly the same result as `torch.stack(batch)`
+                # - will achieve zero-copy if the tensor is contiguous
+                return batch[0].unsqueeze(0).contiguous()
             first_shape = batch[0].shape
             if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)
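
(Illustration only, not part of the diff.) A minimal sketch of why this fast path is safe: for a one-element batch, unsqueeze(0) on an already-contiguous tensor is just a view and the trailing .contiguous() is a no-op, whereas torch.stack always writes into freshly allocated storage. The tensor values below are made up for demonstration:

import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # contiguous input tensor
batch = [t]                                              # single-element batch

fast = batch[0].unsqueeze(0).contiguous()  # view + no-op contiguous(): no data copy
slow = torch.stack(batch)                  # allocates a new (1, 2, 3) tensor

assert torch.equal(fast, slow)             # same values and shape either way
assert fast.data_ptr() == t.data_ptr()     # fast path still aliases t's storage
print(slow.data_ptr() == t.data_ptr())     # expected False: stack copied the data
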
@@ -235,6 +240,11 @@ def build_elems(
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
+            if len(batch) == 1:
+                # An optimization when `batch` contains only one tensor:
+                # - produce exactly the same result as `torch.concat(batch)`
+                # - will achieve zero-copy if the tensor is contiguous
+                return batch[0].contiguous()
             first_shape = batch[0].shape
             if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)
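
(Illustration only, not part of the diff.) The concat case works the same way: .contiguous() on an already-contiguous tensor hands back the input itself, while torch.concat of even a single tensor materializes a new one, which is exactly the copy this change skips. Example values are made up:

import torch

t = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # contiguous input tensor
batch = [t]

fast = batch[0].contiguous()   # no-op: returns t itself, zero-copy
slow = torch.concat(batch)     # concatenates a single tensor along dim 0

assert torch.equal(fast, slow)           # identical values and shape (4, 3)
assert fast is t                         # contiguous() returned the same object
print(slow.data_ptr() == t.data_ptr())   # expected False: concat copied the data
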
@@ -407,6 +417,12 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
             return stacked
 
         tensors_ = cast(list[torch.Tensor], stacked)
+        if len(tensors_) == 1:
+            # An optimization when `tensors_` contains only one tensor:
+            # - produce exactly the same result as `torch.stack(tensors_)`
+            # - will achieve zero-copy if the tensor is contiguous
+            return tensors_[0].unsqueeze(0).contiguous()
+
         if any(t.shape != tensors_[0].shape for t in tensors_):
             # The tensors have incompatible shapes and can't be stacked.
             return tensors_
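
(Illustration only, not part of the diff.) A last sketch showing why the trailing .contiguous() is kept even on the fast path: if the single tensor happens to be a non-contiguous view, the result still matches torch.stack's contiguous output; only the zero-copy benefit is lost in that case.

import torch

t = torch.arange(6, dtype=torch.float32).reshape(2, 3).t()  # transpose -> non-contiguous view
single = [t]

fast = single[0].unsqueeze(0).contiguous()  # contiguous() must copy here
slow = torch.stack(single)                  # also yields a contiguous (1, 3, 2) tensor

assert not t.is_contiguous()
assert torch.equal(fast, slow)                        # results still match
assert fast.is_contiguous() and slow.is_contiguous()  # same layout guarantee as stack
assert fast.data_ptr() != t.data_ptr()                # zero-copy is not possible here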