[Core][Bugfix][Perf] Refactor Server to Avoid AsyncLLMEngine #8092
```diff
@@ -124,17 +124,7 @@ def _init_engine(self, *args,
         elif self.worker_use_ray:
             engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
         else:
-            # FIXME(woosuk): This is a bit hacky. Be careful when changing the
-            # order of the arguments.
-            cache_config = kwargs["cache_config"]
-            parallel_config = kwargs["parallel_config"]
-            if (parallel_config.tensor_parallel_size == 1
-                    and parallel_config.pipeline_parallel_size == 1):
-                num_gpus = cache_config.gpu_memory_utilization
-            else:
-                num_gpus = 1
-            engine_class = ray.remote(num_gpus=num_gpus)(
-                self._engine_class).remote
+            raise NotImplementedError("Not supported yet!")
         return engine_class(*args, **kwargs)
```
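For context on the removed branch: Ray supports fractional GPU reservations, which is why the old code passed `cache_config.gpu_memory_utilization` as `num_gpus` when running without tensor or pipeline parallelism, so the engine actor could share a single device with the worker. A minimal sketch of that Ray feature (illustrative values, not the PR's code):

```python
import ray

ray.init()

# An actor may reserve a fraction of a GPU; Ray packs actors whose
# fractions sum to <= 1.0 onto the same physical device.
@ray.remote(num_gpus=0.9)
class EngineActor:
    def ping(self) -> str:
        return "ok"

actor = EngineActor.remote()
print(ray.get(actor.ping.remote()))  # -> "ok"
```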
```diff
     def run_background_loop(self):
```

Review comment: Inner loop? How should we make this propagate exceptions and do things like have the […]?

Review comment (reply): Is it even worth pulling over the […]? IIUC that's all there to prevent misconfiguration errors from putting the engine into a state where it only ever responds with an exception, but it looks like it's done pretty bluntly, where any exception at all from the model executor will kill the loop. If we want to keep that behavior, we could simply raise from here and exit (after notifying the clients of the exception), and let the frontend die as well.
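A minimal sketch of the "notify clients, then raise" option floated in the reply above, reusing the `self.engine`, `self.output_socket`, and `stream_outputs` names from this diff; the loop body and the error framing are assumptions, not the PR's implementation:

```python
import pickle


def run_background_loop(self):
    """Drive the engine and stream outputs until an error occurs."""
    try:
        while True:
            request_outputs = self.engine.step()
            self.stream_outputs(request_outputs)
    except Exception as e:
        # Tell connected clients why the engine is going down, then
        # re-raise so this process (and with it the frontend) exits
        # instead of answering every future request with a stale error.
        self.output_socket.send_multipart((pickle.dumps(e), ), copy=False)
        raise
```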
```diff
@@ -218,7 +208,7 @@ def stream_outputs(self, request_outputs: List[RequestOutput]):
         self.output_socket.send_multipart((pickle.dumps(request_outputs), ),
                                           copy=False)
 
-    def awk_check_health(self):
+    def ack_check_health(self):
         self.health_socket.send_multipart(
             (pickle.dumps(VLLM_RPC_SUCCESS_STR), ), copy=False)
 
```
```diff
@@ -255,8 +245,7 @@ def _handle_utility_request(self, request: RPCUtilityRequest):
             self.engine.do_log_stats()
         elif request == RPCUtilityRequest.CHECK_HEALTH:
             self.engine.check_health()
-            # Special check_health channel for awk check health.
-            self.awk_check_health()
+            self.ack_check_health()
 
 
 def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
```
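A hypothetical client-side counterpart to `ack_check_health`: send the `CHECK_HEALTH` utility request, then poll the dedicated health socket with a timeout so a dead engine surfaces as an error instead of a hang. The socket names are taken from this diff (assumed to be pyzmq sockets); the timeout value and error types are assumptions:

```python
import pickle

HEALTH_TIMEOUT_MS = 10_000  # assumed timeout


def check_health(self):
    # Ask the engine to run its health check...
    self.input_socket.send_multipart(
        (pickle.dumps(RPCUtilityRequest.CHECK_HEALTH), ), copy=False)
    # ...and wait for the ack on the dedicated health channel.
    if self.health_socket.poll(timeout=HEALTH_TIMEOUT_MS) == 0:
        raise TimeoutError("engine did not ack health check in time")
    (frame, ) = self.health_socket.recv_multipart(copy=False)
    if pickle.loads(frame.buffer) != VLLM_RPC_SUCCESS_STR:
        raise RuntimeError("engine reported an unhealthy state")
```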
Review comment: What is `worker_use_ray`?

Review comment (reply): It is the same as `--distributed_executor_backend=ray`.
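To illustrate that equivalence, a sketch in terms of the engine arguments (argument names assumed against the vLLM version in this PR; treat as illustrative, not the canonical API):

```python
from vllm.engine.arg_utils import AsyncEngineArgs

# Legacy flag:
args_legacy = AsyncEngineArgs(model="facebook/opt-125m", worker_use_ray=True)

# Equivalent modern spelling:
args_modern = AsyncEngineArgs(model="facebook/opt-125m",
                              distributed_executor_backend="ray")
```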