Skip to content

Commit

Permalink
Fix
Browse files — browse the repository at this point in the history
  • Loading branch information
DarkLight1337 committed Oct 18, 2024
1 parent 72f1e72 commit beae068
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
20 changes: 10 additions & 10 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,33 +261,33 @@ def _resolve_task(
architectures = getattr(hf_config, "architectures", [])

task_support: Dict[Task, bool] = {
# NOTE: They are listed from highest to lowest priority, in case
# the model supports multiple of them
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
}
supported_tasks: Set[Task] = {
task
for task, is_supported in task_support.items() if is_supported
}
supported_tasks_lst: List[Task] = [
task for task, is_supported in task_support.items() if is_supported
]
supported_tasks = set(supported_tasks_lst)

if task_option == "auto":
task = next(iter(supported_tasks))
selected_task = next(iter(supported_tasks_lst))

if len(supported_tasks) > 1:
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, task)
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)

task = task_option
selected_task = task_option

return supported_tasks, task
return supported_tasks, selected_task

def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
Expand Down
4 changes: 2 additions & 2 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def generate(

supported_tasks = self.llm_engine.model_config.supported_tasks
if "generate" in supported_tasks:
messages += (
messages.append(
"Your model supports the 'generate' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task generate`.")
Expand Down Expand Up @@ -724,7 +724,7 @@ def encode(

supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
messages += (
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
Expand Down

0 comments on commit beae068

Please sign in to comment.