[Core] *Prompt* logprobs support in Multi-step #8199

Merged
merged 66 commits into from
Sep 18, 2024
Changes from 60 commits
f97e0ae
added example
afeldman-nm Aug 21, 2024
f969241
wip:
afeldman-nm Aug 21, 2024
642d31b
first working attempt at logprobs
afeldman-nm Aug 21, 2024
a0ca262
merge; format
afeldman-nm Aug 21, 2024
ed97288
passing test; dataclass
afeldman-nm Aug 21, 2024
861e1b9
refactoring
afeldman-nm Aug 21, 2024
8bc0765
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
a34d1ac
refactoring
afeldman-nm Aug 21, 2024
4cda5c0
Merge branch 'logprobs' into logprobs_merge
afeldman-nm Aug 21, 2024
ac8a39a
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
1284327
removing example
afeldman-nm Aug 21, 2024
a6c1207
removed example from build pipeline
afeldman-nm Aug 21, 2024
fe42995
fixed one docstring; embedded NUM_LOGPROBS
afeldman-nm Aug 21, 2024
9fb5bbe
test refactor
afeldman-nm Aug 21, 2024
046a8b1
incremental refactors
afeldman-nm Aug 21, 2024
fa86efd
remove unnecessary conftest change
afeldman-nm Aug 21, 2024
1c0ffb6
Update vllm/model_executor/layers/sampler.py
afeldman-nm Aug 21, 2024
3babadb
refactor
afeldman-nm Aug 21, 2024
f502029
Merge branch 'afeldman-nm/logprobs' of https://github.com/neuralmagic…
afeldman-nm Aug 21, 2024
1875b37
test_multi_step comment
afeldman-nm Aug 21, 2024
3760a95
utils function docstrings
afeldman-nm Aug 21, 2024
d43308c
docstring refactors
afeldman-nm Aug 21, 2024
54db498
merge
afeldman-nm Aug 21, 2024
dfbbaf0
passing tests & formatted
afeldman-nm Aug 21, 2024
5eebfca
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 21, 2024
5e23d9a
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 22, 2024
717efa3
merge; format
afeldman-nm Aug 22, 2024
e0d59ce
removed incorrect SamplerOutput imports
afeldman-nm Aug 22, 2024
102fd92
formatting
afeldman-nm Aug 22, 2024
948f4ef
Update tests/multi_step/test_correctness.py
afeldman-nm Aug 22, 2024
6e6711f
fixed comment
afeldman-nm Aug 22, 2024
f61163e
merge; format
afeldman-nm Aug 23, 2024
1cc93dd
rename
afeldman-nm Aug 23, 2024
4995204
Merge branch 'logprobs' into logprobs_merge
afeldman-nm Aug 23, 2024
da5826b
test modification
afeldman-nm Aug 26, 2024
d4fb430
merge; format
afeldman-nm Aug 26, 2024
b6752e0
merge
afeldman-nm Aug 27, 2024
1e42656
formatting
afeldman-nm Aug 27, 2024
cd0fdf9
disabled logprobs pythonization when logprobs are disabled
afeldman-nm Aug 27, 2024
3fecbc4
wip
afeldman-nm Aug 27, 2024
67bd035
skip logprobs processing entirely when logprobs are not enabled; form…
afeldman-nm Aug 27, 2024
419659d
multi-step output processing; formatting
afeldman-nm Aug 27, 2024
55eaab9
wip
afeldman-nm Aug 27, 2024
bae1fb9
small fixes
afeldman-nm Aug 27, 2024
4c0c9f8
Added prompt logprobs tests
afeldman-nm Aug 27, 2024
c249e51
Merge branch 'main' into logprobs_merge
afeldman-nm Aug 27, 2024
8865bbd
wip
afeldman-nm Aug 27, 2024
e05670b
increased max wait time
afeldman-nm Aug 27, 2024
42af633
formatting
afeldman-nm Aug 27, 2024
81fedc1
upstream merge; upstream formatting issues
afeldman-nm Aug 28, 2024
c6f703d
upstream merge; passing logprobs tests
afeldman-nm Sep 5, 2024
55d7276
Merge branch 'main' into logprobs_merge
afeldman-nm Sep 11, 2024
d8a3f8c
upstream merge; format
afeldman-nm Sep 13, 2024
ad7f261
seems to be passing tests
afeldman-nm Sep 13, 2024
ac4b36f
comments
afeldman-nm Sep 13, 2024
dcad218
refactoring
afeldman-nm Sep 13, 2024
0f373ab
comment
afeldman-nm Sep 13, 2024
9bff9b6
updated prompt logprobs test comment
afeldman-nm Sep 13, 2024
1a28003
updated check_logprobs_close() comment
afeldman-nm Sep 13, 2024
c9d9537
small fix
afeldman-nm Sep 13, 2024
5d64bf3
Update tests/models/utils.py
afeldman-nm Sep 17, 2024
c152446
Update vllm/worker/multi_step_model_runner.py
afeldman-nm Sep 17, 2024
8197219
Update vllm/worker/multi_step_model_runner.py
afeldman-nm Sep 17, 2024
908709c
Update vllm/worker/multi_step_model_runner.py
afeldman-nm Sep 17, 2024
cce5394
addressing feedback
abf149 Sep 17, 2024
7f147d6
Merge branch 'main' into logprobs
abf149 Sep 17, 2024
84 changes: 52 additions & 32 deletions tests/conftest.py
@@ -19,6 +19,8 @@
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
BatchFeature)

+ from tests.models.utils import (TokensTextLogprobs,
+ TokensTextLogprobsPromptLogprobs)
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
@@ -32,7 +34,6 @@
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
- from vllm.sequence import SampleLogprobs
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
identity, is_cpu)

@@ -477,7 +478,7 @@ def generate_greedy_logprobs_limit(
audios: Optional[PromptAudioInput] = None,
videos: Optional[List[np.ndarray]] = None,
**kwargs: Any,
- ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+ ) -> List[TokensTextLogprobs]:
all_logprobs: List[List[Dict[int, float]]] = []
all_output_ids: List[List[int]] = []
all_output_strs: List[str] = []
@@ -533,7 +534,7 @@ def generate_encoder_decoder_greedy_logprobs_limit(
max_tokens: int,
num_logprobs: int,
**kwargs: Any,
- ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+ ) -> List[TokensTextLogprobs]:
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
@@ -661,14 +662,16 @@ def generate(
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: List[RequestOutput],
- ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
- outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
+ ) -> List[TokensTextLogprobsPromptLogprobs]:
+ outputs: List[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
assert len(req_output.outputs) > 0
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
- outputs.append((output_ids, output_str, output_logprobs))
+ outputs.append((output_ids, output_str, output_logprobs,
+ req_output.prompt_logprobs))
return outputs

def generate_w_logprobs(
@@ -678,7 +681,8 @@ def generate_w_logprobs(
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
- ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+ ) -> Union[List[TokensTextLogprobs],
+ List[TokensTextLogprobsPromptLogprobs]]:
assert sampling_params.logprobs is not None

if images is not None:
@@ -703,21 +707,33 @@ def generate_w_logprobs(

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
- return self._final_steps_generate_w_logprobs(req_outputs)

+ toks_str_logsprobs_prompt_logprobs = (
+ self._final_steps_generate_w_logprobs(req_outputs))
+ # Omit prompt logprobs if not required by sampling params
+ return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+ if sampling_params.prompt_logprobs is None else
+ toks_str_logsprobs_prompt_logprobs)

def generate_encoder_decoder_w_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
sampling_params: SamplingParams,
- ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+ ) -> Union[List[TokensTextLogprobs],
+ List[TokensTextLogprobsPromptLogprobs]]:
'''
Logprobs generation for vLLM encoder/decoder models
'''

assert sampling_params.logprobs is not None
req_outputs = self.model.generate(encoder_decoder_prompts,
sampling_params=sampling_params)
- return self._final_steps_generate_w_logprobs(req_outputs)
+ toks_str_logsprobs_prompt_logprobs = (
+ self._final_steps_generate_w_logprobs(req_outputs))
+ # Omit prompt logprobs if not required by sampling params
+ return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+ if sampling_params.prompt_logprobs is None else
+ toks_str_logsprobs_prompt_logprobs)

def generate_greedy(
self,
@@ -735,44 +751,48 @@ def generate_greedy_logprobs(
prompts: List[str],
max_tokens: int,
num_logprobs: int,
+ num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
- ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
- greedy_logprobs_params = SamplingParams(temperature=0.0,
- max_tokens=max_tokens,
- logprobs=num_logprobs,
- stop_token_ids=stop_token_ids)
- outputs = self.generate_w_logprobs(prompts,
- greedy_logprobs_params,
- images=images,
- audios=audios,
- videos=videos)
-
- return [(output_ids, output_str, output_logprobs)
- for output_ids, output_str, output_logprobs in outputs]
+ ) -> Union[List[TokensTextLogprobs],
+ List[TokensTextLogprobsPromptLogprobs]]:
+ greedy_logprobs_params = SamplingParams(
+ temperature=0.0,
+ max_tokens=max_tokens,
+ logprobs=num_logprobs,
+ prompt_logprobs=(num_prompt_logprobs),
+ stop_token_ids=stop_token_ids)
+
+ return self.generate_w_logprobs(prompts,
+ greedy_logprobs_params,
+ images=images,
+ audios=audios,
+ videos=videos)

def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
- ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
- greedy_logprobs_params = SamplingParams(temperature=0.0,
- use_beam_search=False,
- max_tokens=max_tokens,
- logprobs=num_logprobs)
+ num_prompt_logprobs: Optional[int] = None,
+ ) -> Union[List[TokensTextLogprobs],
+ List[TokensTextLogprobsPromptLogprobs]]:
+ greedy_logprobs_params = SamplingParams(
+ temperature=0.0,
+ use_beam_search=False,
+ max_tokens=max_tokens,
+ logprobs=num_logprobs,
+ prompt_logprobs=(num_prompt_logprobs),
+ )
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''

- outputs = self.generate_encoder_decoder_w_logprobs(
+ return self.generate_encoder_decoder_w_logprobs(
encoder_decoder_prompts, greedy_logprobs_params)

- return [(output_ids, output_str, output_logprobs)
- for output_ids, output_str, output_logprobs in outputs]

def generate_beam_search(
self,
prompts: List[str],
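
Note on the `tests/conftest.py` change: `_final_steps_generate_w_logprobs` now always builds 4-tuples, and the callers slice off the last element when `sampling_params.prompt_logprobs` is `None`. A minimal standalone sketch of that slicing follows. It assumes plain dicts in place of vLLM's logprob objects, and the helper name `drop_prompt_logprobs_if_unrequested` and the `...Sketch` aliases are hypothetical, not part of this PR.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical stand-ins for the test-suite aliases; plain dicts replace
# vLLM's Logprob objects so the sketch runs without vLLM installed.
SampleLogprobsSketch = List[Dict[int, float]]
PromptLogprobsSketch = List[Optional[Dict[int, float]]]
FourTuple = Tuple[List[int], str, Optional[SampleLogprobsSketch],
                  Optional[PromptLogprobsSketch]]


def drop_prompt_logprobs_if_unrequested(
    outputs: List[FourTuple],
    num_prompt_logprobs: Optional[int],
) -> list:
    """Return 3-tuples when prompt logprobs were not requested,
    otherwise return the 4-tuples unchanged."""
    if num_prompt_logprobs is None:
        # Slicing a tuple with [0:-1] yields a new tuple without
        # its last element (the prompt logprobs slot).
        return [x[0:-1] for x in outputs]
    return outputs


outputs = [([11, 12], "hi there", [{11: -0.1}, {12: -0.2}], None)]
trimmed = drop_prompt_logprobs_if_unrequested(outputs, None)
kept = drop_prompt_logprobs_if_unrequested(outputs, 1)
```

The consequence for callers is that the tuple arity depends on the sampling params, which is why `check_logprobs_close` in `tests/models/utils.py` branches on `len(outputs_0)`.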
107 changes: 101 additions & 6 deletions tests/models/utils.py
@@ -1,7 +1,7 @@
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

- from vllm.sequence import Logprob, SampleLogprobs
+ from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs

TokensText = Tuple[List[int], str]

@@ -34,20 +34,47 @@ def check_outputs_equal(
assert output_ids_0 == output_ids_1, fail_msg


+ # Representation of generated sequence as a tuple of
+ # * Token ID list
+ # * String
+ # * List of top sample logprobs for each sampled token
+ #
+ # Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]

- # Allow for tokens to be represented as str's rather than IDs
+ # Allow for tokens to be represented as str's rather than IDs;
+ # tuple of
+ # * Token string representations list
+ # * String
+ # * Optional list of top sample logprobs for each sampled token
+ #
+ # Assumes prompt logprobs were not requested.
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
List[Dict[str,
Logprob]]]]]

+ # Representation of generated sequence as a tuple of
+ # * Token ID list
+ # * String
+ # * Optional list of top sample logprobs for each sampled token
+ # * Optional list of top prompt logprobs for each prompt token
+ #
+ # Allows prompt logprobs to be requested.
+ TokensTextLogprobsPromptLogprobs = Tuple[
+ List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
+ Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]


def check_logprobs_close(
*,
- outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
- outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
+ outputs_0_lst: Sequence[Union[TokensTextLogprobs,
+ TokensTextLogprobsPromptLogprobs,
+ TextTextLogprobs]],
+ outputs_1_lst: Sequence[Union[TokensTextLogprobs,
+ TokensTextLogprobsPromptLogprobs,
+ TextTextLogprobs]],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
@@ -57,6 +84,18 @@ def check_logprobs_close(
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.

+ How sample logprobs are compared:
+ * `always_check_logprobs == True`: set of highest-logprob token ids
+ must match between seq0 and seq1 at all sampled token offsets
+ * `always_check_logprobs == False`: highest-logprob token ids are
+ only compared at sampled token offsets for which generated token
+ ids don't match
+
+ Prompt logprobs must be provided either for both input sequences, or
+ for neither. If prompt logprobs are provided, then highest-logprob
+ prompt token ids must match between seq0 and seq1 at all prompt token
+ offsets.
+
Args:
outputs_0_lst: First sequence to compare
outputs_1_lst: Second sequence to compare
@@ -78,8 +117,64 @@ def check_logprobs_close(
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
- output_ids_0, output_str_0, logprobs_0 = outputs_0
- output_ids_1, output_str_1, logprobs_1 = outputs_1
+ if len(outputs_0) == 3:
+ assert len(outputs_1) == 3
+ # Break out tokens, text & sample logprobs
+ # (prompt logprobs were not provided)
+ output_ids_0, output_str_0, logprobs_0 = outputs_0
+ output_ids_1, output_str_1, logprobs_1 = outputs_1
+ elif len(outputs_0) == 4:
+ assert len(outputs_1) == 4
+ # Break out tokens, text, sample logprobs & prompt logprobs
+ (
+ output_ids_0,
+ output_str_0,
+ logprobs_0,
+ prompt_logprobs_0,
+ ) = outputs_0
+ (
+ output_ids_1,
+ output_str_1,
+ logprobs_1,
+ prompt_logprobs_1,
+ ) = outputs_1
+
+ # Test prompt logprobs closeness
+ if (prompt_logprobs_0 is not None
+ and prompt_logprobs_1 is not None):
+ # Both sequences' prompt logprobs lists are not `None`
+ # (although individual list elements may be `None`);
+ # for each token's logprobs:
+ for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
+ zip(prompt_logprobs_0, prompt_logprobs_1)):
+ fail_msg = (
+ f"Prompt logprobs test:"
+ f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
+ f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
+
+ if logprobs_elem_0 is None:
+ # If the seq 0 token's logprobs are `None`,
+ # the seq 1 token's logprobs must be `None`
+ assert logprobs_elem_1 is None, fail_msg
+ else:
+ # If the seq 0 token's logprobs are not `None`,
+ # the seq 1 token's logprobs must not be `None`
+ assert logprobs_elem_1 is not None, fail_msg
+ # Logprobs check: top-k token choices must be the same
+ assert (set(logprobs_elem_0.keys()) == set(
+ logprobs_elem_1.keys())), fail_msg
+ else:
+ # Both sequence logprobs lists must be `None`
+ fail_msg = (f"Prompt logprobs test:"
+ f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
+ f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
+
+ assert (prompt_logprobs_0 is None
+ and prompt_logprobs_1 is None), fail_msg
+ else:
+ raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
+ f"{len(outputs_0)} elements were provided: "
+ f"{outputs_0}")

if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)
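
Note on the prompt logprobs check added to `check_logprobs_close`: the rule is that both sequences either carry prompt logprobs or neither does, and at each prompt offset the entries are either both `None` or have identical top-k token id sets (the logprob values themselves are not compared). A minimal standalone sketch of that rule, returning a bool instead of asserting, might look as follows. The function name `prompt_logprobs_agree` and the alias `PromptLogprobsSketch` are hypothetical, and plain `float` values stand in for vLLM's `Logprob` objects.

```python
from typing import Dict, List, Optional

# Hypothetical alias: a per-prompt-token list of {token_id: logprob} dicts,
# where the whole list or any single entry may be None.
PromptLogprobsSketch = Optional[List[Optional[Dict[int, float]]]]


def prompt_logprobs_agree(p0: PromptLogprobsSketch,
                          p1: PromptLogprobsSketch) -> bool:
    """Mirror the comparison rule described above as a boolean check."""
    if p0 is None or p1 is None:
        # Prompt logprobs must be present for both sequences or for neither.
        return p0 is None and p1 is None
    # As in the diff, zip() stops at the shorter list, so a length
    # mismatch alone is not detected here.
    for elem_0, elem_1 in zip(p0, p1):
        if elem_0 is None or elem_1 is None:
            # A None entry on one side requires None on the other side.
            if elem_0 is not elem_1:
                return False
        elif set(elem_0.keys()) != set(elem_1.keys()):
            # Top-k token id choices must be the same at this offset.
            return False
    return True


seq_a = [None, {5: -0.1, 9: -2.3}]
seq_b = [None, {9: -2.0, 5: -0.4}]  # same token ids, different values
seq_c = [None, {5: -0.1, 7: -2.3}]  # different token ids at offset 1

same_ids = prompt_logprobs_agree(seq_a, seq_b)
different_ids = prompt_logprobs_agree(seq_a, seq_c)
both_absent = prompt_logprobs_agree(None, None)
one_absent = prompt_logprobs_agree(seq_a, None)
```

Comparing key sets rather than values keeps the check tolerant of small numerical differences between the two runs being compared, which matches the "similar but not necessarily equal" wording in the docstring.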