
Commit

fix #5966
hiyouga committed Nov 8, 2024
1 parent 707ff5a commit 8f3a322
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/llamafactory/chat/vllm_engine.py
@@ -115,7 +115,6 @@ async def _generate(
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)

-        use_beam_search: bool = self.generating_args["num_beams"] > 1
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
         top_p: Optional[float] = input_kwargs.pop("top_p", None)
         top_k: Optional[float] = input_kwargs.pop("top_k", None)
@@ -126,6 +125,9 @@ async def _generate(
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

+        if length_penalty is not None:
+            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
         if "max_new_tokens" in self.generating_args:
             max_tokens = self.generating_args["max_new_tokens"]
         elif "max_length" in self.generating_args:
@@ -149,8 +151,6 @@ async def _generate(
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
-            use_beam_search=use_beam_search,
-            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
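In short, the patch stops forwarding use_beam_search and length_penalty to vLLM's SamplingParams and instead logs a rank-0 warning when a length_penalty is supplied, presumably to track vLLM releases whose SamplingParams no longer expose those fields. Below is a minimal sketch of the trimmed-down construction, outside the engine class (illustrative only: build_sampling_params, generating_args, eos_token_id, and max_tokens are stand-ins for values the engine resolves itself, not names from the repository):

from vllm import SamplingParams

def build_sampling_params(generating_args: dict, eos_token_id: int, max_tokens: int) -> SamplingParams:
    # Only the options that survive the patch are passed through;
    # the beam-search related fields are intentionally absent.
    return SamplingParams(
        temperature=generating_args["temperature"],
        top_p=generating_args["top_p"] or 1.0,  # vLLM requires top_p > 0
        top_k=generating_args["top_k"],
        stop_token_ids=[eos_token_id],
        max_tokens=max_tokens,
    )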
