diff --git a/README.md b/README.md
index 27e6513b..83eed709 100644
--- a/README.md
+++ b/README.md
@@ -178,24 +178,6 @@ mii.deploy(...
            mii_config=mii_configs)
 ```
-**Non-persistent Deployment**
-
-You can enable a non-persistent deployment which allows you to make queries without standing up a server. The non-persistent deployment acts as a simplified interface to DeepSpeed-inference for use cases that do not require creating a persistent model server process. Changing the `deployment_type` to `NON_PERSISTENT` in `mii.deploy(...)` will activate this option.
-
-```python
-...
-mii.deploy(deployment_name = DEPLOYMENT_NAME,
-           deployment_type=mii.constants.DeploymentType.NON_PERSISTENT
-           ...
-           )
-
-generator = mii.mii_query_handle(DEPLOYMENT_NAME)
-result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30})
-
-```
-
-You can find a complete example [here]("https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/non_persistent")
-
 Any HTTP client can be used to call the APIs. An example of using curl is:
 
 ```bash
 # Assume deployment_name and restful_api_port are set to bloom560m_deployment and 28080 respectively:
@@ -219,6 +201,24 @@ response = requests.post(url, data=json_params, headers={
 print(response.json())
 ```
+**Non-persistent Deployment**
+
+You can enable a non-persistent deployment, which allows you to make queries without standing up a server. The non-persistent deployment acts as a simplified interface to DeepSpeed-Inference for use cases that do not require creating a persistent model server process. Changing the `deployment_type` to `NON_PERSISTENT` in `mii.deploy(...)` activates this option.
+
+```python
+...
+mii.deploy(deployment_name=DEPLOYMENT_NAME,
+           deployment_type=mii.constants.DeploymentType.NON_PERSISTENT,
+           ...
+           )
+
+generator = mii.mii_query_handle(DEPLOYMENT_NAME)
+result = generator.query({"query": ["DeepSpeed is", "Seattle is"]}, do_sample=True, max_new_tokens=30)
+
+```
+
+You can find a complete example [here](https://github.com/microsoft/DeepSpeed-MII/tree/main/examples/non_persistent).
+
 ## Deploying with MII-Azure
 
 MII supports deployment on Azure via AML Inference. To enable this, MII generates AML deployment assets for a given model that can be deployed using the Azure-CLI, as shown in the code below. Furthermore, deploying on Azure allows MII to leverage DeepSpeed-Azure as its optimization backend, which offers better latency and cost reduction than DeepSpeed-Public.
diff --git a/examples/multi_model/add_delete_models.py b/examples/multi_model/add_delete_models.py
new file mode 100644
index 00000000..2a85b0f3
--- /dev/null
+++ b/examples/multi_model/add_delete_models.py
@@ -0,0 +1,32 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import mii + +deployments = [] +results = [] +name = 'bigscience/bloom-560m' +mii_configs1 = {"tensor_parallel": 1, "dtype": "fp16"} +deployments.append( + mii.DeploymentConfig(task='text-generation', + model=name, + deployment_name=name + "_deployment5", + mii_configs=mii.config.MIIConfig(**mii_configs1) + )) + +generator = mii.mii_query_handle("multi_models") +generator.add_models(deployments=deployments) + +result = generator.query( + { + "query": ["DeepSpeed is", + "Seattle is"], + "deployment_name": "bigscience/bloom-560m_deployment5" + }, + do_sample=True, + max_new_tokens=30, +) +print(result) +generator.delete_model("bigscience/bloom-560m_deployment5") diff --git a/examples/multi_model/deploy.py b/examples/multi_model/deploy.py new file mode 100644 index 00000000..525b2da3 --- /dev/null +++ b/examples/multi_model/deploy.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import mii + +gpu_index_map1 = {'master': [0]} +gpu_index_map2 = {'master': [1]} +gpu_index_map3 = {'master': [0, 1]} + +deployments = [] + +mii_configs1 = {"tensor_parallel": 2, "dtype": "fp16"} +mii_configs2 = {"tensor_parallel": 1} + +name = "bigscience/bloom-560m" +deployments.append( + mii.DeploymentConfig(task='text-generation', + model=name, + deployment_name=name + "_deployment", + GPU_index_map=gpu_index_map3, + tensor_parallel=2, + dtype="fp16")) + +# gpt2 +name = "microsoft/DialogRPT-human-vs-rand" +deployments.append( + mii.DeploymentConfig(task='text-classification', + model=name, + deployment_name=name + "_deployment", + GPU_index_map=gpu_index_map2)) + +name = "microsoft/DialoGPT-large" +deployments.append( + mii.DeploymentConfig( + task='conversational', + model=name, + deployment_name=name + "_deployment", + GPU_index_map=gpu_index_map1, + )) + +name = "deepset/roberta-large-squad2" +deployments.append( + mii.DeploymentConfig(task="question-answering", + model=name, + deployment_name=name + "-qa-deployment", + GPU_index_map=gpu_index_map2)) + +mii.deploy(deployment_tag="multi_models", deployments=deployments) diff --git a/examples/multi_model/query.py b/examples/multi_model/query.py new file mode 100644 index 00000000..f506830f --- /dev/null +++ b/examples/multi_model/query.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import mii + +results = [] +generator = mii.mii_query_handle("multi_models") +result = generator.query( + { + "query": ["DeepSpeed is", + "Seattle is"], + "deployment_name": "bigscience/bloom-560m_deployment" + }, + do_sample=True, + max_new_tokens=30, +) +results.append(result) +print(result) + +result = generator.query({ + 'query': + "DeepSpeed is the greatest", + "deployment_name": + "microsoft/DialogRPT-human-vs-rand_deployment" +}) +results.append(result) +print(result) + +result = generator.query({ + 'text': "DeepSpeed is the greatest", + 'conversation_id': 3, + 'past_user_inputs': [], + 'generated_responses': [], + "deployment_name": "microsoft/DialoGPT-large_deployment" +}) +results.append(result) +print(result) + +result = generator.query({ + 'question': + "What is the greatest?", + 'context': + "DeepSpeed is the greatest", + "deployment_name": + "deepset/roberta-large-squad2" + "-qa-deployment" +}) +results.append(result) +print(result) diff --git a/examples/multi_model/shutdown.py b/examples/multi_model/shutdown.py new file mode 100644 index 00000000..6b718a4d --- /dev/null +++ b/examples/multi_model/shutdown.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import mii + +mii.terminate("multi_models") diff --git a/mii/__init__.py b/mii/__init__.py index ab409d4c..66748a56 100644 --- a/mii/__init__.py +++ b/mii/__init__.py @@ -10,7 +10,7 @@ from .constants import DeploymentType, Tasks from .aml_related.utils import aml_output_path -from .config import MIIConfig, LoadBalancerConfig +from .config import MIIConfig, LoadBalancerConfig, DeploymentConfig from .grpc_related.proto import modelresponse_pb2_grpc __version__ = "0.0.0" diff --git a/mii/client.py b/mii/client.py index 535b55c8..d478f7f2 100644 --- a/mii/client.py +++ b/mii/client.py @@ -8,21 +8,27 @@ import mii from mii.utils import get_task from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc -from mii.constants import GRPC_MAX_MSG_SIZE, Tasks +from mii.constants import GRPC_MAX_MSG_SIZE, Tasks, DeploymentType from mii.method_table import GRPC_METHOD_TABLE +from mii.deployment import allocate_processes +from mii.config import DeploymentConfig, MIIConfig -def _get_deployment_info(deployment_name): - configs = mii.utils.import_score_file(deployment_name).configs - task = configs[mii.constants.TASK_NAME_KEY] - mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY] - mii_configs = mii.config.MIIConfig(**mii_configs_dict) +def _get_deployment_configs(deployment_tag): + deployments = {} + configs = mii.utils.import_score_file(deployment_tag).configs + for deployment in configs.get(mii.constants.DEPLOYMENTS_KEY).values(): + deployment_name = deployment[mii.constants.DEPLOYMENT_NAME_KEY] + deployments[deployment_name] = DeploymentConfig(**deployment) + lb_config = configs.get(mii.constants.LOAD_BALANCER_CONFIG_KEY) + model_path = configs.get(mii.constants.MODEL_PATH_KEY) + port_map = configs.get(mii.constants.PORT_MAP_KEY) + deployment_type = configs.get(mii.constants.DEPLOYMENT_TYPE_KEY) + mii_configs = MIIConfig(**configs.get(mii.constants.MII_CONFIGS_KEY)) + return deployments, lb_config, model_path, port_map, deployment_type, mii_configs - assert task is not None, "The task name should be set before calling init" - return task, mii_configs - -def mii_query_handle(deployment_name): +def mii_query_handle(deployment_tag): """Get a query handle for a local deployment: 
mii/examples/local/gpt2-query-example.py @@ -35,12 +41,31 @@ def mii_query_handle(deployment_name): query_handle: A query handle with a single method `.query(request_dictionary)` using which queries can be sent to the model. """ - if deployment_name in mii.non_persistent_models: - inference_pipeline, task = mii.non_persistent_models[deployment_name] - return MIINonPersistentClient(task, deployment_name) + if deployment_tag in mii.non_persistent_models: + inference_pipeline, task = mii.non_persistent_models[deployment_tag] + return MIINonPersistentClient(task, deployment_tag) - task_name, mii_configs = _get_deployment_info(deployment_name) - return MIIClient(task_name, "localhost", mii_configs.port_number) + deployments, lb_config, model_path, port_map, deployment_type, mii_configs = _get_deployment_configs(deployment_tag) + """mii_configs = None + if len(deployments) > 0: + mii_configs = getattr(next(iter(deployments.values())), + mii.constants.MII_CONFIGS_KEY) + """ + port_number = None if mii_configs == None else mii_configs.port_number + """if port_number: + for deployment in deployments.values(): + assert getattr(deployment, mii.constants.MII_CONFIGS_KEY).port_number == port_number, f"All port numbers is each deployments mii_configs must match" + """ + + return LBClient(deployments, + "localhost", + port_number, + lb_config, + model_path, + port_map, + deployment_tag, + deployment_type, + mii_configs) def create_channel(host, port): @@ -55,51 +80,191 @@ class MIIClient(): """ Client to send queries to a single endpoint. """ - def __init__(self, task_name, host, port): + def __init__(self, deployments, host, port): self.asyncio_loop = asyncio.get_event_loop() - channel = create_channel(host, port) - self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) - self.task = get_task(task_name) - - async def _request_async_response(self, request_dict, **query_kwargs): - if self.task not in GRPC_METHOD_TABLE: - raise ValueError(f"unknown task: {self.task}") - - task_methods = GRPC_METHOD_TABLE[self.task] + self.mr_stub = None + self.channel = None + self.host = host + if port is not None: + self.channel = create_channel(host, port) + self.mr_stub = modelresponse_pb2_grpc.ModelResponseStub(self.channel) + self.deployments = deployments + + def _get_deployment_task(self, deployment_name=None): + task = None + if deployment_name is None or deployment_name == mii.constants.MII_TERMINATE_DEP_KEY: #mii.terminate() or single model + if deployment_name is None: + assert len(self.deployments) == 1, "Must pass deployment_name to query when using multiple deployments" + deployment = next(iter(self.deployments.values())) + deployment_name = getattr(deployment, mii.constants.DEPLOYMENT_NAME_KEY) + task = getattr(deployment, mii.constants.TASK_NAME_KEY) + else: + if deployment_name in self.deployments: + deployment = self.deployments[deployment_name] + task = getattr(deployment, mii.constants.TASK_NAME_KEY) + else: + assert False, f"{deployment_name} not found in list of deployments" + return deployment_name, task + + async def _request_async_response(self, request_dict, task, **query_kwargs): + if task not in GRPC_METHOD_TABLE: + raise ValueError(f"unknown task: {task}") + + task_methods = GRPC_METHOD_TABLE[task] proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs) - proto_response = await getattr(self.stub, task_methods.method)(proto_request) + proto_response = await getattr(self.mr_stub, task_methods.method)(proto_request) return 
task_methods.unpack_response_from_proto(proto_response) def query(self, request_dict, **query_kwargs): + deployment_name = request_dict.get(mii.constants.DEPLOYMENT_NAME_KEY) + deployment_name, task = self._get_deployment_task(deployment_name) + request_dict['deployment_name'] = deployment_name return self.asyncio_loop.run_until_complete( self._request_async_response(request_dict, + task, **query_kwargs)) async def terminate_async(self): - await self.stub.Terminate( + await self.lb_stub.Terminate( modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty()) def terminate(self): self.asyncio_loop.run_until_complete(self.terminate_async()) async def create_session_async(self, session_id): - return await self.stub.CreateSession( + return await self.mr_stub.CreateSession( modelresponse_pb2.SessionID(session_id=session_id)) - def create_session(self, session_id): - assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'." + def create_session(self, session_id, deployment_name=None): + if len(self.deployments > 1): + assert deployment_name is not None, "Deployment name must be passed in to create session when there are multiple models" + deployment_name, task = self._get_deployment_task(deployment_name) + assert task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'." return self.asyncio_loop.run_until_complete( self.create_session_async(session_id)) async def destroy_session_async(self, session_id): - await self.stub.DestroySession(modelresponse_pb2.SessionID(session_id=session_id) - ) + await self.mr_stub.DestroySession( + modelresponse_pb2.SessionID(session_id=session_id)) - def destroy_session(self, session_id): - assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'." + def destroy_session(self, session_id, deployment_name=None): + if len(self.deployments > 1): + assert deployment_name is not None, "Deployment name must be passed in to destroy session when there are multiple models" + deployment_name, task = self._get_deployment_task(deployment_name) + assert task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'." 
self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id)) +class LBClient(MIIClient): + def __init__(self, + deployments, + host, + port, + lb_config=None, + model_path=None, + port_map=None, + deployment_tag=None, + deployment_type=DeploymentType.LOCAL, + mii_configs={}): + super().__init__(deployments, host, port) + self.lb_stub = None + if port is not None: + channel = create_channel(host, port) if not self.channel else self.channel + self.lb_stub = modelresponse_pb2_grpc.DeploymentManagementStub(channel) + self.lb_config = lb_config + self.model_path = model_path + self.port_map = port_map if port_map is not None else {} + self.deployment_tag = deployment_tag + self.deployment_type = deployment_type + self.mii_configs = mii_configs + + async def add_models_async(self, proto_request): + await getattr(self.lb_stub, "AddDeployment")(proto_request) + + def add_models(self, deployments=[], model_path=None, version=1): + assert self.deployment_type != DeploymentType.AML, "Cannot currently add models to AML deployment" + """_, deployments = validate_deployment(task=task, + model=model, + deployment_name=deployment_name, + enable_deepspeed=enable_deepspeed, + enable_zero=enable_zero, + ds_config=ds_config, + mii_config=mii_config, + deployment_tag=self.deployment_tag, + deployments=deployments, + deployment_type=deployment_type, + model_path=model_path, + version=version) + """ + if not deployments: #Empty deployment + return None + + deps = { + getattr(deployment, + mii.constants.DEPLOYMENT_NAME_KEY): deployment + for deployment in deployments + } + lb_config, self.port_map = allocate_processes(deps, self.port_map, self.mii_configs) + lb_enabled = True if len(self.deployments) else False + if self.lb_config is not None: + self.lb_config.replica_configs.extend(lb_config.replica_configs) + else: + self.lb_config = lb_config + for deployment in deployments: + self.deployments[getattr(deployment, + mii.constants.DEPLOYMENT_NAME_KEY)] = deployment + if self.model_path is None and self.deployment_type == DeploymentType.LOCAL: + self.model_path = mii.constants.MII_MODEL_PATH_DEFAULT + """create_score_file(deployment_tag=self.deployment_tag, + deployment_type=deployment_type, + deployments=deps, + model_path=self.model_path, + port_map=self.port_map, + lb_config=lb_config, + deployed=lb_enabled) + + if deployment_type == DeploymentType.LOCAL: + mii.utils.import_score_file(self.deployment_tag).init() + """ + if not self.mii_configs: + self.mii_configs = mii.configs.MIIConfigs(**{}) + mii.MIIServer(self.deployment_tag, + deps.values(), + self.model_path, + lb_config=lb_config, + lb_enabled=lb_enabled, + mii_configs=self.mii_configs) + + if self.lb_stub is None: + self.port_number = self.mii_configs.port_number + self.channel = create_channel(self.host, self.port_number) + self.lb_stub = modelresponse_pb2_grpc.DeploymentManagementStub(self.channel) + if not self.mr_stub: + self.mr_stub = modelresponse_pb2_grpc.ModelResponseStub(self.channel) + for replica in lb_config.replica_configs: + request_proto = modelresponse_pb2.AddDeployRequest( + task=replica.task, + deployment_name=replica.deployment_name, + hostname=replica.hostname, + tensor_parallel_ports=replica.tensor_parallel_ports, + torch_dist_port=replica.torch_dist_port, + gpu_indices=replica.gpu_indices) + + self.asyncio_loop.run_until_complete(self.add_models_async(request_proto)) + + async def delete_model_async(self, proto_request): + await getattr(self.lb_stub, "DeleteDeployment")(proto_request) + + def delete_model(self, 
deployment_name): + if deployment_name in self.deployments: + request_proto = modelresponse_pb2.DeleteDeployRequest( + deployment_name=deployment_name) + self.asyncio_loop.run_until_complete(self.delete_model_async(request_proto)) + del self.deployments[deployment_name] + return None + assert False, f"Deployment: {deployment_name} not found" + + class MIITensorParallelClient(): """ Client to send queries to multiple endpoints in parallel. @@ -157,7 +322,7 @@ def destroy_session(self, session_id): class MIINonPersistentClient(): def __init__(self, task, deployment_name): - self.task = task + self.task = get_task(task) self.deployment_name = deployment_name def query(self, request_dict, **query_kwargs): @@ -188,7 +353,9 @@ def terminate(self): del mii.non_persistent_models[self.deployment_name] -def terminate_restful_gateway(deployment_name): - _, mii_configs = _get_deployment_info(deployment_name) - if mii_configs.enable_restful_api: - requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate") +def terminate_restful_gateway(deployment_tag): + deployments, _, _, _, _, mii_configs = _get_deployment_configs(deployment_tag) + for deployment in deployments.values(): + #mii_configs = getattr(deployment, mii.constants.MII_CONFIGS_KEY) + if deployment.enable_restful_api: + requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate") diff --git a/mii/config.py b/mii/config.py index 6a8bac16..695054c3 100644 --- a/mii/config.py +++ b/mii/config.py @@ -5,9 +5,10 @@ import torch from typing import Union, List from enum import Enum -from pydantic import BaseModel, validator, root_validator - +from pydantic import BaseModel, validator, root_validator, Field from deepspeed.launcher.runner import DLTS_HOSTFILE +from mii.utils import get_task +from mii.constants import DEPLOYMENT_NAME_KEY, TASK_NAME_KEY, MODEL_NAME_KEY, ENABLE_DEEPSPEED_KEY, ENABLE_DEEPSPEED_ZERO_KEY, GPU_INDEX_KEY, DEEPSPEED_CONFIG_KEY, VERSION_KEY class DtypeEnum(Enum): @@ -56,7 +57,8 @@ class MIIConfig(BaseModel): restful_api_port: int = 51080 replica_num: int = 1 hostfile: str = DLTS_HOSTFILE - + trust_remote_code: bool = False + """ @validator("deploy_rank") def deploy_valid(cls, field_value, values): if "tensor_parallel" not in values: @@ -75,8 +77,9 @@ def deploy_valid(cls, field_value, values): # number of ranks provided must be equal to TP size, DP is handled outside MII currently assert values["tensor_parallel"] == len(field_value), \ f"{len(field_value)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {values['tensor_parallel']}" - return field_value - + return field_value + """ + """ @validator('checkpoint_dict') def checkpoint_dict_valid(cls, value): if value is None: @@ -89,7 +92,8 @@ def checkpoint_dict_valid(cls, value): if not value.get(k, ''): raise ValueError(f"Missing key={k} in checkpoint_dict") return value - + """ + """ @root_validator def meta_tensor_or_sys_mem(cls, values): if values.get("meta_tensor") and values.get("load_with_sys_mem"): @@ -97,7 +101,7 @@ def meta_tensor_or_sys_mem(cls, values): "`meta_tensor` and `load_with_sys_mem` cannot be active at the same time." 
) return values - + """ class Config: validate_all = True validate_assignment = True @@ -107,6 +111,8 @@ class Config: class ReplicaConfig(BaseModel): + task: str = "" + deployment_name: str = "" hostname: str = "" tensor_parallel_ports: List[int] = [] torch_dist_port: int = None @@ -123,4 +129,86 @@ class LoadBalancerConfig(BaseModel): class Config: validate_all = True + + validate_assignment = True + + +class DeploymentConfig(BaseModel): + deployment_name: str = Field(alias=DEPLOYMENT_NAME_KEY) + task: str = Field(alias=TASK_NAME_KEY) + model: str = Field(alias=MODEL_NAME_KEY) + ds_optimize: bool = Field(default=True, alias=ENABLE_DEEPSPEED_KEY) + ds_zero: bool = Field(default=False, alias=ENABLE_DEEPSPEED_ZERO_KEY) + GPU_index_map: dict = Field(default=None, alias=GPU_INDEX_KEY) + #mii_configs: MIIConfig = Field(default={}, alias=MII_CONFIGS_KEY) + ds_config: dict = Field(default=None, alias=DEEPSPEED_CONFIG_KEY) + version: int = Field(default=1, alias=VERSION_KEY) + tensor_parallel: int = 1 + dtype: DtypeEnum = torch.float32 + meta_tensor: bool = False + load_with_sys_mem: bool = False + replace_with_kernel_inject: bool = True + profile_model_time: bool = False + skip_model_check: bool = False + max_tokens: int = 1024 + enable_restful_api: bool = False + replica_num: int = 1 + hostfile: str = DLTS_HOSTFILE + deploy_rank: Union[int, List[int]] = -1 + enable_cuda_graph: bool = False + checkpoint_dict: Union[dict, None] = None + hf_auth_token: str = None + trust_remote_code: bool = False + + @validator('checkpoint_dict') + def checkpoint_dict_valid(cls, value): + if value is None: + return value + if value.get('base_dir', ''): + raise ValueError( + "please unset 'base_dir' it will be set w.r.t. the deployment 'model_path'" + ) + for k in ['checkpoints', 'parallelization', 'version', 'type']: + if not value.get(k, ''): + raise ValueError(f"Missing key={k} in checkpoint_dict") + return value + + @validator("deploy_rank") + def deploy_valid(cls, field_value, values): + if "tensor_parallel" not in values: + raise ValueError( + "'tensor_parallel' must be defined in the pydantic model before 'deploy_rank'" + ) + + # if deploy rank is not given, default to align with TP value + if field_value == -1: + field_value = list(range(values["tensor_parallel"])) + + # ensure deploy rank type is always list for easier consumption later + if not isinstance(field_value, list): + field_value = [field_value] + + # number of ranks provided must be equal to TP size, DP is handled outside MII currently + assert values["tensor_parallel"] == len(field_value), \ + f"{len(field_value)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {values['tensor_parallel']}" + return field_value + + @root_validator + def meta_tensor_or_sys_mem(cls, values): + if values.get("meta_tensor") and values.get("load_with_sys_mem"): + raise ValueError( + "`meta_tensor` and `load_with_sys_mem` cannot be active at the same time." 
+ ) + return values + + @validator("task") + def convert_task_str(cls, field_value, values): + return get_task(field_value) + + class Config: + allow_population_by_field_name = True + validate_all = True validate_assignment = True + use_enum_values = True + extra = 'forbid' + json_encoders = {torch.dtype: lambda x: str(x)} diff --git a/mii/constants.py b/mii/constants.py index ba4cfa2f..3d674efe 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -88,17 +88,23 @@ class ModelProvider(enum.Enum): 'generated_responses'], TEXT2IMG_NAME: ["query"] } - -MODEL_NAME_KEY = 'model_name' -TASK_NAME_KEY = 'task_name' +GPU_INDEX_KEY = "GPU_index_map" +DEPLOYMENTS_KEY = 'deployments' +PORT_MAP_KEY = 'port_map' +MODEL_NAME_KEY = 'model' +TASK_NAME_KEY = 'task' DEPLOYMENT_NAME_KEY = 'deployment_name' MODEL_PATH_KEY = 'model_path' LOAD_BALANCER_CONFIG_KEY = 'load_balancer_config' - +DEPLOYMENT_TAG_KEY = 'deployment_tag' ENABLE_DEEPSPEED_KEY = 'ds_optimize' ENABLE_DEEPSPEED_ZERO_KEY = 'ds_zero' DEEPSPEED_CONFIG_KEY = 'ds_config' CHECKPOINT_KEY = "checkpoint" +DEPLOYED_KEY = "deployed" +VERSION_KEY = "version" +MII_TERMINATE_DEP_KEY = "__MII_TERMINATE_CALL__" +DEPLOYMENT_TYPE_KEY = "deployment_type" MII_CACHE_PATH = "MII_CACHE_PATH" MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache" @@ -118,7 +124,8 @@ class ModelProvider(enum.Enum): TERMINATE_METHOD = "Terminate" CREATE_SESSION_METHOD = "CreateSession" DESTROY_SESSION_METHOD = "DestroySession" - +ADD_DEPLOYMENT_METHOD = "AddDeployment" +DELETE_DEPLOYMENT_METHOD = "DeleteDeployment" LB_MAX_WORKER_THREADS = 32 SERVER_SHUTDOWN_TIMEOUT = 10 diff --git a/mii/deployment.py b/mii/deployment.py index 3cadd994..1976ca4a 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -13,18 +13,20 @@ from .utils import logger, get_task_name, get_provider_name from .models.score import create_score_file from .models import load_models -from .config import ReplicaConfig, LoadBalancerConfig +from .config import ReplicaConfig, LoadBalancerConfig, DeploymentConfig -def deploy(task, - model, - deployment_name, - deployment_type=DeploymentType.LOCAL, - model_path=None, +def deploy(task=None, + model=None, + deployment_name=None, enable_deepspeed=True, enable_zero=False, ds_config=None, mii_config={}, + deployment_tag=None, + deployments=[], + deployment_type=DeploymentType.LOCAL, + model_path=None, version=1): """Deploy a task using specified model. 
For usage examples see: @@ -66,119 +68,235 @@ def deploy(task, If deployment_type is `LOCAL`, returns just the name of the deployment that can be used to create a query handle using `mii.mii_query_handle(deployment_name)` """ + if not mii_config: + mii_config = mii.config.MIIConfig(**{}) + + if model_path is None and deployment_type == DeploymentType.LOCAL: + model_path = MII_MODEL_PATH_DEFAULT + elif model_path is None and deployment_type == DeploymentType.AML: + model_path = "model" + + deployment_tag, deployments = validate_deployment(task=task, + model=model, + deployment_name=deployment_name, + enable_deepspeed=enable_deepspeed, + enable_zero=enable_zero, + ds_config=ds_config, + mii_config=mii_config, + deployment_tag=deployment_tag, + deployments=deployments, + deployment_type=deployment_type, + model_path=model_path, + version=version) + + if not deployments: #Empty deployment + create_score_file(deployment_tag=deployment_tag, + deployment_type=deployment_type, + deployments=None, + model_path=model_path, + port_map=None, + lb_config=None) + print(f"Starting empty deployment, deployment_tag -> {deployment_tag}") + return None # parse and validate mii config - mii_config = mii.config.MIIConfig(**mii_config) - if enable_zero: - if ds_config.get("fp16", {}).get("enabled", False): - assert (mii_config.dtype == torch.half), "MII Config Error: MII dtype and ZeRO dtype must match" - else: - assert (mii_config.dtype == torch.float), "MII Config Error: MII dtype and ZeRO dtype must match" - assert not (enable_deepspeed and enable_zero), "MII Config Error: DeepSpeed and ZeRO cannot both be enabled, select only one" + for deployment in deployments: + #mii_config = getattr(deployment, mii.constants.MII_CONFIGS_KEY) + if getattr(deployment, mii.constants.ENABLE_DEEPSPEED_ZERO_KEY): + if getattr(deployment, + mii.constants.DEEPSPEED_CONFIG_KEY).get("fp16", + {}).get("enabled", + False): + assert (deployment.dtype == torch.half), "MII Config Error: MII dtype and ZeRO dtype must match" + else: + assert (deployment.dtype == torch.float), "MII Config Error: MII dtype and ZeRO dtype must match" + assert not (enable_deepspeed and enable_zero), "MII Config Error: DeepSpeed and ZeRO cannot both be enabled, select only one" # aml only allows certain characters for deployment names if deployment_type == DeploymentType.AML: + assert len(deployments == 1), "mii does not currently support empty/multi-model deployment on AML" allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '-') assert set(deployment_name) <= allowed_chars, "AML deployment names can only contain a-z, A-Z, 0-9, and '-'" - task = mii.utils.get_task(task) + if not mii_config.skip_model_check: + mii.utils.check_if_task_and_model_is_valid( + getattr(deployment, + mii.constants.TASK_NAME_KEY), + getattr(deployment, + mii.constants.MODEL_NAME_KEY)) + if enable_deepspeed: + mii.utils.check_if_task_and_model_is_supported( + deployment.task, + deployment.model) - if not mii_config.skip_model_check: - mii.utils.check_if_task_and_model_is_valid(task, model) if enable_deepspeed: - mii.utils.check_if_task_and_model_is_supported(task, model) - - if enable_deepspeed: - logger.info( - f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" - ) - else: - logger.info( - f"************* DeepSpeed Optimizations not enabled. 
Please use enable_deepspeed to get better performance *************" - ) + logger.info( + f"************* MII is using DeepSpeed Optimizations to accelerate your model: {deployment.model} *************" + ) + else: + logger.info( + f"************* DeepSpeed Optimizations not enabled. Please use enable_deepspeed to get better performance for: {deployment.model} *************" + ) + deps = {deployment.deployment_name: deployment for deployment in deployments} # In local deployments use default path if no model path set - if model_path is None and deployment_type == DeploymentType.LOCAL: - model_path = MII_MODEL_PATH_DEFAULT - elif model_path is None and deployment_type == DeploymentType.AML: - model_path = "model" # add fields for replica deployment - replica_pool = _allocate_processes(mii_config.hostfile, - mii_config.tensor_parallel, - mii_config.replica_num) - replica_configs = [] - for i, (hostname, gpu_indices) in enumerate(replica_pool): - # Reserver port for a LB proxy when replication is enabled - port_offset = 1 - base_port = mii_config.port_number + i * mii_config.tensor_parallel + port_offset - tensor_parallel_ports = list( - range(base_port, - base_port + mii_config.tensor_parallel)) - torch_dist_port = mii_config.torch_dist_port + i - replica_configs.append( - ReplicaConfig(hostname=hostname, - tensor_parallel_ports=tensor_parallel_ports, - torch_dist_port=torch_dist_port, - gpu_indices=gpu_indices)) - lb_config = LoadBalancerConfig(port=mii_config.port_number, - replica_configs=replica_configs) + port_map = {} + lb_config, port_map = allocate_processes(deps, port_map, mii_config) if deployment_type != DeploymentType.NON_PERSISTENT: - create_score_file(deployment_name=deployment_name, + create_score_file(deployment_tag=deployment_tag, deployment_type=deployment_type, - task=task, - model_name=model, - ds_optimize=enable_deepspeed, - ds_zero=enable_zero, - ds_config=ds_config, - mii_config=mii_config, + deployments=deps, model_path=model_path, - lb_config=lb_config) + port_map=port_map, + lb_config=lb_config, + mii_configs=mii_config) if deployment_type == DeploymentType.AML: - _deploy_aml(deployment_name=deployment_name, model_name=model, version=version) + _deploy_aml(deployment_tag=deployment_tag, model_name=model, version=version) elif deployment_type == DeploymentType.LOCAL: - return _deploy_local(deployment_name, model_path=model_path) + return _deploy_local(deployment_tag, model_path=model_path) elif deployment_type == DeploymentType.NON_PERSISTENT: assert int(os.getenv('WORLD_SIZE', '1')) == mii_config.tensor_parallel, "World Size does not equal number of tensors. 
When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" provider = MODEL_PROVIDER_MAP[get_provider_name(model, task)] mii.non_persistent_models[deployment_name] = (load_models( - get_task_name(task), + task, model, model_path, enable_deepspeed, enable_zero, provider, - mii_config), + deployment), task) else: raise Exception(f"Unknown deployment type: {deployment_type}") -def _deploy_local(deployment_name, model_path): - mii.utils.import_score_file(deployment_name).init() - - -def _deploy_aml(deployment_name, model_name, version): +def allocate_processes(deployments, port_map, mii_config): + replica_configs = [] + port_offset = 1 + for deployment in deployments.values(): + #mii_config = getattr(deployment, mii.constants.MII_CONFIGS_KEY) + replica_pool = _allocate_processes( + deployment.hostfile, + deployment.tensor_parallel, + deployment.replica_num, + getattr(deployment, + mii.constants.GPU_INDEX_KEY)) + + for i, (hostname, gpu_indices) in enumerate(replica_pool): + # Reserver port for a LB proxy when replication is enabled + if hostname not in port_map: + port_map[hostname] = set() + base_port = mii_config.port_number + i * deployment.tensor_parallel + port_offset + if base_port in port_map[hostname]: + base_port = max(port_map[hostname]) + 1 + tensor_parallel_ports = list( + range(base_port, + base_port + deployment.tensor_parallel)) + for i in range(base_port, base_port + deployment.tensor_parallel): + port_map[hostname].add(i) + torch_dist_port = mii_config.torch_dist_port + i + replica_configs.append( + ReplicaConfig( + task=get_task_name(getattr(deployment, + mii.constants.TASK_NAME_KEY)), + deployment_name=(getattr(deployment, + mii.constants.DEPLOYMENT_NAME_KEY)), + hostname=hostname, + tensor_parallel_ports=tensor_parallel_ports, + torch_dist_port=torch_dist_port, + gpu_indices=gpu_indices)) + lb_config = LoadBalancerConfig(port=mii_config.port_number, + replica_configs=replica_configs) + return lb_config, port_map + + +def validate_deployment(task=None, + model=None, + deployment_name=None, + enable_deepspeed=True, + enable_zero=False, + ds_config=None, + mii_config={}, + deployment_tag=None, + deployments=[], + deployment_type=DeploymentType.LOCAL, + model_path=None, + version=1): + + if deployments and any((model, task, deployment_name)): + assert False, "Do not input deployments and model/task/deployment_name at the same time" + + elif deployments: + assert deployment_tag, "deployment_tag must be set to for multiple models" + return deployment_tag, deployments + + elif not any((model, task, deployment_name)): + assert deployment_tag, "deployment_tag must be set for an empty deployment" + create_score_file(deployment_tag=deployment_tag, + deployment_type=deployment_type, + deployments=None, + model_path=model_path, + mii_configs={}, + port_map=None, + lb_config=None) + return deployment_tag, None + + assert all((model, task, deployment_name)), "model, task, and deployment_name must be set for a single model" + deployments = [ + DeploymentConfig(DEPLOYMENT_NAME_KEY=deployment_name, + TASK_NAME_KEY=task, + MODEL_NAME_KEY=model, + ENABLE_DEEPSPEED_KEY=enable_deepspeed, + ENABLE_DEEPSPEED_ZERO_KEY=enable_zero, + GPU_INDEX_KEY=None, + MII_CONFIGS_KEY=mii.config.MIIConfig(**mii_config), + DEEPSPEED_CONFIG_KEY=ds_config, + VERSION_KEY=version) + ] + if deployment_tag is None: + deployment_tag = deployment_name + return deployment_tag, deployments + + +def _deploy_local(deployment_tag, model_path): + mii.utils.import_score_file(deployment_tag).init() + + 
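The port bookkeeping in `allocate_processes` above is easier to see in isolation. The following is a minimal, self-contained sketch (not part of this diff; the `reserve_ports` helper and the concrete hostname and port numbers are illustrative only) of what `port_map` tracks: every replica gets a contiguous block of `tensor_parallel` ports on its host, and a later deployment whose requested base port is already taken is pushed past the highest port reserved so far.

```python
# Illustrative sketch of the per-host port bookkeeping done by allocate_processes().
def reserve_ports(port_map, hostname, base_port, tensor_parallel):
    """Reserve `tensor_parallel` consecutive ports on `hostname`, skipping ports already in use."""
    used = port_map.setdefault(hostname, set())
    if base_port in used:
        # Requested base port is taken by an earlier deployment: start above the highest used port.
        base_port = max(used) + 1
    ports = list(range(base_port, base_port + tensor_parallel))
    used.update(ports)
    return ports


port_map = {}
print(reserve_ports(port_map, "master", 50051, 2))  # [50051, 50052] -- first deployment
print(reserve_ports(port_map, "master", 50051, 1))  # [50053] -- second deployment shifted up
```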
+def _deploy_aml(deployment_tag, model_name, version): acr_name = mii.aml_related.utils.get_acr_name() mii.aml_related.utils.generate_aml_scripts(acr_name=acr_name, - deployment_name=deployment_name, + deployment_name=deployment_tag, model_name=model_name, version=version) print( - f"AML deployment assets at {mii.aml_related.utils.aml_output_path(deployment_name)}" + f"AML deployment assets at {mii.aml_related.utils.aml_output_path(deployment_tag)}" ) print("Please run 'deploy.sh' to bring your deployment online") -def _allocate_processes(hostfile_path, tensor_parallel, num_replicas): +def _allocate_processes(hostfile_path, + tensor_parallel, + num_replicas, + gpu_index_map=None): resource_pool = fetch_hostfile(hostfile_path) assert resource_pool is not None and len( resource_pool) > 0, f'No hosts found in {hostfile_path}' replica_pool = [] + + if gpu_index_map is not None: + for host in gpu_index_map: + assert host in resource_pool, f"Host: {host} was not found" + assert resource_pool[host] >= tensor_parallel, f"Host {host} has {resource_pool[host]} slot(s), but {tensor_parallel} slot(s) are required" + for host in gpu_index_map: + replica_pool.append((host, gpu_index_map[host])) + return replica_pool + allocated_num = 0 for host, slots in resource_pool.items(): available_on_host = slots diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 4a0a5d00..4aa485dc 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -13,10 +13,9 @@ import threading import time -from mii.constants import GRPC_MAX_MSG_SIZE, CREATE_SESSION_METHOD, DESTROY_SESSION_METHOD, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, Tasks +from mii.constants import GRPC_MAX_MSG_SIZE, ADD_DEPLOYMENT_METHOD, DELETE_DEPLOYMENT_METHOD, CREATE_SESSION_METHOD, DESTROY_SESSION_METHOD, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, Tasks from mii.method_table import GRPC_METHOD_TABLE from mii.client import create_channel -from mii.utils import get_task, unpack_proto_query_kwargs class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): @@ -34,6 +33,18 @@ def get_stop_event(self): return self._stop_event +class DeploymentManagement(ServiceBase, + modelresponse_pb2_grpc.DeploymentManagementServicer): + def __init__(self): + ServiceBase.__init__(self) + + def AddDeployment(self, request, context): + return google_dot_protobuf_dot_empty__pb2.Empty() + + def DeleteDeployment(self, request, context): + return google_dot_protobuf_dot_empty__pb2.Empty() + + class ModelResponse(ServiceBase): """ Implementation class of an MII inference server @@ -138,15 +149,16 @@ class ParallelStubInvoker: This class aims to call gRPC methods without conversions between proto and python object. TensorParallelClient can be used for invocation with the conversions. 
""" - def __init__(self, host, ports): + def __init__(self, host, ports, asyncio_loop): # Assumption: target services are all on the same host self.stubs = [] for port in ports: + asyncio.set_event_loop(asyncio_loop) channel = create_channel(host, port) stub = modelresponse_pb2_grpc.ModelResponseStub(channel) self.stubs.append(stub) - self.asyncio_loop = asyncio.get_event_loop() + self.asyncio_loop = asyncio_loop async def _invoke_async(self, method_name, proto_request): responses = [] @@ -164,18 +176,24 @@ def invoke(self, method_name, proto_request): class LoadBalancingInterceptor(grpc.ServerInterceptor): - def __init__(self, task_name, replica_configs): + def __init__(self, replica_configs): super().__init__() self.asyncio_loop = asyncio.get_event_loop() - self.stubs = [ - ParallelStubInvoker(replica.hostname, - replica.tensor_parallel_ports) - for replica in replica_configs - ] - self.counter = AtomicCounter() - self.task = get_task(task_name) - self.replica_sessions = {} + self.stubs = {} + self.counter = {} + self.replica_configs = replica_configs + self.tasks = {} + for repl in replica_configs: + self.stubs[repl.deployment_name] = [] + self.counter[repl.deployment_name] = AtomicCounter() + self.tasks[repl.deployment_name] = repl.task + + for repl in replica_configs: + self.stubs[repl.deployment_name].append( + ParallelStubInvoker(repl.hostname, + repl.tensor_parallel_ports, + self.asyncio_loop)) # Start the asyncio loop in a separate thread def run_asyncio_loop(loop): @@ -193,38 +211,77 @@ def intercept_service(self, continuation, handler_call_details): def invoke_intercept_method(request_proto, context): method_name = _get_grpc_method_name(handler_call_details.method) + if method_name == ADD_DEPLOYMENT_METHOD: + deployment_name = str(getattr(request_proto, "deployment_name")) + if deployment_name not in self.stubs: + task = str(getattr(request_proto, "task")) + hostname = str(getattr(request_proto, "hostname")) + tensor_parallel_ports = list( + getattr(request_proto, + "tensor_parallel_ports")) + torch_dist_port = int(getattr(request_proto, "torch_dist_port")) + gpu_indices = list(getattr(request_proto, "gpu_indices")) + self.stubs[deployment_name] = [] + self.counter[deployment_name] = AtomicCounter() + self.tasks[deployment_name] = task + self.stubs[deployment_name].append( + ParallelStubInvoker(hostname, + tensor_parallel_ports, + self.asyncio_loop)) + else: + print(f"deployment: {deployment_name} already exists") + return google_dot_protobuf_dot_empty__pb2.Empty() if method_name == TERMINATE_METHOD: - for stub in self.stubs: - stub.invoke(TERMINATE_METHOD, - google_dot_protobuf_dot_empty__pb2.Empty()) + for deployment_name in self.stubs: + for stub in self.stubs[deployment_name]: + stub.invoke(TERMINATE_METHOD, + google_dot_protobuf_dot_empty__pb2.Empty()) self.asyncio_loop.call_soon_threadsafe(self.asyncio_loop.stop) return next_handler.unary_unary(request_proto, context) - call_count = self.counter.get_and_increment() - replica_index = call_count % len(self.stubs) + if method_name == DELETE_DEPLOYMENT_METHOD: + deployment_name = str(getattr(request_proto, "deployment_name")) + assert deployment_name in self.stubs, f"Deployment: {deployment_name} not found" + for stub in self.stubs[deployment_name]: + stub.invoke(TERMINATE_METHOD, + google_dot_protobuf_dot_empty__pb2.Empty()) + del self.stubs[deployment_name] + del self.counter[deployment_name] + del self.tasks[deployment_name] + return google_dot_protobuf_dot_empty__pb2.Empty() + + deployment_name = getattr(request_proto, 
'deployment_name') + assert deployment_name in self.stubs, f"Deployment: {deployment_name} not found" + call_count = self.counter[deployment_name].get_and_increment() + replica_index = call_count % len(self.stubs[deployment_name]) if method_name == CREATE_SESSION_METHOD: if request_proto.session_id in self.sessions: raise ValueError( f"session {request_proto.session_id} already exists") self.replica_sessions[request_proto.session_id] = replica_index - self.stubs[replica_index].invoke(CREATE_SESSION_METHOD, request_proto) + self.stubs[deployment_name][replica_index].invoke( + CREATE_SESSION_METHOD, + request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() if method_name == DESTROY_SESSION_METHOD: replica_index = self.replica_sessions.pop(request_proto.session_id) - self.stubs[replica_index].invoke(DESTROY_SESSION_METHOD, request_proto) + self.stubs[deployment_name][replica_index].invoke( + DESTROY_SESSION_METHOD, + request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() - kwargs = unpack_proto_query_kwargs(request_proto.query_kwargs) - if "session_id" in kwargs: - session_id = kwargs["session_id"] + if "session_id" in request_proto.query_kwargs: + session_id = request_proto.query_kwargs["session_id"] if session_id not in self.replica_sessions: raise ValueError(f"session not found") replica_index = self.replica_sessions[session_id] - ret = self.stubs[replica_index].invoke(method_name, request_proto) + ret = self.stubs[deployment_name][replica_index].invoke( + method_name, + request_proto) return ret return grpc.unary_unary_rpc_method_handler( @@ -233,7 +290,7 @@ def invoke_intercept_method(request_proto, context): response_serializer=next_handler.response_serializer) -def _do_serve(service_impl, port, interceptors=[]): +def _do_serve(service_impl, port, interceptors=[], is_lb=False): stop_event = service_impl.get_stop_event() server = grpc.server(futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS), interceptors=interceptors, @@ -241,7 +298,15 @@ def _do_serve(service_impl, port, interceptors=[]): GRPC_MAX_MSG_SIZE), ('grpc.max_receive_message_length', GRPC_MAX_MSG_SIZE)]) - modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(service_impl, server) + if is_lb: + modelresponse_pb2_grpc.add_DeploymentManagementServicer_to_server( + service_impl, + server) + modelresponse_pb2_grpc.add_ModelResponseServicer_to_server( + ModelResponse(None), + server) + else: + modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(service_impl, server) server.add_insecure_port(f'[::]:{port}') print(f"About to start server") server.start() @@ -254,11 +319,11 @@ def serve_inference(inference_pipeline, port): _do_serve(ModelResponse(inference_pipeline), port) -def serve_load_balancing(task_name, lb_config): - _do_serve(ServiceBase(), +def serve_load_balancing(lb_config): + _do_serve(DeploymentManagement(), lb_config.port, - [LoadBalancingInterceptor(task_name, - lb_config.replica_configs)]) + [LoadBalancingInterceptor(lb_config.replica_configs)], + True) if __name__ == '__main__': diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index a0698899..146e1f30 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -36,6 +36,12 @@ service ModelResponse { rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {} } +service DeploymentManagement { + rpc AddDeployment(AddDeployRequest) returns (google.protobuf.Empty) {} + rpc DeleteDeployment(DeleteDeployRequest) returns 
(google.protobuf.Empty) {} + rpc Terminate (google.protobuf.Empty) returns (google.protobuf.Empty) {} +} + message Value { oneof oneof_values { string svalue = 1; @@ -52,29 +58,34 @@ message SessionID { message SingleStringRequest { string request = 1; map query_kwargs = 2; + optional string deployment_name = 3; } message MultiStringRequest { repeated string request = 1; map query_kwargs = 2; + optional string deployment_name = 3; } message SingleStringReply { string response = 1; float time_taken = 2; float model_time_taken = 3; + optional string deployment_name = 4; } message MultiStringReply { repeated string response = 1; float time_taken = 2; float model_time_taken = 3; + optional string deployment_name = 4; } message QARequest { string question = 1; string context = 2; map query_kwargs = 3; + optional string deployment_name = 4; } message ConversationRequest { @@ -83,6 +94,7 @@ message ConversationRequest { repeated string past_user_inputs = 3; repeated string generated_responses = 4; map query_kwargs = 5; + optional string deployment_name = 6; } message ConversationReply { @@ -91,6 +103,7 @@ message ConversationReply { repeated string generated_responses = 3; float time_taken = 4; float model_time_taken = 5; + optional string deployment_name = 6; } message ImageReply { @@ -100,4 +113,19 @@ message ImageReply { int64 size_w = 4; int64 size_h = 5; float time_taken = 6; + optional string deployment_name = 7; +} + +message AddDeployRequest { + string task = 1; + string deployment_name = 2; + string hostname = 3; + repeated int64 tensor_parallel_ports = 4; + int64 torch_dist_port = 5; + repeated int64 gpu_indices = 6; + +} + +message DeleteDeployRequest { + string deployment_name = 1; } diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 76b1f994..72c33ed8 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -2,14 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team - +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: modelresponse.proto """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -17,11 +17,12 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 
\x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xed\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xeb\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\x85\x01\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\x12\x1c\n\x0f\x64\x65ployment_name\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\x12\n\x10_deployment_name\"\x84\x01\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\x12\x1c\n\x0f\x64\x65ployment_name\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\x12\n\x10_deployment_name\"\xeb\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x04 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xd3\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 
\x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x06 \x01(\tH\x01\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_idB\x12\n\x10_deployment_name\"\xc3\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\x12\x1c\n\x0f\x64\x65ployment_name\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\x12\n\x10_deployment_name\"\xaf\x01\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x12\x1c\n\x0f\x64\x65ployment_name\x18\x07 \x01(\tH\x00\x88\x01\x01\x42\x12\n\x10_deployment_name\"\x98\x01\n\x10\x41\x64\x64\x44\x65ployRequest\x12\x0c\n\x04task\x18\x01 \x01(\t\x12\x17\n\x0f\x64\x65ployment_name\x18\x02 \x01(\t\x12\x10\n\x08hostname\x18\x03 \x01(\t\x12\x1d\n\x15tensor_parallel_ports\x18\x04 \x03(\x03\x12\x17\n\x0ftorch_dist_port\x18\x05 \x01(\x03\x12\x13\n\x0bgpu_indices\x18\x06 \x03(\x03\".\n\x13\x44\x65leteDeployRequest\x12\x17\n\x0f\x64\x65ployment_name\x18\x01 \x01(\t2\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x32\xf3\x01\n\x14\x44\x65ploymentManagement\x12J\n\rAddDeployment\x12\x1f.modelresponse.AddDeployRequest\x1a\x16.google.protobuf.Empty\"\x00\x12P\n\x10\x44\x65leteDeployment\x12\".modelresponse.DeleteDeployRequest\x1a\x16.google.protobuf.Empty\"\x00\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x62\x06proto3' ) -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -33,34 +34,40 @@ _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _VALUE._serialized_start = 67 - _VALUE._serialized_end = 162 - 
_SESSIONID._serialized_start = 164 - _SESSIONID._serialized_end = 195 - _SINGLESTRINGREQUEST._serialized_start = 198 - _SINGLESTRINGREQUEST._serialized_end = 385 - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _MULTISTRINGREQUEST._serialized_start = 388 - _MULTISTRINGREQUEST._serialized_end = 573 - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _SINGLESTRINGREPLY._serialized_start = 575 - _SINGLESTRINGREPLY._serialized_end = 658 - _MULTISTRINGREPLY._serialized_start = 660 - _MULTISTRINGREPLY._serialized_end = 742 - _QAREQUEST._serialized_start = 745 - _QAREQUEST._serialized_end = 930 - _QAREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _QAREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _CONVERSATIONREQUEST._serialized_start = 933 - _CONVERSATIONREQUEST._serialized_end = 1222 - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _CONVERSATIONREPLY._serialized_start = 1225 - _CONVERSATIONREPLY._serialized_end = 1370 - _IMAGEREPLY._serialized_start = 1372 - _IMAGEREPLY._serialized_end = 1497 - _MODELRESPONSE._serialized_start = 1500 - _MODELRESPONSE._serialized_end = 2352 + _globals['_VALUE']._serialized_start = 67 + _globals['_VALUE']._serialized_end = 162 + _globals['_SESSIONID']._serialized_start = 164 + _globals['_SESSIONID']._serialized_end = 195 + _globals['_SINGLESTRINGREQUEST']._serialized_start = 198 + _globals['_SINGLESTRINGREQUEST']._serialized_end = 435 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_MULTISTRINGREQUEST']._serialized_start = 438 + _globals['_MULTISTRINGREQUEST']._serialized_end = 673 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_SINGLESTRINGREPLY']._serialized_start = 676 + _globals['_SINGLESTRINGREPLY']._serialized_end = 809 + _globals['_MULTISTRINGREPLY']._serialized_start = 812 + _globals['_MULTISTRINGREPLY']._serialized_end = 944 + _globals['_QAREQUEST']._serialized_start = 947 + _globals['_QAREQUEST']._serialized_end = 1182 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_CONVERSATIONREQUEST']._serialized_start = 1185 + _globals['_CONVERSATIONREQUEST']._serialized_end = 1524 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_CONVERSATIONREPLY']._serialized_start = 1527 + _globals['_CONVERSATIONREPLY']._serialized_end = 1722 + _globals['_IMAGEREPLY']._serialized_start = 1725 + _globals['_IMAGEREPLY']._serialized_end = 1900 + _globals['_ADDDEPLOYREQUEST']._serialized_start = 1903 + _globals['_ADDDEPLOYREQUEST']._serialized_end = 2055 + _globals['_DELETEDEPLOYREQUEST']._serialized_start = 2057 + _globals['_DELETEDEPLOYREQUEST']._serialized_end = 2103 + _globals['_MODELRESPONSE']._serialized_start = 2106 + _globals['_MODELRESPONSE']._serialized_end = 2958 + _globals['_DEPLOYMENTMANAGEMENT']._serialized_start = 2961 + _globals['_DEPLOYMENTMANAGEMENT']._serialized_end = 3204 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py 
b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 95cfa825..5334f127 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc @@ -468,3 +467,162 @@ def Txt2ImgReply(request, wait_for_ready, timeout, metadata) + + +class DeploymentManagementStub(object): + """Missing associated documentation comment in .proto file.""" + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.AddDeployment = channel.unary_unary( + '/modelresponse.DeploymentManagement/AddDeployment', + request_serializer=modelresponse__pb2.AddDeployRequest.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) + self.DeleteDeployment = channel.unary_unary( + '/modelresponse.DeploymentManagement/DeleteDeployment', + request_serializer=modelresponse__pb2.DeleteDeployRequest.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) + self.Terminate = channel.unary_unary( + '/modelresponse.DeploymentManagement/Terminate', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) + + +class DeploymentManagementServicer(object): + """Missing associated documentation comment in .proto file.""" + def AddDeployment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DeleteDeployment(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Terminate(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_DeploymentManagementServicer_to_server(servicer, server): + rpc_method_handlers = { + 'AddDeployment': + grpc.unary_unary_rpc_method_handler( + servicer.AddDeployment, + request_deserializer=modelresponse__pb2.AddDeployRequest.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'DeleteDeployment': + grpc.unary_unary_rpc_method_handler( + servicer.DeleteDeployment, + request_deserializer=modelresponse__pb2.DeleteDeployRequest.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'Terminate': + grpc.unary_unary_rpc_method_handler( + servicer.Terminate, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'modelresponse.DeploymentManagement', + rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler, )) + + +# This class is part of an EXPERIMENTAL API. 
+class DeploymentManagement(object): + """Missing associated documentation comment in .proto file.""" + @staticmethod + def AddDeployment(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/modelresponse.DeploymentManagement/AddDeployment', + modelresponse__pb2.AddDeployRequest.SerializeToString, + google_dot_protobuf_dot_empty__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) + + @staticmethod + def DeleteDeployment(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/modelresponse.DeploymentManagement/DeleteDeployment', + modelresponse__pb2.DeleteDeployRequest.SerializeToString, + google_dot_protobuf_dot_empty__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) + + @staticmethod + def Terminate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/modelresponse.DeploymentManagement/Terminate', + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + google_dot_protobuf_dot_empty__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index e8cfa934..f4302d45 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -19,7 +19,7 @@ def shutdown(thread): def createRestfulGatewayApp(deployment_name, task, mii_config, server_thread): # client must be thread-safe - client = mii.MIIClient(task, "localhost", mii_config.port_number) + client = mii.mii_query_handle(deployment_name) class RestfulGatewayService(Resource): def __init__(self): diff --git a/mii/launch/load_balance_server.py b/mii/launch/load_balance_server.py new file mode 100644 index 00000000..01de3822 --- /dev/null +++ b/mii/launch/load_balance_server.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import argparse + +from mii import LoadBalancerConfig + +from mii.grpc_related.modelresponse_server import serve_load_balancing +from .utils import decode_config_from_str + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--load-balancer", + type=str, + default=None, + help="base64 encoded load balancer config") + + args = parser.parse_args() + assert args.load_balancer is not None, "lb_config required to use load balancer" + lb_config_dict = decode_config_from_str(args.load_balancer) + lb_config = LoadBalancerConfig(**lb_config_dict) + + print(f"Starting load balancer on port: {lb_config.port}") + serve_load_balancing(lb_config) + + +if __name__ == "__main__": + # python -m mii.launch.load_balance_server + main() diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 27878725..194cc4a9 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -5,23 +5,13 @@ import os import argparse import mii -import base64 -import json - -from mii import MIIConfig, LoadBalancerConfig +from mii import MIIConfig, LoadBalancerConfig, DeploymentConfig +from mii.utils import get_task_name from mii.models.load_models import load_models from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing from mii.grpc_related.restful_gateway import RestfulGatewayThread - - -def decode_config_from_str(config_str): - # str -> bytes - b64_bytes = config_str.encode() - # decode b64 bytes -> json bytes - config_bytes = base64.urlsafe_b64decode(b64_bytes) - # convert json bytes -> str -> dict - return json.loads(config_bytes.decode()) +from .utils import decode_config_from_str def main(): @@ -55,6 +45,7 @@ def main(): "--restful-gateway", action='store_true', help="launch restful api gateway") + parser.add_argument("-f", "--deployment", type=str, help="base64 encoded deployment") args = parser.parse_args() @@ -63,6 +54,9 @@ def main(): # convert dict -> mii config mii_config = MIIConfig(**config_dict) + deployment_dict = decode_config_from_str(args.deployment) + deployment_dict['task'] = get_task_name(mii.constants.Tasks(deployment_dict['task'])) + deployment = DeploymentConfig(**deployment_dict) if args.restful_gateway: print(f"Starting RESTful API gateway on port: {mii_config.restful_api_port}") gateway_thread = RestfulGatewayThread(args.deployment_name, @@ -87,7 +81,7 @@ def main(): ds_zero=args.ds_zero, ds_config_path=args.ds_config, provider=provider, - mii_config=mii_config) + mii_config=deployment) print(f"Starting server on port: {port}") serve_inference(inference_pipeline, port) diff --git a/mii/launch/utils.py b/mii/launch/utils.py new file mode 100644 index 00000000..9e039409 --- /dev/null +++ b/mii/launch/utils.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import base64 +import json + + +def decode_config_from_str(config_str): + # str -> bytes + b64_bytes = config_str.encode() + # decode b64 bytes -> json bytes + config_bytes = base64.urlsafe_b64decode(b64_bytes) + # convert json bytes -> str -> dict + return json.loads(config_bytes.decode()) diff --git a/mii/method_table.py b/mii/method_table.py index c412f446..f7f87d28 100644 --- a/mii/method_table.py +++ b/mii/method_table.py @@ -13,7 +13,8 @@ def single_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.SingleStringRequest( request=request_dict['query'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + query_kwargs=kwarg_dict_to_proto(query_kwargs), + deployment_name=request_dict.get('deployment_name')) def single_string_response_to_proto(self, response, time_taken, model_time_taken): @@ -26,7 +27,8 @@ def multi_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.MultiStringRequest( request=request_dict['query'] if isinstance(request_dict['query'], list) else [request_dict['query']], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + query_kwargs=kwarg_dict_to_proto(query_kwargs), + deployment_name=request_dict.get('deployment_name')) def proto_request_to_single_input(self, request): @@ -143,7 +145,8 @@ def pack_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.QARequest( question=request_dict['question'], context=request_dict['context'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + query_kwargs=kwarg_dict_to_proto(query_kwargs), + deployment_name=request_dict.get('deployment_name')) def unpack_request_from_proto(self, request): kwargs = unpack_proto_query_kwargs(request.query_kwargs) @@ -222,7 +225,8 @@ def pack_request_to_proto(self, request_dict, **query_kwargs): if 'conversation_id' in request_dict else None, past_user_inputs=request_dict['past_user_inputs'], generated_responses=request_dict['generated_responses'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + query_kwargs=kwarg_dict_to_proto(query_kwargs), + deployment_name=request_dict.get('deployment_name')) class Text2ImgMethods(TaskMethods): diff --git a/mii/models/providers/huggingface.py b/mii/models/providers/huggingface.py index c04a6829..27f456aa 100644 --- a/mii/models/providers/huggingface.py +++ b/mii/models/providers/huggingface.py @@ -194,5 +194,6 @@ def hf_provider(model_path, model_name, task_name, mii_config): framework="pt", use_auth_token=mii_config.hf_auth_token, torch_dtype=mii_config.dtype, + trust_remote_code=mii_config.trust_remote_code, ) return inference_pipeline diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py index 1184d70e..6d608fc8 100644 --- a/mii/models/score/generate.py +++ b/mii/models/score/generate.py @@ -9,25 +9,88 @@ from mii.constants import DeploymentType -def create_score_file(deployment_name, +def create_score_file(deployment_tag, deployment_type, - task, - model_name, - ds_optimize, - ds_zero, - ds_config, - mii_config, + deployments, model_path, - lb_config): + port_map, + lb_config, + mii_configs={}, + deployed=False): + config_dict = {} - config_dict[mii.constants.DEPLOYMENT_NAME_KEY] = deployment_name - config_dict[mii.constants.TASK_NAME_KEY] = mii.utils.get_task_name(task) - config_dict[mii.constants.MODEL_NAME_KEY] = model_name - config_dict[mii.constants.ENABLE_DEEPSPEED_KEY] = ds_optimize - config_dict[mii.constants.MII_CONFIGS_KEY] = mii_config.dict() - 
config_dict[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY] = ds_zero - config_dict[mii.constants.DEEPSPEED_CONFIG_KEY] = ds_config + config_dict[ + mii.constants.MII_CONFIGS_KEY] = mii_configs.dict() if mii_configs else {} + config_dict[mii.constants.DEPLOYMENT_TYPE_KEY] = deployment_type.value config_dict[mii.constants.MODEL_PATH_KEY] = model_path + config_dict[mii.constants.DEPLOYMENT_TAG_KEY] = deployment_tag + config_dict[mii.constants.DEPLOYED_KEY] = deployed + config_dict[mii.constants.DEPLOYMENTS_KEY] = {} + if port_map is not None: + config_dict[mii.constants.PORT_MAP_KEY] = port_map + + if deployments is not None: + for deployment in deployments.values(): + deployment_config = { + mii.constants.DEPLOYMENT_NAME_KEY: + getattr(deployment, + mii.constants.DEPLOYMENT_NAME_KEY), + mii.constants.TASK_NAME_KEY: + mii.utils.get_task_name(getattr(deployment, + mii.constants.TASK_NAME_KEY)), + mii.constants.MODEL_NAME_KEY: + getattr(deployment, + mii.constants.MODEL_NAME_KEY), + mii.constants.ENABLE_DEEPSPEED_KEY: + getattr(deployment, + mii.constants.ENABLE_DEEPSPEED_KEY), + #mii.constants.MII_CONFIGS_KEY: + #getattr(deployment, + # mii.constants.MII_CONFIGS_KEY).dict(), + mii.constants.ENABLE_DEEPSPEED_ZERO_KEY: + getattr(deployment, + mii.constants.ENABLE_DEEPSPEED_ZERO_KEY), + mii.constants.DEEPSPEED_CONFIG_KEY: + getattr(deployment, + mii.constants.DEEPSPEED_CONFIG_KEY), + mii.constants.GPU_INDEX_KEY: + getattr(deployment, + mii.constants.GPU_INDEX_KEY), + 'tensor_parallel': + deployment.tensor_parallel, + 'dtype': + deployment.dtype, + 'meta_tensor': + deployment.meta_tensor, + 'load_with_sys_mem': + deployment.load_with_sys_mem, + 'replace_with_kernel_inject': + deployment.replace_with_kernel_inject, + 'profile_model_time': + deployment.profile_model_time, + 'skip_model_check': + deployment.skip_model_check, + 'max_tokens': + deployment.max_tokens, + 'enable_restful_api': + deployment.enable_restful_api, + 'replica_num': + deployment.replica_num, + 'hostfile': + deployment.hostfile, + 'deploy_rank': + deployment.deploy_rank, + 'enable_cuda_graph': + deployment.enable_cuda_graph, + 'checkpoint_dict': + deployment.checkpoint_dict, + 'hf_auth_token': + deployment.hf_auth_token, + 'trust_remote_code': + deployment.trust_remote_code + } + config_dict[mii.constants.DEPLOYMENTS_KEY][ + deployment.deployment_name] = deployment_config if lb_config is not None: config_dict[mii.constants.LOAD_BALANCER_CONFIG_KEY] = lb_config @@ -46,16 +109,16 @@ def create_score_file(deployment_name, source_with_config = f"{score_src}\n" source_with_config += f"configs = {pprint.pformat(config_dict, indent=4)}" - with open(generated_score_path(deployment_name, deployment_type), "w") as fd: + with open(generated_score_path(deployment_tag, deployment_type), "w") as fd: fd.write(source_with_config) fd.write("\n") -def generated_score_path(deployment_name, deployment_type): +def generated_score_path(deployment_tag, deployment_type): if deployment_type == DeploymentType.LOCAL: - score_path = os.path.join(mii.utils.mii_cache_path(), deployment_name) + score_path = os.path.join(mii.utils.mii_cache_path(), deployment_tag) elif deployment_type == DeploymentType.AML: - score_path = os.path.join(mii.aml_related.utils.aml_output_path(deployment_name), + score_path = os.path.join(mii.aml_related.utils.aml_output_path(deployment_tag), "code") if not os.path.isdir(score_path): os.makedirs(score_path) diff --git a/mii/models/score/score_template.py b/mii/models/score/score_template.py index 04e47fae..df4d94d0 100644 --- 
a/mii/models/score/score_template.py +++ b/mii/models/score/score_template.py @@ -8,7 +8,7 @@ import json import torch import mii -from mii.config import LoadBalancerConfig, ReplicaConfig +from mii.config import LoadBalancerConfig, ReplicaConfig, MIIConfig import time model = None @@ -16,24 +16,20 @@ def init(): model_path = mii.utils.full_model_path(configs[mii.constants.MODEL_PATH_KEY]) - - deployment_name = configs[mii.constants.DEPLOYMENT_NAME_KEY] - model_name = configs[mii.constants.MODEL_NAME_KEY] - task_name = configs[mii.constants.TASK_NAME_KEY] - - assert model_name is not None, "The model name should be set before calling init" - assert task_name is not None, "The task name should be set before calling init" - - mii.MIIServer(deployment_name, - task_name, - model_name, + mii_configs = configs[mii.constants.MII_CONFIGS_KEY] + deployment_tag = configs[mii.constants.DEPLOYMENT_TAG_KEY] + deployments = [] + lb_enabled = configs[mii.constants.DEPLOYED_KEY] + for deployment in configs[mii.constants.DEPLOYMENTS_KEY].values(): + deployments.append(mii.DeploymentConfig(**deployment)) + mii_configs = MIIConfig(**mii_configs) + mii.MIIServer(deployment_tag, + deployments, model_path, - ds_optimize=configs[mii.constants.ENABLE_DEEPSPEED_KEY], - ds_zero=configs[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY], - ds_config=configs[mii.constants.DEEPSPEED_CONFIG_KEY], - mii_configs=configs[mii.constants.MII_CONFIGS_KEY], lb_config=configs.get(mii.constants.LOAD_BALANCER_CONFIG_KEY, - None)) + None), + lb_enabled=lb_enabled, + mii_configs=mii_configs) global model model = None diff --git a/mii/server.py b/mii/server.py index 0825e060..1aeac364 100644 --- a/mii/server.py +++ b/mii/server.py @@ -29,41 +29,31 @@ def config_to_b64_str(config): class MIIServer(): '''Initialize the model, setup the server for the model under model_path''' def __init__(self, - deployment_name, - task_name, - model_name, + deployment_tag, + deployments, model_path, - ds_optimize=True, - ds_zero=False, - ds_config=None, - mii_configs={}, - lb_config=None): - - mii_configs = mii.config.MIIConfig(**mii_configs) - - self.task = mii.utils.get_task(task_name) - - self.num_gpus = get_num_gpus(mii_configs) - assert self.num_gpus > 0, "GPU count must be greater than 0" - - self.port_number = mii_configs.port_number - - if mii_configs.hostfile is None: - hostfile = tempfile.NamedTemporaryFile(delete=False) - num_gpu = torch.cuda.device_count() - with open(hostfile, "w") as f: - f.write(f"localhost slots={num_gpu}") - mii.configs.hostfile = hostfile - - processes = self._initialize_service(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config) - self._wait_until_server_is_live(processes, lb_config.replica_configs) + lb_config=None, + lb_enabled=False, + mii_configs={}): + if len(deployments) > 0: + self.lb_enabled = lb_enabled + self.deployments = deployments + for deployment in deployments: + #mii_configs = getattr(deployment, mii.constants.MII_CONFIGS_KEY) + assert get_num_gpus(deployment) > 0, f"GPU count for {deployment.deployment_name} must be greater than 0" + if deployment.hostfile is None: + hostfile = tempfile.NamedTemporaryFile(delete=False) + num_gpu = torch.cuda.device_count() + with open(hostfile, "w") as f: + f.write(f"localhost slots={num_gpu}") + deployment.hostfile = hostfile + deps = {dep.deployment_name: dep for dep in deployments} + processes = self._initialize_service(deployment_tag, + deps, + model_path, + lb_config, + mii_configs) + 
self._wait_until_server_is_live(processes, lb_config.replica_configs) def _wait_until_server_is_live(self, processes, deployment): for process, repl_config in zip(processes, deployment): @@ -110,18 +100,25 @@ def _build_server_args(self, ds_zero, ds_config, mii_configs, - port): + port, + deployment): # serialize mii config b64_config_str = config_to_b64_str(mii_configs) - - server_args_str = f"--deployment-name {deployment_name} --task-name {mii.utils.get_task_name(self.task)} --model {model_name} --model-path {model_path} --port {port}" + b64_deployment = config_to_b64_str(deployment) + task = "" + for deployment in self.deployments: + if deployment_name == getattr(deployment, mii.constants.DEPLOYMENT_NAME_KEY): + task = getattr(deployment, mii.constants.TASK_NAME_KEY) + break + server_args_str = f"--deployment-name {deployment_name} --task-name {mii.utils.get_task_name(task)} --model {model_name} --model-path {model_path} --port {port}" server_args_str += " --ds-optimize" if ds_optimize else "" # XXX: fetch model provider based on model name in a more general way - provider = get_provider_name(model_name, self.task) + provider = get_provider_name(model_name, task) server_args_str += f" --provider {provider}" server_args_str += f" --config {b64_config_str}" + server_args_str += f" -f {b64_deployment}" server_args_str += " --ds-zero" if ds_zero else "" if ds_zero and ds_config is not None: if isinstance(ds_config, dict): @@ -143,7 +140,7 @@ def create_config_from_dict(tmpdir, config_dict): f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {ds_config}" ) server_args_str += f" --ds-config {ds_config_path}" - printable_config = f"task-name {mii.utils.get_task_name(self.task)} model {model_name} model-path {model_path} port {self.port_number} provider {provider}" + printable_config = f"task-name {task} model {model_name} model-path {model_path} port {port} provider {provider}" logger.info(f"MII using multi-gpu deepspeed launcher:\n" + self.print_helper(printable_config)) return server_args_str @@ -161,30 +158,16 @@ def print_helper(self, args): printable_string += " " + "-" * 60 return printable_string - def _launch_load_balancer(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config): + def _launch_load_balancer(self, model_path, lb_config): # serialize mii config b64_config_str = config_to_b64_str(lb_config) - - return self._launch_server_process( - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - mii_configs.port_number, - "load balancer", - ex_server_args=[f"--load-balancer {b64_config_str}"]) + launch_str = f"{sys.executable} -m mii.launch.load_balance_server --load-balancer {b64_config_str}" + cmd = launch_str.split(" ") + mii_env = os.environ.copy() + mii_env["TRANSFORMERS_CACHE"] = model_path + logger.info(f"load balancer server launch: {cmd}") + return subprocess.Popen(cmd, env=mii_env) def _launch_restful_gateway(self, deployment_name, @@ -194,17 +177,21 @@ def _launch_restful_gateway(self, ds_zero, ds_config, mii_configs, - port): - return self._launch_server_process(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port, - "restful api gateway", - ex_server_args=["--restful-gateway"]) + port, + deployment): + return self._launch_server_process( + deployment_name, + model_name, + model_path, + ds_optimize, + ds_zero, + ds_config, + mii_configs, + port, + "restful api gateway", + 
deployment, + ex_server_args=["--restful-gateway"], + ) def _launch_server_process(self, deployment_name, @@ -216,6 +203,7 @@ def _launch_server_process(self, mii_configs, port, msg_server_type, + deployment, ds_launch_str=None, ex_server_args=[]): launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server" @@ -226,7 +214,8 @@ def _launch_server_process(self, ds_zero, ds_config, mii_configs, - port) + port, + deployment) server_args_str += f" " + \ " ".join(ex_server_args) if ex_server_args else "" @@ -252,7 +241,8 @@ def _launch_deepspeed(self, host, port, master_port, - deploy_ranks): + deploy_ranks, + deployment): # use different hostfiles for replica instances # pass /dev/null when no replica is used worker_str = f"-H {hostfile} " @@ -275,69 +265,81 @@ def _launch_deepspeed(self, mii_configs, port, "MII server", + deployment, ds_launch_str=ds_launch_str) def _initialize_service(self, - deployment_name, - model_name, + deployment_tag, + deployments, model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config): + lb_config, + mii_configs): processes = [] - host_gpus = defaultdict(list) for repl_config in lb_config.replica_configs: host_gpus[repl_config.hostname].extend(repl_config.gpu_indices) # Start replica instances for i, repl_config in enumerate(lb_config.replica_configs): + name = repl_config.deployment_name + deployment = None if name not in deployments else deployments[name] + """for dep in deployments: + if getattr(dep, mii.constants.DEPLOYMENT_NAME_KEY) == name: + deployment = dep + """ + if deployment is None: + continue hostfile = tempfile.NamedTemporaryFile(delete=False) hostfile.write( f'{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n' .encode()) processes.append( self._launch_deepspeed( - deployment_name, - model_name, + name, + getattr(deployment, + mii.constants.MODEL_NAME_KEY), model_path, - ds_optimize, - ds_zero, - ds_config, + getattr(deployment, + mii.constants.ENABLE_DEEPSPEED_KEY), + getattr(deployment, + mii.constants.ENABLE_DEEPSPEED_ZERO_KEY), + getattr(deployment, + mii.constants.DEEPSPEED_CONFIG_KEY), mii_configs, hostfile.name, repl_config.hostname, repl_config.tensor_parallel_ports[0], mii_configs.torch_dist_port + (100 * i) + repl_config.gpu_indices[0], - repl_config.gpu_indices)) + repl_config.gpu_indices, + deployment)) # start load balancer here. # we don't use deepspeed launcher for the load balancer because it does not need a GPU. # The deepspeed launcher determines the number of processes to launch based on GPUs available on the host or CUDA_VISIBLE_DEVICES, # and it is expected to assign one GPU to one process. 
- processes.append( - self._launch_load_balancer(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config)) - - if mii_configs.enable_restful_api: - # start rest api server - processes.append( - self._launch_restful_gateway(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - mii_configs.port_number)) + if not self.lb_enabled: + processes.append(self._launch_load_balancer(model_path, lb_config)) + + for deployment in self.deployments: + if deployment.enable_restful_api: + # start rest api server + processes.append( + self._launch_restful_gateway( + getattr(deployment, + mii.constants.DEPLOYMENT_NAME_KEY), + getattr(deployment, + mii.constants.MODEL_NAME_KEY), + model_path, + getattr(deployment, + mii.constants.ENABLE_DEEPSPEED_KEY), + getattr(deployment, + mii.constants.ENABLE_DEEPSPEED_ZERO_KEY), + getattr(deployment, + mii.constants.DEEPSPEED_CONFIG_KEY), + mii_configs, + mii_configs.port_number), + deployment) + break return processes diff --git a/mii/terminate.py b/mii/terminate.py index 167c5a5a..5585832b 100644 --- a/mii/terminate.py +++ b/mii/terminate.py @@ -7,21 +7,21 @@ import mii -def terminate(deployment_name): - mii.utils.logger.info(f"Terminating server for {deployment_name}") - generator = mii.mii_query_handle(deployment_name) - if (deployment_name in mii.non_persistent_models): +def terminate(deployment_tag): + mii.utils.logger.info(f"Terminating server for {deployment_tag}") + generator = mii.mii_query_handle(deployment_tag) + if (deployment_tag in mii.non_persistent_models): generator.terminate() return try: - generator.query({'query': ''}) + generator.query({'query': ''}, mii.constants.MII_TERMINATE_DEP_KEY) except grpc.aio._call.AioRpcError as error: if error._code == grpc.StatusCode.UNAVAILABLE: - mii.utils.logger.warn(f"Server for {deployment_name} not found") + mii.utils.logger.warn(f"Server for {deployment_tag} not found") else: pass except (KeyError, TypeError) as error: pass generator.terminate() - mii.client.terminate_restful_gateway(deployment_name) + mii.client.terminate_restful_gateway(deployment_tag) diff --git a/tests/conftest.py b/tests/conftest.py index cb812069..29be37be 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -108,6 +108,55 @@ def ds_config(request): return request.param +@pytest.fixture(scope="function", params=["Multi_Model_Tag"]) +def deployment_tag(request): + return request.param + + +@pytest.fixture(scope="function", params=[[]]) +def deployments(request): + ret = {} + gpu_index_map1 = {'master': [0]} + gpu_index_map2 = {'master': [1]} + gpu_index_map3 = {'master': [0, 1]} + + deployments = [] + + mii_configs1 = {"tensor_parallel": 2, "dtype": "fp16"} + mii_configs2 = {"tensor_parallel": 1} + + name = "bigscience/bloom-560m" + deployments.append( + mii.DeploymentConfig(task='text-generation', + model=name, + deployment_name=name + "_deployment", + GPU_index_map=gpu_index_map3, + mii_configs=mii.config.MIIConfig(**mii_configs1))) + + name = "microsoft/DialogRPT-human-vs-rand" + deployments.append( + mii.DeploymentConfig(task='text-classification', + model=name, + deployment_name=name + "_deployment", + GPU_index_map=gpu_index_map2)) + + name = "microsoft/DialoGPT-large" + deployments.append( + mii.DeploymentConfig(task='conversational', + model=name, + deployment_name=name + "_deployment", + GPU_index_map=gpu_index_map1, + mii_configs=mii.config.MIIConfig(**mii_configs2))) + + name = "deepset/roberta-large-squad2" + 
deployments.append( + mii.DeploymentConfig(task="question-answering", + model=name, + deployment_name=name + "-qa-deployment", + GPU_index_map=gpu_index_map2)) + return deployments + + @pytest.fixture(scope="function") def deployment_config(task_name: str, model_name: str, @@ -130,6 +179,19 @@ def deployment_config(task_name: str, return config +@pytest.fixture(scope="function") +def multi_deployment_config(deployments: list, + deployment_tag: str, + deployment_type: str): + config = SimpleNamespace(deployments=deployments, + deployment_type=deployment_type, + deployment_tag=deployment_tag, + model_path=os.getenv("TRANSFORMERS_CACHE", + None)) + validate_config(config) + return config + + @pytest.fixture(scope="function", params=[None]) def expected_failure(request): return request.param @@ -147,6 +209,43 @@ def deployment(deployment_config, expected_failure): mii.terminate(deployment_config.deployment_name) +@pytest.fixture(scope="function") +def multi_deployment(deployment_tag, multi_deployment_config): + mii.deploy(**multi_deployment_config.__dict__) + yield multi_deployment_config + mii.terminate(deployment_tag) + + @pytest.fixture(scope="function", params=[{"query": "DeepSpeed is the greatest"}]) def query(request): return request.param + + +@pytest.fixture(scope="function") +def multi_query(request): + queries = [] + queries.append({ + "query": ["DeepSpeed is", + "Seattle is"], + "deployment_name": "bigscience/bloom-560m_deployment" + }) + + queries.append({ + 'query': "DeepSpeed is the greatest", + "deployment_name": "microsoft/DialogRPT-human-vs-rand_deployment" + }) + + queries.append({ + 'text': "DeepSpeed is the greatest", + 'conversation_id': 3, + 'past_user_inputs': [], + 'generated_responses': [], + "deployment_name": "microsoft/DialoGPT-large_deployment" + }) + + queries.append({ + 'question': "What is the greatest?", + 'context': "DeepSpeed is the greatest", + "deployment_name": "deepset/roberta-large-squad2" + "-qa-deployment" + }) + return queries diff --git a/tests/test_multi_deployment.py b/tests/test_multi_deployment.py new file mode 100644 index 00000000..9caa9828 --- /dev/null +++ b/tests/test_multi_deployment.py @@ -0,0 +1,35 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import pytest +import mii + + +def test_multi_deploy(deployment_tag, multi_deployment, multi_query): + generator = mii.mii_query_handle(deployment_tag) + for query in multi_query: + result = generator.query(query) + assert result + + +@pytest.mark.parametrize( + "task_name, model_name, query", + [ + ( + "text-generation", + "bigscience/bloom-560m", + { + "query": ["DeepSpeed is the greatest", + 'Seattle is'] + }, + ), + ], +) +def test_partial_deploy(deployment_tag, multi_deployment, deployment_config, query): + generator = mii.mii_query_handle(deployment_tag) + generator.add_models(**deployment_config.__dict__) + query["deployment_name"] = deployment_config.deployment_name + result = generator.query(query) + generator.delete_model(deployment_config.deployment_name) + assert result
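
A note on how configuration reaches the launched subprocesses in this PR: both the per-deployment config and the load-balancer config are serialized to JSON, url-safe base64 encoded, and passed on the command line (see `config_to_b64_str` in `mii/server.py` and `decode_config_from_str` in the new `mii/launch/utils.py`). The sketch below illustrates that round-trip only; it is not part of the PR. The decode side mirrors `decode_config_from_str` as shown in the diff, while `encode_config_to_str` and the example config values are hypothetical stand-ins for the encode side, whose body is not shown in this diff.

```python
# Minimal sketch (assumptions noted above) of the base64 config round-trip
# used to hand configs from mii/server.py to the launcher subprocesses.
import base64
import json


def encode_config_to_str(config_dict):
    # dict -> JSON str -> bytes -> url-safe base64 -> str
    json_bytes = json.dumps(config_dict).encode()
    return base64.urlsafe_b64encode(json_bytes).decode()


def decode_config_from_str(config_str):
    # str -> base64 bytes -> JSON bytes -> str -> dict
    return json.loads(base64.urlsafe_b64decode(config_str.encode()).decode())


if __name__ == "__main__":
    # Hypothetical load-balancer config, encoded the way the launcher expects,
    # e.g.: python -m mii.launch.load_balance_server --load-balancer <encoded>
    lb_config = {"port": 50050, "replica_configs": []}
    encoded = encode_config_to_str(lb_config)
    assert decode_config_from_str(encoded) == lb_config
```

Keeping the encode and decode helpers as exact inverses (same url-safe alphabet, same JSON encoding) is what lets an arbitrary nested config dict survive the trip through `argparse` as a single opaque string.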