Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Priority-based scheduling in async engine #8850

Merged
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,15 @@ async def add_request_async(
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
"""Async version of :meth:`add_request`."""
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if priority > 0 and not self.scheduler_config.policy == "priority":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if priority > 0 and not self.scheduler_config.policy == "priority":
if priority and not self.scheduler_config.policy == "priority":

Negative is also allowed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@schoennenbeck I see that you copied this from LLMEngine ... perhaps you could fix it there too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also merge from main to fix the test failures now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done and done. I also added the error handling to add_request in AsyncLLMEngine since it was previously missing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@DarkLight1337 The tests still seem to fail for reasons not related to the PR itself.

raise ValueError(f"Got priority {priority} but "
"Priority scheduling is not enabled.")
if arrival_time is None:
arrival_time = time.time()

Expand All @@ -435,6 +439,7 @@ async def add_request_async(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
Copy link
Contributor

@sfc-gh-zhwang sfc-gh-zhwang Sep 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think we are passing policy="priority" anywhere, shall we add it to here so that user can actually leverage this. Otherwise, how are folks supposed to use this feature? @njhill

not self.scheduler_config.policy == "priority":

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think currently the only way of actually using the feature is to manually change the policy after creating the engine.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, but shall we allow folks to specify through args? Like other param in the scheduler_config? I feel that's better for wider adoption for this feature.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I misread your comment. Yes I think this needs to go in the EngineArgs or some similar place. However, I think that should be a separate PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's in SchedulerConfig but still needs to be wired in the EngineArgs to be able to enable it externally. Agree that could be done in a separate PR

)

async def check_health_async(self) -> None:
Expand Down Expand Up @@ -782,7 +787,8 @@ async def add_request(
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if not self.is_running:
if self.start_engine_loop:
Expand All @@ -802,7 +808,9 @@ async def add_request(
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
priority=priority,
)

return stream.generator()

Expand All @@ -813,7 +821,8 @@ async def generate(
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.

Expand All @@ -831,6 +840,8 @@ async def generate(
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.

Yields:
The output `RequestOutput` objects from the LLMEngine
Expand Down Expand Up @@ -886,6 +897,7 @@ async def generate(
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
):
yield LLMEngine.validate_output(output, RequestOutput)

Expand Down
Loading