diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f4943cb38da44..5a10e72e5c165 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -396,6 +396,7 @@ def beam_search( beam_width: int, max_tokens: int, ignore_eos: bool = False, + temperature: float = 0.0, ) -> List[BeamSearchOutput]: """ Generate sequences using beam search. @@ -405,6 +406,7 @@ def beam_search( of token IDs. beam_width: The number of beams to keep at each step. max_tokens: The max number of tokens to generate for each prompt. + temperature: The temperature to use for generation. TODO: how does beam search work together with length penalty, frequency penalty, and stopping criteria, etc.? @@ -416,7 +418,7 @@ def beam_search( # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa beam_search_params = SamplingParams(logprobs=2 * beam_width, max_tokens=1, - temperature=0.0) + temperature=temperature) instances: List[BeamSearchInstance] = [] for prompt in prompts: