diff --git a/examples/multi_model/deploy.py b/examples/multi_model/deploy.py new file mode 100644 index 00000000..1e8b7aed --- /dev/null +++ b/examples/multi_model/deploy.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import mii + +gpu_index_map1 = {'master': [0]} +gpu_index_map2 = {'master': [1]} +gpu_index_map3 = {'master': [0, 1]} + +deployments = [] + +mii_configs1 = {"tensor_parallel": 2, "dtype": "fp16"} +mii_configs2 = {"tensor_parallel": 1} + +name = "bigscience/bloom-560m" +deployments.append({ + 'task': 'text-generation', + 'model': name, + 'deployment_name': name + "_deployment", + 'GPU_index_map': gpu_index_map3, + 'tensor_parallel': 2, + 'dtype': "fp16" +}) + +# gpt2 +name = "microsoft/DialogRPT-human-vs-rand" +deployments.append({ + 'task': 'text-classification', + 'model': name, + 'deployment_name': name + "_deployment", + 'GPU_index_map': gpu_index_map2 +}) + +name = "microsoft/DialoGPT-large" +deployments.append({ + 'task': 'conversational', + 'model': name, + 'deployment_name': name + "_deployment", + 'GPU_index_map': gpu_index_map1, +}) + +name = "deepset/roberta-large-squad2" +deployments.append({ + 'task': "question-answering", + 'model': name, + 'deployment_name': name + "-qa-deployment", + 'GPU_index_map': gpu_index_map2 +}) + +mii.deploy(deployment_tag="multi_models", deployment_configs=deployments[:2]) diff --git a/examples/multi_model/query.py b/examples/multi_model/query.py new file mode 100644 index 00000000..f506830f --- /dev/null +++ b/examples/multi_model/query.py @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import mii + +results = [] +generator = mii.mii_query_handle("multi_models") +result = generator.query( + { + "query": ["DeepSpeed is", + "Seattle is"], + "deployment_name": "bigscience/bloom-560m_deployment" + }, + do_sample=True, + max_new_tokens=30, +) +results.append(result) +print(result) + +result = generator.query({ + 'query': + "DeepSpeed is the greatest", + "deployment_name": + "microsoft/DialogRPT-human-vs-rand_deployment" +}) +results.append(result) +print(result) + +result = generator.query({ + 'text': "DeepSpeed is the greatest", + 'conversation_id': 3, + 'past_user_inputs': [], + 'generated_responses': [], + "deployment_name": "microsoft/DialoGPT-large_deployment" +}) +results.append(result) +print(result) + +result = generator.query({ + 'question': + "What is the greatest?", + 'context': + "DeepSpeed is the greatest", + "deployment_name": + "deepset/roberta-large-squad2" + "-qa-deployment" +}) +results.append(result) +print(result) diff --git a/examples/multi_model/shutdown.py b/examples/multi_model/shutdown.py new file mode 100644 index 00000000..6b718a4d --- /dev/null +++ b/examples/multi_model/shutdown.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. 
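Taken together, the three example scripts above exercise the new multi-model API end to end: deploy.py groups several per-model dicts under one deployment_tag, query.py routes every request with a deployment_name key, and shutdown.py tears the whole group down. A condensed sketch of that flow, reusing two of the example's models (adjust names and GPU counts to your hardware):

import mii

deployments = [
    {"task": "text-generation", "model": "bigscience/bloom-560m",
     "deployment_name": "bigscience/bloom-560m_deployment", "tensor_parallel": 1},
    {"task": "question-answering", "model": "deepset/roberta-large-squad2",
     "deployment_name": "deepset/roberta-large-squad2-qa-deployment"},
]
mii.deploy(deployment_tag="multi_models", deployment_configs=deployments)

handle = mii.mii_query_handle("multi_models")
# Every query names the deployment it should be routed to.
print(handle.query({"query": ["DeepSpeed is"],
                    "deployment_name": "bigscience/bloom-560m_deployment"},
                   max_new_tokens=30))
print(handle.query({"question": "What is the greatest?",
                    "context": "DeepSpeed is the greatest",
                    "deployment_name": "deepset/roberta-large-squad2-qa-deployment"}))

mii.terminate("multi_models")  # shuts down every deployment under the tag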
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +import mii + +mii.terminate("multi_models") diff --git a/mii/__init__.py b/mii/__init__.py index ab409d4c..94bbdccf 100644 --- a/mii/__init__.py +++ b/mii/__init__.py @@ -7,10 +7,10 @@ from .client import MIIClient, mii_query_handle from .deployment import deploy from .terminate import terminate -from .constants import DeploymentType, Tasks +from .constants import DeploymentType, TaskType from .aml_related.utils import aml_output_path -from .config import MIIConfig, LoadBalancerConfig +from .config import MIIConfig, DeploymentConfig from .grpc_related.proto import modelresponse_pb2_grpc __version__ = "0.0.0" diff --git a/mii/client.py b/mii/client.py index 535b55c8..fbc166e0 100644 --- a/mii/client.py +++ b/mii/client.py @@ -6,20 +6,15 @@ import grpc import requests import mii -from mii.utils import get_task from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc -from mii.constants import GRPC_MAX_MSG_SIZE, Tasks +from mii.constants import GRPC_MAX_MSG_SIZE, TaskType from mii.method_table import GRPC_METHOD_TABLE +from mii.config import MIIConfig -def _get_deployment_info(deployment_name): - configs = mii.utils.import_score_file(deployment_name).configs - task = configs[mii.constants.TASK_NAME_KEY] - mii_configs_dict = configs[mii.constants.MII_CONFIGS_KEY] - mii_configs = mii.config.MIIConfig(**mii_configs_dict) - - assert task is not None, "The task name should be set before calling init" - return task, mii_configs +def _get_mii_config(deployment_name): + mii_config = mii.utils.import_score_file(deployment_name).mii_config + return MIIConfig(**mii_config) def mii_query_handle(deployment_name): @@ -39,40 +34,64 @@ def mii_query_handle(deployment_name): inference_pipeline, task = mii.non_persistent_models[deployment_name] return MIINonPersistentClient(task, deployment_name) - task_name, mii_configs = _get_deployment_info(deployment_name) - return MIIClient(task_name, "localhost", mii_configs.port_number) + mii_config = _get_mii_config(deployment_name) + return MIIClient(mii_config, "localhost", mii_config.port_number) def create_channel(host, port): - return grpc.aio.insecure_channel(f'{host}:{port}', - options=[('grpc.max_send_message_length', - GRPC_MAX_MSG_SIZE), - ('grpc.max_receive_message_length', - GRPC_MAX_MSG_SIZE)]) - - -class MIIClient(): + return grpc.aio.insecure_channel( + f"{host}:{port}", + options=[ + ("grpc.max_send_message_length", + GRPC_MAX_MSG_SIZE), + ("grpc.max_receive_message_length", + GRPC_MAX_MSG_SIZE), + ], + ) + + +class MIIClient: """ Client to send queries to a single endpoint. 
""" - def __init__(self, task_name, host, port): + def __init__(self, mii_config, host, port): self.asyncio_loop = asyncio.get_event_loop() channel = create_channel(host, port) self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) - self.task = get_task(task_name) - - async def _request_async_response(self, request_dict, **query_kwargs): - if self.task not in GRPC_METHOD_TABLE: - raise ValueError(f"unknown task: {self.task}") - - task_methods = GRPC_METHOD_TABLE[self.task] + self.mii_config = mii_config + + def _get_deployment_task(self, deployment_name=None): + task = None + if deployment_name is None: #mii.terminate() or single model + if deployment_name is None: + assert len(self.deployments) == 1, "Must pass deployment_name to query when using multiple deployments" + deployment = self.mii_config.deployment_configs[0] + deployment_name = getattr(deployment, deployment_name) + task = getattr(deployment, task) + else: + if deployment_name in self.deployments: + deployment = self.mii_config.deployment_configs[deployment_name] + task = getattr(deployment, task) + else: + assert False, f"{deployment_name} not found in list of deployments" + return deployment_name, task + + async def _request_async_response(self, request_dict, task, **query_kwargs): + if task not in GRPC_METHOD_TABLE: + raise ValueError(f"unknown task: {task}") + + task_methods = GRPC_METHOD_TABLE[task] proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs) - proto_response = await getattr(self.stub, task_methods.method)(proto_request) + proto_response = await getattr(self.mr_stub, task_methods.method)(proto_request) return task_methods.unpack_response_from_proto(proto_response) def query(self, request_dict, **query_kwargs): + deployment_name = request_dict.get(mii.constants.DEPLOYMENT_NAME_KEY) + deployment_name, task = self._get_deployment_task(deployment_name) + request_dict['deployment_name'] = deployment_name return self.asyncio_loop.run_until_complete( self._request_async_response(request_dict, + task, **query_kwargs)) async def terminate_async(self): @@ -87,7 +106,9 @@ async def create_session_async(self, session_id): modelresponse_pb2.SessionID(session_id=session_id)) def create_session(self, session_id): - assert self.task == Tasks.TEXT_GENERATION, f"Session creation only available for task '{Tasks.TEXT_GENERATION}'." + assert ( + self.task == TaskType.TEXT_GENERATION + ), f"Session creation only available for task '{TaskType.TEXT_GENERATION}'." return self.asyncio_loop.run_until_complete( self.create_session_async(session_id)) @@ -96,18 +117,20 @@ async def destroy_session_async(self, session_id): ) def destroy_session(self, session_id): - assert self.task == Tasks.TEXT_GENERATION, f"Session deletion only available for task '{Tasks.TEXT_GENERATION}'." + assert ( + self.task == TaskType.TEXT_GENERATION + ), f"Session deletion only available for task '{TaskType.TEXT_GENERATION}'." self.asyncio_loop.run_until_complete(self.destroy_session_async(session_id)) -class MIITensorParallelClient(): +class MIITensorParallelClient: """ Client to send queries to multiple endpoints in parallel. This is used to call multiple servers deployed for tensor parallelism. 
""" - def __init__(self, task_name, host, ports): - self.task = get_task(task_name) - self.clients = [MIIClient(task_name, host, port) for port in ports] + def __init__(self, task, host, ports): + self.task = task + self.clients = [MIIClient(task, host, port) for port in ports] self.asyncio_loop = asyncio.get_event_loop() # runs task in parallel and return the result from the first task @@ -155,30 +178,32 @@ def destroy_session(self, session_id): client.destroy_session(session_id) -class MIINonPersistentClient(): +class MIINonPersistentClient: def __init__(self, task, deployment_name): self.task = task self.deployment_name = deployment_name def query(self, request_dict, **query_kwargs): - assert self.deployment_name in mii.non_persistent_models, f"deployment: {self.deployment_name} not found" + assert ( + self.deployment_name in mii.non_persistent_models + ), f"deployment: {self.deployment_name} not found" task_methods = GRPC_METHOD_TABLE[self.task] inference_pipeline = mii.non_persistent_models[self.deployment_name][0] - if self.task == Tasks.QUESTION_ANSWERING: - if 'question' not in request_dict or 'context' not in request_dict: + if self.task == TaskType.QUESTION_ANSWERING: + if "question" not in request_dict or "context" not in request_dict: raise Exception( "Question Answering Task requires 'question' and 'context' keys") args = (request_dict["question"], request_dict["context"]) kwargs = query_kwargs - elif self.task == Tasks.CONVERSATIONAL: + elif self.task == TaskType.CONVERSATIONAL: conv = task_methods.create_conversation(request_dict, **query_kwargs) args = (conv, ) kwargs = {} else: - args = (request_dict['query'], ) + args = (request_dict["query"], ) kwargs = query_kwargs return task_methods.run_inference(inference_pipeline, args, query_kwargs) @@ -189,6 +214,6 @@ def terminate(self): def terminate_restful_gateway(deployment_name): - _, mii_configs = _get_deployment_info(deployment_name) - if mii_configs.enable_restful_api: - requests.get(f"http://localhost:{mii_configs.restful_api_port}/terminate") + mii_config = _get_mii_config(deployment_name) + if mii_config.enable_restful_api: + requests.get(f"http://localhost:{mii_config.restful_api_port}/terminate") diff --git a/mii/config.py b/mii/config.py index 2714cb40..1a581e3e 100644 --- a/mii/config.py +++ b/mii/config.py @@ -3,93 +3,148 @@ # DeepSpeed Team import torch -from typing import Union, List -from enum import Enum -from pydantic import BaseModel, validator, root_validator - -from deepspeed.launcher.runner import DLTS_HOSTFILE - - -class DtypeEnum(Enum): - # The torch dtype must always be the first value (so we return torch.dtype) - fp16 = torch.float16, "torch.float16", "fp16", "float16", "half" - bf16 = torch.bfloat16, "torch.bfloat16", "bf16", "bfloat16" - fp32 = torch.float32, "torch.float32", "fp32", "float32", "float" - int8 = torch.int8, "torch.int8", "int8" - - # Copied from https://stackoverflow.com/a/43210118 - # Allows us to use multiple values for each Enum index and returns first - # listed value when Enum is called - def __new__(cls, *values): - obj = object.__new__(cls) - # first value is canonical value - obj._value_ = values[0] - for other_value in values[1:]: - cls._value2member_map_[other_value] = obj - obj._all_values = values - return obj - - def __repr__(self): - return "<%s.%s: %s>" % ( - self.__class__.__name__, - self._name_, - ", ".join([repr(v) for v in self._all_values]), - ) +import os +import string +from typing import List, Optional, Dict, Any +from pydantic import validator, 
root_validator, BaseModel +import mii +from mii.constants import DeploymentType, TaskType, MII_MODEL_PATH_DEFAULT +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +from deepspeed.inference.config import DtypeEnum +from deepspeed.launcher.runner import DLTS_HOSTFILE, fetch_hostfile -class MIIConfig(BaseModel): - tensor_parallel: int = 1 - port_number: int = 50050 - dtype: DtypeEnum = torch.float32 - meta_tensor: bool = False + +class ReplicaConfig(DeepSpeedConfigModel): + hostname: str = "" + tensor_parallel_ports: List[int] = [] + torch_dist_port: int = None + gpu_indices: List[int] = [] + + +class DeploymentConfig(DeepSpeedConfigModel): + # Deployment configs + deployment_name: str load_with_sys_mem: bool = False - enable_cuda_graph: bool = False - checkpoint_dict: Union[dict, None] = None - deploy_rank: Union[int, List[int]] = -1 + meta_tensor: bool = False + hf_auth_token: Optional[str] = None + deploy_rank: Optional[List[int]] = None torch_dist_port: int = 29500 - hf_auth_token: str = None - replace_with_kernel_inject: bool = True + replica_num: int = 1 + replica_configs: List[ReplicaConfig] = [] profile_model_time: bool = False skip_model_check: bool = False - max_tokens: int = 1024 - enable_restful_api: bool = False - restful_api_port: int = 51080 - replica_num: int = 1 - hostfile: str = DLTS_HOSTFILE trust_remote_code: bool = False - @validator("deploy_rank") - def deploy_valid(cls, field_value, values): - if "tensor_parallel" not in values: - raise ValueError( - "'tensor_parallel' must be defined in the pydantic model before 'deploy_rank'" - ) - - # if deploy rank is not given, default to align with TP value - if field_value == -1: - field_value = list(range(values["tensor_parallel"])) + # Model configs + model: str + task: TaskType + dtype: DtypeEnum = DtypeEnum.fp32 + model_path: str = "" + checkpoint_dict: Optional[Dict[str, Any]] = None + max_tokens: int = 1024 + GPU_index_map: dict = None - # ensure deploy rank type is always list for easier consumption later - if not isinstance(field_value, list): - field_value = [field_value] + # Performance configs + enable_deepspeed: bool = True + enable_zero: bool = False + ds_config: Dict[str, Any] = {} + tensor_parallel: int = 1 + enable_cuda_graph: bool = False + replace_with_kernel_inject: bool = True - # number of ranks provided must be equal to TP size, DP is handled outside MII currently - assert values["tensor_parallel"] == len(field_value), \ - f"{len(field_value)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {values['tensor_parallel']}" - return field_value + class Config: + json_encoders = {torch.dtype: lambda x: str(x)} - @validator('checkpoint_dict') - def checkpoint_dict_valid(cls, value): - if value is None: - return value - if value.get('base_dir', ''): + @validator("checkpoint_dict") + def checkpoint_dict_valid(cls, field_value, values): + if field_value is None: + return field_value + if field_value.get("base_dir", ""): raise ValueError( "please unset 'base_dir' it will be set w.r.t. 
the deployment 'model_path'" ) - for k in ['checkpoints', 'parallelization', 'version', 'type']: - if not value.get(k, ''): + for k in ["checkpoints", "parallelization", "version", "type"]: + if not field_value.get(k, ""): raise ValueError(f"Missing key={k} in checkpoint_dict") - return value + return field_value + + @validator("deploy_rank", pre=True) + def deploy_rank_to_list(cls, field_value, values): + if field_value and not isinstance(field_value, list): + field_value = [field_value] + return field_value + + @root_validator + def zero_or_meta(cls, values): + if values.get("enable_zero"): + assert not values.get( + "meta_tensor" + ), "ZeRO-Inference does not support meta tensors." + return values + + @root_validator + def bloom_model_valid(cls, values): + if "bigscience/bloom" in values.get("model"): + # TODO: SHould be albe to use DtypeEnum here + assert values.get("dtype") in [ + torch.int8, + torch.float16, + ], "Bloom models only support fp16/int8." + assert (not values.get( + "enable_cuda_graph" + )), "Bloom models do not support CUDA Graph." + return values + + @root_validator + def deploy_rank_valid(cls, values): + tensor_parallel = values.get("tensor_parallel") + deploy_rank = values.get("deploy_rank") + + # if deploy rank is not given, default to align with TP value + if deploy_rank is None: + deploy_rank = list(range(tensor_parallel)) + + # number of ranks provided must be equal to TP size, DP is handled outside MII currently + assert tensor_parallel == len( + deploy_rank + ), f"{len(deploy_rank)} rank(s) provided in 'deploy_rank' does not align with tensor_parallel size of {tensor_parallel}" + + values["deploy_rank"] = deploy_rank + return values + + @root_validator + def set_model_path(cls, values): + model_path = values.get("model_path") + if not model_path: + if values.get("deployment_type") == DeploymentType.AML: + model_path = "model" + else: + model_path = MII_MODEL_PATH_DEFAULT + aml_model_dir = os.environ.get("AZUREML_MODEL_DIR", None) + if aml_model_dir: + assert os.path.isabs( + aml_model_dir + ), "AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path." + assert not os.path.isabs( + model_path + ), f"model_path={model_path} must be relative to append w/ AML path." 
+ model_path = os.path.join(aml_model_dir, model_path) + + values["model_path"] = model_path + return values + + @root_validator + def validate_model_and_task(cls, values): + task = values.get("task") + model = values.get("model") + if not values.get("skip_model_check"): + mii.utils.check_if_task_and_model_is_valid(task, model) + if values.get("enable_deepspeed"): + mii.utils.check_if_task_and_model_is_supported(task, model) + # Skip any future checks + values["skip_model_check"] = True + return values @root_validator def meta_tensor_or_sys_mem(cls, values): @@ -99,29 +154,163 @@ def meta_tensor_or_sys_mem(cls, values): ) return values - class Config: - validate_all = True - validate_assignment = True - use_enum_values = True - extra = 'forbid' - json_encoders = {torch.dtype: lambda x: str(x)} + @root_validator + def zero_dtype_valid(cls, values): + if values.get("enable_zero"): + if values.get("ds_config").get("fp16", {}).get("enabled", False): + # TODO: We should be able to use DtypeEnum instead of torch.float + assert ( + values.get("dtype") == torch.float16 + ), "ZeRO FP16 enabled, `dtype` must be set to `torch.float16`" + else: + assert ( + values.get("dtype") == torch.float32 + ), "ZeRO FP16 disabled, `dtype` must be set to `torch.float32`" + return values + @root_validator + def deepspeed_or_zero(cls, values): + assert not ( + values.get("enable_deepspeed") and values.get("enable_zero") + ), "DeepSpeed and ZeRO cannot both be enabled, select only one" + return values -class ReplicaConfig(BaseModel): - hostname: str = "" - tensor_parallel_ports: List[int] = [] - torch_dist_port: int = None - gpu_indices: List[int] = [] - class Config: - validate_all = True - validate_assignment = True +""" + @root_validator + def index_map_valid(cls, values): + if values.get("GPU_index_map"): + for host in values.get("GPU_index_map"): + assert host in resource_pool, f"Host: {host} was not found" + assert resource_pool[host] >= tensor_parallel, f"Host {host} has {resource_pool[host]} slot(s), but {tensor_parallel} slot(s) are required" + return values +""" -class LoadBalancerConfig(BaseModel): - port: int = None - replica_configs: List[ReplicaConfig] = [] +class MIIConfig(DeepSpeedConfigModel): + deployment_configs: dict[str, DeploymentConfig] = None + deployment_type: DeploymentType = DeploymentType.LOCAL + deployment_tag: str = None + hf_auth_token: Optional[str] = None + port_number: int = 50050 + enable_restful_api: bool = False + restful_api_port: int = 51080 + hostfile: str = DLTS_HOSTFILE + version: int = 1 + port_map: dict = {} - class Config: - validate_all = True - validate_assignment = True + @root_validator(skip_on_failure=True) + def propagate_hf_auth(cls, values): + # This validator is for when we support multiple models in a deployment + hf_auth_token = values.get("hf_auth_token") + deployment_config_list = values.get("deployment_configs") + print(deployment_config_list) + for deployment_config in deployment_config_list.values(): + if not deployment_config.hf_auth_token: + deployment_config.hf_auth_token = hf_auth_token + return values + + @root_validator(skip_on_failure=True) + def AML_name_valid(cls, values): + if values.get("deployment_type") == DeploymentType.AML: + allowed_chars = set(string.ascii_lowercase + string.ascii_uppercaes + + string.digits + "-") + assert ( + set(values.get("deployment_configs").deployment_name) <= allowed_chars + ), "AML deployment names can only contain a-z, A-Z, 0-9, and '-'." 
+ return values + + @root_validator() + def generate_replica_configs(cls, values): + port_map = values.get("port_map") + hostfile = values.get("hostfile") + port_number = values.get("port_number") + port_offset = 1 + for deployment_config in values.get("deployment_configs").values(): + + replica_configs = deployment_config.replica_configs + replica_num = deployment_config.replica_num + if replica_configs: + assert len(replica_configs) == replica_num + return values + + torch_dist_port = deployment_config.torch_dist_port + tensor_parallel = deployment_config.tensor_parallel + replica_num = deployment_config.replica_num + GPU_index_map = deployment_config.GPU_index_map + replica_pool, GPU_index_map = _allocate_processes(hostfile, tensor_parallel, replica_num, GPU_index_map) + deployment_config.GPU_index_map = GPU_index_map + replica_configs = [] + print(replica_pool) + for i, (hostname, gpu_indices) in enumerate(replica_pool): + # Reserver port for a LB proxy when replication is enabled + if hostname not in port_map: + port_map[hostname] = set() + base_port = port_number + i * tensor_parallel + port_offset + if base_port in port_map[hostname]: + base_port = max(port_map[hostname]) + 1 + tensor_parallel_ports = list( + range(base_port, + base_port + tensor_parallel)) + for i in range(base_port, base_port + tensor_parallel): + port_map[hostname].add(i) + replica_torch_dist_port = torch_dist_port + (100 * i) + replica_configs.append( + ReplicaConfig( + hostname=hostname, + tensor_parallel_ports=tensor_parallel_ports, + torch_dist_port=replica_torch_dist_port, + gpu_indices=gpu_indices, + )) + + deployment_config.replica_configs = replica_configs + return values + + +def _allocate_processes(hostfile_path, tensor_parallel, replica_num, GPU_index_map=None): + resource_pool = fetch_hostfile(hostfile_path) + assert ( + resource_pool is not None and len(resource_pool) > 0 + ), f"No hosts found in {hostfile_path}" + + replica_pool = [] + + if GPU_index_map is not None: + for host in GPU_index_map: + assert host in resource_pool, f"Host: {host} was not found" + assert resource_pool[host] >= tensor_parallel, f"Host {host} has {resource_pool[host]} slot(s), but {tensor_parallel} slot(s) are required" + for host in GPU_index_map: + replica_pool.append((host, GPU_index_map[host])) + return replica_pool, GPU_index_map + + allocated_num = 0 + for host, slots in resource_pool.items(): + available_on_host = slots + while available_on_host >= tensor_parallel: + if allocated_num >= replica_num: + break + if slots < tensor_parallel: + raise ValueError( + f"Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required" + ) + + allocated_num_on_host = slots - available_on_host + replica_pool.append(( + host, + [ + i for i in range( + allocated_num_on_host, + allocated_num_on_host + tensor_parallel, + ) + ], + )) + allocated_num += 1 + + available_on_host -= tensor_parallel + + if allocated_num < replica_num: + raise ValueError( + f"Not sufficient GPUs for {replica_num} replica(s), only {allocated_num} replica(s) can be deployed" + ) + + return replica_pool, GPU_index_map diff --git a/mii/constants.py b/mii/constants.py index ba4cfa2f..6fc80854 100644 --- a/mii/constants.py +++ b/mii/constants.py @@ -2,104 +2,61 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -import enum +from enum import Enum -#TODO naming.. 
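generate_replica_configs above reserves port_number for the load-balancer proxy and hands each replica a contiguous block of tensor_parallel ports starting one past it, while each replica's torch_dist_port is offset by 100 * replica index. A standalone sketch of that arithmetic with the defaults (port_number=50050, tensor_parallel=2, two replicas, empty port_map):

port_number, tensor_parallel, replica_num, port_offset = 50050, 2, 2, 1
torch_dist_port = 29500

for i in range(replica_num):
    base_port = port_number + i * tensor_parallel + port_offset
    print(list(range(base_port, base_port + tensor_parallel)), torch_dist_port + 100 * i)
# -> [50051, 50052] 29500 for replica 0 and [50053, 50054] 29600 for replica 1;
#    the load balancer itself is expected to listen on port_number (50050).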
-class DeploymentType(enum.Enum): - LOCAL = 1 - AML = 2 - NON_PERSISTENT = 3 +class DeploymentType(str, Enum): + LOCAL = "local" + AML = "aml" + NON_PERSISTENT = "non-persistent" -MII_CONFIGS_KEY = 'mii_configs' +class TaskType(str, Enum): + TEXT_GENERATION = "text-generation" + TEXT_CLASSIFICATION = "text-classification" + QUESTION_ANSWERING = "question-answering" + FILL_MASK = "fill-mask" + TOKEN_CLASSIFICATION = "token-classification" + CONVERSATIONAL = "conversational" + TEXT2IMG = "text-to-image" -class Tasks(enum.Enum): - TEXT_GENERATION = 1 - TEXT_CLASSIFICATION = 2 - QUESTION_ANSWERING = 3 - FILL_MASK = 4 - TOKEN_CLASSIFICATION = 5 - CONVERSATIONAL = 6 - TEXT2IMG = 7 +class ModelProvider(str, Enum): + HUGGING_FACE = "hugging-face" + ELEUTHER_AI = "eleuther-ai" + DIFFUSERS = "diffusers" -TEXT_GENERATION_NAME = 'text-generation' -TEXT_CLASSIFICATION_NAME = 'text-classification' -QUESTION_ANSWERING_NAME = 'question-answering' -FILL_MASK_NAME = 'fill-mask' -TOKEN_CLASSIFICATION_NAME = 'token-classification' -CONVERSATIONAL_NAME = 'conversational' -TEXT2IMG_NAME = "text-to-image" - - -class ModelProvider(enum.Enum): - HUGGING_FACE = 1 - ELEUTHER_AI = 2 - DIFFUSERS = 3 - - -MODEL_PROVIDER_NAME_HF = "hugging-face" -MODEL_PROVIDER_NAME_EA = "eleuther-ai" -MODEL_PROVIDER_NAME_DIFFUSERS = "diffusers" - -MODEL_PROVIDER_MAP = { - MODEL_PROVIDER_NAME_HF: ModelProvider.HUGGING_FACE, - MODEL_PROVIDER_NAME_EA: ModelProvider.ELEUTHER_AI, - MODEL_PROVIDER_NAME_DIFFUSERS: ModelProvider.DIFFUSERS -} - SUPPORTED_MODEL_TYPES = { - 'roberta': ModelProvider.HUGGING_FACE, - 'xlm-roberta': ModelProvider.HUGGING_FACE, - 'gpt2': ModelProvider.HUGGING_FACE, - 'bert': ModelProvider.HUGGING_FACE, - 'gpt_neo': ModelProvider.HUGGING_FACE, - 'gptj': ModelProvider.HUGGING_FACE, - 'opt': ModelProvider.HUGGING_FACE, - 'bloom': ModelProvider.HUGGING_FACE, - 'gpt-neox': ModelProvider.ELEUTHER_AI, - 'stable-diffusion': ModelProvider.DIFFUSERS, - 'llama': ModelProvider.HUGGING_FACE + "roberta": ModelProvider.HUGGING_FACE, + "xlm-roberta": ModelProvider.HUGGING_FACE, + "gpt2": ModelProvider.HUGGING_FACE, + "bert": ModelProvider.HUGGING_FACE, + "gpt_neo": ModelProvider.HUGGING_FACE, + "gptj": ModelProvider.HUGGING_FACE, + "opt": ModelProvider.HUGGING_FACE, + "bloom": ModelProvider.HUGGING_FACE, + "gpt-neox": ModelProvider.ELEUTHER_AI, + "stable-diffusion": ModelProvider.DIFFUSERS, + "llama": ModelProvider.HUGGING_FACE, } -SUPPORTED_TASKS = [ - TEXT_GENERATION_NAME, - TEXT_CLASSIFICATION_NAME, - QUESTION_ANSWERING_NAME, - FILL_MASK_NAME, - TOKEN_CLASSIFICATION_NAME, - CONVERSATIONAL_NAME, - TEXT2IMG_NAME -] - REQUIRED_KEYS_PER_TASK = { - TEXT_GENERATION_NAME: ["query"], - TEXT_CLASSIFICATION_NAME: ["query"], - QUESTION_ANSWERING_NAME: ["context", - "question"], - FILL_MASK_NAME: ["query"], - TOKEN_CLASSIFICATION_NAME: ["query"], - CONVERSATIONAL_NAME: - ['text', - 'conversation_id', - 'past_user_inputs', - 'generated_responses'], - TEXT2IMG_NAME: ["query"] + TaskType.TEXT_GENERATION: ["query"], + TaskType.TEXT_CLASSIFICATION: ["query"], + TaskType.QUESTION_ANSWERING: ["context", + "question"], + TaskType.FILL_MASK: ["query"], + TaskType.TOKEN_CLASSIFICATION: ["query"], + TaskType.CONVERSATIONAL: [ + "text", + "conversation_id", + "past_user_inputs", + "generated_responses", + ], + TaskType.TEXT2IMG: ["query"], } -MODEL_NAME_KEY = 'model_name' -TASK_NAME_KEY = 'task_name' -DEPLOYMENT_NAME_KEY = 'deployment_name' -MODEL_PATH_KEY = 'model_path' -LOAD_BALANCER_CONFIG_KEY = 'load_balancer_config' - -ENABLE_DEEPSPEED_KEY = 
'ds_optimize' -ENABLE_DEEPSPEED_ZERO_KEY = 'ds_zero' -DEEPSPEED_CONFIG_KEY = 'ds_config' -CHECKPOINT_KEY = "checkpoint" - MII_CACHE_PATH = "MII_CACHE_PATH" MII_CACHE_PATH_DEFAULT = "/tmp/mii_cache" diff --git a/mii/deployment.py b/mii/deployment.py index 3cadd994..115f6c55 100644 --- a/mii/deployment.py +++ b/mii/deployment.py @@ -2,208 +2,134 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -import torch -import string import os import mii - -from deepspeed.launcher.runner import fetch_hostfile - -from .constants import DeploymentType, MII_MODEL_PATH_DEFAULT, MODEL_PROVIDER_MAP -from .utils import logger, get_task_name, get_provider_name +from typing import List +from .utils import logger from .models.score import create_score_file from .models import load_models -from .config import ReplicaConfig, LoadBalancerConfig - - -def deploy(task, - model, - deployment_name, - deployment_type=DeploymentType.LOCAL, - model_path=None, - enable_deepspeed=True, - enable_zero=False, - ds_config=None, - mii_config={}, - version=1): - """Deploy a task using specified model. For usage examples see: - - mii/examples/local/text-generation-example.py - - - Arguments: - task: Name of the machine learning task to be deployed.Currently MII supports the following list of tasks - ``['text-generation', 'text-classification', 'question-answering', 'fill-mask', 'token-classification', 'conversational', 'text-to-image']`` - - model: Name of a supported model for the task. Models in MII are sourced from multiple open-source projects - such as Huggingface Transformer, FairSeq, EluetherAI etc. For the list of supported models for each task, please - see here [TODO]. - - deployment_name: Name of the deployment. Used as an identifier for posting queries for ``LOCAL`` deployment. - - deployment_type: One of the ``enum mii.DeploymentTypes: [LOCAL]``. - *``LOCAL`` uses a grpc server to create a local deployment, and query the model must be done by creating a query handle using - `mii.mii_query_handle` and posting queries using ``mii_request_handle.query`` API, - - model_path: Optional: In LOCAL deployments this is the local path where model checkpoints are available. In AML deployments this - is an optional relative path with AZURE_MODEL_DIR for the deployment. - - enable_deepspeed: Optional: Defaults to True. Use this flag to enable or disable DeepSpeed-Inference optimizations - - enable_zero: Optional: Defaults to False. Use this flag to enable or disable DeepSpeed-ZeRO inference - - ds_config: Optional: Defaults to None. Use this to specify the DeepSpeed configuration when enabling DeepSpeed-ZeRO inference - - force_register_model: Optional: Defaults to False. For AML deployments, set it to True if you want to re-register your model - with the same ``aml_model_tags`` using checkpoints from ``model_path``. - - mii_config: Optional: Dictionary specifying optimization and deployment configurations that should override defaults in ``mii.config.MIIConfig``. - mii_config is future looking to support extensions in optimization strategies supported by DeepSpeed Inference as we extend mii. - As of now, it can be used to set tensor-slicing degree using 'tensor_parallel' and port number for deployment using 'port_number'. 
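Calls written against this older signature keep working: support_legacy_api, added immediately below, splits the old arguments into a per-model deployment_config and a slimmed-down mii_config. A sketch of the two equivalent call styles, using the BLOOM model from the examples (names are illustrative):

import mii

# Legacy single-model style: task/model and the old mii_config dict are folded
# into one DeploymentConfig keyed by deployment_name.
mii.deploy(task="text-generation",
           model="bigscience/bloom-560m",
           deployment_name="bloom_deployment",
           mii_config={"tensor_parallel": 1, "dtype": "fp16"})

# New multi-model style: explicit per-deployment dicts grouped under a deployment_tag.
mii.deploy(deployment_tag="bloom_tag",
           deployment_configs=[{"task": "text-generation",
                                "model": "bigscience/bloom-560m",
                                "deployment_name": "bloom_deployment",
                                "tensor_parallel": 1,
                                "dtype": "fp16"}])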
+from .config import MIIConfig, DeploymentType, DeploymentConfig + + +def support_legacy_api( + task, + model, + deployment_type=DeploymentType.LOCAL, + model_path="", + enable_deepspeed=True, + enable_zero=False, + ds_config=None, + mii_config=None, + version=1, +): + if ds_config is None: + ds_config = {} + if mii_config is None: + mii_config = {} + + deployment_config = { + "task": task, + "model": model, + "model_path": model_path, + "enable_deepspeed": enable_deepspeed, + "enable_zero": enable_zero, + "ds_config": ds_config, + } + for key, val in mii_config.items(): + if not hasattr(MIIConfig, key): + deployment_config[key] = val + + mii_config = {k: v for k, v in mii_config.items() if hasattr(MIIConfig, k)} + mii_config["version"] = version + mii_config["deployment_type"] = deployment_type + + return deployment_config, mii_config + + +def deploy( + deployment_name: str = None, + deployment_config: dict = None, + mii_config: dict = None, + deployment_configs: list[dict] = None, + deployment_tag: str = None, + *args, + **kwargs, +): + if mii_config is None: + mii_config = {} + + if args or kwargs: + assert ( + not deployment_config + ), "We do not support mixture of legacy and new API options, use latest API." + assert deployment_name, "deployment_name required for singular deployment" + kwargs["mii_config"] = mii_config + deployment_config, mii_config = support_legacy_api(*args, **kwargs) + + deployment_config["deployment_name"] = deployment_name + mii_config["deployment_tag"] = deployment_name + mii_config["deployment_configs"] = { + deployment_name: DeploymentConfig(**deployment_config) + } + else: + assert all((deployment_tag, deployment_configs)), "To deploy multiple models you must use deployment_tag and deployment_configs" + deployment_dict = {} + for deployment_config in deployment_configs: + deployment_dict[deployment_config.get('deployment_name')] = DeploymentConfig( + **deployment_config) + #print(deployment_dict) + mii_config["deployment_configs"] = deployment_dict + mii_config["deployment_tag"] = deployment_tag + + print(mii_config.keys()) + mii_config = mii.config.MIIConfig(**mii_config) - version: Optional: Version to be set for AML deployment, useful if you want to deploy the same model with different settings. - Returns: - If deployment_type is `LOCAL`, returns just the name of the deployment that can be used to create a query handle using `mii.mii_query_handle(deployment_name)` + for deployment_config in mii_config.deployment_configs.values(): + if deployment_config.enable_deepspeed: + logger.info( + f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" + ) + else: + logger.info( + f"************* DeepSpeed Optimizations not enabled. 
Please use enable_deepspeed to get better performance *************" + ) - """ + if mii_config.deployment_type != DeploymentType.NON_PERSISTENT: + create_score_file(mii_config) - # parse and validate mii config - mii_config = mii.config.MIIConfig(**mii_config) - if enable_zero: - if ds_config.get("fp16", {}).get("enabled", False): - assert (mii_config.dtype == torch.half), "MII Config Error: MII dtype and ZeRO dtype must match" - else: - assert (mii_config.dtype == torch.float), "MII Config Error: MII dtype and ZeRO dtype must match" - assert not (enable_deepspeed and enable_zero), "MII Config Error: DeepSpeed and ZeRO cannot both be enabled, select only one" - - # aml only allows certain characters for deployment names - if deployment_type == DeploymentType.AML: - allowed_chars = set(string.ascii_lowercase + string.ascii_uppercase + - string.digits + '-') - assert set(deployment_name) <= allowed_chars, "AML deployment names can only contain a-z, A-Z, 0-9, and '-'" - - task = mii.utils.get_task(task) - - if not mii_config.skip_model_check: - mii.utils.check_if_task_and_model_is_valid(task, model) - if enable_deepspeed: - mii.utils.check_if_task_and_model_is_supported(task, model) - - if enable_deepspeed: - logger.info( - f"************* MII is using DeepSpeed Optimizations to accelerate your model *************" - ) - else: - logger.info( - f"************* DeepSpeed Optimizations not enabled. Please use enable_deepspeed to get better performance *************" - ) - - # In local deployments use default path if no model path set - if model_path is None and deployment_type == DeploymentType.LOCAL: - model_path = MII_MODEL_PATH_DEFAULT - elif model_path is None and deployment_type == DeploymentType.AML: - model_path = "model" - - # add fields for replica deployment - replica_pool = _allocate_processes(mii_config.hostfile, - mii_config.tensor_parallel, - mii_config.replica_num) - replica_configs = [] - for i, (hostname, gpu_indices) in enumerate(replica_pool): - # Reserver port for a LB proxy when replication is enabled - port_offset = 1 - base_port = mii_config.port_number + i * mii_config.tensor_parallel + port_offset - tensor_parallel_ports = list( - range(base_port, - base_port + mii_config.tensor_parallel)) - torch_dist_port = mii_config.torch_dist_port + i - replica_configs.append( - ReplicaConfig(hostname=hostname, - tensor_parallel_ports=tensor_parallel_ports, - torch_dist_port=torch_dist_port, - gpu_indices=gpu_indices)) - lb_config = LoadBalancerConfig(port=mii_config.port_number, - replica_configs=replica_configs) - - if deployment_type != DeploymentType.NON_PERSISTENT: - create_score_file(deployment_name=deployment_name, - deployment_type=deployment_type, - task=task, - model_name=model, - ds_optimize=enable_deepspeed, - ds_zero=enable_zero, - ds_config=ds_config, - mii_config=mii_config, - model_path=model_path, - lb_config=lb_config) - - if deployment_type == DeploymentType.AML: - _deploy_aml(deployment_name=deployment_name, model_name=model, version=version) - elif deployment_type == DeploymentType.LOCAL: - return _deploy_local(deployment_name, model_path=model_path) - elif deployment_type == DeploymentType.NON_PERSISTENT: - assert int(os.getenv('WORLD_SIZE', '1')) == mii_config.tensor_parallel, "World Size does not equal number of tensors. 
When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" - provider = MODEL_PROVIDER_MAP[get_provider_name(model, task)] - mii.non_persistent_models[deployment_name] = (load_models( - get_task_name(task), - model, - model_path, - enable_deepspeed, - enable_zero, - provider, - mii_config), - task) - else: - raise Exception(f"Unknown deployment type: {deployment_type}") + if mii_config.deployment_type == DeploymentType.AML: + _deploy_aml(mii_config) + elif mii_config.deployment_type == DeploymentType.LOCAL: + _deploy_local(mii_config) + elif mii_config.deployment_type == DeploymentType.NON_PERSISTENT: + _deploy_nonpersistent(mii_config) -def _deploy_local(deployment_name, model_path): - mii.utils.import_score_file(deployment_name).init() +def _deploy_local(mii_config): + mii.utils.import_score_file(mii_config.deployment_tag).init() -def _deploy_aml(deployment_name, model_name, version): +def _deploy_aml(mii_config): acr_name = mii.aml_related.utils.get_acr_name() - mii.aml_related.utils.generate_aml_scripts(acr_name=acr_name, - deployment_name=deployment_name, - model_name=model_name, - version=version) + mii.aml_related.utils.generate_aml_scripts( + acr_name=acr_name, + deployment_name=mii_config.deployment_config.deployment_name, + model_name=mii_config.deployment_config.model, + version=mii_config.version, + ) print( - f"AML deployment assets at {mii.aml_related.utils.aml_output_path(deployment_name)}" + f"AML deployment assets at {mii.aml_related.utils.aml_output_path(mii_config.deployment_config.deployment_name)}" ) print("Please run 'deploy.sh' to bring your deployment online") -def _allocate_processes(hostfile_path, tensor_parallel, num_replicas): - resource_pool = fetch_hostfile(hostfile_path) - assert resource_pool is not None and len( - resource_pool) > 0, f'No hosts found in {hostfile_path}' - - replica_pool = [] - allocated_num = 0 - for host, slots in resource_pool.items(): - available_on_host = slots - while available_on_host >= tensor_parallel: - if allocated_num >= num_replicas: - break - if slots < tensor_parallel: - raise ValueError( - f'Host {host} has {slots} slot(s), but {tensor_parallel} slot(s) are required' - ) - - allocated_num_on_host = slots - available_on_host - replica_pool.append( - (host, - [ - i for i in range(allocated_num_on_host, - allocated_num_on_host + tensor_parallel) - ])) - allocated_num += 1 - - available_on_host -= tensor_parallel - - if allocated_num < num_replicas: - raise ValueError( - f'No sufficient GPUs for {num_replicas} replica(s), only {allocated_num} replica(s) can be deployed' - ) - - return replica_pool +def _deploy_nonpersistent(mii_config): + assert ( + int(os.getenv("WORLD_SIZE", "1")) + == mii_config.deployment_config.tensor_parallel + ), "World Size does not equal number of tensors. 
When using non-persistent deployment type, please launch with `deepspeed --num_gpus `" + deployment_name = mii_config.deployment_config.deployment_name + mii.non_persistent_models[deployment_name] = ( + load_models(mii_config.deployment_config), + mii_config.deployment_config.task, + ) diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 4a0a5d00..c5815b26 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -13,10 +13,18 @@ import threading import time -from mii.constants import GRPC_MAX_MSG_SIZE, CREATE_SESSION_METHOD, DESTROY_SESSION_METHOD, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, Tasks +from mii.constants import ( + GRPC_MAX_MSG_SIZE, + CREATE_SESSION_METHOD, + DESTROY_SESSION_METHOD, + TERMINATE_METHOD, + LB_MAX_WORKER_THREADS, + SERVER_SHUTDOWN_TIMEOUT, + TaskType, +) from mii.method_table import GRPC_METHOD_TABLE from mii.client import create_channel -from mii.utils import get_task, unpack_proto_query_kwargs +from mii.utils import unpack_proto_query_kwargs class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): @@ -62,12 +70,12 @@ def _get_model_time(self, model, sum_times=False): return model_time def CreateSession(self, request, context): - task_methods = GRPC_METHOD_TABLE[Tasks.TEXT_GENERATION] + task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION] task_methods.create_session(request.session_id) return google_dot_protobuf_dot_empty__pb2.Empty() def DestroySession(self, request, context): - task_methods = GRPC_METHOD_TABLE[Tasks.TEXT_GENERATION] + task_methods = GRPC_METHOD_TABLE[TaskType.TEXT_GENERATION] task_methods.destroy_session(request.session_id) return google_dot_protobuf_dot_empty__pb2.Empty() @@ -87,10 +95,10 @@ def _run_inference(self, method_name, request_proto): response = task_methods.run_inference(self.inference_pipeline, args, kwargs) end = time.time() - model_time = self._get_model_time(self.inference_pipeline.model, - sum_times=True) if hasattr( - self.inference_pipeline, - "model") else -1 + model_time = (self._get_model_time(self.inference_pipeline.model, + sum_times=True) if hasattr( + self.inference_pipeline, + "model") else -1) return task_methods.pack_response_to_proto(response, end - start, model_time) @@ -164,18 +172,26 @@ def invoke(self, method_name, proto_request): class LoadBalancingInterceptor(grpc.ServerInterceptor): - def __init__(self, task_name, replica_configs): + def __init__(self, mii_config): super().__init__() self.asyncio_loop = asyncio.get_event_loop() - self.stubs = [ - ParallelStubInvoker(replica.hostname, - replica.tensor_parallel_ports) - for replica in replica_configs - ] - self.counter = AtomicCounter() - self.task = get_task(task_name) - self.replica_sessions = {} + self.stubs = {} + self.counter = {} + self.replica_configs = replica_configs + self.tasks = {} + for deployment in mii_config.deployment_configs: + self.stubs[deployment.deployment_name] = [] + self.counter[deployment.deployment_name] = AtomicCounter() + self.tasks[deployment.deployment_name] = repl.task + + for deployment in mii_config.deployment_configs: + deployment_name = deployment.deployment_name + for repl in deployment.replica_configs: + self.stubs[deployment_name].append( + ParallelStubInvoker(repl.hostname, + repl.tensor_parallel_ports, + self.asyncio_loop)) # Start the asyncio loop in a separate thread def run_asyncio_loop(loop): @@ -193,38 +209,45 @@ def intercept_service(self, continuation, handler_call_details): def 
invoke_intercept_method(request_proto, context): method_name = _get_grpc_method_name(handler_call_details.method) - if method_name == TERMINATE_METHOD: - for stub in self.stubs: - stub.invoke(TERMINATE_METHOD, - google_dot_protobuf_dot_empty__pb2.Empty()) + for deployment_name in self.stubs: + for stub in self.stubs[deployment_name]: + stub.invoke(TERMINATE_METHOD, + google_dot_protobuf_dot_empty__pb2.Empty()) self.asyncio_loop.call_soon_threadsafe(self.asyncio_loop.stop) return next_handler.unary_unary(request_proto, context) - call_count = self.counter.get_and_increment() - replica_index = call_count % len(self.stubs) - if method_name == CREATE_SESSION_METHOD: if request_proto.session_id in self.sessions: raise ValueError( f"session {request_proto.session_id} already exists") self.replica_sessions[request_proto.session_id] = replica_index - self.stubs[replica_index].invoke(CREATE_SESSION_METHOD, request_proto) + self.stubs[deployment_name][replica_index].invoke( + CREATE_SESSION_METHOD, + request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() if method_name == DESTROY_SESSION_METHOD: replica_index = self.replica_sessions.pop(request_proto.session_id) - self.stubs[replica_index].invoke(DESTROY_SESSION_METHOD, request_proto) + self.stubs[deployment_name][replica_index].invoke( + DESTROY_SESSION_METHOD, + request_proto) return google_dot_protobuf_dot_empty__pb2.Empty() - kwargs = unpack_proto_query_kwargs(request_proto.query_kwargs) - if "session_id" in kwargs: - session_id = kwargs["session_id"] + if "session_id" in request_proto.query_kwargs: + session_id = request_proto.query_kwargs["session_id"] if session_id not in self.replica_sessions: raise ValueError(f"session not found") replica_index = self.replica_sessions[session_id] - ret = self.stubs[replica_index].invoke(method_name, request_proto) + deployment_name = getattr(request_proto, 'deployment_name') + assert deployment_name in self.stubs, f"Deployment: {deployment_name} not found" + call_count = self.counter[deployment_name].get_and_increment() + replica_index = call_count % len(self.stubs[deployment_name]) + + ret = self.stubs[deployment_name][replica_index].invoke( + method_name, + request_proto) return ret return grpc.unary_unary_rpc_method_handler( @@ -235,14 +258,18 @@ def invoke_intercept_method(request_proto, context): def _do_serve(service_impl, port, interceptors=[]): stop_event = service_impl.get_stop_event() - server = grpc.server(futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS), - interceptors=interceptors, - options=[('grpc.max_send_message_length', - GRPC_MAX_MSG_SIZE), - ('grpc.max_receive_message_length', - GRPC_MAX_MSG_SIZE)]) + server = grpc.server( + futures.ThreadPoolExecutor(max_workers=LB_MAX_WORKER_THREADS), + interceptors=interceptors, + options=[ + ("grpc.max_send_message_length", + GRPC_MAX_MSG_SIZE), + ("grpc.max_receive_message_length", + GRPC_MAX_MSG_SIZE), + ], + ) modelresponse_pb2_grpc.add_ModelResponseServicer_to_server(service_impl, server) - server.add_insecure_port(f'[::]:{port}') + server.add_insecure_port(f"[::]:{port}") print(f"About to start server") server.start() print(f"Started") @@ -254,14 +281,11 @@ def serve_inference(inference_pipeline, port): _do_serve(ModelResponse(inference_pipeline), port) -def serve_load_balancing(task_name, lb_config): - _do_serve(ServiceBase(), - lb_config.port, - [LoadBalancingInterceptor(task_name, - lb_config.replica_configs)]) +def serve_load_balancing(mii_config, lb_port): + _do_serve(ServiceBase(), lb_port, 
[LoadBalancingInterceptor(mii_config)]) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig() print(sys.argv[1]) serve_inference(None, sys.argv[1]) diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index a0698899..74ca4913 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -52,11 +52,13 @@ message SessionID { message SingleStringRequest { string request = 1; map query_kwargs = 2; + optional string deployment_name = 3; } message MultiStringRequest { repeated string request = 1; map query_kwargs = 2; + optional string deployment_name = 3; } message SingleStringReply { @@ -75,6 +77,7 @@ message QARequest { string question = 1; string context = 2; map query_kwargs = 3; + optional string deployment_name = 4; } message ConversationRequest { @@ -83,6 +86,7 @@ message ConversationRequest { repeated string past_user_inputs = 3; repeated string generated_responses = 4; map query_kwargs = 5; + optional string deployment_name = 6; } message ConversationReply { diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 76b1f994..0a219bf3 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -1,15 +1,10 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -17,11 +12,12 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 
\x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xa1\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_id\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"_\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xed\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xeb\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 
\x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"\xeb\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x04 \x01(\tH\x00\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_deployment_name\"\xd3\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x1c\n\x0f\x63onversation_id\x18\x02 \x01(\x03H\x00\x88\x01\x01\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x12\x1c\n\x0f\x64\x65ployment_name\x18\x06 \x01(\tH\x01\x88\x01\x01\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\x42\x12\n\x10_conversation_idB\x12\n\x10_deployment_name\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\x03\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xd4\x06\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x62\x06proto3' ) -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', globals()) +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -33,34 +29,34 @@ _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _VALUE._serialized_start = 67 - _VALUE._serialized_end = 162 - 
_SESSIONID._serialized_start = 164 - _SESSIONID._serialized_end = 195 - _SINGLESTRINGREQUEST._serialized_start = 198 - _SINGLESTRINGREQUEST._serialized_end = 385 - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _MULTISTRINGREQUEST._serialized_start = 388 - _MULTISTRINGREQUEST._serialized_end = 573 - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _SINGLESTRINGREPLY._serialized_start = 575 - _SINGLESTRINGREPLY._serialized_end = 658 - _MULTISTRINGREPLY._serialized_start = 660 - _MULTISTRINGREPLY._serialized_end = 742 - _QAREQUEST._serialized_start = 745 - _QAREQUEST._serialized_end = 930 - _QAREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _QAREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _CONVERSATIONREQUEST._serialized_start = 933 - _CONVERSATIONREQUEST._serialized_end = 1222 - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_start = 313 - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_end = 385 - _CONVERSATIONREPLY._serialized_start = 1225 - _CONVERSATIONREPLY._serialized_end = 1370 - _IMAGEREPLY._serialized_start = 1372 - _IMAGEREPLY._serialized_end = 1497 - _MODELRESPONSE._serialized_start = 1500 - _MODELRESPONSE._serialized_end = 2352 + _globals['_VALUE']._serialized_start = 67 + _globals['_VALUE']._serialized_end = 162 + _globals['_SESSIONID']._serialized_start = 164 + _globals['_SESSIONID']._serialized_end = 195 + _globals['_SINGLESTRINGREQUEST']._serialized_start = 198 + _globals['_SINGLESTRINGREQUEST']._serialized_end = 435 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_MULTISTRINGREQUEST']._serialized_start = 438 + _globals['_MULTISTRINGREQUEST']._serialized_end = 673 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_SINGLESTRINGREPLY']._serialized_start = 675 + _globals['_SINGLESTRINGREPLY']._serialized_end = 758 + _globals['_MULTISTRINGREPLY']._serialized_start = 760 + _globals['_MULTISTRINGREPLY']._serialized_end = 842 + _globals['_QAREQUEST']._serialized_start = 845 + _globals['_QAREQUEST']._serialized_end = 1080 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_CONVERSATIONREQUEST']._serialized_start = 1083 + _globals['_CONVERSATIONREQUEST']._serialized_end = 1422 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 343 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 415 + _globals['_CONVERSATIONREPLY']._serialized_start = 1425 + _globals['_CONVERSATIONREPLY']._serialized_end = 1570 + _globals['_IMAGEREPLY']._serialized_start = 1572 + _globals['_IMAGEREPLY']._serialized_end = 1697 + _globals['_MODELRESPONSE']._serialized_start = 1700 + _globals['_MODELRESPONSE']._serialized_end = 2552 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 95cfa825..f6515b0a 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -1,8 +1,3 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
"""Client and server classes corresponding to protobuf-defined services.""" import grpc diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index e8cfa934..68b17c73 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -17,9 +17,9 @@ def shutdown(thread): thread.server.shutdown() -def createRestfulGatewayApp(deployment_name, task, mii_config, server_thread): +def createRestfulGatewayApp(deployment_config, lb_port, server_thread): # client must be thread-safe - client = mii.MIIClient(task, "localhost", mii_config.port_number) + client = mii.MIIClient(deployment_config.task, "localhost", lb_port) class RestfulGatewayService(Resource): def __init__(self): @@ -33,26 +33,25 @@ def post(self): app = Flask("RestfulGateway") - @app.route("/terminate", methods=['GET']) + @app.route("/terminate", methods=["GET"]) def terminate(): # Need to shutdown *after* completing the request threading.Thread(target=shutdown, args=(server_thread, )).start() return "Shutting down RESTful API gateway server" api = Api(app) - path = '/{}/{}'.format(RESTFUL_API_PATH, deployment_name) + path = "/{}/{}".format(RESTFUL_API_PATH, deployment_config.deployment_name) api.add_resource(RestfulGatewayService, path) return app class RestfulGatewayThread(threading.Thread): - def __init__(self, deployment_name, task, mii_config): + def __init__(self, deployment_config, lb_port, rest_port): threading.Thread.__init__(self) - self.mii_config = mii_config - app = createRestfulGatewayApp(deployment_name, task, mii_config, self) - self.server = make_server('127.0.0.1', mii_config.restful_api_port, app) + app = createRestfulGatewayApp(deployment_config, lb_port, self) + self.server = make_server("127.0.0.1", rest_port, app) self.ctx = app.app_context() self.ctx.push() diff --git a/mii/launch/multi_gpu_server.py b/mii/launch/multi_gpu_server.py index 27878725..34b46845 100644 --- a/mii/launch/multi_gpu_server.py +++ b/mii/launch/multi_gpu_server.py @@ -4,99 +4,106 @@ # DeepSpeed Team import os import argparse -import mii import base64 import json -from mii import MIIConfig, LoadBalancerConfig - +from mii.config import DeploymentConfig, MIIConfig from mii.models.load_models import load_models from mii.grpc_related.modelresponse_server import serve_inference, serve_load_balancing from mii.grpc_related.restful_gateway import RestfulGatewayThread -def decode_config_from_str(config_str): +def b64_encoded_config(config_str): # str -> bytes b64_bytes = config_str.encode() # decode b64 bytes -> json bytes config_bytes = base64.urlsafe_b64decode(b64_bytes) # convert json bytes -> str -> dict - return json.loads(config_bytes.decode()) + config_dict = json.loads(config_bytes.decode()) + # return mii.DeploymentConfig object + return DeploymentConfig(**config_dict) + + +def b64_encoded_config_MII(config_str): #TODO: Remove Duplicated Funciton + # str -> bytes + b64_bytes = config_str.encode() + # decode b64 bytes -> json bytes + config_bytes = base64.urlsafe_b64decode(b64_bytes) + # convert json bytes -> str -> dict + config_dict = json.loads(config_bytes.decode()) + # return mii.MIIConfig object + return MIIConfig(**config_dict) def main(): parser = argparse.ArgumentParser() - parser.add_argument("-n", "--deployment-name", type=str, help="deployment name") - parser.add_argument("-t", "--task-name", type=str, help="task name") - parser.add_argument("-m", "--model", type=str, help="model name") - parser.add_argument("-d", "--model-path", type=str, help="path to model") - 
parser.add_argument('-b', '--provider', type=str, help="model provider") - parser.add_argument("-o", - "--ds-optimize", - action='store_true', - help="Enable DeepSpeed") - parser.add_argument("-z", - "--ds-zero", - action='store_true', - help="Enable DeepSpeed ZeRO") - parser.add_argument("--ds-config", type=str, help="path to DeepSpeed ZeRO config") - parser.add_argument( - "-p", - "--port", + "--deployment-config", + type=b64_encoded_config, + help="base64 encoded deployment config", + ) + parser.add_argument( + "--mii-config", + type=b64_encoded_config_MII, + help="base64 encoded mii config", + ) + parser.add_argument( + "--server-port", type=int, - help="base server port, each rank will have unique port based on this value") - parser.add_argument("-c", "--config", type=str, help="base64 encoded mii config") + default=0, + help="Port to use for DeepSpeed inference server.", + ) parser.add_argument("--load-balancer", - type=str, - default=None, - help="base64 encoded load balancer config") - parser.add_argument("-r", - "--restful-gateway", - action='store_true', - help="launch restful api gateway") - + action="store_true", + help="Launch load balancer process.") + parser.add_argument( + "--load-balancer-port", + type=int, + default=0, + help="Port to use for load balancer.", + ) + parser.add_argument( + "--restful-gateway", + action="store_true", + help="Launches restful gateway process.", + ) + parser.add_argument( + "--restful-gateway-port", + type=int, + default=0, + help="Port to use for restful gateway.", + ) args = parser.parse_args() - - # de-serialize config object - config_dict = decode_config_from_str(args.config) - # convert dict -> mii config - mii_config = MIIConfig(**config_dict) + assert not ( + args.load_balancer and args.restful_gateway + ), "Select only load-balancer OR restful-gateway." if args.restful_gateway: - print(f"Starting RESTful API gateway on port: {mii_config.restful_api_port}") - gateway_thread = RestfulGatewayThread(args.deployment_name, - args.task_name, - mii_config) + assert args.restful_gateway_port, "--restful-gateway-port must be provided." + print(f"Starting RESTful API gateway on port: {args.restful_gateway_port}") + gateway_thread = RestfulGatewayThread( + args.deployment_config, + lb_port=args.load_balancer_port, + rest_port=args.restful_gateway_port, + ) stop_event = gateway_thread.get_stop_event() gateway_thread.start() stop_event.wait() - elif args.load_balancer is None: - provider = mii.constants.MODEL_PROVIDER_MAP.get(args.provider, None) - assert provider is not None, f"Unknown model provider: {args.provider}" + elif args.load_balancer: + assert args.load_balancer_port, "--load-balancer-port must be provided." + print(f"Starting load balancer on port: {args.load_balancer_port}") + serve_load_balancing(args.mii_config, args.load_balancer_port) - assert args.port is not None, "port is required for inference server" - local_rank = int(os.getenv('LOCAL_RANK', '0')) - port = args.port + local_rank + else: + assert args.server_port, "--server-port must be provided." 
+ local_rank = int(os.getenv("LOCAL_RANK", "0")) + port = args.server_port + local_rank - inference_pipeline = load_models(task_name=args.task_name, - model_name=args.model, - model_path=args.model_path, - ds_optimize=args.ds_optimize, - ds_zero=args.ds_zero, - ds_config_path=args.ds_config, - provider=provider, - mii_config=mii_config) + inference_pipeline = load_models(args.deployment_config) print(f"Starting server on port: {port}") serve_inference(inference_pipeline, port) - else: - lb_config_dict = decode_config_from_str(args.load_balancer) - lb_config = LoadBalancerConfig(**lb_config_dict) - - print(f"Starting load balancer on port: {lb_config.port}") - serve_load_balancing(args.task_name, lb_config) if __name__ == "__main__": diff --git a/mii/method_table.py b/mii/method_table.py index c412f446..d1eb0b9d 100644 --- a/mii/method_table.py +++ b/mii/method_table.py @@ -4,7 +4,7 @@ # DeepSpeed Team from abc import ABC, abstractmethod from transformers import Conversation -from mii.constants import Tasks +from mii.constants import TaskType from mii.grpc_related.proto import modelresponse_pb2 from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs from mii.models.utils import ImageResponse @@ -12,7 +12,7 @@ def single_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.SingleStringRequest( - request=request_dict['query'], + request=request_dict["query"], query_kwargs=kwarg_dict_to_proto(query_kwargs)) @@ -24,9 +24,10 @@ def single_string_response_to_proto(self, response, time_taken, model_time_taken def multi_string_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.MultiStringRequest( - request=request_dict['query'] if isinstance(request_dict['query'], - list) else [request_dict['query']], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + request=request_dict["query"] if isinstance(request_dict["query"], + list) else [request_dict["query"]], + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) def proto_request_to_single_input(self, request): @@ -114,12 +115,14 @@ def postprocess_session(self, session_id, args, response): def pack_response_to_proto(self, response, time_taken, model_time_taken): text_responses = [] for response in response: - text = response[0]['generated_text'] + text = response[0]["generated_text"] text_responses.append(text) - return modelresponse_pb2.MultiStringReply(response=text_responses, - time_taken=time_taken, - model_time_taken=model_time_taken) + return modelresponse_pb2.MultiStringReply( + response=text_responses, + time_taken=time_taken, + model_time_taken=model_time_taken, + ) class TextClassificationMethods(TaskMethods): @@ -141,9 +144,10 @@ def method(self): def pack_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.QARequest( - question=request_dict['question'], - context=request_dict['context'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + question=request_dict["question"], + context=request_dict["context"], + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) def unpack_request_from_proto(self, request): kwargs = unpack_proto_query_kwargs(request.query_kwargs) @@ -180,24 +184,30 @@ def method(self): def create_conversation(self, request, **kwargs): if isinstance(request, dict): - assert 'text' in request and 'past_user_inputs' in request and 'generated_responses' in request, "Conversation requires 'text', 'past_user_inputs', and 'generated_responses' keys" - text = request['text'] - conversation_id = request[ - 'conversation_id'] if 
'conversation_id' in request else None - past_user_inputs = request['past_user_inputs'] - generated_responses = request['generated_responses'] + assert ( + "text" in request + and "past_user_inputs" in request + and "generated_responses" in request + ), "Conversation requires 'text', 'past_user_inputs', and 'generated_responses' keys" + text = request["text"] + conversation_id = (request["conversation_id"] + if "conversation_id" in request else None) + past_user_inputs = request["past_user_inputs"] + generated_responses = request["generated_responses"] else: - text = getattr(request, 'text') - conversation_id = getattr(request, 'conversation_id') - past_user_inputs = getattr(request, 'past_user_inputs') - generated_responses = getattr(request, 'generated_responses') - - conv = Conversation(text=text, - conversation_id=conversation_id, - past_user_inputs=past_user_inputs, - generated_responses=generated_responses, - **kwargs) + text = getattr(request, "text") + conversation_id = getattr(request, "conversation_id") + past_user_inputs = getattr(request, "past_user_inputs") + generated_responses = getattr(request, "generated_responses") + + conv = Conversation( + text=text, + conversation_id=conversation_id, + past_user_inputs=past_user_inputs, + generated_responses=generated_responses, + **kwargs, + ) return conv def pack_response_to_proto(self, conv, time_taken, model_time_taken): @@ -206,7 +216,8 @@ def pack_response_to_proto(self, conv, time_taken, model_time_taken): past_user_inputs=conv.past_user_inputs, generated_responses=conv.generated_responses, time_taken=time_taken, - model_time_taken=model_time_taken) + model_time_taken=model_time_taken, + ) def unpack_request_from_proto(self, request): kwargs = unpack_proto_query_kwargs(request.query_kwargs) @@ -217,12 +228,13 @@ def unpack_request_from_proto(self, request): def pack_request_to_proto(self, request_dict, **query_kwargs): return modelresponse_pb2.ConversationRequest( - text=request_dict['text'], - conversation_id=request_dict['conversation_id'] - if 'conversation_id' in request_dict else None, - past_user_inputs=request_dict['past_user_inputs'], - generated_responses=request_dict['generated_responses'], - query_kwargs=kwarg_dict_to_proto(query_kwargs)) + text=request_dict["text"], + conversation_id=request_dict["conversation_id"] + if "conversation_id" in request_dict else None, + past_user_inputs=request_dict["past_user_inputs"], + generated_responses=request_dict["generated_responses"], + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) class Text2ImgMethods(TaskMethods): @@ -245,23 +257,25 @@ def pack_response_to_proto(self, response, time_taken, model_time_taken): img_mode = response.images[0].mode img_size_w, img_size_h = response.images[0].size - return modelresponse_pb2.ImageReply(images=images_bytes, - nsfw_content_detected=nsfw_content_detected, - mode=img_mode, - size_w=img_size_w, - size_h=img_size_h, - time_taken=time_taken) + return modelresponse_pb2.ImageReply( + images=images_bytes, + nsfw_content_detected=nsfw_content_detected, + mode=img_mode, + size_w=img_size_w, + size_h=img_size_h, + time_taken=time_taken, + ) def unpack_response_from_proto(self, response): return ImageResponse(response) GRPC_METHOD_TABLE = { - Tasks.TEXT_GENERATION: TextGenerationMethods(), - Tasks.TEXT_CLASSIFICATION: TextClassificationMethods(), - Tasks.QUESTION_ANSWERING: QuestionAnsweringMethods(), - Tasks.FILL_MASK: FillMaskMethods(), - Tasks.TOKEN_CLASSIFICATION: TokenClassificationMethods(), - Tasks.CONVERSATIONAL: 
ConversationalMethods(), - Tasks.TEXT2IMG: Text2ImgMethods(), + TaskType.TEXT_GENERATION: TextGenerationMethods(), + TaskType.TEXT_CLASSIFICATION: TextClassificationMethods(), + TaskType.QUESTION_ANSWERING: QuestionAnsweringMethods(), + TaskType.FILL_MASK: FillMaskMethods(), + TaskType.TOKEN_CLASSIFICATION: TokenClassificationMethods(), + TaskType.CONVERSATIONAL: ConversationalMethods(), + TaskType.TEXT2IMG: Text2ImgMethods(), } diff --git a/mii/models/load_models.py b/mii/models/load_models.py index a7d5e861..1e28d9aa 100644 --- a/mii/models/load_models.py +++ b/mii/models/load_models.py @@ -4,55 +4,49 @@ # DeepSpeed Team import os import mii -import json import torch import inspect import deepspeed from deepspeed.runtime.config import DeepSpeedConfig from deepspeed.runtime.zero.config import ZeroStageEnum +from mii.utils import get_provider -def load_models(task_name, - model_name, - model_path, - ds_optimize, - ds_zero, - provider, - mii_config, - ds_config_path=None): - global generator - local_rank = int(os.getenv('LOCAL_RANK', '0')) - world_size = int(os.getenv('WORLD_SIZE', '1')) +def load_models(deployment_config): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) inf_config = { "tensor_parallel": { - "tp_size": world_size, + "tp_size": deployment_config.tensor_parallel, "mpu": None }, - "dtype": mii_config.dtype, + "dtype": deployment_config.dtype, "replace_method": "auto", - "enable_cuda_graph": mii_config.enable_cuda_graph, + "enable_cuda_graph": deployment_config.enable_cuda_graph, "checkpoint": None, "config": None, "training_mp_size": 1, - "replace_with_kernel_inject": mii_config.replace_with_kernel_inject, - "max_tokens": mii_config.max_tokens + "replace_with_kernel_inject": deployment_config.replace_with_kernel_inject, + "max_tokens": deployment_config.max_tokens, } + provider = get_provider(deployment_config.model, deployment_config.task) if provider == mii.constants.ModelProvider.HUGGING_FACE: from mii.models.providers.huggingface import hf_provider - if "bigscience/bloom" in model_name: - assert mii_config.dtype == torch.half or mii_config.dtype == torch.int8, "Bloom models only support fp16/int8" - assert mii_config.enable_cuda_graph == False, "Bloom models do no support Cuda Graphs" - inference_pipeline = hf_provider(model_path, model_name, task_name, mii_config) - if mii_config.meta_tensor: + + inference_pipeline = hf_provider(deployment_config) + if deployment_config.meta_tensor: inf_config["checkpoint"] = inference_pipeline.checkpoint_dict - if mii_config.dtype == torch.int8: + if deployment_config.dtype == torch.int8: # Support for older DeepSpeed versions - if "enable_qkv_quantization" in inspect.signature( - deepspeed.init_inference).parameters: + if ("enable_qkv_quantization" + in inspect.signature(deepspeed.init_inference).parameters): inf_config["enable_qkv_quantization"] = True elif provider == mii.constants.ModelProvider.ELEUTHER_AI: + assert False, "Eleuther AI support is currently disabled." 
+ # TODO: Re-enable EleutherAI model support + """ from mii.models.providers.eleutherai import eleutherai_provider assert mii_config.dtype == torch.half, "gpt-neox only support fp16" assert mii_config.enable_cuda_graph == False, "Provider EleutherAI not supported with Cuda Graphs" @@ -64,44 +58,42 @@ def load_models(task_name, mii_config) inf_config["training_mp_size"] = 2 inf_config["config"] = inference_pipeline.neox_args + """ elif provider == mii.constants.ModelProvider.DIFFUSERS: from mii.models.providers.diffusers import diffusers_provider - inference_pipeline = diffusers_provider(model_path, - model_name, - task_name, - mii_config) - inf_config["replace_with_kernel_inject"] = False #not supported yet + + inference_pipeline = diffusers_provider(deployment_config) inf_config["enable_cuda_graph"] = True else: raise ValueError(f"Unknown model provider {provider}") - + """ print( f"> --------- MII Settings: ds_optimize={ds_optimize}, replace_with_kernel_inject={mii_config.replace_with_kernel_inject}, enable_cuda_graph={mii_config.enable_cuda_graph} " ) - if ds_optimize: + """ + if deployment_config.enable_deepspeed: engine = deepspeed.init_inference(getattr(inference_pipeline, "model", inference_pipeline), config=inf_config) - if mii_config.profile_model_time: + if deployment_config.profile_model_time: engine.profile_model_time() if hasattr(inference_pipeline, "model"): inference_pipeline.model = engine - elif ds_zero: - assert not mii_config.meta_tensor, "ZeRO-Inference does not support meta tensors" - ds_config = DeepSpeedConfig(ds_config_path) - #TODO: don't read ds-config from disk, we should pass this around as a dict instead - ds_config_dict = json.load(open(ds_config_path, 'r')) - assert ds_config.zero_optimization_stage == ZeroStageEnum.weights, "DeepSpeed ZeRO inference is only supported for ZeRO-3" + elif deployment_config.enable_zero: + ds_config = DeepSpeedConfig(deployment_config.ds_config) + assert ( + ds_config.zero_optimization_stage == ZeroStageEnum.weights + ), "DeepSpeed ZeRO inference is only supported for ZeRO-3" # initialise Deepspeed ZeRO and store only the engine object ds_engine = deepspeed.initialize(model=inference_pipeline.model, - config_params=ds_config_dict)[0] + config=deployment_config.ds_config)[0] ds_engine.module.eval() # inference inference_pipeline.model = ds_engine.module - if mii_config.load_with_sys_mem: + if deployment_config.load_with_sys_mem: inference_pipeline.device = torch.device(f"cuda:{local_rank}") return inference_pipeline diff --git a/mii/models/providers/diffusers.py b/mii/models/providers/diffusers.py index 82768db0..517559fe 100644 --- a/mii/models/providers/diffusers.py +++ b/mii/models/providers/diffusers.py @@ -6,18 +6,20 @@ import torch -def diffusers_provider(model_path, model_name, task_name, mii_config): +def diffusers_provider(deployment_config): from diffusers import DiffusionPipeline - local_rank = int(os.getenv('LOCAL_RANK', '0')) + + local_rank = int(os.getenv("LOCAL_RANK", "0")) kwargs = {} - if mii_config.dtype == torch.half: + if deployment_config.dtype == torch.half: kwargs["torch_dtype"] = torch.float16 kwargs["revision"] = "fp16" - pipeline = DiffusionPipeline.from_pretrained(model_name, - use_auth_token=mii_config.hf_auth_token, - **kwargs) + pipeline = DiffusionPipeline.from_pretrained( + deployment_config.model, + use_auth_token=deployment_config.hf_auth_token, + **kwargs) pipeline = pipeline.to(f"cuda:{local_rank}") pipeline.set_progress_bar_config(disable=True) return pipeline diff --git 
a/mii/models/providers/huggingface.py b/mii/models/providers/huggingface.py index 27f456aa..52b55730 100644 --- a/mii/models/providers/huggingface.py +++ b/mii/models/providers/huggingface.py @@ -18,9 +18,11 @@ try: from transformers.utils import cached_path, hf_bucket_url + USE_NEW_HF_CACHE = False except ImportError: from huggingface_hub import snapshot_download + USE_NEW_HF_CACHE = True @@ -54,7 +56,7 @@ def __call__(self, inputs, **kwargs): # construct output to align w. HF pipeline output_dicts = [] for output in outputs: - output_dicts.append([{'generated_text': output}]) + output_dicts.append([{"generated_text": output}]) return output_dicts @@ -63,7 +65,7 @@ def get_device(load_with_sys_mem=False): if load_with_sys_mem: device = torch.device("cpu") else: - local_rank = int(os.getenv('LOCAL_RANK', '0')) + local_rank = int(os.getenv("LOCAL_RANK", "0")) device = torch.device(f"cuda:{local_rank}") return device @@ -72,7 +74,7 @@ def _attempt_load(load_fn, model_name, cache_path, kwargs={}): try: value = load_fn(model_name, **kwargs) except OSError: - print(f'Attempted load but failed, retrying using cache_dir={cache_path}') + print(f"Attempted load but failed, retrying using cache_dir={cache_path}") value = load_fn(model_name, cache_dir=cache_path, **kwargs) return value @@ -117,25 +119,27 @@ def get_checkpoint_files(pretrained_model_name_or_path): pretrained_model_name_or_path, resolved_archive_file, cache_dir=cache_dir, - revision=revision + revision=revision, ) return resolved_archive_file -def create_checkpoint_dict(model_name, model_path, mii_config): +def create_checkpoint_dict(model_name, model_path, checkpoint_dict): if USE_NEW_HF_CACHE: - model_path = snapshot_download(model_name, - cache_dir=model_path, - allow_patterns=[ - "*.bin", - "*.json", - "*.pt", - ], - revision=None) - if mii_config.checkpoint_dict: - mii_config.checkpoint_dict['base_dir'] = model_path - return mii_config.checkpoint_dict + model_path = snapshot_download( + model_name, + cache_dir=model_path, + allow_patterns=[ + "*.bin", + "*.json", + "*.pt", + ], + revision=None, + ) + if checkpoint_dict: + checkpoint_dict["base_dir"] = model_path + return checkpoint_dict elif os.path.isfile(os.path.join(model_path, "ds_inference_config.json")): with open(os.path.join(model_path, "ds_inference_config.json")) as f: data = json.load(f) @@ -144,7 +148,7 @@ def create_checkpoint_dict(model_name, model_path, mii_config): else: if USE_NEW_HF_CACHE: checkpoint_files = [ - str(entry).split('/')[-1] + str(entry).split("/")[-1] for entry in Path(model_path).rglob("*.[bp][it][n]") if entry.is_file() ] else: @@ -153,28 +157,34 @@ def create_checkpoint_dict(model_name, model_path, mii_config): "type": "DS_MODEL", "checkpoints": checkpoint_files, "version": 1.0, - "base_dir": model_path + "base_dir": model_path, } return data -def load_with_meta_tensor(model_path, model_name, task_name, mii_config): - deepspeed.init_distributed('nccl') +def load_with_meta_tensor(deployment_config): + deepspeed.init_distributed("nccl") cache_path = mii_cache_path() - tokenizer = _attempt_load(AutoTokenizer.from_pretrained, - model_name, - cache_path, - kwargs={"padding_side": 'left'}) + tokenizer = _attempt_load( + AutoTokenizer.from_pretrained, + deployment_config.model, + cache_path, + kwargs={"padding_side": "left"}, + ) tokenizer.pad_token = tokenizer.eos_token - config = _attempt_load(AutoConfig.from_pretrained, model_name, cache_path) + config = _attempt_load(AutoConfig.from_pretrained, + deployment_config.model, + cache_path) - with 
OnDevice(dtype=torch.float16, device='meta', enabled=True): + with OnDevice(dtype=torch.float16, device="meta", enabled=True): model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) model = model.eval() - checkpoint_dict = create_checkpoint_dict(model_name, model_path, mii_config) + checkpoint_dict = create_checkpoint_dict(deployment_config.model, + deployment_config.model_path, + deployment_config.checkpoint_dict) torch.distributed.barrier() inference_pipeline = MetaTensorPipeline(model=model, tokenizer=tokenizer, @@ -182,18 +192,18 @@ def load_with_meta_tensor(model_path, model_name, task_name, mii_config): return inference_pipeline -def hf_provider(model_path, model_name, task_name, mii_config): - if mii_config.meta_tensor: - return load_with_meta_tensor(model_path, model_name, task_name, mii_config) +def hf_provider(deployment_config): + if deployment_config.meta_tensor: + return load_with_meta_tensor(deployment_config) else: - device = get_device(load_with_sys_mem=mii_config.load_with_sys_mem) + device = get_device(load_with_sys_mem=deployment_config.load_with_sys_mem) inference_pipeline = pipeline( - task_name, - model=model_name, + deployment_config.task, + model=deployment_config.model, device=device, framework="pt", - use_auth_token=mii_config.hf_auth_token, - torch_dtype=mii_config.dtype, - trust_remote_code=mii_config.trust_remote_code, + use_auth_token=deployment_config.hf_auth_token, + torch_dtype=deployment_config.dtype, + trust_remote_code=deployment_config.trust_remote_code, ) return inference_pipeline diff --git a/mii/models/score/generate.py b/mii/models/score/generate.py index 1184d70e..0839a793 100644 --- a/mii/models/score/generate.py +++ b/mii/models/score/generate.py @@ -9,29 +9,7 @@ from mii.constants import DeploymentType -def create_score_file(deployment_name, - deployment_type, - task, - model_name, - ds_optimize, - ds_zero, - ds_config, - mii_config, - model_path, - lb_config): - config_dict = {} - config_dict[mii.constants.DEPLOYMENT_NAME_KEY] = deployment_name - config_dict[mii.constants.TASK_NAME_KEY] = mii.utils.get_task_name(task) - config_dict[mii.constants.MODEL_NAME_KEY] = model_name - config_dict[mii.constants.ENABLE_DEEPSPEED_KEY] = ds_optimize - config_dict[mii.constants.MII_CONFIGS_KEY] = mii_config.dict() - config_dict[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY] = ds_zero - config_dict[mii.constants.DEEPSPEED_CONFIG_KEY] = ds_config - config_dict[mii.constants.MODEL_PATH_KEY] = model_path - - if lb_config is not None: - config_dict[mii.constants.LOAD_BALANCER_CONFIG_KEY] = lb_config - +def create_score_file(mii_config): if len(mii.__path__) > 1: logger.warning( f"Detected mii path as multiple sources: {mii.__path__}, might cause unknown behavior" @@ -43,15 +21,18 @@ def create_score_file(deployment_name, score_src = fd.read() # update score file w. 
global config dict + config_dict = mii_config.dict() source_with_config = f"{score_src}\n" - source_with_config += f"configs = {pprint.pformat(config_dict, indent=4)}" + source_with_config += f"mii_config = {pprint.pformat(config_dict, indent=4)}" - with open(generated_score_path(deployment_name, deployment_type), "w") as fd: + with open(generated_score_path(mii_config), "w") as fd: fd.write(source_with_config) fd.write("\n") -def generated_score_path(deployment_name, deployment_type): +def generated_score_path(mii_config): + deployment_type = mii_config.deployment_type + deployment_name = mii_config.deployment_tag if deployment_type == DeploymentType.LOCAL: score_path = os.path.join(mii.utils.mii_cache_path(), deployment_name) elif deployment_type == DeploymentType.AML: diff --git a/mii/models/score/score_template.py b/mii/models/score/score_template.py index 04e47fae..8000c34c 100644 --- a/mii/models/score/score_template.py +++ b/mii/models/score/score_template.py @@ -6,59 +6,44 @@ # flake8: noqa import os import json +import time import torch + import mii -from mii.config import LoadBalancerConfig, ReplicaConfig -import time +from mii.config import DeploymentConfig model = None def init(): - model_path = mii.utils.full_model_path(configs[mii.constants.MODEL_PATH_KEY]) - - deployment_name = configs[mii.constants.DEPLOYMENT_NAME_KEY] - model_name = configs[mii.constants.MODEL_NAME_KEY] - task_name = configs[mii.constants.TASK_NAME_KEY] - - assert model_name is not None, "The model name should be set before calling init" - assert task_name is not None, "The task name should be set before calling init" - - mii.MIIServer(deployment_name, - task_name, - model_name, - model_path, - ds_optimize=configs[mii.constants.ENABLE_DEEPSPEED_KEY], - ds_zero=configs[mii.constants.ENABLE_DEEPSPEED_ZERO_KEY], - ds_config=configs[mii.constants.DEEPSPEED_CONFIG_KEY], - mii_configs=configs[mii.constants.MII_CONFIGS_KEY], - lb_config=configs.get(mii.constants.LOAD_BALANCER_CONFIG_KEY, - None)) + global mii_config + mii_config = mii.MIIConfig(**mii_config) + mii.MIIServer(mii_config) global model model = None # In AML deployments both the GRPC client and server are used in the same process if mii.utils.is_aml(): - model = mii.MIIClient(task_name, - mii_configs=configs[mii.constants.MII_CONFIGS_KEY]) + model = mii.MIIClient(mii_config=mii_config) def run(request): - global model - assert model is not None, "grpc client has not been setup when this model was created" + global mii_config, model + assert ( + model is not None + ), "grpc client has not been setup when this model was created" request_dict = json.loads(request) - query_dict = mii.utils.extract_query_dict(configs[mii.constants.TASK_NAME_KEY], - request_dict) + query_dict = mii.utils.extract_query_dict(mii_config.task, request_dict) response = model.query(query_dict, **request_dict) time_taken = response.time_taken if not isinstance(response.response, str): response = [r for r in response.response] - return json.dumps({'responses': response, 'time': time_taken}) + return json.dumps({"responses": response, "time": time_taken}) ### Auto-generated config will be appended below at run-time diff --git a/mii/models/utils.py b/mii/models/utils.py index 8baa7333..d44b2871 100644 --- a/mii/models/utils.py +++ b/mii/models/utils.py @@ -7,23 +7,24 @@ def supported_models_from_huggingface(): - return ['gpt2', "deepset/roberta-large-squad2"] + return ["gpt2", "deepset/roberta-large-squad2"] -'''TODO make this more robust. 
If the pipeline has already been imported then -this might not work since the cache is set by the first import''' +"""TODO make this more robust. If the pipeline has already been imported then +this might not work since the cache is set by the first import""" def _download_hf_model_to_path(task, model_name, model_path): os.environ["TRANSFORMERS_CACHE"] = model_path from transformers import pipeline + inference_pipeline = pipeline(task, model=model_name) -'''generic method that will allow downloading all models that we support. +"""generic method that will allow downloading all models that we support. Currently only supports HF models, but will be extended to support model checkpoints -from other sources''' +from other sources""" def download_model_and_get_path(task, model_name): @@ -40,7 +41,7 @@ def download_model_and_get_path(task, model_name): return model_path -class ImageResponse(): +class ImageResponse: def __init__(self, response): self._response = response self.nsfw_content_detected = response.nsfw_content_detected @@ -50,6 +51,7 @@ def __init__(self, response): def images(self): if self._deserialized_images is None: from PIL import Image + images = [] for idx, img_bytes in enumerate(self._response.images): size = (self._response.size_w, self._response.size_h) diff --git a/mii/server.py b/mii/server.py index 0825e060..742f2488 100644 --- a/mii/server.py +++ b/mii/server.py @@ -3,18 +3,14 @@ # DeepSpeed Team import base64 -import json import os import subprocess import sys import tempfile import time -import torch -from pathlib import Path from collections import defaultdict -import mii -from mii.utils import get_num_gpus, logger, get_provider_name +from mii.utils import get_num_gpus, logger def config_to_b64_str(config): @@ -26,44 +22,29 @@ def config_to_b64_str(config): return b64_config_bytes.decode() -class MIIServer(): - '''Initialize the model, setup the server for the model under model_path''' - def __init__(self, - deployment_name, - task_name, - model_name, - model_path, - ds_optimize=True, - ds_zero=False, - ds_config=None, - mii_configs={}, - lb_config=None): +class MIIServer: + """Initialize the model, setup the server for the model under model_path""" + def __init__(self, mii_config): - mii_configs = mii.config.MIIConfig(**mii_configs) + #self.task = mii_config.deployment_config.task - self.task = mii.utils.get_task(task_name) - - self.num_gpus = get_num_gpus(mii_configs) - assert self.num_gpus > 0, "GPU count must be greater than 0" - - self.port_number = mii_configs.port_number + #self.num_gpus = get_num_gpus(mii_config) + #assert self.num_gpus > 0, "GPU count must be greater than 0" + for deployment_config in mii_config.deployment_configs.values(): + assert get_num_gpus(deployment_config) > 0, f"GPU count for {deployment_config.deployment_name} must be greater than 0" + """ if mii_configs.hostfile is None: hostfile = tempfile.NamedTemporaryFile(delete=False) num_gpu = torch.cuda.device_count() with open(hostfile, "w") as f: f.write(f"localhost slots={num_gpu}") - mii.configs.hostfile = hostfile + mii_config.hostfile = hostfile + """ - processes = self._initialize_service(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config) - self._wait_until_server_is_live(processes, lb_config.replica_configs) + processes = self._initialize_service(mii_config) + self._wait_until_server_is_live(processes, + mii_config.deployment_config.replica_configs) def _wait_until_server_is_live(self, processes, deployment): for process, 
repl_config in zip(processes, deployment): @@ -84,6 +65,7 @@ def _wait_until_server_is_live(self, processes, deployment): def _is_socket_open(self, host, port): import socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) result = sock.connect_ex((host, port)) sock.close() @@ -102,242 +84,85 @@ def _is_server_process_alive(self, process): is_alive = False return is_alive - def _build_server_args(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port): - # serialize mii config - b64_config_str = config_to_b64_str(mii_configs) - - server_args_str = f"--deployment-name {deployment_name} --task-name {mii.utils.get_task_name(self.task)} --model {model_name} --model-path {model_path} --port {port}" - server_args_str += " --ds-optimize" if ds_optimize else "" - - # XXX: fetch model provider based on model name in a more general way - provider = get_provider_name(model_name, self.task) - server_args_str += f" --provider {provider}" - - server_args_str += f" --config {b64_config_str}" - server_args_str += " --ds-zero" if ds_zero else "" - if ds_zero and ds_config is not None: - if isinstance(ds_config, dict): - - def create_config_from_dict(tmpdir, config_dict): - if not os.path.exists(tmpdir): - os.makedirs(tmpdir) - config_path = os.path.join(tmpdir, 'temp_config.json') - with open(config_path, 'w') as fd: - json.dump(config_dict, fd) - return config_path - - model_dir = Path(model_path).parent.resolve() - ds_config_path = create_config_from_dict(model_dir, ds_config) - elif isinstance(ds_config, str): - ds_config_path = ds_config - else: - raise ValueError( - f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {ds_config}" - ) - server_args_str += f" --ds-config {ds_config_path}" - printable_config = f"task-name {mii.utils.get_task_name(self.task)} model {model_name} model-path {model_path} port {self.port_number} provider {provider}" - logger.info(f"MII using multi-gpu deepspeed launcher:\n" + - self.print_helper(printable_config)) - return server_args_str - - def print_helper(self, args): - # convert to list - args = args.split(" ") - # convert to dict - dct = {args[i]: args[i + 1] for i in range(0, len(args), 2)} - printable_string = "" - printable_string += " " + "-" * 60 + "\n" - for k, v in dct.items(): - dots = "." 
* (29 - len(k)) - printable_string += f" {k} {dots} {v} \n" - printable_string += " " + "-" * 60 - return printable_string - - def _launch_load_balancer(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config): - - # serialize mii config - b64_config_str = config_to_b64_str(lb_config) - - return self._launch_server_process( - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - mii_configs.port_number, - "load balancer", - ex_server_args=[f"--load-balancer {b64_config_str}"]) - - def _launch_restful_gateway(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port): - return self._launch_server_process(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port, - "restful api gateway", - ex_server_args=["--restful-gateway"]) - def _launch_server_process(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port, + deployment_config, msg_server_type, - ds_launch_str=None, - ex_server_args=[]): + ds_launch_str="", + server_args=[]): launch_str = f"{sys.executable} -m mii.launch.multi_gpu_server" - server_args_str = self._build_server_args(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port) - server_args_str += f" " + \ - " ".join(ex_server_args) if ex_server_args else "" + print(deployment_config) + b64_config_str = config_to_b64_str(deployment_config) + server_args.append(f"--deployment-config {b64_config_str}") - if ds_launch_str is None: - cmd = f'{launch_str} {server_args_str}'.split(" ") - else: - cmd = f'{ds_launch_str} {launch_str} {server_args_str}'.split(" ") + server_args_str = " ".join(server_args) + cmd = f"{ds_launch_str} {launch_str} {server_args_str}".strip().split(" ") mii_env = os.environ.copy() - mii_env["TRANSFORMERS_CACHE"] = model_path + mii_env["TRANSFORMERS_CACHE"] = deployment_config.model_path logger.info(f"{msg_server_type} server launch: {cmd}") return subprocess.Popen(cmd, env=mii_env) - def _launch_deepspeed(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - hostfile, - host, - port, - master_port, - deploy_ranks): + def _generate_ds_launch_str(self, replica_config, hostfile): # use different hostfiles for replica instances # pass /dev/null when no replica is used worker_str = f"-H {hostfile} " # pin deepspeed launch to specific gpu id(s) - included_gpus = f"{host}:{','.join(map(str, deploy_ranks))}" + included_gpus = f"{replica_config.hostname}:{','.join(map(str, replica_config.gpu_indices))}" worker_str += f"-i {included_gpus} " # adjust torch dist port depending on rank, otherwise multi-replica deployments will conflict # assign different ports to replicas because they could be on the same host - worker_str += f"--master_port {master_port}" + worker_str += f"--master_port {replica_config.torch_dist_port}" ds_launch_str = f"deepspeed {worker_str} --no_local_rank --no_python" - return self._launch_server_process(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - port, - "MII server", - ds_launch_str=ds_launch_str) - - def _initialize_service(self, - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config): + return ds_launch_str + def _initialize_service(self, mii_config): processes = 
[] + server_args = [ + f"--load-balancer-port {mii_config.port_number}", + f"--restful-gateway-port {mii_config.restful_api_port}", + ] host_gpus = defaultdict(list) - for repl_config in lb_config.replica_configs: - host_gpus[repl_config.hostname].extend(repl_config.gpu_indices) + for deployment in mii_config.deployment_configs.values(): + for repl_config in deployment.replica_configs: + host_gpus[repl_config.hostname].extend(repl_config.gpu_indices) # Start replica instances - for i, repl_config in enumerate(lb_config.replica_configs): - hostfile = tempfile.NamedTemporaryFile(delete=False) - hostfile.write( - f'{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n' - .encode()) - processes.append( - self._launch_deepspeed( - deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - hostfile.name, - repl_config.hostname, - repl_config.tensor_parallel_ports[0], - mii_configs.torch_dist_port + (100 * i) + repl_config.gpu_indices[0], - repl_config.gpu_indices)) - + for deployment_config in mii_config.deployment_configs.values(): + for repl_config in deployment_config.replica_configs: + hostfile = tempfile.NamedTemporaryFile(delete=False) + hostfile.write( + f"{repl_config.hostname} slots={max(host_gpus[repl_config.hostname])+1}\n" + .encode()) + ds_launch_str = self._generate_ds_launch_str(repl_config, hostfile.name) + processes.append( + self._launch_server_process( + deployment_config, + "MII server", + ds_launch_str=ds_launch_str, + server_args=server_args + + [f"--server-port {repl_config.tensor_parallel_ports[0]}"], + )) # start load balancer here. # we don't use deepspeed launcher for the load balancer because it does not need a GPU. # The deepspeed launcher determines the number of processes to launch based on GPUs available on the host or CUDA_VISIBLE_DEVICES, # and it is expected to assign one GPU to one process. 
processes.append( - self._launch_load_balancer(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - lb_config)) + self._launch_server_process( + mii_config, + "load balancer", + server_args=server_args + ["--load-balancer"], + )) - if mii_configs.enable_restful_api: - # start rest api server + if mii_config.enable_restful_api: processes.append( - self._launch_restful_gateway(deployment_name, - model_name, - model_path, - ds_optimize, - ds_zero, - ds_config, - mii_configs, - mii_configs.port_number)) + self._launch_server_process( + next(iter(mii_config.deployment_configs.values())), + "restful api gateway", + server_args=server_args + ["--restful-gateway"], + )) return processes diff --git a/mii/terminate.py b/mii/terminate.py index 167c5a5a..a0a6b99e 100644 --- a/mii/terminate.py +++ b/mii/terminate.py @@ -10,11 +10,11 @@ def terminate(deployment_name): mii.utils.logger.info(f"Terminating server for {deployment_name}") generator = mii.mii_query_handle(deployment_name) - if (deployment_name in mii.non_persistent_models): + if deployment_name in mii.non_persistent_models: generator.terminate() return try: - generator.query({'query': ''}) + generator.query({"query": ""}) except grpc.aio._call.AioRpcError as error: if error._code == grpc.StatusCode.UNAVAILABLE: mii.utils.logger.warn(f"Server for {deployment_name} not found") diff --git a/mii/utils.py b/mii/utils.py index 674c64a3..9a93b3c1 100644 --- a/mii/utils.py +++ b/mii/utils.py @@ -10,70 +10,15 @@ import mii from huggingface_hub import HfApi -from mii.constants import (CONVERSATIONAL_NAME, - FILL_MASK_NAME, - MII_CACHE_PATH, - MII_CACHE_PATH_DEFAULT, - TEXT_GENERATION_NAME, - TEXT_CLASSIFICATION_NAME, - QUESTION_ANSWERING_NAME, - TOKEN_CLASSIFICATION_NAME, - SUPPORTED_MODEL_TYPES, - ModelProvider, - REQUIRED_KEYS_PER_TASK, - TEXT2IMG_NAME) +from mii.constants import ( + MII_CACHE_PATH, + MII_CACHE_PATH_DEFAULT, + ModelProvider, + SUPPORTED_MODEL_TYPES, + REQUIRED_KEYS_PER_TASK, +) -from mii.constants import Tasks - - -def get_task_name(task): - if task == Tasks.QUESTION_ANSWERING: - return QUESTION_ANSWERING_NAME - - if task == Tasks.TEXT_GENERATION: - return TEXT_GENERATION_NAME - - if task == Tasks.TEXT_CLASSIFICATION: - return TEXT_CLASSIFICATION_NAME - - if task == Tasks.FILL_MASK: - return FILL_MASK_NAME - - if task == Tasks.TOKEN_CLASSIFICATION: - return TOKEN_CLASSIFICATION_NAME - - if task == Tasks.CONVERSATIONAL: - return CONVERSATIONAL_NAME - - if task == Tasks.TEXT2IMG: - return TEXT2IMG_NAME - - raise ValueError(f"Unknown Task {task}") - - -def get_task(task_name): - if task_name == QUESTION_ANSWERING_NAME: - return Tasks.QUESTION_ANSWERING - - if task_name == TEXT_GENERATION_NAME: - return Tasks.TEXT_GENERATION - - if task_name == TEXT_CLASSIFICATION_NAME: - return Tasks.TEXT_CLASSIFICATION - - if task_name == FILL_MASK_NAME: - return Tasks.FILL_MASK - - if task_name == TOKEN_CLASSIFICATION_NAME: - return Tasks.TOKEN_CLASSIFICATION - - if task_name == CONVERSATIONAL_NAME: - return Tasks.CONVERSATIONAL - - if task_name == TEXT2IMG_NAME: - return Tasks.TEXT2IMG - - assert False, f"Unknown Task {task_name}" +from mii.config import TaskType def _get_hf_models_by_type(model_type, task=None): @@ -81,7 +26,7 @@ def _get_hf_models_by_type(model_type, task=None): models = api.list_models(filter=model_type) models = ([m.modelId for m in models] if task is None else [m.modelId for m in models if m.pipeline_tag == task]) - if task == TEXT_GENERATION_NAME: + if task == 
TaskType.TEXT_GENERATION: # TODO: this is a temp solution to get around some HF models not having the correct tags models.append("microsoft/bloom-deepspeed-inference-fp16") models.append("microsoft/bloom-deepspeed-inference-int8") @@ -92,16 +37,15 @@ def _get_hf_models_by_type(model_type, task=None): # TODO read this from a file containing list of files supported for each task def _get_supported_models_name(task): supported_models = [] - task_name = get_task_name(task) for model_type, provider in SUPPORTED_MODEL_TYPES.items(): if provider == ModelProvider.HUGGING_FACE: - models = _get_hf_models_by_type(model_type, task_name) + models = _get_hf_models_by_type(model_type, task) elif provider == ModelProvider.ELEUTHER_AI: - if task_name == TEXT_GENERATION_NAME: + if task == TaskType.TEXT_GENERATION: models = [model_type] elif provider == ModelProvider.DIFFUSERS: - models = _get_hf_models_by_type(model_type, task_name) + models = _get_hf_models_by_type(model_type, task) supported_models.extend(models) if not supported_models: raise ValueError(f"Task {task} not supported") @@ -115,27 +59,8 @@ def check_if_task_and_model_is_supported(task, model_name): def check_if_task_and_model_is_valid(task, model_name): - task_name = get_task_name(task) - valid_task_models = _get_hf_models_by_type(None, task_name) - assert ( - model_name in valid_task_models - ), f"{task_name} only supports {valid_task_models}" - - -def full_model_path(model_path): - aml_model_dir = os.environ.get('AZUREML_MODEL_DIR', None) - if aml_model_dir: - # (potentially) append relative model_path w. aml path - assert os.path.isabs(aml_model_dir), f"AZUREML_MODEL_DIR={aml_model_dir} must be an absolute path" - if model_path: - assert not os.path.isabs(model_path), f"model_path={model_path} must be relative to append w. 
AML path" - return os.path.join(aml_model_dir, model_path) - else: - return aml_model_dir - elif model_path: - return model_path - else: - return mii.constants.MII_MODEL_PATH_DEFAULT + valid_task_models = _get_hf_models_by_type(None, task) + assert model_name in valid_task_models, f"{task} only supports {valid_task_models}" def is_aml(): @@ -198,21 +123,22 @@ def extract_query_dict(task, request_dict): return query_dict -def get_num_gpus(mii_configs): - num_gpus = mii_configs.tensor_parallel +def get_num_gpus(deployment_config): + num_gpus = deployment_config.tensor_parallel - assert torch.cuda.device_count( - ) >= num_gpus, f"Available GPU count: {torch.cuda.device_count()} does not meet the required gpu count: {num_gpus}" + assert ( + torch.cuda.device_count() >= num_gpus + ), f"Available GPU count: {torch.cuda.device_count()} does not meet the required gpu count: {num_gpus}" return num_gpus -def get_provider_name(model_name, task): +def get_provider(model_name, task): if model_name == "gpt-neox": - provider = mii.constants.MODEL_PROVIDER_NAME_EA - elif task == mii.Tasks.TEXT2IMG: - provider = mii.constants.MODEL_PROVIDER_NAME_DIFFUSERS + provider = ModelProvider.ELEUTHER_AI + elif task == TaskType.TEXT2IMG: + provider = ModelProvider.DIFFUSERS else: - provider = mii.constants.MODEL_PROVIDER_NAME_HF + provider = ModelProvider.HUGGING_FACE return provider diff --git a/tests/conftest.py b/tests/conftest.py index cb812069..9dac6e69 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,12 +9,7 @@ from types import SimpleNamespace -# Add pytest.skip here for configs that we do not want to test -def validate_config(config): - pass - - -@pytest.fixture(scope="function", params=['fp16']) +@pytest.fixture(scope="function", params=["fp16"]) def dtype(request): return request.param @@ -54,30 +49,6 @@ def restful_api_port(request): return request.param -@pytest.fixture(scope="function") -def mii_config( - tmpdir: str, - dtype: str, - tensor_parallel: int, - port_number: int, - meta_tensor: bool, - load_with_sys_mem: bool, - replica_num: int, - enable_restful_api: bool, - restful_api_port: int, -): - return { - 'dtype': dtype, - 'tensor_parallel': tensor_parallel, - 'port_number': port_number, - 'meta_tensor': meta_tensor, - 'load_with_sys_mem': load_with_sys_mem, - 'replica_num': replica_num, - 'enable_restful_api': enable_restful_api, - 'restful_api_port': restful_api_port, - } - - @pytest.fixture(scope="function", params=["text-generation"]) def task_name(request): return request.param @@ -88,6 +59,11 @@ def model_name(request): return request.param +@pytest.fixture(scope="function") +def deployment_name(model_name): + return model_name + "-deployment" + + @pytest.fixture(scope="function", params=[mii.DeploymentType.LOCAL]) def deployment_type(request): return request.param @@ -109,25 +85,48 @@ def ds_config(request): @pytest.fixture(scope="function") -def deployment_config(task_name: str, - model_name: str, - deployment_type: str, - mii_config: dict, - enable_deepspeed: bool, - enable_zero: bool, - ds_config: dict): - config = SimpleNamespace(task=task_name, - model=model_name, - deployment_type=deployment_type, - deployment_name=model_name + "-deployment", - model_path=os.getenv("TRANSFORMERS_CACHE", - None), - mii_config=mii_config, - enable_deepspeed=enable_deepspeed, - enable_zero=enable_zero, - ds_config=ds_config) - validate_config(config) - return config +def deployment_config( + task_name: str, + model_name: str, + dtype: str, + tensor_parallel: int, + meta_tensor: bool, + 
load_with_sys_mem: bool, + replica_num: int, + enable_deepspeed: bool, + enable_zero: bool, + ds_config: dict, +): + config = SimpleNamespace( + task=task_name, + model=model_name, + dtype=dtype, + tensor_parallel=tensor_parallel, + model_path=os.getenv("TRANSFORMERS_CACHE", + ""), + meta_tensor=meta_tensor, + replica_num=replica_num, + enable_deepspeed=enable_deepspeed, + enable_zero=enable_zero, + ds_config=ds_config, + ) + return config.__dict__ + + +@pytest.fixture(scope="function") +def mii_config( + deployment_type: str, + port_number: int, + enable_restful_api: bool, + restful_api_port: int, +): + config = SimpleNamespace( + deployment_type=deployment_type, + port_number=port_number, + enable_restful_api=enable_restful_api, + restful_api_port=restful_api_port, + ) + return config.__dict__ @pytest.fixture(scope="function", params=[None]) @@ -136,15 +135,23 @@ def expected_failure(request): @pytest.fixture(scope="function") -def deployment(deployment_config, expected_failure): +def deployment(deployment_name, mii_config, deployment_config, expected_failure): if expected_failure is not None: with pytest.raises(expected_failure) as excinfo: - mii.deploy(**deployment_config.__dict__) + mii.deploy( + deployment_name=deployment_name, + mii_config=mii_config, + deployment_config=deployment_config, + ) yield excinfo else: - mii.deploy(**deployment_config.__dict__) - yield deployment_config - mii.terminate(deployment_config.deployment_name) + mii.deploy( + deployment_name=deployment_name, + mii_config=mii_config, + deployment_config=deployment_config, + ) + yield deployment_name + mii.terminate(deployment_name) @pytest.fixture(scope="function", params=[{"query": "DeepSpeed is the greatest"}]) diff --git a/tests/test_config.py b/tests/test_config.py index 2d8de70b..73e68eeb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,33 +9,21 @@ import mii -def test_base_config(): - config = {'port_number': 12345, 'tensor_parallel': 4} - mii_config = mii.config.MIIConfig(**config) - - assert mii_config.port_number == config['port_number'] - assert mii_config.tensor_parallel == config['tensor_parallel'] - - -@pytest.mark.parametrize("config", - [ - { - 'port_number': 'fail', - 'tensor_parallel': 'fail' - }, - { - 'port_number': 'fail', - 'tensor_parallel': 4 - }, - { - 'port_number': 12345, - 'tensor_parallel': 'fail' - }, - { - 'port_fail': 12345, - 'tensor_parallel': 4 - }, - ]) -def test_base_config_literalfail(config): +@pytest.mark.parametrize("port_number", [12345]) +@pytest.mark.parametrize("tensor_parallel", [4]) +def test_base_configs(deployment_name, mii_config, deployment_config): + deployment_config["deployment_name"] = deployment_name + mii_config["deployment_config"] = deployment_config + mii_config = mii.config.MIIConfig(**mii_config) + + assert mii_config.port_number == 12345 + assert mii_config.deployment_config.tensor_parallel == 4 + + +@pytest.mark.parametrize("port_number", ["fail"]) +@pytest.mark.parametrize("tensor_parallel", [3.5]) +def test_base_configs_literalfail(deployment_name, mii_config, deployment_config): with pytest.raises(pydantic.ValidationError): - mii_config = mii.config.MIIConfig(**config) + deployment_config["deployment_name"] = deployment_name + mii_config["deployment_config"] = deployment_config + mii_config = mii.config.MIIConfig(**mii_config) diff --git a/tests/test_deployment_options.py b/tests/test_deployment_options.py index a84318d7..5b1dcf22 100644 --- a/tests/test_deployment_options.py +++ b/tests/test_deployment_options.py @@ -6,6 +6,7 @@ 
 import pytest
 import json
 import requests
+import pydantic
 
 import mii
 
@@ -13,18 +14,18 @@
 @pytest.mark.parametrize("meta_tensor", [True])
 @pytest.mark.parametrize("tensor_parallel", [2])
 def test_meta_tensor(deployment, query):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     result = generator.query(query)
     assert result
 
 
 @pytest.mark.parametrize("enable_restful_api", [True])
 def test_restful_api(deployment, query, restful_api_port):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     for _ in range(2):
         result = generator.query(query)
 
-    url = f'http://localhost:{restful_api_port}/mii/{deployment.deployment_name}'
+    url = f"http://localhost:{restful_api_port}/mii/{deployment}"
     params = {"request": query}
     json_params = json.dumps(params)
     result = requests.post(url,
@@ -36,14 +37,14 @@ def test_restful_api(deployment, query, restful_api_port):
 
 @pytest.mark.parametrize("load_with_sys_mem", [True])
 def test_load_to_sys_mem(deployment, query):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     result = generator.query(query)
     assert result
 
 
 @pytest.mark.parametrize("replica_num", [2])
 def test_replicas(deployment, query, replica_num):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     # Replicas are given queries in round-robin, so test each model is responding
     for _ in range(replica_num):
         result = generator.query(query)
@@ -53,69 +54,62 @@
 @pytest.mark.deepspeed
 @pytest.mark.parametrize("enable_deepspeed", [False])
 @pytest.mark.parametrize("enable_zero", [True])
-@pytest.mark.parametrize("ds_config",
-                         [
-                             {
-                                 "fp16": {
-                                     "enabled": True
-                                 },
-                                 "bf16": {
-                                     "enabled": False
-                                 },
-                                 "zero_optimization": {
-                                     "stage": 3,
-                                     "offload_param": {
-                                         "device": "cpu",
-                                     },
-                                 },
-                                 "train_micro_batch_size_per_gpu": 1,
-                             },
-                         ])
+@pytest.mark.parametrize(
+    "ds_config",
+    [
+        {
+            "fp16": {
+                "enabled": True
+            },
+            "bf16": {
+                "enabled": False
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "offload_param": {
+                    "device": "cpu",
+                },
+            },
+            "train_micro_batch_size_per_gpu": 1,
+        },
+    ],
+)
 def test_zero_config(deployment, query):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     result = generator.query(query)
     assert result
 
 
 @pytest.mark.deepspeed
-@pytest.mark.parametrize("expected_failure", [AssertionError])
-@pytest.mark.parametrize("enable_deepspeed, enable_zero, dtype",
-                         [(True,
-                           True,
-                           'fp32'),
-                          (False,
-                           True,
-                           'fp16')])
-@pytest.mark.parametrize("ds_config",
-                         [
-                             {
-                                 "fp16": {
-                                     "enabled": False
-                                 },
-                                 "bf16": {
-                                     "enabled": False
-                                 },
-                                 "zero_optimization": {
-                                     "stage": 3,
-                                     "offload_param": {
-                                         "device": "cpu",
-                                     },
-                                 },
-                                 "train_micro_batch_size_per_gpu": 1,
-                             },
-                         ])
+@pytest.mark.parametrize("expected_failure", [pydantic.ValidationError])
+@pytest.mark.parametrize(
+    "enable_deepspeed, enable_zero, dtype",
+    [(True,
+      True,
+      "fp32"),
+     (False,
+      True,
+      "fp16")],
+)
 @pytest.mark.parametrize(
-    "task_name, model_name, query",
+    "ds_config",
     [
-        (
-            "text-generation",
-            "distilgpt2",
-            {
-                "query": "DeepSpeed is the greatest"
+        {
+            "fp16": {
+                "enabled": False
+            },
+            "bf16": {
+                "enabled": False
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "offload_param": {
+                    "device": "cpu",
+                },
             },
-        ),
+            "train_micro_batch_size_per_gpu": 1,
+        },
     ],
 )
 def test_zero_config_fail(deployment, query):
-    print(deployment)
-    assert "MII Config Error" in str(deployment.value)
+    assert "assertion_error" in str(deployment.value)
diff --git a/tests/test_local_deployment.py b/tests/test_local_deployment.py
index 7e0ab7b4..ec77613d 100644
--- a/tests/test_local_deployment.py
+++ b/tests/test_local_deployment.py
@@ -46,14 +46,16 @@
             "bigscience/bloom-560m",
             {
                 "query": ["DeepSpeed is the greatest",
-                          'Seattle is']
+                          "Seattle is"]
+            },
+        ),
+        (
+            "token-classification",
+            "Jean-Baptiste/roberta-large-ner-english",
+            {
+                "query": "My name is jean-baptiste and I live in montreal."
             },
         ),
-        ("token-classification",
-         "Jean-Baptiste/roberta-large-ner-english",
-         {
-             "query": "My name is jean-baptiste and I live in montreal."
-         }),
         (
             "text-classification",
             "roberta-large-mnli",
@@ -64,7 +66,7 @@
     ],
 )
 def test_single_GPU(deployment, query):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     result = generator.query(query)
     assert result
 
@@ -77,12 +79,12 @@ def test_single_GPU(deployment, query):
             "bigscience/bloom-560m",
             {
                 "query": ["DeepSpeed is the greatest",
-                          'Seattle is']
+                          "Seattle is"]
             },
         ),
     ],
 )
 def test_multi_GPU(deployment, query):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     result = generator.query(query)
     assert result
diff --git a/tests/test_non_persistent_deployment.py b/tests/test_non_persistent_deployment.py
index 2f555b64..71234201 100644
--- a/tests/test_non_persistent_deployment.py
+++ b/tests/test_non_persistent_deployment.py
@@ -48,14 +48,16 @@
             "bigscience/bloom-560m",
             {
                 "query": ["DeepSpeed is the greatest",
-                          'Seattle is']
+                          "Seattle is"]
+            },
+        ),
+        (
+            "token-classification",
+            "Jean-Baptiste/roberta-large-ner-english",
+            {
+                "query": "My name is jean-baptiste and I live in montreal."
             },
         ),
-        ("token-classification",
-         "Jean-Baptiste/roberta-large-ner-english",
-         {
-             "query": "My name is jean-baptiste and I live in montreal."
-         }),
         (
             "text-classification",
             "roberta-large-mnli",
@@ -66,6 +68,6 @@
     ],
 )
 def test_single_GPU(deployment, query):
-    generator = mii.mii_query_handle(deployment.deployment_name)
+    generator = mii.mii_query_handle(deployment)
     result = generator.query(query)
     assert result
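
A minimal usage sketch of the refactored config API that the updated tests above exercise: per-model settings are nested as a deployment_config inside the top-level MIIConfig and validated by pydantic, mirroring test_base_configs. This assumes MIIConfig accepts a deployment_config mapping with deployment_name, task, model, and tensor_parallel fields and that unspecified fields fall back to defaults; the concrete values below are hypothetical, not taken from the diff.

import mii
import pydantic

# Hypothetical per-deployment settings, named after the fields exercised in the tests above.
deployment_config = {
    "deployment_name": "bloom560m_deployment",
    "task": "text-generation",
    "model": "bigscience/bloom-560m",
    "tensor_parallel": 4,
}

# Nest the deployment config inside the server-level config, as test_base_configs does.
mii_config = mii.config.MIIConfig(port_number=12345, deployment_config=deployment_config)
assert mii_config.deployment_config.tensor_parallel == 4

# Invalid literals should surface as pydantic.ValidationError,
# which is what test_base_configs_literalfail now expects.
try:
    mii.config.MIIConfig(port_number="fail", deployment_config=deployment_config)
except pydantic.ValidationError:
    pass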