Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-blocking client API #158

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
37 changes: 24 additions & 13 deletions mii/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
from mii.constants import GRPC_MAX_MSG_SIZE
from mii.method_table import GRPC_METHOD_TABLE
from mii.event_loop import get_event_loop


def _get_deployment_info(deployment_name):
Expand Down Expand Up @@ -57,7 +58,7 @@ class MIIClient():
Client to send queries to a single endpoint.
"""
def __init__(self, task_name, host, port):
    """Open a gRPC channel to a single MII server endpoint.

    Args:
        task_name: name of the inference task served at this endpoint.
        host: server hostname.
        port: server gRPC port.
    """
    # All requests from this client are scheduled onto the shared
    # background event loop (see mii.event_loop).
    self.asyncio_loop = get_event_loop()
    self.task = get_task(task_name)
    self.stub = modelresponse_pb2_grpc.ModelResponseStub(create_channel(host, port))
Expand All @@ -74,17 +75,22 @@ async def _request_async_response(self, request_dict, **query_kwargs):
proto_response
) if "unpack_response_from_proto" in conversions else proto_response

def query_async(self, request_dict, **query_kwargs):
    """Send a request to the endpoint without blocking.

    The gRPC call is scheduled on the shared background event loop;
    the returned `concurrent.futures.Future` resolves to the response.
    """
    coro = self._request_async_response(request_dict, **query_kwargs)
    return asyncio.run_coroutine_threadsafe(coro, get_event_loop())

def query(self, request_dict, **query_kwargs):
    """Send a request and block until the response arrives.

    Thin synchronous wrapper over `query_async`.
    """
    future = self.query_async(request_dict, **query_kwargs)
    return future.result()

async def terminate_async(self):
    """Coroutine: ask the server behind this stub to shut down."""
    empty = modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty()
    await self.stub.Terminate(empty)

def terminate(self):
    """Blocking terminate: run `terminate_async` on the shared loop and wait."""
    future = asyncio.run_coroutine_threadsafe(self.terminate_async(),
                                              get_event_loop())
    future.result()


class MIITensorParallelClient():
Expand All @@ -95,7 +101,7 @@ class MIITensorParallelClient():
def __init__(self, task_name, host, ports):
    """Create one MIIClient per tensor-parallel server port.

    Args:
        task_name: inference task name shared by all replicas.
        host: hostname common to every port.
        ports: one gRPC port per tensor-parallel rank.
    """
    self.task = get_task(task_name)
    self.asyncio_loop = get_event_loop()
    self.clients = [MIIClient(task_name, host, p) for p in ports]

# runs task in parallel and return the result from the first task
async def _query_in_tensor_parallel(self, request_string, query_kwargs):
Expand All @@ -107,7 +113,16 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
**query_kwargs)))

await responses[0]
return responses[0]
return responses[0].result()

def query_async(self, request_dict, **query_kwargs):
    """Asynchronously query a local deployment.

    See `query` for the arguments; instead of blocking, this returns a
    `concurrent.futures.Future` that resolves to the model response.
    """
    # Note: query_kwargs is deliberately passed as a single positional
    # dict, matching _query_in_tensor_parallel's signature.
    return asyncio.run_coroutine_threadsafe(
        self._query_in_tensor_parallel(request_dict, query_kwargs),
        self.asyncio_loop)

def query(self, request_dict, **query_kwargs):
"""Query a local deployment:
Expand All @@ -122,11 +137,7 @@ def query(self, request_dict, **query_kwargs):
Returns:
response: Response of the model
"""
response = self.asyncio_loop.run_until_complete(
self._query_in_tensor_parallel(request_dict,
query_kwargs))
ret = response.result()
return ret
return self.query_async(request_dict, **query_kwargs).result()

def terminate(self):
"""Terminates the deployment"""
Expand All @@ -136,5 +147,5 @@ def terminate(self):

def terminate_restful_gateway(deployment_name):
    """Shut down the RESTful API gateway for `deployment_name`, if one is enabled."""
    _, mii_configs = _get_deployment_info(deployment_name)
    if not mii_configs.enable_restful_api:
        return
    requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate")
14 changes: 14 additions & 0 deletions mii/event_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import asyncio
import threading

# A single module-wide event loop, driven by a daemon thread, shared by every
# client in this process. `asyncio.new_event_loop()` is used instead of the
# deprecated-at-import-time `asyncio.get_event_loop()` so importing this module
# never depends on (or steals) a caller's current event loop. The thread is a
# daemon so it does not keep the process alive at exit.
event_loop = asyncio.new_event_loop()
threading.Thread(target=event_loop.run_forever, daemon=True).start()


def get_event_loop():
    """Return the shared background event loop (already running in its own thread)."""
    return event_loop
16 changes: 6 additions & 10 deletions mii/grpc_related/modelresponse_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mii.method_table import GRPC_METHOD_TABLE
from mii.client import create_channel
from mii.utils import get_task
from mii.event_loop import get_event_loop


class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer):
Expand All @@ -42,6 +43,7 @@ def __init__(self, inference_pipeline):
super().__init__()
self.inference_pipeline = inference_pipeline
self.method_name_to_task = {m["method"]: t for t, m in GRPC_METHOD_TABLE.items()}
self.lock = threading.Lock()

def _get_model_time(self, model, sum_times=False):
model_times = []
Expand Down Expand Up @@ -72,7 +74,8 @@ def _run_inference(self, method_name, request_proto):
args, kwargs = conversions["unpack_request_from_proto"](request_proto)

start = time.time()
response = self.inference_pipeline(*args, **kwargs)
with self.lock:
response = self.inference_pipeline(*args, **kwargs)
end = time.time()

model_time = self._get_model_time(self.inference_pipeline.model,
Expand Down Expand Up @@ -134,7 +137,7 @@ def __init__(self, host, ports):
stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
self.stubs.append(stub)

self.asyncio_loop = asyncio.get_event_loop()
self.asyncio_loop = get_event_loop()

async def _invoke_async(self, method_name, proto_request):
responses = []
Expand All @@ -154,7 +157,7 @@ def invoke(self, method_name, proto_request):
class LoadBalancingInterceptor(grpc.ServerInterceptor):
def __init__(self, task_name, replica_configs):
super().__init__()
self.asyncio_loop = asyncio.get_event_loop()
self.asyncio_loop = get_event_loop()

self.stubs = [
ParallelStubInvoker(replica.hostname,
Expand All @@ -164,13 +167,6 @@ def __init__(self, task_name, replica_configs):
self.counter = AtomicCounter()
self.task = get_task(task_name)

# Start the asyncio loop in a separate thread
def run_asyncio_loop(loop):
asyncio.set_event_loop(loop)
loop.run_forever()

threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start()

def choose_stub(self, call_count):
return self.stubs[call_count % len(self.stubs)]

Expand Down