[Frontend, Core] Adding stop and stop_token_ids for beam search. #9264

Closed · wants to merge 5 commits · Changes from 1 commit
30 changes: 18 additions & 12 deletions vllm/engine/async_llm_engine.py
@@ -1045,12 +1045,13 @@ async def beam_search(
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
+        stop = params.stop
+        stop_token_ids = params.stop_token_ids

        tokenizer = await self.get_tokenizer()
        tokenizedPrompt = prompt if isinstance(
@@ -1060,9 +1061,14 @@ async def beam_search(
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id, length_penalty)

-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
+        beam_search_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
+            ignore_eos=ignore_eos
+        )
        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
        completed = []

@@ -1073,22 +1079,18 @@ async def beam_search(
            ]

            tasks = []

-            request_id = f"beam_search-{random_uuid()}"
+            request_id_base = f"beam_search-{random_uuid()}"
            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
+                request_id_item = f"{request_id_base}-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      request_id_item)))
                tasks.append(task)

            output = await asyncio.gather(*tasks)

            output = [x[0] for x in output]

-            logger.info(output)

            new_beams = []
            for i, current_beam in enumerate(all_beams):
                result = output[i]
@@ -1101,15 +1103,17 @@ async def beam_search(
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)

-                    if token_id == tokenizer.eos_token_id and \
-                        not ignore_eos:
+                    if result.outputs[0].finish_reason == "stop":
                        completed.append(new_beam)
                    else:
                        new_beams.append(new_beam)

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

+            if not all_beams:
+                break

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]
@@ -1132,6 +1136,8 @@ async def beam_search(
                finished=True,
                prompt_token_ids=tokenizedPrompt,
                prompt_logprobs=None)

+        logger.info(beam_search_output)

        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
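The substantive change in this file is that a beam is now moved to `completed` when the engine reports `finish_reason == "stop"`, instead of comparing the sampled token against `tokenizer.eos_token_id`. Because `stop`, `stop_token_ids`, and `ignore_eos` are forwarded into the per-step `SamplingParams` (which generates exactly one token per step), the engine's own stop handling decides when a beam is finished. The snippet below is a rough, hypothetical approximation of those stop semantics for illustration only; it is not vLLM's actual stop-checking code.

# Illustrative approximation of the stop semantics the new finish_reason
# check relies on. NOT vLLM's actual StopChecker.
from typing import List, Optional


def toy_finish_reason(token_id: int,
                      text_so_far: str,
                      eos_token_id: int,
                      ignore_eos: bool,
                      stop: Optional[List[str]] = None,
                      stop_token_ids: Optional[List[int]] = None) -> str:
    """Return "stop" if the newly sampled token ends the sequence, else "length".

    With max_tokens=1 per beam-search step, every request finishes after one
    token, so the only question is whether that token triggered a stop
    condition ("stop") or simply exhausted the token budget ("length").
    """
    if not ignore_eos and token_id == eos_token_id:
        return "stop"                      # EOS still ends the beam by default
    if stop_token_ids and token_id in stop_token_ids:
        return "stop"                      # an explicit stop token id ends it too
    if stop and any(s in text_so_far for s in stop):
        return "stop"                      # a stop string appearing in the text
    return "length"                        # max_tokens=1 reached without a stop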
28 changes: 16 additions & 12 deletions vllm/engine/multiprocessing/client.py
@@ -450,12 +450,13 @@ async def beam_search(
        request_id: str,
        params: BeamSearchParams,
    ) -> AsyncGenerator[RequestOutput, None]:

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        ignore_eos = params.ignore_eos
        temperature = params.temperature
        length_penalty = params.length_penalty
+        stop = params.stop
+        stop_token_ids = params.stop_token_ids

        tokenizer = await self.get_tokenizer(lora_request=None)
        tokenizedPrompt = prompt if isinstance(
@@ -465,9 +466,14 @@ async def beam_search(
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id, length_penalty)

-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
+        beam_search_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
+            ignore_eos=ignore_eos
+        )
        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
        completed = []

@@ -478,22 +484,18 @@ async def beam_search(
            ]

            tasks = []

-            request_id = f"beam_search-{random_uuid()}"
+            request_id_base = f"beam_search-{random_uuid()}"
            for i, individual_prompt in enumerate(prompts_batch):
-                request_id_item = f"{request_id}-{i}"
+                request_id_item = f"{request_id_base}-{i}"
                task = asyncio.create_task(
                    collect_from_async_generator(
                        self.generate(individual_prompt, beam_search_params,
                                      request_id_item)))
                tasks.append(task)

            output = await asyncio.gather(*tasks)

            output = [x[0] for x in output]

-            logger.info(output)

            new_beams = []
            for i, current_beam in enumerate(all_beams):
                result = output[i]
@@ -506,15 +508,17 @@ async def beam_search(
                        cum_logprob=current_beam.cum_logprob +
                        logprob_obj.logprob)

-                    if token_id == tokenizer.eos_token_id and \
-                        not ignore_eos:
+                    if result.outputs[0].finish_reason == "stop":
                        completed.append(new_beam)
                    else:
                        new_beams.append(new_beam)

            sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
            all_beams = sorted_beams[:beam_width]

+            if not all_beams:
+                break
Member: Was this a bug? Better to move it a couple of lines up and use `if not new_beams`?

Contributor (author): I'll be honest, I got confused by the similar check in llm.py and decided to do something similar here. I'm not sure it is needed; we can do as you suggest, or we can leave it out altogether.

Member: OK, thanks. I guess we should figure out which is correct based on the prior/expected behavior...

        completed.extend(all_beams)
        sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
        best_beams = sorted_completed[:beam_width]
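For reference, a minimal sketch of the rearrangement the reviewer suggests in the thread above, assuming the loop structure shown in this diff; the helper name is hypothetical and not code from this PR. Note that the two placements appear to trigger under the same condition, since sorting and slicing an empty `new_beams` list still yields an empty `all_beams`; the suggestion is mainly about making the early exit explicit.

# Hypothetical sketch of the reviewer's suggestion (not code from this PR):
# exit the search loop as soon as a step produces no continuations,
# before sorting and truncating to the beam width.
from typing import List, Tuple


def prune_step(new_beams: List[Tuple[List[int], float]],
               beam_width: int) -> List[Tuple[List[int], float]]:
    """Keep the top `beam_width` (tokens, score) candidates of one step."""
    if not new_beams:        # reviewer's variant: check before the sort
        return []            # caller then breaks out of the beam-search loop
    ranked = sorted(new_beams, key=lambda beam: beam[1], reverse=True)
    return ranked[:beam_width]

# The PR's variant checks after truncation (`if not all_beams: break`);
# both exits fire in exactly the same situations.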
21 changes: 14 additions & 7 deletions vllm/entrypoints/llm.py
@@ -373,8 +373,10 @@ def beam_search(
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
-        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty
+        stop = params.stop
+        stop_token_ids = params.stop_token_ids
+        ignore_eos = params.ignore_eos

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
@@ -385,9 +387,14 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
-        beam_search_params = SamplingParams(logprobs=2 * beam_width,
-                                            max_tokens=1,
-                                            temperature=temperature)
+        beam_search_params = SamplingParams(
+            logprobs=2 * beam_width,
+            max_tokens=1,
+            temperature=temperature,
+            stop=stop,
+            stop_token_ids=stop_token_ids,
+            ignore_eos=ignore_eos
+        )
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
@@ -426,7 +433,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                result = output[i]

                if result.outputs[0].logprobs is not None:
-                    # if `result.outputs[0].logprobs` is None, it means
+                    # if `result.outputs[0].logprobs`is None, it means
                    # the sequence is completed because of the max-model-len
                    # or abortion. we don't need to add it to the new beams.
                    logprobs = result.outputs[0].logprobs[0]
@@ -436,11 +443,11 @@ def sort_beams_key(x: BeamSearchSequence) -> float:
                            cum_logprob=current_beam.cum_logprob +
                            logprob_obj.logprob)

-                        if token_id == tokenizer.eos_token_id and \
-                            not ignore_eos:
+                        if result.outputs[0].finish_reason == "stop":
                            instance.completed.append(new_beam)
                        else:
                            instance_new_beams.append(new_beam)

                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
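Taken together, these changes let the synchronous entry point honor stop strings and stop token ids during beam search. A minimal usage sketch follows; the model name, prompt, and stop values are placeholders, and the exact `LLM.beam_search` signature and return type may differ between vLLM versions.

# Hypothetical usage sketch; model, prompt, and stop values are placeholders.
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")

params = BeamSearchParams(
    beam_width=4,
    max_tokens=64,
    ignore_eos=False,
    temperature=0.0,
    length_penalty=1.0,
    stop=["\n\n"],            # newly supported: stop strings
    stop_token_ids=[50118],   # newly supported: stop token ids (placeholder id)
)

# Beams that hit a stop condition are moved to the completed set instead of
# being extended further, per the finish_reason check above.
outputs = llm.beam_search(["The capital of France is"], params)
for output in outputs:
    print(output)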
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -594,6 +594,8 @@ def to_beam_search_params(self,
            ignore_eos=self.ignore_eos,
            temperature=temperature,
            length_penalty=self.length_penalty,
+            stop=self.stop,
+            stop_token_ids=self.stop_token_ids,
        )

    def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
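On the OpenAI-compatible server, `to_beam_search_params` now forwards the request's `stop` and `stop_token_ids` into `BeamSearchParams`. A hedged client-side sketch follows; the `use_beam_search` and `stop_token_ids` extra-body fields, the tie between `n` and the beam width, the base URL, and the model name are assumptions about the server's request schema rather than anything shown in this diff.

# Hypothetical client sketch; base_url, model, and the extra_body fields are
# assumptions about the vLLM OpenAI-compatible server, not taken from this PR.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="facebook/opt-125m",           # placeholder model name
    prompt="List three prime numbers:",
    max_tokens=64,
    temperature=0.0,
    stop=["\n\n"],                       # forwarded into BeamSearchParams.stop
    extra_body={
        "use_beam_search": True,         # assumed extra parameter enabling beam search
        "n": 4,                          # assumed to set the beam width
        "stop_token_ids": [50118],       # forwarded into BeamSearchParams.stop_token_ids
    },
)
print(completion.choices[0].text)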
2 changes: 2 additions & 0 deletions vllm/sampling_params.py
@@ -488,3 +488,5 @@ class BeamSearchParams(
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
+    stop: Optional[Union[str, List[str]]] = None
+    stop_token_ids: Optional[List[int]] = None