From 16eff28d0c9531c59b653fe50821e06fbe62ee0d Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 21 Sep 2024 10:30:39 +0800 Subject: [PATCH] [Core] Factor out common code in `SequenceData` and `Sequence` (#8675) --- tests/samplers/test_sampler.py | 27 +++----- tests/spec_decode/utils.py | 12 +--- tests/test_logits_processor.py | 8 +-- tests/test_sequence.py | 7 +-- .../test_encoder_decoder_model_runner.py | 22 +++---- tests/worker/test_model_runner.py | 16 ++--- vllm/inputs/registry.py | 8 +-- vllm/sequence.py | 61 +++++++++++-------- 8 files changed, 64 insertions(+), 97 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 19a5ca5e27502..308b708feab71 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -1,6 +1,5 @@ import itertools import random -from array import array from typing import Dict, List, Optional, Tuple from unittest.mock import Mock, patch @@ -12,8 +11,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import Counter, is_pin_memory_available @@ -59,9 +57,7 @@ def _do_sample( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -205,9 +201,8 @@ def create_sampling_params(min_tokens, return sampling_params def create_sequence_data(num_input=3, num_generated=0): - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, - random.choices(range(0, VOCAB_SIZE), k=num_input))) + seq_data = SequenceData.from_seqs( + random.choices(range(0, VOCAB_SIZE), k=num_input)) if num_generated > 0: seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), k=num_generated) @@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params, block_tables={0: [1]}, )) @@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams( temperature=1, top_k=top_k, @@ -699,11 +690,7 @@ def test_sampling_params(sampling_params: List[SamplingParams]): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: - SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=sampling_params[i], block_tables={0: [1]}, )) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 9075a433eb66e..f17e872881633 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,4 +1,3 @@ -from array import array from itertools import count from typing import Callable, Dict, List, Optional from typing import Sequence as GenericSequence @@ -11,8 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine @@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts( request_id=str(i), is_prompt=len(cont_token_ids) == 0, seq_data={ - i: - SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids[:]), - _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, - cont_token_ids[:]), - ), + i: SequenceData.from_seqs(prompt_token_ids[:], + cont_token_ids[:]), }, sampling_params=SamplingParams(temperature=0.0, ), block_tables={i: block_allocations[i][:]}, diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index 1ce49a50688ae..39c1c38151fd0 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -1,5 +1,4 @@ import random -from array import array from typing import Tuple from unittest.mock import patch @@ -9,8 +8,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import is_pin_memory_available @@ -71,9 +69,7 @@ def pick_ith(token_ids, logits): SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={ - 0: SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3])) - }, + seq_data={0: SequenceData.from_seqs([1, 2, 3])}, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 348ba7dd41d99..30e53a180ea31 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -1,10 +1,7 @@ -from array import array - import pytest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, SequenceData, +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceData, SequenceOutput) from .core.utils import create_dummy_prompt @@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs): def test_sequence_data_prefill(): - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4])) + seq_data = SequenceData.from_seqs([1, 2, 3, 4]) assert seq_data.get_num_uncomputed_tokens() == 4 assert seq_data.get_num_computed_tokens() == 0 # advance by 2 diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 27cdf5f339ede..3dccc1b325d95 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -1,13 +1,11 @@ import itertools -from array import array from typing import List import pytest import torch from vllm.engine.arg_utils import EngineArgs -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import is_cpu, make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import _get_graph_batch_size @@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 encoder_seq_lens.append(encoder_seq_len) - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(encoder_seq_len))) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len)))) + seq_data = SequenceData.from_seqs(range(seq_len)) encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1 - encoder_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len)))) + encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 42b2337f46914..fe97199bac62d 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -1,4 +1,3 @@ -from array import array from typing import List import pytest @@ -8,8 +7,7 @@ init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams, - SequenceData, SequenceGroupMetadata) +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len))) + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. seq_data.append_token_id(1, 0) @@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, - range(seq_len))) + seq_data = SequenceData.from_seqs(range(seq_len)) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = array(VLLM_TOKEN_ID_ARRAY_TYPE, range(context_len)) - seq_data = SequenceData(prompt_toks) + seq_data = SequenceData.from_seqs(range(context_len)) seq_data.append_token_id(1, 0) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f72..a0f02ba29e219 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -1,5 +1,4 @@ import functools -from array import array from collections import UserDict from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, @@ -22,10 +21,6 @@ C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) -# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE. -# We cannot import it here because of circular dependencies. -VLLM_TOKEN_ID_ARRAY_TYPE = "l" - @dataclass(frozen=True) class InputContext: @@ -130,8 +125,7 @@ def _default_dummy_data_factory( # Avoid circular import from vllm.sequence import SequenceData - dummy_seq_data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * seq_len) + dummy_seq_data = SequenceData.from_counts({0: seq_len}) dummy_multi_modal_data = None return dummy_seq_data, dummy_multi_modal_data diff --git a/vllm/sequence.py b/vllm/sequence.py index 07ceccf123541..f849211c317ca 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from array import array from collections import defaultdict from dataclasses import dataclass +from functools import cached_property, reduce from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional from typing import Sequence as GenericSequence from typing import Set, Tuple, Union, cast @@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct, # It is used to compute mrope_position_ids. _mrope_position_delta: Optional[int] = None + @staticmethod + def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData": + if len(counts_by_token) == 0: + return SequenceData.from_seqs([]) + + arrs = [ + array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + for token_id, count in counts_by_token.items() + ] + + return SequenceData(reduce(array.__add__, arrs)) + + @staticmethod + def from_seqs( + prompt_token_ids: GenericSequence[int], + output_token_ids: Optional[GenericSequence[int]] = None, + ) -> "SequenceData": + prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + prompt_token_ids) + + if output_token_ids is None: + return SequenceData(prompt_token_ids_arr) + + output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + output_token_ids) + + return SequenceData(prompt_token_ids_arr, + _output_token_ids=output_token_ids_arr) + def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" assert self._output_token_ids.typecode == "l" @@ -370,8 +400,6 @@ def __init__( self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request self.from_decoder_prompt = from_decoder_prompt - self._prompt: Optional[str] = None - self._prompt_token_ids: Optional[List[int]] = None # For decoder-only models, a Sequence is constructed # from an LLMInputs instance (the `inputs` arg.) @@ -400,8 +428,7 @@ def __init__( f"invalid input {inputs}; did you forget the " "encoder input prompt fields?") - self.data = SequenceData( - array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids)) + self.data = SequenceData.from_seqs(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -422,37 +449,23 @@ def __init__( def n_blocks(self) -> int: return (self.get_len() + self.block_size - 1) // self.block_size - @property + @cached_property def prompt(self) -> Optional[str]: - if self._prompt is not None: - # Reuse precomputed prompt string - return self._prompt - - # Select decoder or encoder input prompt str, - # as appropriate + # Select decoder or encoder input prompt str, as appropriate prompt_key: str = ("prompt" if self.from_decoder_prompt else "encoder_prompt") - # Cache prompt - self._prompt = cast(Optional[str], self.inputs.get(prompt_key)) - return self._prompt + return cast(Optional[str], self.inputs.get(prompt_key)) - @property + @cached_property def prompt_token_ids(self) -> List[int]: - if self._prompt_token_ids is not None: - # Reuse precomputed prompt token ids - return self._prompt_token_ids - - # Select decoder or encoder input prompt - # token ids, as appropriate + # Select decoder or encoder input prompt token ids, as appropriate prompt_token_ids_key: str = ("prompt_token_ids" if self.from_decoder_prompt else "encoder_prompt_token_ids") # Cache computed prompt token ids - self._prompt_token_ids = cast(List[int], - self.inputs.get(prompt_token_ids_key)) - return self._prompt_token_ids + return cast(List[int], self.inputs.get(prompt_token_ids_key)) @property def multi_modal_data(self) -> "MultiModalDataDict":