From c43bdad7ebc5aad6bc4f034a59f13f544e0d6917 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 2 Mar 2023 19:47:36 +0000 Subject: [PATCH 01/13] add RESTful api gateway --- mii/client.py | 7 ++++ mii/config.py | 1 + mii/constants.py | 2 + mii/grpc_related/restful_gateway.py | 60 +++++++++++++++++++++++++++++ mii/launch/multi_gpu_server.py | 25 ++++++++---- mii/server.py | 11 ++++++ mii/terminate.py | 1 + requirements/requirements.txt | 2 + 8 files changed, 102 insertions(+), 7 deletions(-) create mode 100644 mii/grpc_related/restful_gateway.py diff --git a/mii/client.py b/mii/client.py index 5923741b..d6f1bbb8 100644 --- a/mii/client.py +++ b/mii/client.py @@ -3,6 +3,7 @@ ''' import asyncio import grpc +import requests import mii from mii.utils import get_task from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc @@ -130,3 +131,9 @@ def terminate(self): """Terminates the deployment""" for client in self.clients: client.terminate() + + +def terminate_restful_gateway(deployment_name): + _, mii_configs = _get_deployment_info(deployment_name) + if mii_configs.restful_api_port > 0: + requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate") diff --git a/mii/config.py b/mii/config.py index ae0f5505..56d482ad 100644 --- a/mii/config.py +++ b/mii/config.py @@ -49,6 +49,7 @@ class MIIConfig(BaseModel): max_tokens: int = 1024 enable_load_balancing: bool = False replica_num: int = 1 + restful_api_port: int = 0 hostfile: str = DLTS_HOSTFILE @validator("deploy_rank") diff --git a/mii/constants.py b/mii/constants.py index ffca47ad..06d76cf4 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -116,3 +116,5 @@ class ModelProvider(enum.Enum): LB_MAX_WORKER_THREADS = 32 SERVER_SHUTDOWN_TIMEOUT = 10 + +RESTFUL_API_PATH = "mii" diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py new file mode 100644 index 00000000..2804f444 --- /dev/null +++ b/mii/grpc_related/restful_gateway.py @@ -0,0 +1,60 @@ +import time +import threading +import mii +from flask import Flask, request +from flask_restful import Resource, Api +from werkzeug.serving import make_server +from mii.constants import SERVER_SHUTDOWN_TIMEOUT, RESTFUL_API_PATH +from google.protobuf.json_format import MessageToJson + + +def shutdown(thread): + time.sleep(SERVER_SHUTDOWN_TIMEOUT) + thread.server.shutdown() + + +def createRestfulGatewayApp(task, mii_config, server_thread): + # client must be thread-safe + client = mii.MIIClient(task, "localhost", mii_config.port_number) + + class RestfulGatewayService(Resource): + def __init__(self): + super().__init__() + + def post(self): + data = request.get_json() + result = client.query(data["request"], **data["kwargs"]) + return MessageToJson(result) + + app = Flask("RestfulGateway") + + @app.route("/terminate", methods=['GET']) + def terminate(): + # Need to shutdown *after* completing the request + threading.Thread(target=shutdown, args=(server_thread, )).start() + return "Shutting down RESTful API gateway server" + + api = Api(app) + path = '/{}/{}'.format(RESTFUL_API_PATH, task) + api.add_resource(RestfulGatewayService, path) + + return app + + +class RestfulGatewayThread(threading.Thread): + def __init__(self, task, mii_config): + threading.Thread.__init__(self) + self.mii_config = mii_config + + app = createRestfulGatewayApp(task, mii_config, self) + self.server = make_server('127.0.0.1', mii_config.restful_api_port, app) + self.ctx = app.app_context() + self.ctx.push() + + self._stop_event = threading.Event() + + def 
run(self): + self.server.serve_forever() + + def get_stop_event(self): + return self._stop_event diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 0a8dde65..52e5f68d 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -11,6 +11,7 @@ from mii.models.load_models import load_models from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing +from mii.grpc_related.restful_gateway import RestfulGatewayThread def decode_config_from_str(config_str): @@ -48,19 +49,29 @@ def main(): type=str, default=None, help="base64 encoded load balancer config") + parser.add_argument("-r", + "--restful-gateway", + action='store_true', + help="launch restful api gateway") args = parser.parse_args() - # if args.load_balancer is not None: - if args.load_balancer is None: + # de-serialize config object + config_dict = decode_config_from_str(args.config) + # convert dict -> mii config + mii_config = MIIConfig(**config_dict) + + if args.restful_gateway: + print(f"Starting RESTful API gateway on port: {mii_config.restful_api_port}") + gateway_thread = RestfulGatewayThread(args.task_name, mii_config) + stop_event = gateway_thread.get_stop_event() + gateway_thread.start() + stop_event.wait() + + elif args.load_balancer is None: provider = mii.constants.MODEL_PROVIDER_MAP.get(args.provider, None) assert provider is not None, f"Unknown model provider: {args.provider}" - # de-serialize config object - config_dict = decode_config_from_str(args.config) - # convert dict -> mii config - mii_config = MIIConfig(**config_dict) - assert args.port is not None, "port is required for inference server" local_rank = int(os.getenv('LOCAL_RANK', '0')) port = args.port + local_rank diff --git a/mii/server.py b/mii/server.py index c7fa1d7d..bc17ffe2 100644 --- a/mii/server.py +++ b/mii/server.py @@ -318,6 +318,17 @@ def _initialize_service(self, mii_configs, lb_config)) + if mii_configs.restful_api_port > 0: + # start rest api server + processes.append( + self._launch_restful_gateway(model_name, + model_path, + ds_optimize, + ds_zero, + ds_config, + mii_configs, + mii_configs.port_number)) + return processes else: if self._is_socket_open("localhost", self.port_number): diff --git a/mii/terminate.py b/mii/terminate.py index d0f53732..61784f58 100644 --- a/mii/terminate.py +++ b/mii/terminate.py @@ -17,3 +17,4 @@ def terminate(deployment_name): pass generator.terminate() + mii.client.terminate_restful_gateway(deployment_name) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 504ea485..243d5aee 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,7 +1,9 @@ asyncio deepspeed>=0.7.6 +Flask-RESTful grpcio grpcio-tools pydantic torch transformers +Werkzeug From 4de010e2db410e577d544a25f190494ae4c814ab Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 2 Mar 2023 21:30:01 +0000 Subject: [PATCH 02/13] add handling empty kwargs --- mii/grpc_related/restful_gateway.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index 2804f444..048b1814 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -23,7 +23,8 @@ def __init__(self): def post(self): data = request.get_json() - result = client.query(data["request"], **data["kwargs"]) + kwargs = data["kwargs"] if "kwargs" in data else {} + result = client.query(data["request"], **kwargs) return MessageToJson(result) 
app = Flask("RestfulGateway") From fd8387235eacb0d632e96d3f7d71f4308f716008 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 2 Mar 2023 22:41:17 +0000 Subject: [PATCH 03/13] set event to stop gateway server --- mii/constants.py | 1 + mii/grpc_related/restful_gateway.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mii/constants.py b/mii/constants.py index 06d76cf4..fd65b091 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -117,4 +117,5 @@ class ModelProvider(enum.Enum): SERVER_SHUTDOWN_TIMEOUT = 10 +RESTFUL_GATEWAY_SHUTDOWN_TIMEOUT = 1 RESTFUL_API_PATH = "mii" diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index 048b1814..725b1cc5 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -4,12 +4,12 @@ from flask import Flask, request from flask_restful import Resource, Api from werkzeug.serving import make_server -from mii.constants import SERVER_SHUTDOWN_TIMEOUT, RESTFUL_API_PATH +from mii.constants import RESTFUL_GATEWAY_SHUTDOWN_TIMEOUT, RESTFUL_API_PATH from google.protobuf.json_format import MessageToJson def shutdown(thread): - time.sleep(SERVER_SHUTDOWN_TIMEOUT) + time.sleep(RESTFUL_GATEWAY_SHUTDOWN_TIMEOUT) thread.server.shutdown() @@ -56,6 +56,7 @@ def __init__(self, task, mii_config): def run(self): self.server.serve_forever() + self._stop_event.set() def get_stop_event(self): return self._stop_event From 7bba5a7e5e67f98df3af3b699576e928c559579b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 2 Mar 2023 22:41:36 +0000 Subject: [PATCH 04/13] add test for RESTful api --- tests/test_local_deployment.py | 38 ++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py index a0dfe339..cd517a15 100644 --- a/tests/test_local_deployment.py +++ b/tests/test_local_deployment.py @@ -2,6 +2,8 @@ import os import torch from types import SimpleNamespace +import json +import requests import mii @@ -42,6 +44,11 @@ def enable_load_balancing(request): return request.param +@pytest.fixture(scope="function", params=[0]) +def restful_api_port(request): + return request.param + + @pytest.fixture(scope="function", params=[True]) def enable_deepspeed(request): return request.param @@ -68,6 +75,7 @@ def mii_configs( port_number: int, load_with_sys_mem: bool, enable_load_balancing: bool, + restful_api_port: int, ): # Create a hostfile for DeepSpeed launcher when load_balancing is enabled @@ -85,6 +93,7 @@ def mii_configs( 'enable_load_balancing': enable_load_balancing, 'replica_num': num_gpu * enable_load_balancing, 'hostfile': hostfile, + 'restful_api_port': restful_api_port, } @@ -215,6 +224,35 @@ def test_load_balancing(local_deployment, query): assert result +@pytest.mark.local +@pytest.mark.parametrize("enable_load_balancing", [True]) +@pytest.mark.parametrize("restful_api_port", [28080]) +@pytest.mark.parametrize( + "task_name, model_name, query", + [ + ( + "text-generation", + "bigscience/bloom-560m", + { + "query": ["DeepSpeed is the greatest"] + }, + ), + ], +) +def test_restful_api(local_deployment, task_name, query, restful_api_port): + generator = mii.mii_query_handle(local_deployment.deployment_name) + for _ in range(2): + result = generator.query(query) + + url = f'http://localhost:{restful_api_port}/mii/{task_name}' + params = {"request": query} + json_params = json.dumps(params) + result = requests.post(url, + data=json_params, + headers={"Content-Type": "application/json"}) + 
assert result
+
+
 @pytest.mark.local
 @pytest.mark.parametrize("load_with_sys_mem", [True])
 @pytest.mark.parametrize(

From cab6394f53a0b1d567a2e315209851fefe57a9e4 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Tue, 7 Mar 2023 00:15:29 +0000
Subject: [PATCH 05/13] fix port number for replica_num == 1

---
 mii/deployment.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mii/deployment.py b/mii/deployment.py
index cd8a5147..e80324ff 100644
--- a/mii/deployment.py
+++ b/mii/deployment.py
@@ -110,7 +110,7 @@ def deploy(task,
     replica_configs = []
     for i, (hostname, gpu_indices) in enumerate(replica_pool):
         # Reserve a port for the LB proxy when replication is enabled
-        port_offset = 1 if mii_config.replica_num > 1 else 0
+        port_offset = 1 if mii_config.enable_load_balancing else 0
         base_port = mii_config.port_number + i * mii_config.tensor_parallel + port_offset
         tensor_parallel_ports = list(
             range(base_port,

From a9ef1f9401ab01a1424962da4b24a07fc1edee5e Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Tue, 7 Mar 2023 01:31:23 +0000
Subject: [PATCH 06/13] change MIIConfig item to enable RESTful API

---
 mii/config.py                  |  9 ++++++++-
 tests/test_local_deployment.py | 10 +++++++++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/mii/config.py b/mii/config.py
index 56d482ad..552c6bf0 100644
--- a/mii/config.py
+++ b/mii/config.py
@@ -47,9 +47,10 @@ class MIIConfig(BaseModel):
     profile_model_time: bool = False
     skip_model_check: bool = False
     max_tokens: int = 1024
+    enable_restful_api: bool = False
+    restful_api_port: int = 51080
     enable_load_balancing: bool = False
     replica_num: int = 1
-    restful_api_port: int = 0
     hostfile: str = DLTS_HOSTFILE
 
     @validator("deploy_rank")
@@ -85,6 +86,12 @@ def checkpoint_dict_valid(cls, value):
                 raise ValueError(f"Missing key={k} in checkpoint_dict")
         return value
 
+    @validator('enable_load_balancing')
+    def enable_load_balancing_valid(cls, field_value, values):
+        if values["enable_restful_api"]:
+            field_value = True
+        return field_value
+
     class Config:
         validate_all = True
         validate_assignment = True
diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py
index cd517a15..608055a4 100644
--- a/tests/test_local_deployment.py
+++ b/tests/test_local_deployment.py
@@ -44,6 +44,11 @@ def enable_load_balancing(request):
     return request.param
 
 
+@pytest.fixture(scope="function", params=[False])
+def enable_restful_api(request):
+    return request.param
+
+
 @pytest.fixture(scope="function", params=[0])
 def restful_api_port(request):
     return request.param
@@ -75,12 +80,14 @@ def mii_configs(
     port_number: int,
     load_with_sys_mem: bool,
     enable_load_balancing: bool,
+    enable_restful_api: bool,
     restful_api_port: int,
 ):
 
     # Create a hostfile for DeepSpeed launcher when load_balancing is enabled
     hostfile = os.path.join(tmpdir, "hostfile")
     num_gpu = torch.cuda.device_count()
+    enable_load_balancing = enable_load_balancing or enable_restful_api
     if enable_load_balancing:
         with open(hostfile, "w") as f:
             f.write(f"localhost slots={num_gpu}")
@@ -93,6 +100,7 @@ def mii_configs(
         'enable_load_balancing': enable_load_balancing,
         'replica_num': num_gpu * enable_load_balancing,
         'hostfile': hostfile,
+        'enable_restful_api': enable_restful_api,
         'restful_api_port': restful_api_port,
     }
@@ -225,7 +233,7 @@ def test_load_balancing(local_deployment, query):
 
 
 @pytest.mark.local
-@pytest.mark.parametrize("enable_load_balancing", [True])
+@pytest.mark.parametrize("enable_restful_api", [True])
 @pytest.mark.parametrize("restful_api_port", [28080])
@pytest.mark.parametrize( "task_name, model_name, query", From d8f1d4a0723786584f50cbdde564a091fff0b540 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 7 Mar 2023 02:05:58 +0000 Subject: [PATCH 07/13] set deployment name to RESTful API path --- mii/constants.py | 1 + mii/grpc_related/restful_gateway.py | 8 ++++---- mii/launch/multi_gpu_server.py | 5 ++++- mii/models/score/generate.py | 1 + mii/models/score/score_template.py | 4 +++- mii/server.py | 32 +++++++++++++++++++++-------- tests/test_local_deployment.py | 4 ++-- 7 files changed, 39 insertions(+), 16 deletions(-) diff --git a/mii/constants.py b/mii/constants.py index fd65b091..f0ec5927 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -88,6 +88,7 @@ class ModelProvider(enum.Enum): MODEL_NAME_KEY = 'model_name' TASK_NAME_KEY = 'task_name' +DEPLOYMENT_NAME_KEY = 'deployment_name' MODEL_PATH_KEY = 'model_path' LOAD_BALANCER_CONFIG_KEY = 'load_balancer_config' diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index 725b1cc5..61967a83 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -13,7 +13,7 @@ def shutdown(thread): thread.server.shutdown() -def createRestfulGatewayApp(task, mii_config, server_thread): +def createRestfulGatewayApp(deployment_name, task, mii_config, server_thread): # client must be thread-safe client = mii.MIIClient(task, "localhost", mii_config.port_number) @@ -36,18 +36,18 @@ def terminate(): return "Shutting down RESTful API gateway server" api = Api(app) - path = '/{}/{}'.format(RESTFUL_API_PATH, task) + path = '/{}/{}'.format(RESTFUL_API_PATH, deployment_name) api.add_resource(RestfulGatewayService, path) return app class RestfulGatewayThread(threading.Thread): - def __init__(self, task, mii_config): + def __init__(self, deployment_name, task, mii_config): threading.Thread.__init__(self) self.mii_config = mii_config - app = createRestfulGatewayApp(task, mii_config, self) + app = createRestfulGatewayApp(deployment_name, task, mii_config, self) self.server = make_server('127.0.0.1', mii_config.restful_api_port, app) self.ctx = app.app_context() self.ctx.push() diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 52e5f68d..1ad12a00 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -25,6 +25,7 @@ def decode_config_from_str(config_str): def main(): parser = argparse.ArgumentParser() + parser.add_argument("-n", "--deployment-name", type=str, help="deployment name") parser.add_argument("-t", "--task-name", type=str, help="task name") parser.add_argument("-m", "--model", type=str, help="model name") parser.add_argument("-d", "--model-path", type=str, help="path to model") @@ -63,7 +64,9 @@ def main(): if args.restful_gateway: print(f"Starting RESTful API gateway on port: {mii_config.restful_api_port}") - gateway_thread = RestfulGatewayThread(args.task_name, mii_config) + gateway_thread = RestfulGatewayThread(args.deployment_name, + args.task_name, + mii_config) stop_event = gateway_thread.get_stop_event() gateway_thread.start() stop_event.wait() diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py index 81c4fdb7..d18777d0 100644 --- a/mii/models/score/generate.py +++ b/mii/models/score/generate.py @@ -19,6 +19,7 @@ def create_score_file(deployment_name, model_path, lb_config): config_dict = {} + config_dict[mii.constants.DEPLOYMENT_NAME_KEY] = deployment_name config_dict[mii.constants.TASK_NAME_KEY] = mii.utils.get_task_name(task) 
config_dict[mii.constants.MODEL_NAME_KEY] = model_name config_dict[mii.constants.ENABLE_DEEPSPEED_KEY] = ds_optimize diff --git a/mii/models/score/score_template.py b/mii/models/score/score_template.py index 9da45281..9019fd32 100644 --- a/mii/models/score/score_template.py +++ b/mii/models/score/score_template.py @@ -15,13 +15,15 @@ def init(): model_path = mii.utils.full_model_path(configs[mii.constants.MODEL_PATH_KEY]) + deployment_name = configs[mii.constants.DEPLOYMENT_NAME_KEY] model_name = configs[mii.constants.MODEL_NAME_KEY] task_name = configs[mii.constants.TASK_NAME_KEY] assert model_name is not None, "The model name should be set before calling init" assert task_name is not None, "The task name should be set before calling init" - mii.MIIServer(task_name, + mii.MIIServer(deployment_name, + task_name, model_name, model_path, ds_optimize=configs[mii.constants.ENABLE_DEEPSPEED_KEY], diff --git a/mii/server.py b/mii/server.py index bc17ffe2..ca138a64 100644 --- a/mii/server.py +++ b/mii/server.py @@ -27,6 +27,7 @@ def config_to_b64_str(config): class MIIServer(): '''Initialize the model, setup the server for the model under model_path''' def __init__(self, + deployment_name, task_name, model_name, model_path, @@ -49,7 +50,8 @@ def __init__(self, raise ValueError( "hostfile must be provided if enable_load_balancing == True") - processes = self._initialize_service(model_name, + processes = self._initialize_service(deployment_name, + model_name, model_path, ds_optimize, ds_zero, @@ -101,6 +103,7 @@ def _is_server_process_alive(self, process): return is_alive def _build_server_args(self, + deployment_name, model_name, model_path, ds_optimize, @@ -111,7 +114,7 @@ def _build_server_args(self, # serialize mii config b64_config_str = config_to_b64_str(mii_configs) - server_args_str = f"--task-name {mii.utils.get_task_name(self.task)} --model {model_name} --model-path {model_path} --port {port}" + server_args_str = f"--deployment-name {deployment_name} --task-name {mii.utils.get_task_name(self.task)} --model {model_name} --model-path {model_path} --port {port}" server_args_str += " --ds-optimize" if ds_optimize else "" # XXX: fetch model provider based on model name in a more general way @@ -166,6 +169,7 @@ def print_helper(self, args): return printable_string def _launch_load_balancer(self, + deployment_name, model_name, model_path, ds_optimize, @@ -178,6 +182,7 @@ def _launch_load_balancer(self, b64_config_str = config_to_b64_str(lb_config) return self._launch_server_process( + deployment_name, model_name, model_path, ds_optimize, @@ -189,6 +194,7 @@ def _launch_load_balancer(self, ex_server_args=[f"--load-balancer {b64_config_str}"]) def _launch_restful_gateway(self, + deployment_name, model_name, model_path, ds_optimize, @@ -196,7 +202,8 @@ def _launch_restful_gateway(self, ds_config, mii_configs, port): - return self._launch_server_process(model_name, + return self._launch_server_process(deployment_name, + model_name, model_path, ds_optimize, ds_zero, @@ -207,6 +214,7 @@ def _launch_restful_gateway(self, ex_server_args=["--restful-gateway"]) def _launch_server_process(self, + deployment_name, model_name, model_path, ds_optimize, @@ -218,7 +226,8 @@ def _launch_server_process(self, ds_launch_str=None, ex_server_args=[]): launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server" - server_args_str = self._build_server_args(model_name, + server_args_str = self._build_server_args(deployment_name, + model_name, model_path, ds_optimize, ds_zero, @@ -239,6 +248,7 @@ def 
_launch_server_process(self, return subprocess.Popen(cmd, env=mii_env) def _launch_deepspeed(self, + deployment_name, model_name, model_path, ds_optimize, @@ -263,7 +273,8 @@ def _launch_deepspeed(self, ds_launch_str = f"deepspeed {worker_str} --no_local_rank --no_python" - return self._launch_server_process(model_name, + return self._launch_server_process(deployment_name, + model_name, model_path, ds_optimize, ds_zero, @@ -274,6 +285,7 @@ def _launch_deepspeed(self, ds_launch_str=ds_launch_str) def _initialize_service(self, + deployment_name, model_name, model_path, ds_optimize, @@ -292,6 +304,7 @@ def _initialize_service(self, f'{repl_config.hostname} slots={mii_configs.replica_num}\n'.encode()) processes.append( self._launch_deepspeed( + deployment_name, model_name, model_path, ds_optimize, @@ -310,7 +323,8 @@ def _initialize_service(self, # The deepspeed launcher determines the number of processes to launch based on GPUs available on the host or CUDA_VISIBLE_DEVICES, # and it is expected to assign one GPU to one process. processes.append( - self._launch_load_balancer(model_name, + self._launch_load_balancer(deployment_name, + model_name, model_path, ds_optimize, ds_zero, @@ -321,7 +335,8 @@ def _initialize_service(self, if mii_configs.restful_api_port > 0: # start rest api server processes.append( - self._launch_restful_gateway(model_name, + self._launch_restful_gateway(deployment_name, + model_name, model_path, ds_optimize, ds_zero, @@ -337,7 +352,8 @@ def _initialize_service(self, ) processes.append( - self._launch_deepspeed(model_name, + self._launch_deepspeed(deployment_name, + model_name, model_path, ds_optimize, ds_zero, diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py index 608055a4..c0969ac5 100644 --- a/tests/test_local_deployment.py +++ b/tests/test_local_deployment.py @@ -247,12 +247,12 @@ def test_load_balancing(local_deployment, query): ), ], ) -def test_restful_api(local_deployment, task_name, query, restful_api_port): +def test_restful_api(local_deployment, query, restful_api_port): generator = mii.mii_query_handle(local_deployment.deployment_name) for _ in range(2): result = generator.query(query) - url = f'http://localhost:{restful_api_port}/mii/{task_name}' + url = f'http://localhost:{restful_api_port}/mii/{local_deployment.deployment_name}' params = {"request": query} json_params = json.dumps(params) result = requests.post(url, From 4b7dcbefd04cf808f64e9863bc5abe4c98bef91a Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Tue, 7 Mar 2023 20:14:09 +0000 Subject: [PATCH 08/13] use root_validator to set enable_load_balancing --- mii/config.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mii/config.py b/mii/config.py index 552c6bf0..1d3226df 100644 --- a/mii/config.py +++ b/mii/config.py @@ -1,10 +1,12 @@ import torch from typing import Union, List from enum import Enum -from pydantic import BaseModel, validator +from pydantic import BaseModel, validator, root_validator from deepspeed.launcher.runner import DLTS_HOSTFILE +from .utils import logger + class DtypeEnum(Enum): # The torch dtype must always be the first value (so we return torch.dtype) @@ -86,11 +88,12 @@ def checkpoint_dict_valid(cls, value): raise ValueError(f"Missing key={k} in checkpoint_dict") return value - @validator('enable_load_balancing') - def enable_load_balancing_valid(cls, field_value, values): - if values["enable_restful_api"]: - field_value = True - return field_value + @root_validator + def auto_enable_load_balancing(cls, 
values):
+        if values["enable_restful_api"] and not values["enable_load_balancing"]:
+            logger.warning("RESTful API is enabled; enabling load balancing")
+            values["enable_load_balancing"] = True
+        return values
 
     class Config:
         validate_all = True
         validate_assignment = True

From c33dbef50de419fe48ccb42f0d7c94fc0753b65f Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Wed, 8 Mar 2023 19:32:55 +0000
Subject: [PATCH 09/13] add usage of load balancer and restful api

---
 README.md | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)

diff --git a/README.md b/README.md
index bd6b968b..a7e9b35e 100644
--- a/README.md
+++ b/README.md
@@ -133,6 +133,74 @@ import mii
 mii.terminate("bloom560m_deployment")
 ```
 
+**Load balancing over multiple replicas**
+
+You can launch a load balancer and multiple replicas of the MII server.
+When `enable_load_balancing` is set to `True`, `mii.deploy()` launches the load balancer server and `replica_num` replicas.
+Note that each replica consists of `tensor_parallel` server processes deployed on the same host.
+
+```python
+mii_configs = {
+...
+    "tensor_parallel": tensor_parallel,
+    "enable_load_balancing": True,
+    "replica_num": replica_num,
+    "hostfile": hostfile
+}
+mii.deploy(...
+           mii_config=mii_configs,
+           ...)
+```
+
+The client sends requests to the load balancer, which forwards them to the replicas, rather than to individual MII servers.
+Currently, the load balancer implements a simple round-robin algorithm.
+When `replica_num` is set to `1`, the load balancer acts as a simple proxy.
+
+`hostfile` is the path to the hostfile used by the DeepSpeed launcher.
+When `hostfile` is not specified, DeepSpeed-MII uses DeepSpeed's default path `/job/hostfile`.
+See the [DeepSpeed documentation](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) for details.
+
+**RESTful API support**
+
+MII allows users to call the inference service through a RESTful API.
+By setting `enable_restful_api` to `True`, `mii.deploy()` launches a gateway that accepts RESTful API calls.
+The gateway receives requests at `http://[HOST]:[PORT_FOR_RESTFUL_API]/mii/[DEPLOYMENT_NAME]`.
+
+```python
+mii_configs = {
+...
+    "enable_restful_api": True,
+    "restful_api_port": PORT_FOR_RESTFUL_API,
+...
+}
+mii.deploy(...
+           deployment_name=DEPLOYMENT_NAME,
+           mii_config=mii_configs)
+```
+
+Any HTTP client can be used to call the API. An example using curl:
+```bash
+# Assume deployment_name and restful_api_port are set to bloom560m_deployment and 28080 respectively:
+$ curl --header "Content-Type: application/json" --request POST -d '{"request": {"query": ["Seattle is", "Bellevue is", "Redmond is"]}, "kwargs": {"do_sample": false, "max_new_tokens": 100}}' http://localhost:28080/mii/bloom560m_deployment
+```
+
+The code below sends the same request using Python.
+
+```python
+import requests
+import json
+
+# text_generation
+url = 'http://localhost:28080/mii/bloom560m_deployment'
+params = {"request": {"query": ["Seattle is", "Bellevue is", "Redmond is"]},
+          "kwargs": {"do_sample": False, "max_new_tokens": 100}}
+
+json_params = json.dumps(params)
+response = requests.post(url, data=json_params, headers={
+    "Content-Type": "application/json"})
+print(response.json())
+```
+
 ## Deploying with MII-Azure
 
 MII supports deployment on Azure via AML Inference. To enable this, MII generates AML deployment assets for a given model that can be deployed using the Azure-CLI, as shown in the code below. Furthermore, deploying on Azure allows MII to leverage DeepSpeed-Azure as its optimization backend, which offers better latency and cost reduction than DeepSpeed-Public.

From ed265f2a6a8d8febaad1b52578cc07570ff8d6b2 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Fri, 10 Mar 2023 21:34:39 +0000
Subject: [PATCH 10/13] check contents of RESTful api response

---
 tests/test_local_deployment.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py
index c0969ac5..0b890fa8 100644
--- a/tests/test_local_deployment.py
+++ b/tests/test_local_deployment.py
@@ -258,7 +258,8 @@ def test_restful_api(local_deployment, query, restful_api_port):
     result = requests.post(url,
                            data=json_params,
                            headers={"Content-Type": "application/json"})
-    assert result
+    assert result.status_code == 200
+    assert "response" in result.json()
 
 
 @pytest.mark.local

From 9cbb69f3d186e9b940d0e634fd8a2e0e49d8baa7 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Mon, 20 Mar 2023 21:57:17 +0000
Subject: [PATCH 11/13] fix condition to launch restful gateway

---
 mii/server.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mii/server.py b/mii/server.py
index ca138a64..c4e8e2a1 100644
--- a/mii/server.py
+++ b/mii/server.py
@@ -332,7 +332,7 @@ def _initialize_service(self,
                                            mii_configs,
                                            lb_config))
 
-        if mii_configs.restful_api_port > 0:
+        if mii_configs.enable_restful_api:
             # start rest api server
             processes.append(
                 self._launch_restful_gateway(deployment_name,

From 9f2c6082d072f266fa85003f5880cf96a1081be1 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Mon, 20 Mar 2023 22:39:13 +0000
Subject: [PATCH 12/13] run thread for event loop

---
 mii/client.py                            | 37 +++++++++++++++---------
 mii/event_loop.py                        | 10 +++++++
 mii/grpc_related/modelresponse_server.py | 16 ++++------
 3 files changed, 40 insertions(+), 23 deletions(-)
 create mode 100644 mii/event_loop.py

diff --git a/mii/client.py b/mii/client.py
index d6f1bbb8..f357a6b2 100644
--- a/mii/client.py
+++ b/mii/client.py
@@ -9,6 +9,7 @@
 from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
 from mii.constants import GRPC_MAX_MSG_SIZE
 from mii.method_table import GRPC_METHOD_TABLE
+from mii.event_loop import get_event_loop
 
 
 def _get_deployment_info(deployment_name):
@@ -56,7 +57,7 @@ class MIIClient():
     Client to send queries to a single endpoint.
     """
     def __init__(self, task_name, host, port):
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
         channel = create_channel(host, port)
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)
         self.task = get_task(task_name)
@@ -73,17 +74,22 @@ async def _request_async_response(self, request_dict, **query_kwargs):
             proto_response
         ) if "unpack_response_from_proto" in conversions else proto_response
 
-    def query(self, request_dict, **query_kwargs):
-        return self.asyncio_loop.run_until_complete(
+    def query_async(self, request_dict, **query_kwargs):
+        return asyncio.run_coroutine_threadsafe(
             self._request_async_response(request_dict,
-                                         **query_kwargs))
+                                         **query_kwargs),
+            get_event_loop())
+
+    def query(self, request_dict, **query_kwargs):
+        return self.query_async(request_dict, **query_kwargs).result()
 
     async def terminate_async(self):
         await self.stub.Terminate(
             modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())
 
     def terminate(self):
-        self.asyncio_loop.run_until_complete(self.terminate_async())
+        asyncio.run_coroutine_threadsafe(self.terminate_async(),
+                                         get_event_loop()).result()
 
 
 class MIITensorParallelClient():
@@ -94,7 +100,7 @@ class MIITensorParallelClient():
     def __init__(self, task_name, host, ports):
         self.task = get_task(task_name)
         self.clients = [MIIClient(task_name, host, port) for port in ports]
-        self.asyncio_loop = asyncio.get_event_loop()
+        self.asyncio_loop = get_event_loop()
 
     # runs tasks in parallel and returns the result from the first task
     async def _query_in_tensor_parallel(self, request_string, query_kwargs):
@@ -106,7 +112,16 @@ async def _query_in_tensor_parallel(self, request_string, query_kwargs):
                                      **query_kwargs)))
 
         await responses[0]
-        return responses[0]
+        return responses[0].result()
+
+    def query_async(self, request_dict, **query_kwargs):
+        """Asynchronously query a local deployment.
+        See `query` for the arguments and the return value.
+ """ + return asyncio.run_coroutine_threadsafe( + self._query_in_tensor_parallel(request_dict, + query_kwargs), + self.asyncio_loop) def query(self, request_dict, **query_kwargs): """Query a local deployment: @@ -121,11 +136,7 @@ def query(self, request_dict, **query_kwargs): Returns: response: Response of the model """ - response = self.asyncio_loop.run_until_complete( - self._query_in_tensor_parallel(request_dict, - query_kwargs)) - ret = response.result() - return ret + return self.query_async(request_dict, **query_kwargs).result() def terminate(self): """Terminates the deployment""" @@ -135,5 +146,5 @@ def terminate(self): def terminate_restful_gateway(deployment_name): _, mii_configs = _get_deployment_info(deployment_name) - if mii_configs.restful_api_port > 0: + if mii_configs.enable_restful_api: requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate") diff --git a/mii/event_loop.py b/mii/event_loop.py new file mode 100644 index 00000000..0315d00d --- /dev/null +++ b/mii/event_loop.py @@ -0,0 +1,10 @@ +import asyncio +import threading + +global event_loop +event_loop = asyncio.get_event_loop() +threading.Thread(target=event_loop.run_forever, daemon=True).start() + + +def get_event_loop(): + return event_loop diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index d99f0274..2b9c3316 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -16,6 +16,7 @@ from mii.method_table import GRPC_METHOD_TABLE from mii.client import create_channel from mii.utils import get_task +from mii.event_loop import get_event_loop class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): @@ -41,6 +42,7 @@ def __init__(self, inference_pipeline): super().__init__() self.inference_pipeline = inference_pipeline self.method_name_to_task = {m["method"]: t for t, m in GRPC_METHOD_TABLE.items()} + self.lock = threading.Lock() def _get_model_time(self, model, sum_times=False): model_times = [] @@ -71,7 +73,8 @@ def _run_inference(self, method_name, request_proto): args, kwargs = conversions["unpack_request_from_proto"](request_proto) start = time.time() - response = self.inference_pipeline(*args, **kwargs) + with self.lock: + response = self.inference_pipeline(*args, **kwargs) end = time.time() model_time = self._get_model_time(self.inference_pipeline.model, @@ -133,7 +136,7 @@ def __init__(self, host, ports): stub = modelresponse_pb2_grpc.ModelResponseStub(channel) self.stubs.append(stub) - self.asyncio_loop = asyncio.get_event_loop() + self.asyncio_loop = get_event_loop() async def _invoke_async(self, method_name, proto_request): responses = [] @@ -153,7 +156,7 @@ def invoke(self, method_name, proto_request): class LoadBalancingInterceptor(grpc.ServerInterceptor): def __init__(self, task_name, replica_configs): super().__init__() - self.asyncio_loop = asyncio.get_event_loop() + self.asyncio_loop = get_event_loop() self.stubs = [ ParallelStubInvoker(replica.hostname, @@ -163,13 +166,6 @@ def __init__(self, task_name, replica_configs): self.counter = AtomicCounter() self.task = get_task(task_name) - # Start the asyncio loop in a separate thread - def run_asyncio_loop(loop): - asyncio.set_event_loop(loop) - loop.run_forever() - - threading.Thread(target=run_asyncio_loop, args=(self.asyncio_loop, )).start() - def choose_stub(self, call_count): return self.stubs[call_count % len(self.stubs)] From b0ae84cf2cb61399a88e0282ab1822d9d6f7f60b Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 19 Apr 
2023 01:29:49 +0500 Subject: [PATCH 13/13] add license --- mii/event_loop.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mii/event_loop.py b/mii/event_loop.py index 0315d00d..4040c86f 100644 --- a/mii/event_loop.py +++ b/mii/event_loop.py @@ -1,3 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team import asyncio import threading
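
A note on the event-loop change in PATCH 12: the RESTful gateway calls `MIIClient.query()` from Flask handler threads, and a loop driven with `run_until_complete` cannot be shared across threads that way, so the series switches to a single module-level loop running on a daemon thread plus `asyncio.run_coroutine_threadsafe`. The sketch below is a minimal, self-contained reduction of that pattern; `fake_rpc` and `query` are illustrative stand-ins, not MII APIs.

```python
import asyncio
import threading

# One shared event loop for the whole process, driven by a daemon thread --
# the same pattern mii/event_loop.py adopts in PATCH 12.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()


async def fake_rpc(x: int) -> int:
    # Illustrative stand-in for an awaitable gRPC stub call.
    await asyncio.sleep(0.1)
    return x * 2


def query(x: int) -> int:
    # Callable from any thread (e.g., a Flask request handler): the coroutine
    # is scheduled onto the shared loop, and .result() blocks only this thread.
    future = asyncio.run_coroutine_threadsafe(fake_rpc(x), loop)
    return future.result()


if __name__ == "__main__":
    results = []
    threads = [
        threading.Thread(target=lambda i=i: results.append(query(i)))
        for i in range(4)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(sorted(results))  # [0, 2, 4, 6]
```

Blocking on `future.result()` stalls only the calling thread, so multiple gateway requests can be in flight on the one shared loop; it also avoids the "this event loop is already running" failure that `run_until_complete` hits when several threads share a client, which is exactly the situation the RESTful gateway creates.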