From 4ee2064df0319eb1a61e32507f47f412adc3479c Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Mon, 13 Nov 2023 14:55:34 -0800 Subject: [PATCH 1/9] move request/response dataclasses to new file --- mii/batching/data_classes.py | 267 ++++++++++++++++++++++++++++++++ mii/batching/ragged_batching.py | 262 +------------------------------ 2 files changed, 270 insertions(+), 259 deletions(-) create mode 100644 mii/batching/data_classes.py diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py new file mode 100644 index 00000000..c57e04d1 --- /dev/null +++ b/mii/batching/data_classes.py @@ -0,0 +1,267 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +from dataclasses import dataclass, field +from typing import Dict, List, Iterator, Union +from typing_extensions import Self + +import torch + +from mii.constants import GenerationFinishReason + + +@dataclass +class Response: + generated_text: str + prompt_length: int + generated_length: int + finish_reason: GenerationFinishReason + + @staticmethod + def from_msg(msg: Dict[str, Union[str, int]]) -> Self: + return Response( + generated_text=msg["generated_text"], + prompt_length=msg["prompt_length"], + generated_length=msg["generated_length"], + finish_reason=GenerationFinishReason(msg["finish_reason"]), + ) + + def get_msg(self) -> Dict[str, Union[str, int]]: + return { + "generated_text": self.generated_text, + "prompt_length": self.prompt_length, + "generated_length": self.generated_length, + "finish_reason": self.finish_reason.value + } + + def __repr__(self) -> str: + return self.generated_text + + def __str__(self) -> str: + return self.generated_text + + +class ResponseBatch: + def __init__(self, responses: List[Response]) -> None: + self.responses = responses + + def __iter__(self) -> Iterator[Response]: + return iter(self.responses) + + def __repr__(self) -> str: + return "\n\n".join(str(r) for r in self.responses) + + @property + def generated_texts(self) -> List[str]: + return [r.generated_text for r in self.responses] + + @property + def prompt_lengths(self) -> List[int]: + return [r.prompt_length for r in self.responses] + + @property + def generated_lengths(self) -> List[int]: + return [r.generated_length for r in self.responses] + + @property + def finish_reasons(self) -> List[GenerationFinishReason]: + return [r.finish_reason for r in self.responses] + + def append(self, response: Response) -> None: + self.responses.append(response) + + +@dataclass +class RaggedRequestMsg: + uid: int + input_tokens: Union[torch.Tensor, List[int]] + + @property + def is_flush_request(self): + return self.input_tokens is None + + @staticmethod + def from_msg(msg: Dict[str, int]) -> Self: + return RaggedRequestMsg( + uid=msg["uid"], + input_tokens=None + if msg["input_tokens"] is None else torch.tensor(msg["input_tokens"], + dtype=torch.int32, + device=torch.device("cpu")), + ) + + +@dataclass +class RaggedRequest: + tid: int + uid: int + input_tokens: torch.Tensor + prompt_tokens: torch.Tensor + seq_length: int + max_length: int + max_new_tokens: int + min_new_tokens: int + last_in_prompt: bool + post_processing: List[object] + stream: bool = False + ignore_eos: bool = False + return_full_text: bool = False + + _next_token: Union[None, torch.Tensor] = None + _is_done: bool = False + _generated_tokens: List[torch.Tensor] = field(default_factory=list) + _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE + + @property + def prompt_length(self) -> int: + return len(self.prompt_tokens) + + @property + def next_token(self) -> Union[None, torch.Tensor]: + return self._next_token + + @next_token.setter + def next_token(self, next_token: Union[None, torch.Tensor]) -> None: + self._next_token = next_token + + @property + def is_done(self) -> bool: + if self.ignore_eos: + return False + if self.seq_length < self.min_new_tokens: + return False + return self._is_done + + @is_done.setter + def is_done(self, is_done: bool) -> None: + self._is_done = is_done + + @property + def generated_tokens(self) -> List[torch.Tensor]: + return self._generated_tokens + + @property + def finish_reason(self) -> GenerationFinishReason: + return self._finish_reason + + @property + def is_flush_request(self): + return self.input_tokens is None + + @property + def num_generated_tokens(self) -> int: + # We return zero while we are processing decomposed prompts + return self.seq_length - self.prompt_length + 1 if self.seq_length >= self.prompt_length else 0 + + @property + def stop_generation(self) -> bool: + if self.is_done: + self._finish_reason = GenerationFinishReason.STOP + return True + if (self.seq_length >= self.max_length) or (self.num_generated_tokens >= + self.max_new_tokens): + self._finish_reason = GenerationFinishReason.LENGTH + return True + return False + + def get_msg(self) -> RaggedRequestMsg: + return RaggedRequestMsg( + uid=self.uid, + input_tokens=None + if self.input_tokens is None else self.input_tokens.tolist(), + ) + + def accumulate_generated_token(self) -> None: + if not self.is_done: + self._generated_tokens.append(self.next_token) + + def clear_generated_token(self) -> None: + self._generated_tokens.clear() + + def set_next_as_input(self) -> None: + if self.next_token is not None: + self.input_tokens = self.next_token.unsqueeze(0) + self.last_in_prompt = True + self.next_token = None + self.is_done = False + + +class RaggedRequestBatch: + def __init__(self, requests: List[RaggedRequest]) -> None: + self.requests = requests + + def __len__(self) -> int: + return len(self.requests) + + def __contains__(self, r: RaggedRequest) -> bool: + return r in self.requests + + def __nonzero__(self) -> bool: + if len(self.requests) != 0: + return True + return False + + def __iter__(self) -> Iterator[RaggedRequest]: + return iter(self.requests) + + def __repr__(self) -> str: + return f"RaggedRequestBatch({self.requests})" + + @property + def requests_to_run(self) -> Self: + return RaggedRequestBatch([r for r in self.requests if not r.is_flush_request]) + + @property + def requests_to_flush(self) -> Self: + return RaggedRequestBatch([r for r in self.requests if r.is_flush_request]) + + @property + def last_in_prompt(self) -> Self: + return RaggedRequestBatch([r for r in self.requests if r.last_in_prompt]) + + @property + def completed(self) -> Self: + return RaggedRequestBatch([r for r in self.requests if r.stop_generation]) + + @property + def uids(self) -> List[int]: + return [r.uid for r in self.requests] + + @property + def lengths(self) -> List[int]: + return [len(r.input_tokens) for r in self.requests] + + @property + def tokens(self) -> List[torch.Tensor]: + return [r.input_tokens for r in self.requests] + + @property + def next_tokens(self) -> List[torch.Tensor]: + return [r.next_token for r in self.requests] + + @property + def done_tokens(self) -> List[torch.Tensor]: + return [r.is_done for r in self.requests] + + @next_tokens.setter + def next_tokens(self, next_tokens: List[torch.Tensor]) -> None: + assert len(next_tokens) == len(self.requests) + for idx, r in enumerate(self.requests): + r.next_token = next_tokens[idx] + + @done_tokens.setter + def done_tokens(self, done_tokens: List[torch.Tensor]) -> None: + assert len(done_tokens) == len(self.requests) + for idx, r in enumerate(self.requests): + r.is_done = done_tokens[idx] + + def prune(self, uids: List[int]) -> None: + self.requests = [r for r in self.requests if r.uid not in uids] + + def append(self, r: RaggedRequest) -> None: + self.requests.append(r) + + def update_seq_length(self) -> None: + for r in self.requests: + r.seq_length += r.input_tokens.size(0) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 217fa606..b9331ecf 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -9,10 +9,9 @@ import threading import time from collections import deque, defaultdict -from dataclasses import dataclass, asdict, field +from dataclasses import asdict from functools import cached_property -from typing import Dict, Tuple, List, Any, Iterator, Union, DefaultDict -from typing_extensions import Self +from typing import Dict, Tuple, List, Any, Union, DefaultDict import torch import ujson @@ -42,6 +41,7 @@ TEMP_NAME, SAMPLER_NAME, STOP_NAME) +from mii.batching.data_classes import Response, RaggedRequest, ResponseBatch, RaggedRequestBatch, RaggedRequestMsg from mii.batching.generation.logit_processors import TopPLogitProcessor, TopKLogitProcessor, TemperatureLogitProcessor from mii.batching.generation.samplers import LogitsSampler, GreedySampler from mii.batching.generation.stop_criterion import EosGenerationStopCriterion, TokenStopCriterion @@ -55,262 +55,6 @@ from mii.logging import logger -@dataclass -class Response: - generated_text: str - prompt_length: int - generated_length: int - finish_reason: GenerationFinishReason - - @staticmethod - def from_msg(msg: Dict[str, Union[str, int]]) -> Self: - return Response( - generated_text=msg["generated_text"], - prompt_length=msg["prompt_length"], - generated_length=msg["generated_length"], - finish_reason=GenerationFinishReason(msg["finish_reason"]), - ) - - def get_msg(self) -> Dict[str, Union[str, int]]: - return { - "generated_text": self.generated_text, - "prompt_length": self.prompt_length, - "generated_length": self.generated_length, - "finish_reason": self.finish_reason.value - } - - def __repr__(self) -> str: - return self.generated_text - - def __str__(self) -> str: - return self.generated_text - - -class ResponseBatch: - def __init__(self, responses: List[Response]) -> None: - self.responses = responses - - def __iter__(self) -> Iterator[Response]: - return iter(self.responses) - - def __repr__(self) -> str: - return "\n\n".join(str(r) for r in self.responses) - - @property - def generated_texts(self) -> List[str]: - return [r.generated_text for r in self.responses] - - @property - def prompt_lengths(self) -> List[int]: - return [r.prompt_length for r in self.responses] - - @property - def generated_lengths(self) -> List[int]: - return [r.generated_length for r in self.responses] - - @property - def finish_reasons(self) -> List[GenerationFinishReason]: - return [r.finish_reason for r in self.responses] - - def append(self, response: Response) -> None: - self.responses.append(response) - - -@dataclass -class RaggedRequestMsg: - uid: int - input_tokens: Union[torch.Tensor, List[int]] - - @property - def is_flush_request(self): - return self.input_tokens is None - - @staticmethod - def from_msg(msg: Dict[str, int]) -> Self: - return RaggedRequestMsg( - uid=msg["uid"], - input_tokens=None - if msg["input_tokens"] is None else torch.tensor(msg["input_tokens"], - dtype=torch.int32, - device=torch.device("cpu")), - ) - - -@dataclass -class RaggedRequest: - tid: int - uid: int - input_tokens: torch.Tensor - prompt_tokens: torch.Tensor - seq_length: int - max_length: int - max_new_tokens: int - min_new_tokens: int - last_in_prompt: bool - post_processing: List[object] - stream: bool = False - ignore_eos: bool = False - return_full_text: bool = False - - _next_token: Union[None, torch.Tensor] = None - _is_done: bool = False - _generated_tokens: List[torch.Tensor] = field(default_factory=list) - _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE - - @property - def prompt_length(self) -> int: - return len(self.prompt_tokens) - - @property - def next_token(self) -> Union[None, torch.Tensor]: - return self._next_token - - @next_token.setter - def next_token(self, next_token: Union[None, torch.Tensor]) -> None: - self._next_token = next_token - - @property - def is_done(self) -> bool: - if self.ignore_eos: - return False - if self.seq_length < self.min_new_tokens: - return False - return self._is_done - - @is_done.setter - def is_done(self, is_done: bool) -> None: - self._is_done = is_done - - @property - def generated_tokens(self) -> List[torch.Tensor]: - return self._generated_tokens - - @property - def finish_reason(self) -> GenerationFinishReason: - return self._finish_reason - - @property - def is_flush_request(self): - return self.input_tokens is None - - @property - def num_generated_tokens(self) -> int: - # We return zero while we are processing decomposed prompts - return self.seq_length - self.prompt_length + 1 if self.seq_length >= self.prompt_length else 0 - - @property - def stop_generation(self) -> bool: - if self.is_done: - self._finish_reason = GenerationFinishReason.STOP - return True - if (self.seq_length >= self.max_length) or (self.num_generated_tokens >= - self.max_new_tokens): - self._finish_reason = GenerationFinishReason.LENGTH - return True - return False - - def get_msg(self) -> RaggedRequestMsg: - return RaggedRequestMsg( - uid=self.uid, - input_tokens=None - if self.input_tokens is None else self.input_tokens.tolist(), - ) - - def accumulate_generated_token(self) -> None: - if not self.is_done: - self._generated_tokens.append(self.next_token) - - def clear_generated_token(self) -> None: - self._generated_tokens.clear() - - def set_next_as_input(self) -> None: - if self.next_token is not None: - self.input_tokens = self.next_token.unsqueeze(0) - self.last_in_prompt = True - self.next_token = None - self.is_done = False - - -class RaggedRequestBatch: - def __init__(self, requests: List[RaggedRequest]) -> None: - self.requests = requests - - def __len__(self) -> int: - return len(self.requests) - - def __contains__(self, r: RaggedRequest) -> bool: - return r in self.requests - - def __nonzero__(self) -> bool: - if len(self.requests) != 0: - return True - return False - - def __iter__(self) -> Iterator[RaggedRequest]: - return iter(self.requests) - - def __repr__(self) -> str: - return f"RaggedRequestBatch({self.requests})" - - @property - def requests_to_run(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if not r.is_flush_request]) - - @property - def requests_to_flush(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.is_flush_request]) - - @property - def last_in_prompt(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.last_in_prompt]) - - @property - def completed(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.stop_generation]) - - @property - def uids(self) -> List[int]: - return [r.uid for r in self.requests] - - @property - def lengths(self) -> List[int]: - return [len(r.input_tokens) for r in self.requests] - - @property - def tokens(self) -> List[torch.Tensor]: - return [r.input_tokens for r in self.requests] - - @property - def next_tokens(self) -> List[torch.Tensor]: - return [r.next_token for r in self.requests] - - @property - def done_tokens(self) -> List[torch.Tensor]: - return [r.is_done for r in self.requests] - - @next_tokens.setter - def next_tokens(self, next_tokens: List[torch.Tensor]) -> None: - assert len(next_tokens) == len(self.requests) - for idx, r in enumerate(self.requests): - r.next_token = next_tokens[idx] - - @done_tokens.setter - def done_tokens(self, done_tokens: List[torch.Tensor]) -> None: - assert len(done_tokens) == len(self.requests) - for idx, r in enumerate(self.requests): - r.is_done = done_tokens[idx] - - def prune(self, uids: List[int]) -> None: - self.requests = [r for r in self.requests if r.uid not in uids] - - def append(self, r: RaggedRequest) -> None: - self.requests.append(r) - - def update_seq_length(self) -> None: - for r in self.requests: - r.seq_length += r.input_tokens.size(0) - - class RaggedBatchBase: def __init__(self, inference_engine, tokenizer, model_config): self.inference_engine = inference_engine From b03947eaa497747ff5025bc4b41d4d0513c44339 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Mon, 13 Nov 2023 15:55:39 -0800 Subject: [PATCH 2/9] rename RaggedRequest -> Request, add some additional class methods --- mii/backend/client.py | 5 +- mii/batching/data_classes.py | 125 +++++++++++++++++--------------- mii/batching/ragged_batching.py | 32 ++++---- 3 files changed, 85 insertions(+), 77 deletions(-) diff --git a/mii/backend/client.py b/mii/backend/client.py index b2cd8118..44da5855 100644 --- a/mii/backend/client.py +++ b/mii/backend/client.py @@ -7,6 +7,7 @@ import requests from typing import Dict, Any, Callable, List, Union +from mii.batching.data_classes import ResponseBatch from mii.config import MIIConfig from mii.constants import GRPC_MAX_MSG_SIZE from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc @@ -37,7 +38,7 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None: channel = create_channel(host, self.port) self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> ResponseBatch: return self.generate(*args, **kwargs) async def _request_async_response(self, request_dict, **query_kwargs): @@ -59,7 +60,7 @@ def generate(self, List[str]], streaming_fn: Callable = None, **query_kwargs: Dict[str, - Any]): + Any]) -> ResponseBatch: if isinstance(prompts, str): prompts = [prompts] if streaming_fn is not None: diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py index c57e04d1..9bad1e1c 100644 --- a/mii/batching/data_classes.py +++ b/mii/batching/data_classes.py @@ -3,7 +3,7 @@ # DeepSpeed Team from dataclasses import dataclass, field -from typing import Dict, List, Iterator, Union +from typing import Any, Dict, List, Iterator, Union from typing_extensions import Self import torch @@ -27,7 +27,7 @@ def from_msg(msg: Dict[str, Union[str, int]]) -> Self: finish_reason=GenerationFinishReason(msg["finish_reason"]), ) - def get_msg(self) -> Dict[str, Union[str, int]]: + def to_msg(self) -> Dict[str, Union[str, int]]: return { "generated_text": self.generated_text, "prompt_length": self.prompt_length, @@ -42,38 +42,8 @@ def __str__(self) -> str: return self.generated_text -class ResponseBatch: - def __init__(self, responses: List[Response]) -> None: - self.responses = responses - - def __iter__(self) -> Iterator[Response]: - return iter(self.responses) - - def __repr__(self) -> str: - return "\n\n".join(str(r) for r in self.responses) - - @property - def generated_texts(self) -> List[str]: - return [r.generated_text for r in self.responses] - - @property - def prompt_lengths(self) -> List[int]: - return [r.prompt_length for r in self.responses] - - @property - def generated_lengths(self) -> List[int]: - return [r.generated_length for r in self.responses] - - @property - def finish_reasons(self) -> List[GenerationFinishReason]: - return [r.finish_reason for r in self.responses] - - def append(self, response: Response) -> None: - self.responses.append(response) - - @dataclass -class RaggedRequestMsg: +class RequestMsg: uid: int input_tokens: Union[torch.Tensor, List[int]] @@ -82,18 +52,17 @@ def is_flush_request(self): return self.input_tokens is None @staticmethod - def from_msg(msg: Dict[str, int]) -> Self: - return RaggedRequestMsg( - uid=msg["uid"], - input_tokens=None - if msg["input_tokens"] is None else torch.tensor(msg["input_tokens"], - dtype=torch.int32, - device=torch.device("cpu")), - ) + def from_msg_dict(msg: Dict[str, Any]) -> Self: + input_tokens = msg["input_tokens"] + if input_tokens is not None: + input_tokens = torch.tensor(msg["input_tokens"], + dtype=torch.int32, + device=torch.device("cpu")) + return RequestMsg(uid=msg["uid"], input_tokens=input_tokens) @dataclass -class RaggedRequest: +class Request: tid: int uid: int input_tokens: torch.Tensor @@ -156,6 +125,7 @@ def num_generated_tokens(self) -> int: @property def stop_generation(self) -> bool: + # Returns whether to stop generation for request if self.is_done: self._finish_reason = GenerationFinishReason.STOP return True @@ -165,14 +135,15 @@ def stop_generation(self) -> bool: return True return False - def get_msg(self) -> RaggedRequestMsg: - return RaggedRequestMsg( - uid=self.uid, - input_tokens=None - if self.input_tokens is None else self.input_tokens.tolist(), - ) + def to_msg_dict(self) -> Dict[str, Any]: + # Returns a minimal version of the request of purposes of broadcasting to all ranks + input_tokens = self.input_tokens + if input_tokens is not None: + input_tokens = self.input_tokens.tolist() + return {"uid": self.uid, "input_tokens": input_tokens} def accumulate_generated_token(self) -> None: + # Append the latest token to the list of generated tokens if not self.is_done: self._generated_tokens.append(self.next_token) @@ -180,6 +151,7 @@ def clear_generated_token(self) -> None: self._generated_tokens.clear() def set_next_as_input(self) -> None: + # Places the next token into the input token for next round of generation if self.next_token is not None: self.input_tokens = self.next_token.unsqueeze(0) self.last_in_prompt = True @@ -187,14 +159,44 @@ def set_next_as_input(self) -> None: self.is_done = False -class RaggedRequestBatch: - def __init__(self, requests: List[RaggedRequest]) -> None: +class ResponseBatch: + def __init__(self, responses: List[Response] = []) -> None: + self.responses = responses + + def __iter__(self) -> Iterator[Response]: + return iter(self.responses) + + def __repr__(self) -> str: + return "\n\n".join(str(r) for r in self.responses) + + @property + def generated_texts(self) -> List[str]: + return [r.generated_text for r in self.responses] + + @property + def prompt_lengths(self) -> List[int]: + return [r.prompt_length for r in self.responses] + + @property + def generated_lengths(self) -> List[int]: + return [r.generated_length for r in self.responses] + + @property + def finish_reasons(self) -> List[GenerationFinishReason]: + return [r.finish_reason for r in self.responses] + + def append(self, response: Response) -> None: + self.responses.append(response) + + +class RequestBatch: + def __init__(self, requests: List[Request] = []) -> None: self.requests = requests def __len__(self) -> int: return len(self.requests) - def __contains__(self, r: RaggedRequest) -> bool: + def __contains__(self, r: Request) -> bool: return r in self.requests def __nonzero__(self) -> bool: @@ -202,27 +204,27 @@ def __nonzero__(self) -> bool: return True return False - def __iter__(self) -> Iterator[RaggedRequest]: + def __iter__(self) -> Iterator[Request]: return iter(self.requests) def __repr__(self) -> str: - return f"RaggedRequestBatch({self.requests})" + return f"RequestBatch({self.requests})" @property def requests_to_run(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if not r.is_flush_request]) + return RequestBatch([r for r in self.requests if not r.is_flush_request]) @property def requests_to_flush(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.is_flush_request]) + return RequestBatch([r for r in self.requests if r.is_flush_request]) @property def last_in_prompt(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.last_in_prompt]) + return RequestBatch([r for r in self.requests if r.last_in_prompt]) @property def completed(self) -> Self: - return RaggedRequestBatch([r for r in self.requests if r.stop_generation]) + return RequestBatch([r for r in self.requests if r.stop_generation]) @property def uids(self) -> List[int]: @@ -256,10 +258,17 @@ def done_tokens(self, done_tokens: List[torch.Tensor]) -> None: for idx, r in enumerate(self.requests): r.is_done = done_tokens[idx] + def to_msg_dicts(self) -> List[Dict[str, Any]]: + return [r.to_msg_dict() for r in self.requests] + + @staticmethod + def from_msg_dicts(msg_dicts: List[Dict[str, Any]]) -> Self: + return RequestBatch([RequestMsg.from_msg_dict(msg) for msg in msg_dicts]) + def prune(self, uids: List[int]) -> None: self.requests = [r for r in self.requests if r.uid not in uids] - def append(self, r: RaggedRequest) -> None: + def append(self, r: Request) -> None: self.requests.append(r) def update_seq_length(self) -> None: diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index b9331ecf..c693dc1f 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -9,7 +9,6 @@ import threading import time from collections import deque, defaultdict -from dataclasses import asdict from functools import cached_property from typing import Dict, Tuple, List, Any, Union, DefaultDict @@ -41,7 +40,7 @@ TEMP_NAME, SAMPLER_NAME, STOP_NAME) -from mii.batching.data_classes import Response, RaggedRequest, ResponseBatch, RaggedRequestBatch, RaggedRequestMsg +from mii.batching.data_classes import Response, Request, ResponseBatch, RequestBatch from mii.batching.generation.logit_processors import TopPLogitProcessor, TopKLogitProcessor, TemperatureLogitProcessor from mii.batching.generation.samplers import LogitsSampler, GreedySampler from mii.batching.generation.stop_criterion import EosGenerationStopCriterion, TokenStopCriterion @@ -71,7 +70,7 @@ def __init__(self, inference_engine, tokenizer, model_config): self.request_queue: queue.Queue = queue.Queue() self.result_queues: Dict[int, queue.Queue] = {} - self.scheduled_requests: RaggedRequestBatch = RaggedRequestBatch([]) + self.scheduled_requests: RequestBatch = RequestBatch() self.buffer = deque() self.scheduled_length = 0 self.scheduled_seq_num = 0 @@ -171,27 +170,26 @@ def _print_profiled_times(self) -> None: self._num_generated_tokens = 0 @sync_debug - def _bcast_requests(self, force=False) -> RaggedRequestBatch: + def _bcast_requests(self, force=False) -> RequestBatch: if self.is_rank_0: if not self.scheduled_requests and not force: return self.scheduled_requests # Rank 0 gets batch of requests and broadcasts to other ranks - data_dicts = [asdict(r.get_msg()) for r in self.scheduled_requests] + data_dicts = self.scheduled_requests.to_msg_dicts() json_data = ujson.dumps(data_dicts) self.socket.send_string(json_data) else: try: json_data = self.socket.recv_string() data_dicts = ujson.loads(json_data) - self.scheduled_requests = RaggedRequestBatch( - [RaggedRequestMsg.from_msg(msg) for msg in data_dicts]) + self.scheduled_requests = RequestBatch.from_msg_dicts(data_dicts) except zmq.Again: - self.scheduled_requests = RaggedRequestBatch([]) + self.scheduled_requests = RequestBatch() return self.scheduled_requests def _reset_scheduler_bookkeeping(self) -> None: - self.scheduled_requests = RaggedRequestBatch([]) + self.scheduled_requests = RequestBatch() self.scheduled_length = 0 self.scheduled_seq_num = 0 self.scheduled_req_blocks = 0 @@ -200,8 +198,8 @@ def _reset_scheduler_bookkeeping(self) -> None: def _process_logits( self, next_token_logits: torch.Tensor, - running_requests: RaggedRequestBatch) -> Tuple[torch.Tensor, - torch.Tensor]: + running_requests: RequestBatch) -> Tuple[torch.Tensor, + torch.Tensor]: next_token_logits = next_token_logits[:, :self.vocab_size] next_token_logits = self.logit_processor(next_token_logits, running_requests, @@ -216,7 +214,7 @@ def _process_logits( return next_tokens, done_tokens @sync_debug - def _generate_output(self, r: RaggedRequest) -> bool: + def _generate_output(self, r: Request) -> bool: outputs = [] if r.stream: outputs.append(( @@ -245,7 +243,7 @@ def _generate_output(self, r: RaggedRequest) -> bool: for output in outputs: self.result_queues[r.tid].put_nowait(output) - def _do_schedule_requests(self, requests: List[RaggedRequest]) -> None: + def _do_schedule_requests(self, requests: List[Request]) -> None: free_blocks = self.inference_engine._state_manager.free_blocks conf_manager = self.inference_engine._config.state_manager @@ -322,7 +320,7 @@ def schedule_requests(self) -> None: print( "Deadlock detected. Resetting KV cache and recomputing requests. Consider limiting number of concurrent requests or decreasing max lengths of prompts/generations." ) - self.scheduled_requests = RaggedRequestBatch([]) + self.scheduled_requests = RequestBatch() self.reset_request_status() else: scheduled_requests_ids = set(id(r) for r in self.scheduled_requests) @@ -331,7 +329,7 @@ def schedule_requests(self) -> None: def _queue_flush_request(self, uid: int) -> None: self.request_queue.put_nowait( - RaggedRequest( + Request( tid=None, uid=uid, input_tokens=None, @@ -366,7 +364,7 @@ def make_request(self, tid: int, uid: int, input_tokens: torch.Tensor, - kwargs: Dict) -> RaggedRequest: + kwargs: Dict) -> Request: prompt_length = len(input_tokens) max_length = kwargs.pop(MAX_LENGTH_KWARG, self.max_length) assert max_length > prompt_length, f"prompt length must be less than {MAX_LENGTH_KWARG}" @@ -426,7 +424,7 @@ def make_request(self, assert kwargs == {}, f"Unknown keyword arguments {kwargs}" - return RaggedRequest( + return Request( tid=tid, uid=uid, input_tokens=input_tokens, From 512e9aa9465f0b653a22fe277837c65457a6f024 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 14 Nov 2023 09:26:52 -0800 Subject: [PATCH 3/9] update dataclass methods --- mii/batching/data_classes.py | 5 +- mii/batching/ragged_batching.py | 2 +- mii/grpc_related/proto/modelresponse_pb2.py | 108 ++-- .../proto/modelresponse_pb2_grpc.py | 609 +++++++----------- mii/grpc_related/task_methods.py | 1 + 5 files changed, 303 insertions(+), 422 deletions(-) diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py index 9bad1e1c..a6f524c3 100644 --- a/mii/batching/data_classes.py +++ b/mii/batching/data_classes.py @@ -36,10 +36,10 @@ def to_msg(self) -> Dict[str, Union[str, int]]: } def __repr__(self) -> str: - return self.generated_text + return str(self.to_msg()) def __str__(self) -> str: - return self.generated_text + return self.to_msg() @dataclass @@ -167,6 +167,7 @@ def __iter__(self) -> Iterator[Response]: return iter(self.responses) def __repr__(self) -> str: + return self.responses return "\n\n".join(str(r) for r in self.responses) @property diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index c693dc1f..9823631f 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -521,7 +521,7 @@ def _get_response(self) -> Tuple[int, Response]: def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch: if self.is_rank_0: - data_dicts = [r.get_msg() for r in responses] + data_dicts = [r.to_msg() for r in responses] json_data = ujson.dumps(data_dicts) self.socket.send_string(json_data) else: diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 6b5294f7..cad660f3 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -1,7 +1,3 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto @@ -14,63 +10,63 @@ _sym_db = _symbol_database.Default() + from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"[\n\x11GenerationDetails\x12\x15\n\rfinish_reason\x18\x01 \x01(\t\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x03 \x01(\x03\"\x95\x01\n\x0fGenerationReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x0f\n\x07indices\x18\x02 \x03(\x03\x12\x31\n\x07\x64\x65tails\x18\x03 \x03(\x0b\x32 .modelresponse.GenerationDetails\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xb3\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x12]\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x30\x01\x62\x06proto3' -) + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"[\n\x11GenerationDetails\x12\x15\n\rfinish_reason\x18\x01 \x01(\t\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x03 \x01(\x03\"\x95\x01\n\x0fGenerationReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x0f\n\x07indices\x18\x02 \x03(\x03\x12\x31\n\x07\x64\x65tails\x18\x03 \x03(\x0b\x32 .modelresponse.GenerationDetails\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xb3\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x12]\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x30\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _DICTIONARY_VALUESENTRY._options = None - _DICTIONARY_VALUESENTRY._serialized_options = b'8\001' - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _globals['_DICTIONARY']._serialized_start = 68 - _globals['_DICTIONARY']._serialized_end = 204 - _globals['_DICTIONARY_VALUESENTRY']._serialized_start = 137 - _globals['_DICTIONARY_VALUESENTRY']._serialized_end = 204 - _globals['_VALUE']._serialized_start = 207 - _globals['_VALUE']._serialized_end = 347 - _globals['_SESSIONID']._serialized_start = 349 - _globals['_SESSIONID']._serialized_end = 380 - _globals['_SINGLESTRINGREQUEST']._serialized_start = 383 - _globals['_SINGLESTRINGREQUEST']._serialized_end = 570 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_MULTISTRINGREQUEST']._serialized_start = 573 - _globals['_MULTISTRINGREQUEST']._serialized_end = 758 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_SINGLESTRINGREPLY']._serialized_start = 760 - _globals['_SINGLESTRINGREPLY']._serialized_end = 843 - _globals['_MULTISTRINGREPLY']._serialized_start = 845 - _globals['_MULTISTRINGREPLY']._serialized_end = 927 - _globals['_GENERATIONDETAILS']._serialized_start = 929 - _globals['_GENERATIONDETAILS']._serialized_end = 1020 - _globals['_GENERATIONREPLY']._serialized_start = 1023 - _globals['_GENERATIONREPLY']._serialized_end = 1172 - _globals['_QAREQUEST']._serialized_start = 1175 - _globals['_QAREQUEST']._serialized_end = 1360 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_CONVERSATIONREQUEST']._serialized_start = 1363 - _globals['_CONVERSATIONREQUEST']._serialized_end = 1627 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_CONVERSATIONREPLY']._serialized_start = 1630 - _globals['_CONVERSATIONREPLY']._serialized_end = 1775 - _globals['_IMAGEREPLY']._serialized_start = 1777 - _globals['_IMAGEREPLY']._serialized_end = 1902 - _globals['_MODELRESPONSE']._serialized_start = 1905 - _globals['_MODELRESPONSE']._serialized_end = 2852 + DESCRIPTOR._options = None + _DICTIONARY_VALUESENTRY._options = None + _DICTIONARY_VALUESENTRY._serialized_options = b'8\001' + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _QAREQUEST_QUERYKWARGSENTRY._options = None + _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _globals['_DICTIONARY']._serialized_start=68 + _globals['_DICTIONARY']._serialized_end=204 + _globals['_DICTIONARY_VALUESENTRY']._serialized_start=137 + _globals['_DICTIONARY_VALUESENTRY']._serialized_end=204 + _globals['_VALUE']._serialized_start=207 + _globals['_VALUE']._serialized_end=347 + _globals['_SESSIONID']._serialized_start=349 + _globals['_SESSIONID']._serialized_end=380 + _globals['_SINGLESTRINGREQUEST']._serialized_start=383 + _globals['_SINGLESTRINGREQUEST']._serialized_end=570 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=498 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=570 + _globals['_MULTISTRINGREQUEST']._serialized_start=573 + _globals['_MULTISTRINGREQUEST']._serialized_end=758 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=498 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=570 + _globals['_SINGLESTRINGREPLY']._serialized_start=760 + _globals['_SINGLESTRINGREPLY']._serialized_end=843 + _globals['_MULTISTRINGREPLY']._serialized_start=845 + _globals['_MULTISTRINGREPLY']._serialized_end=927 + _globals['_GENERATIONDETAILS']._serialized_start=929 + _globals['_GENERATIONDETAILS']._serialized_end=1020 + _globals['_GENERATIONREPLY']._serialized_start=1023 + _globals['_GENERATIONREPLY']._serialized_end=1172 + _globals['_QAREQUEST']._serialized_start=1175 + _globals['_QAREQUEST']._serialized_end=1360 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start=498 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end=570 + _globals['_CONVERSATIONREQUEST']._serialized_start=1363 + _globals['_CONVERSATIONREQUEST']._serialized_end=1627 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start=498 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end=570 + _globals['_CONVERSATIONREPLY']._serialized_start=1630 + _globals['_CONVERSATIONREPLY']._serialized_end=1775 + _globals['_IMAGEREPLY']._serialized_start=1777 + _globals['_IMAGEREPLY']._serialized_end=1902 + _globals['_MODELRESPONSE']._serialized_start=1905 + _globals['_MODELRESPONSE']._serialized_end=2852 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 4f16a368..48df7c9d 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -1,7 +1,3 @@ -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc @@ -12,6 +8,7 @@ class ModelResponseStub(object): """Missing associated documentation comment in .proto file.""" + def __init__(self, channel): """Constructor. @@ -19,126 +16,124 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Terminate = channel.unary_unary( - '/modelresponse.ModelResponse/Terminate', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/Terminate', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.CreateSession = channel.unary_unary( - '/modelresponse.ModelResponse/CreateSession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/CreateSession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.DestroySession = channel.unary_unary( - '/modelresponse.ModelResponse/DestroySession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/DestroySession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.GeneratorReply = channel.unary_unary( - '/modelresponse.ModelResponse/GeneratorReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.MultiStringReply.FromString, - ) + '/modelresponse.ModelResponse/GeneratorReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.MultiStringReply.FromString, + ) self.ClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/ClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/ClassificationReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.QuestionAndAnswerReply = channel.unary_unary( - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - request_serializer=modelresponse__pb2.QARequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/QuestionAndAnswerReply', + request_serializer=modelresponse__pb2.QARequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.FillMaskReply = channel.unary_unary( - '/modelresponse.ModelResponse/FillMaskReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/FillMaskReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.TokenClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/TokenClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/TokenClassificationReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.ConversationalReply = channel.unary_unary( - '/modelresponse.ModelResponse/ConversationalReply', - request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ConversationReply.FromString, - ) + '/modelresponse.ModelResponse/ConversationalReply', + request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ConversationReply.FromString, + ) self.Txt2ImgReply = channel.unary_unary( - '/modelresponse.ModelResponse/Txt2ImgReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ImageReply.FromString, - ) + '/modelresponse.ModelResponse/Txt2ImgReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ImageReply.FromString, + ) self.GeneratorReplyStream = channel.unary_stream( - '/modelresponse.ModelResponse/GeneratorReplyStream', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.GenerationReply.FromString, - ) + '/modelresponse.ModelResponse/GeneratorReplyStream', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.GenerationReply.FromString, + ) class ModelResponseServicer(object): """Missing associated documentation comment in .proto file.""" - ERROR_MSG = 'Method not implemented!' def Terminate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def CreateSession(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def DestroySession(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GeneratorReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def ClassificationReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def QuestionAndAnswerReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def FillMaskReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def TokenClassificationReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def ConversationalReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def Txt2ImgReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details(self.ERROR_MSG) - raise NotImplementedError(self.ERROR_MSG) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') def GeneratorReplyStream(self, request, context): """Missing associated documentation comment in .proto file.""" @@ -149,366 +144,254 @@ def GeneratorReplyStream(self, request, context): def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { - 'Terminate': - grpc.unary_unary_rpc_method_handler( - servicer.Terminate, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), - 'CreateSession': - grpc.unary_unary_rpc_method_handler( - servicer.CreateSession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), - 'DestroySession': - grpc.unary_unary_rpc_method_handler( - servicer.DestroySession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), - 'GeneratorReply': - grpc.unary_unary_rpc_method_handler( - servicer.GeneratorReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, - ), - 'ClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.ClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'QuestionAndAnswerReply': - grpc.unary_unary_rpc_method_handler( - servicer.QuestionAndAnswerReply, - request_deserializer=modelresponse__pb2.QARequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'FillMaskReply': - grpc.unary_unary_rpc_method_handler( - servicer.FillMaskReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'TokenClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.TokenClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'ConversationalReply': - grpc.unary_unary_rpc_method_handler( - servicer.ConversationalReply, - request_deserializer=modelresponse__pb2.ConversationRequest.FromString, - response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, - ), - 'Txt2ImgReply': - grpc.unary_unary_rpc_method_handler( - servicer.Txt2ImgReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.ImageReply.SerializeToString, - ), - 'GeneratorReplyStream': - grpc.unary_stream_rpc_method_handler( - servicer.GeneratorReplyStream, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, - ), + 'Terminate': grpc.unary_unary_rpc_method_handler( + servicer.Terminate, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), + 'CreateSession': grpc.unary_unary_rpc_method_handler( + servicer.CreateSession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), + 'DestroySession': grpc.unary_unary_rpc_method_handler( + servicer.DestroySession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + ), + 'GeneratorReply': grpc.unary_unary_rpc_method_handler( + servicer.GeneratorReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, + ), + 'ClassificationReply': grpc.unary_unary_rpc_method_handler( + servicer.ClassificationReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'QuestionAndAnswerReply': grpc.unary_unary_rpc_method_handler( + servicer.QuestionAndAnswerReply, + request_deserializer=modelresponse__pb2.QARequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'FillMaskReply': grpc.unary_unary_rpc_method_handler( + servicer.FillMaskReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'TokenClassificationReply': grpc.unary_unary_rpc_method_handler( + servicer.TokenClassificationReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'ConversationalReply': grpc.unary_unary_rpc_method_handler( + servicer.ConversationalReply, + request_deserializer=modelresponse__pb2.ConversationRequest.FromString, + response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, + ), + 'Txt2ImgReply': grpc.unary_unary_rpc_method_handler( + servicer.Txt2ImgReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.ImageReply.SerializeToString, + ), + 'GeneratorReplyStream': grpc.unary_stream_rpc_method_handler( + servicer.GeneratorReplyStream, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', - rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler, )) + generic_handler = grpc.method_handlers_generic_handler( + 'modelresponse.ModelResponse', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) -# This class is part of an EXPERIMENTAL API. + # This class is part of an EXPERIMENTAL API. class ModelResponse(object): """Missing associated documentation comment in .proto file.""" + @staticmethod def Terminate(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/Terminate', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Terminate', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def CreateSession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/CreateSession', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/CreateSession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def DestroySession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/DestroySession', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/DestroySession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def GeneratorReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/GeneratorReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/GeneratorReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.MultiStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/ClassificationReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def QuestionAndAnswerReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/QuestionAndAnswerReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/QuestionAndAnswerReply', modelresponse__pb2.QARequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def FillMaskReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/FillMaskReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/FillMaskReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def TokenClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/TokenClassificationReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/TokenClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def ConversationalReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/ConversationalReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ConversationalReply', modelresponse__pb2.ConversationRequest.SerializeToString, modelresponse__pb2.ConversationReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def Txt2ImgReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, target, - '/modelresponse.ModelResponse/Txt2ImgReply', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Txt2ImgReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.ImageReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod def GeneratorReplyStream(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream( - request, target, - '/modelresponse.ModelResponse/GeneratorReplyStream', + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/modelresponse.ModelResponse/GeneratorReplyStream', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.GenerationReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/mii/grpc_related/task_methods.py b/mii/grpc_related/task_methods.py index 7d37805a..1c1f9169 100644 --- a/mii/grpc_related/task_methods.py +++ b/mii/grpc_related/task_methods.py @@ -52,6 +52,7 @@ def pack_response_to_proto(self, response, time_taken, model_time_taken): return response, time_taken, model_time_taken def unpack_response_from_proto(self, response): + print("RESPONSE", response) return response From d6e22e11fd32993b8a435b9762b5ce0d12cb45db Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 14 Nov 2023 14:22:49 -0800 Subject: [PATCH 4/9] refactor task methods --- mii/backend/client.py | 20 +- mii/batching/data_classes.py | 5 +- mii/batching/ragged_batching.py | 5 - mii/grpc_related/modelresponse_server.py | 59 +- mii/grpc_related/proto/modelresponse.proto | 2 +- mii/grpc_related/proto/modelresponse_pb2.py | 105 ++-- .../proto/modelresponse_pb2_grpc.py | 567 +++++++++++------- mii/grpc_related/task_methods.py | 76 ++- 8 files changed, 476 insertions(+), 363 deletions(-) diff --git a/mii/backend/client.py b/mii/backend/client.py index 44da5855..8541aac5 100644 --- a/mii/backend/client.py +++ b/mii/backend/client.py @@ -41,15 +41,15 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None: def __call__(self, *args, **kwargs) -> ResponseBatch: return self.generate(*args, **kwargs) - async def _request_async_response(self, request_dict, **query_kwargs): + async def _request_async_response(self, prompts, **query_kwargs): task_methods = TASK_METHODS_DICT[self.task] - proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs) + proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs) proto_response = await getattr(self.stub, task_methods.method)(proto_request) return task_methods.unpack_response_from_proto(proto_response) - async def _request_async_response_stream(self, request_dict, **query_kwargs): + async def _request_async_response_stream(self, prompts, **query_kwargs): task_methods = TASK_METHODS_DICT[self.task] - proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs) + proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs) assert hasattr(task_methods, "method_stream_out"), f"{self.task} does not support streaming response" async for response in getattr(self.stub, task_methods.method_stream_out)(proto_request): @@ -67,23 +67,21 @@ def generate(self, if len(prompts) > 1: raise RuntimeError( "MII client streaming only supports a single prompt input.") - request_dict = {"query": prompts} - return self._generate_stream(streaming_fn, request_dict, **query_kwargs) + query_kwargs["stream"] = True + return self._generate_stream(streaming_fn, prompts, **query_kwargs) - request_dict = {"query": prompts} return self.asyncio_loop.run_until_complete( - self._request_async_response(request_dict, + self._request_async_response(prompts, **query_kwargs)) def _generate_stream(self, callback, - request_dict: Dict[str, - str], + prompts: List[str], **query_kwargs: Dict[str, Any]): async def put_result(): response_stream = self._request_async_response_stream( - request_dict, + prompts, **query_kwargs) while True: diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py index a6f524c3..53b0078e 100644 --- a/mii/batching/data_classes.py +++ b/mii/batching/data_classes.py @@ -32,7 +32,7 @@ def to_msg(self) -> Dict[str, Union[str, int]]: "generated_text": self.generated_text, "prompt_length": self.prompt_length, "generated_length": self.generated_length, - "finish_reason": self.finish_reason.value + "finish_reason": self.finish_reason } def __repr__(self) -> str: @@ -166,6 +166,9 @@ def __init__(self, responses: List[Response] = []) -> None: def __iter__(self) -> Iterator[Response]: return iter(self.responses) + def __str__(self) -> str: + return str(self.responses) + def __repr__(self) -> str: return self.responses return "\n\n".join(str(r) for r in self.responses) diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 9823631f..61a4323b 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -588,11 +588,6 @@ def put_request(self, prompt: str, kwargs: Dict) -> int: return uid - def is_response_ready(self, uid: int) -> bool: - if not self.is_rank_0: - return True - return not self.result_queues[uid].empty() - def get_response(self) -> Tuple[int, Response]: # TODO: We should avoid any request/response work with non-rank 0, but # this requires some refactoring how we do the put and request in diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 113a3ae2..8a8a68d7 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -3,29 +3,26 @@ # DeepSpeed Team import asyncio +import queue +import sys +import threading from concurrent import futures -import logging +from typing import Dict import grpc - from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 -from .proto import modelresponse_pb2_grpc -import sys -import threading -import time -import queue +from mii.backend.client import create_channel from mii.constants import ( + GenerationFinishReason, GRPC_MAX_MSG_SIZE, TERMINATE_METHOD, LB_MAX_WORKER_THREADS, SERVER_SHUTDOWN_TIMEOUT, STREAM_RESPONSE_QUEUE_TIMEOUT, ) -from mii.grpc_related.task_methods import TASK_METHODS_DICT -from mii.backend.client import create_channel - -from mii.constants import GenerationFinishReason +from mii.grpc_related.proto import modelresponse_pb2_grpc +from mii.grpc_related.task_methods import TASK_METHODS_DICT, TaskMethods class ServiceBase(modelresponse_pb2_grpc.ModelResponseServicer): @@ -53,7 +50,7 @@ def __init__(self, async_pipeline=None): self.method_name_to_task = {m.method: t for t, m in TASK_METHODS_DICT.items()} self.lock = threading.Lock() - def _run_inference(self, method_name, request_proto): + def _get_task_methods(self, method_name: str) -> Dict[str, TaskMethods]: if method_name not in self.method_name_to_task: raise ValueError(f"unknown method: {method_name}") @@ -62,12 +59,14 @@ def _run_inference(self, method_name, request_proto): raise ValueError(f"unknown task: {task}") task_methods = TASK_METHODS_DICT[task] - prompts, kwargs = task_methods.unpack_request_from_proto(request_proto) + return task_methods + + def GeneratorReply(self, request, context): + task_methods = self._get_task_methods("GeneratorReply") + + prompts, kwargs = task_methods.unpack_request_from_proto(request) + uids_running, uids_complete_order, responses = [], [], [] - start = time.time() - uids_running = [] - uids_complete_order = [] - responses = [] # Put requests for all prompts into the pipeline for p in prompts: request_kwargs = kwargs.copy() @@ -85,7 +84,6 @@ def _run_inference(self, method_name, request_proto): self.inference_pipeline.flush_uid(uid) uids_complete_order.append(uids_running.index(uid)) uids_running.remove(uid) - end = time.time() # Sort responses in the order of prompts responses = [ @@ -95,31 +93,19 @@ def _run_inference(self, method_name, request_proto): key=lambda pair: pair[0]) ] - return task_methods.pack_response_to_proto(responses, end - start, -1) - - def GeneratorReply(self, request, context): - return self._run_inference("GeneratorReply", request) - - def _run_inference_stream(self, method_name, request_proto) -> int: - task = self.method_name_to_task[method_name] - task_methods = TASK_METHODS_DICT[task] - prompts, kwargs = task_methods.unpack_request_from_proto(request_proto) - - kwargs["stream"] = True - # NOTE: Streaming handle only single prompt inputs - return self.inference_pipeline.put_request(prompts[0], kwargs) + return task_methods.pack_response_to_proto(responses) def GeneratorReplyStream(self, request, context): - method_name = "GeneratorReply" - task = self.method_name_to_task[method_name] - task_methods = TASK_METHODS_DICT[task] + task_methods = self._get_task_methods("GeneratorReply") + + prompts, kwargs = task_methods.unpack_request_from_proto(request) + uid = self.inference_pipeline.put_request(prompts[0], kwargs) - uid = self._run_inference_stream(method_name, request) while True: response_uid, r = self.inference_pipeline.get_response() assert uid == response_uid, "uid mismatch" done = r.finish_reason != GenerationFinishReason.NONE - response = task_methods.pack_response_to_proto([r], 0.0, 0.0) + response = task_methods.pack_response_to_proto([r]) yield response if done: break @@ -302,5 +288,6 @@ def serve_load_balancing(model_config, lb_port): if __name__ == "__main__": + import logging logging.basicConfig() serve_inference(None, sys.argv[1]) diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index c2d0899f..5ad1f194 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -27,7 +27,7 @@ service ModelResponse { rpc Terminate (google.protobuf.Empty) returns (google.protobuf.Empty) {} rpc CreateSession (SessionID) returns (google.protobuf.Empty) {} rpc DestroySession (SessionID) returns (google.protobuf.Empty) {} - rpc GeneratorReply (MultiStringRequest) returns (MultiStringReply) {} + rpc GeneratorReply (MultiStringRequest) returns (GenerationReply) {} rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {} rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {} rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {} diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index cad660f3..88505cdd 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto """Generated protocol buffer code.""" @@ -10,63 +9,63 @@ _sym_db = _symbol_database.Default() - from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"[\n\x11GenerationDetails\x12\x15\n\rfinish_reason\x18\x01 \x01(\t\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x03 \x01(\x03\"\x95\x01\n\x0fGenerationReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x0f\n\x07indices\x18\x02 \x03(\x03\x12\x31\n\x07\x64\x65tails\x18\x03 \x03(\x0b\x32 .modelresponse.GenerationDetails\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xb3\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12V\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1f.modelresponse.MultiStringReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x12]\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x30\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"[\n\x11GenerationDetails\x12\x15\n\rfinish_reason\x18\x01 \x01(\t\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x03 \x01(\x03\"\x95\x01\n\x0fGenerationReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x0f\n\x07indices\x18\x02 \x03(\x03\x12\x31\n\x07\x64\x65tails\x18\x03 \x03(\x0b\x32 .modelresponse.GenerationDetails\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xb2\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12U\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x12]\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x30\x01\x62\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'modelresponse_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _DICTIONARY_VALUESENTRY._options = None - _DICTIONARY_VALUESENTRY._serialized_options = b'8\001' - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None - _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None - _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _globals['_DICTIONARY']._serialized_start=68 - _globals['_DICTIONARY']._serialized_end=204 - _globals['_DICTIONARY_VALUESENTRY']._serialized_start=137 - _globals['_DICTIONARY_VALUESENTRY']._serialized_end=204 - _globals['_VALUE']._serialized_start=207 - _globals['_VALUE']._serialized_end=347 - _globals['_SESSIONID']._serialized_start=349 - _globals['_SESSIONID']._serialized_end=380 - _globals['_SINGLESTRINGREQUEST']._serialized_start=383 - _globals['_SINGLESTRINGREQUEST']._serialized_end=570 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=498 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=570 - _globals['_MULTISTRINGREQUEST']._serialized_start=573 - _globals['_MULTISTRINGREQUEST']._serialized_end=758 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start=498 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end=570 - _globals['_SINGLESTRINGREPLY']._serialized_start=760 - _globals['_SINGLESTRINGREPLY']._serialized_end=843 - _globals['_MULTISTRINGREPLY']._serialized_start=845 - _globals['_MULTISTRINGREPLY']._serialized_end=927 - _globals['_GENERATIONDETAILS']._serialized_start=929 - _globals['_GENERATIONDETAILS']._serialized_end=1020 - _globals['_GENERATIONREPLY']._serialized_start=1023 - _globals['_GENERATIONREPLY']._serialized_end=1172 - _globals['_QAREQUEST']._serialized_start=1175 - _globals['_QAREQUEST']._serialized_end=1360 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start=498 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end=570 - _globals['_CONVERSATIONREQUEST']._serialized_start=1363 - _globals['_CONVERSATIONREQUEST']._serialized_end=1627 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start=498 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end=570 - _globals['_CONVERSATIONREPLY']._serialized_start=1630 - _globals['_CONVERSATIONREPLY']._serialized_end=1775 - _globals['_IMAGEREPLY']._serialized_start=1777 - _globals['_IMAGEREPLY']._serialized_end=1902 - _globals['_MODELRESPONSE']._serialized_start=1905 - _globals['_MODELRESPONSE']._serialized_end=2852 + DESCRIPTOR._options = None + _DICTIONARY_VALUESENTRY._options = None + _DICTIONARY_VALUESENTRY._serialized_options = b'8\001' + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._options = None + _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None + _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _QAREQUEST_QUERYKWARGSENTRY._options = None + _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None + _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' + _globals['_DICTIONARY']._serialized_start = 68 + _globals['_DICTIONARY']._serialized_end = 204 + _globals['_DICTIONARY_VALUESENTRY']._serialized_start = 137 + _globals['_DICTIONARY_VALUESENTRY']._serialized_end = 204 + _globals['_VALUE']._serialized_start = 207 + _globals['_VALUE']._serialized_end = 347 + _globals['_SESSIONID']._serialized_start = 349 + _globals['_SESSIONID']._serialized_end = 380 + _globals['_SINGLESTRINGREQUEST']._serialized_start = 383 + _globals['_SINGLESTRINGREQUEST']._serialized_end = 570 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 + _globals['_MULTISTRINGREQUEST']._serialized_start = 573 + _globals['_MULTISTRINGREQUEST']._serialized_end = 758 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 + _globals['_SINGLESTRINGREPLY']._serialized_start = 760 + _globals['_SINGLESTRINGREPLY']._serialized_end = 843 + _globals['_MULTISTRINGREPLY']._serialized_start = 845 + _globals['_MULTISTRINGREPLY']._serialized_end = 927 + _globals['_GENERATIONDETAILS']._serialized_start = 929 + _globals['_GENERATIONDETAILS']._serialized_end = 1020 + _globals['_GENERATIONREPLY']._serialized_start = 1023 + _globals['_GENERATIONREPLY']._serialized_end = 1172 + _globals['_QAREQUEST']._serialized_start = 1175 + _globals['_QAREQUEST']._serialized_end = 1360 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 + _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 + _globals['_CONVERSATIONREQUEST']._serialized_start = 1363 + _globals['_CONVERSATIONREQUEST']._serialized_end = 1627 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 + _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 + _globals['_CONVERSATIONREPLY']._serialized_start = 1630 + _globals['_CONVERSATIONREPLY']._serialized_end = 1775 + _globals['_IMAGEREPLY']._serialized_start = 1777 + _globals['_IMAGEREPLY']._serialized_end = 1902 + _globals['_MODELRESPONSE']._serialized_start = 1905 + _globals['_MODELRESPONSE']._serialized_end = 2851 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index 48df7c9d..e94ec498 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -8,7 +8,6 @@ class ModelResponseStub(object): """Missing associated documentation comment in .proto file.""" - def __init__(self, channel): """Constructor. @@ -16,65 +15,65 @@ def __init__(self, channel): channel: A grpc.Channel. """ self.Terminate = channel.unary_unary( - '/modelresponse.ModelResponse/Terminate', - request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/Terminate', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.CreateSession = channel.unary_unary( - '/modelresponse.ModelResponse/CreateSession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/CreateSession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.DestroySession = channel.unary_unary( - '/modelresponse.ModelResponse/DestroySession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) + '/modelresponse.ModelResponse/DestroySession', + request_serializer=modelresponse__pb2.SessionID.SerializeToString, + response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + ) self.GeneratorReply = channel.unary_unary( - '/modelresponse.ModelResponse/GeneratorReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.MultiStringReply.FromString, - ) + '/modelresponse.ModelResponse/GeneratorReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.GenerationReply.FromString, + ) self.ClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/ClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/ClassificationReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.QuestionAndAnswerReply = channel.unary_unary( - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - request_serializer=modelresponse__pb2.QARequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/QuestionAndAnswerReply', + request_serializer=modelresponse__pb2.QARequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.FillMaskReply = channel.unary_unary( - '/modelresponse.ModelResponse/FillMaskReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/FillMaskReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.TokenClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/TokenClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) + '/modelresponse.ModelResponse/TokenClassificationReply', + request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.SingleStringReply.FromString, + ) self.ConversationalReply = channel.unary_unary( - '/modelresponse.ModelResponse/ConversationalReply', - request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ConversationReply.FromString, - ) + '/modelresponse.ModelResponse/ConversationalReply', + request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ConversationReply.FromString, + ) self.Txt2ImgReply = channel.unary_unary( - '/modelresponse.ModelResponse/Txt2ImgReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ImageReply.FromString, - ) + '/modelresponse.ModelResponse/Txt2ImgReply', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.ImageReply.FromString, + ) self.GeneratorReplyStream = channel.unary_stream( - '/modelresponse.ModelResponse/GeneratorReplyStream', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.GenerationReply.FromString, - ) + '/modelresponse.ModelResponse/GeneratorReplyStream', + request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, + response_deserializer=modelresponse__pb2.GenerationReply.FromString, + ) class ModelResponseServicer(object): """Missing associated documentation comment in .proto file.""" - def Terminate(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -144,254 +143,366 @@ def GeneratorReplyStream(self, request, context): def add_ModelResponseServicer_to_server(servicer, server): rpc_method_handlers = { - 'Terminate': grpc.unary_unary_rpc_method_handler( - servicer.Terminate, - request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - ), - 'CreateSession': grpc.unary_unary_rpc_method_handler( - servicer.CreateSession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - ), - 'DestroySession': grpc.unary_unary_rpc_method_handler( - servicer.DestroySession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, - ), - 'GeneratorReply': grpc.unary_unary_rpc_method_handler( - servicer.GeneratorReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.MultiStringReply.SerializeToString, - ), - 'ClassificationReply': grpc.unary_unary_rpc_method_handler( - servicer.ClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'QuestionAndAnswerReply': grpc.unary_unary_rpc_method_handler( - servicer.QuestionAndAnswerReply, - request_deserializer=modelresponse__pb2.QARequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'FillMaskReply': grpc.unary_unary_rpc_method_handler( - servicer.FillMaskReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'TokenClassificationReply': grpc.unary_unary_rpc_method_handler( - servicer.TokenClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'ConversationalReply': grpc.unary_unary_rpc_method_handler( - servicer.ConversationalReply, - request_deserializer=modelresponse__pb2.ConversationRequest.FromString, - response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, - ), - 'Txt2ImgReply': grpc.unary_unary_rpc_method_handler( - servicer.Txt2ImgReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.ImageReply.SerializeToString, - ), - 'GeneratorReplyStream': grpc.unary_stream_rpc_method_handler( - servicer.GeneratorReplyStream, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, - ), + 'Terminate': + grpc.unary_unary_rpc_method_handler( + servicer.Terminate, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'CreateSession': + grpc.unary_unary_rpc_method_handler( + servicer.CreateSession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'DestroySession': + grpc.unary_unary_rpc_method_handler( + servicer.DestroySession, + request_deserializer=modelresponse__pb2.SessionID.FromString, + response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. + SerializeToString, + ), + 'GeneratorReply': + grpc.unary_unary_rpc_method_handler( + servicer.GeneratorReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, + ), + 'ClassificationReply': + grpc.unary_unary_rpc_method_handler( + servicer.ClassificationReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'QuestionAndAnswerReply': + grpc.unary_unary_rpc_method_handler( + servicer.QuestionAndAnswerReply, + request_deserializer=modelresponse__pb2.QARequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'FillMaskReply': + grpc.unary_unary_rpc_method_handler( + servicer.FillMaskReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'TokenClassificationReply': + grpc.unary_unary_rpc_method_handler( + servicer.TokenClassificationReply, + request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, + response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, + ), + 'ConversationalReply': + grpc.unary_unary_rpc_method_handler( + servicer.ConversationalReply, + request_deserializer=modelresponse__pb2.ConversationRequest.FromString, + response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, + ), + 'Txt2ImgReply': + grpc.unary_unary_rpc_method_handler( + servicer.Txt2ImgReply, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.ImageReply.SerializeToString, + ), + 'GeneratorReplyStream': + grpc.unary_stream_rpc_method_handler( + servicer.GeneratorReplyStream, + request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, + response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, + ), } - generic_handler = grpc.method_handlers_generic_handler( - 'modelresponse.ModelResponse', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) + generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', + rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler, )) - # This class is part of an EXPERIMENTAL API. +# This class is part of an EXPERIMENTAL API. class ModelResponse(object): """Missing associated documentation comment in .proto file.""" - @staticmethod def Terminate(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Terminate', + '/modelresponse.ModelResponse/Terminate', google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def CreateSession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/CreateSession', + '/modelresponse.ModelResponse/CreateSession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def DestroySession(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/DestroySession', + '/modelresponse.ModelResponse/DestroySession', modelresponse__pb2.SessionID.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def GeneratorReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/GeneratorReply', + '/modelresponse.ModelResponse/GeneratorReply', modelresponse__pb2.MultiStringRequest.SerializeToString, - modelresponse__pb2.MultiStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + modelresponse__pb2.GenerationReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def ClassificationReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ClassificationReply', + '/modelresponse.ModelResponse/ClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def QuestionAndAnswerReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/QuestionAndAnswerReply', + '/modelresponse.ModelResponse/QuestionAndAnswerReply', modelresponse__pb2.QARequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def FillMaskReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/FillMaskReply', + '/modelresponse.ModelResponse/FillMaskReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def TokenClassificationReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/TokenClassificationReply', + '/modelresponse.ModelResponse/TokenClassificationReply', modelresponse__pb2.SingleStringRequest.SerializeToString, modelresponse__pb2.SingleStringReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def ConversationalReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/ConversationalReply', + '/modelresponse.ModelResponse/ConversationalReply', modelresponse__pb2.ConversationRequest.SerializeToString, modelresponse__pb2.ConversationReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def Txt2ImgReply(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/modelresponse.ModelResponse/Txt2ImgReply', + '/modelresponse.ModelResponse/Txt2ImgReply', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.ImageReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) @staticmethod def GeneratorReplyStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream(request, target, '/modelresponse.ModelResponse/GeneratorReplyStream', + '/modelresponse.ModelResponse/GeneratorReplyStream', modelresponse__pb2.MultiStringRequest.SerializeToString, modelresponse__pb2.GenerationReply.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata) diff --git a/mii/grpc_related/task_methods.py b/mii/grpc_related/task_methods.py index 1c1f9169..37dc64a2 100644 --- a/mii/grpc_related/task_methods.py +++ b/mii/grpc_related/task_methods.py @@ -4,7 +4,11 @@ # DeepSpeed Team from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple +from google.protobuf.message import Message + +from mii.batching.data_classes import Response, ResponseBatch from mii.constants import TaskType from mii.grpc_related.proto import modelresponse_pb2 from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs @@ -22,38 +26,27 @@ def single_string_response_to_proto(self, response, time_taken, model_time_taken model_time_taken=model_time_taken) -def multi_string_request_to_proto(self, request_dict, **query_kwargs): - return modelresponse_pb2.MultiStringRequest( - request=request_dict["query"] if isinstance(request_dict["query"], - list) else [request_dict["query"]], - query_kwargs=kwarg_dict_to_proto(query_kwargs), - ) - - -def proto_request_to_list(self, request): - prompts = [r for r in request.request] - kwargs = unpack_proto_query_kwargs(request.query_kwargs) - return prompts, kwargs - - class TaskMethods(ABC): @property @abstractmethod def method(self): ... - def pack_request_to_proto(self, request_dict, **query_kwargs): - return request_dict, query_kwargs + @abstractmethod + def pack_request_to_proto(self, request, **query_kwargs): + ... - def unpack_request_from_proto(self, request): - return request + @abstractmethod + def unpack_request_from_proto(self, proto_request): + ... - def pack_response_to_proto(self, response, time_taken, model_time_taken): - return response, time_taken, model_time_taken + @abstractmethod + def pack_response_to_proto(self, response): + ... - def unpack_response_from_proto(self, response): - print("RESPONSE", response) - return response + @abstractmethod + def unpack_response_from_proto(self, proto_response): + ... class TextGenerationMethods(TaskMethods): @@ -65,10 +58,25 @@ def method(self): def method_stream_out(self): return "GeneratorReplyStream" - pack_request_to_proto = multi_string_request_to_proto - unpack_request_from_proto = proto_request_to_list + def pack_request_to_proto(self, + prompts: List[str], + **query_kwargs: Dict[str, + Any]) -> Message: + proto_request = modelresponse_pb2.MultiStringRequest( + request=prompts, + query_kwargs=kwarg_dict_to_proto(query_kwargs), + ) + return proto_request + + def unpack_request_from_proto(self, + proto_request: Message) -> Tuple[List[str], + Dict[str, + Any]]: + prompts = [r for r in proto_request.request] + kwargs = unpack_proto_query_kwargs(proto_request.query_kwargs) + return prompts, kwargs - def pack_response_to_proto(self, responses, time_taken, model_time_taken): + def pack_response_to_proto(self, responses: ResponseBatch) -> Message: text_responses = [] details = [] @@ -87,10 +95,22 @@ def pack_response_to_proto(self, responses, time_taken, model_time_taken): response=text_responses, indices=[0], details=details, - time_taken=time_taken, - model_time_taken=model_time_taken, + time_taken=-1, + model_time_taken=-1, ) + def unpack_response_from_proto(self, response: Message) -> ResponseBatch: + response_batch = ResponseBatch() + for i, r in enumerate(response.response): + response_batch.append( + Response( + generated_text=r, + prompt_length=response.details[i].prompt_tokens, + generated_length=response.details[i].generated_tokens, + finish_reason=response.details[i].finish_reason, + )) + return response_batch + TASK_METHODS_DICT = { TaskType.TEXT_GENERATION: TextGenerationMethods(), From 882d471387ee895d082c6154bd48702f1535c33e Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 14 Nov 2023 16:07:52 -0800 Subject: [PATCH 5/9] remove unused proto code, light refactor of remaining proto messages --- mii/backend/client.py | 13 +- mii/batching/data_classes.py | 60 +-- mii/batching/ragged_batching.py | 16 +- mii/grpc_related/modelresponse_server.py | 2 +- mii/grpc_related/proto/modelresponse.proto | 80 +--- mii/grpc_related/proto/modelresponse_pb2.py | 56 +-- .../proto/modelresponse_pb2_grpc.py | 364 +----------------- mii/grpc_related/restful_gateway.py | 6 +- mii/grpc_related/task_methods.py | 52 ++- tests/test_deployment.py | 38 +- 10 files changed, 110 insertions(+), 577 deletions(-) diff --git a/mii/backend/client.py b/mii/backend/client.py index 8541aac5..796324b3 100644 --- a/mii/backend/client.py +++ b/mii/backend/client.py @@ -7,7 +7,7 @@ import requests from typing import Dict, Any, Callable, List, Union -from mii.batching.data_classes import ResponseBatch +from mii.batching.data_classes import Response from mii.config import MIIConfig from mii.constants import GRPC_MAX_MSG_SIZE from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc @@ -38,7 +38,7 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None: channel = create_channel(host, self.port) self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel) - def __call__(self, *args, **kwargs) -> ResponseBatch: + def __call__(self, *args, **kwargs) -> List[Response]: return self.generate(*args, **kwargs) async def _request_async_response(self, prompts, **query_kwargs): @@ -60,7 +60,8 @@ def generate(self, List[str]], streaming_fn: Callable = None, **query_kwargs: Dict[str, - Any]) -> ResponseBatch: + Any]) -> Union[None, + List[Response]]: if isinstance(prompts, str): prompts = [prompts] if streaming_fn is not None: @@ -78,7 +79,7 @@ def _generate_stream(self, callback, prompts: List[str], **query_kwargs: Dict[str, - Any]): + Any]) -> None: async def put_result(): response_stream = self._request_async_response_stream( prompts, @@ -93,11 +94,11 @@ async def put_result(): self.asyncio_loop.run_until_complete(put_result()) - async def terminate_async(self): + async def terminate_async(self) -> None: await self.stub.Terminate( modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty()) - def terminate_server(self): + def terminate_server(self) -> None: self.asyncio_loop.run_until_complete(self.terminate_async()) if self.mii_config.enable_restful_api: requests.get( diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py index 53b0078e..c73ea60b 100644 --- a/mii/batching/data_classes.py +++ b/mii/batching/data_classes.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -from dataclasses import dataclass, field +from dataclasses import dataclass, field, asdict from typing import Any, Dict, List, Iterator, Union from typing_extensions import Self @@ -19,27 +19,17 @@ class Response: finish_reason: GenerationFinishReason @staticmethod - def from_msg(msg: Dict[str, Union[str, int]]) -> Self: - return Response( - generated_text=msg["generated_text"], - prompt_length=msg["prompt_length"], - generated_length=msg["generated_length"], - finish_reason=GenerationFinishReason(msg["finish_reason"]), - ) - - def to_msg(self) -> Dict[str, Union[str, int]]: - return { - "generated_text": self.generated_text, - "prompt_length": self.prompt_length, - "generated_length": self.generated_length, - "finish_reason": self.finish_reason - } + def from_msg_dict(msg: Dict[str, Union[str, int]]) -> Self: + return Response(**msg) + + def to_msg_dict(self) -> Dict[str, Union[str, int]]: + return asdict(self) def __repr__(self) -> str: - return str(self.to_msg()) + return self.generated_text def __str__(self) -> str: - return self.to_msg() + return self.generated_text @dataclass @@ -159,40 +149,6 @@ def set_next_as_input(self) -> None: self.is_done = False -class ResponseBatch: - def __init__(self, responses: List[Response] = []) -> None: - self.responses = responses - - def __iter__(self) -> Iterator[Response]: - return iter(self.responses) - - def __str__(self) -> str: - return str(self.responses) - - def __repr__(self) -> str: - return self.responses - return "\n\n".join(str(r) for r in self.responses) - - @property - def generated_texts(self) -> List[str]: - return [r.generated_text for r in self.responses] - - @property - def prompt_lengths(self) -> List[int]: - return [r.prompt_length for r in self.responses] - - @property - def generated_lengths(self) -> List[int]: - return [r.generated_length for r in self.responses] - - @property - def finish_reasons(self) -> List[GenerationFinishReason]: - return [r.finish_reason for r in self.responses] - - def append(self, response: Response) -> None: - self.responses.append(response) - - class RequestBatch: def __init__(self, requests: List[Request] = []) -> None: self.requests = requests diff --git a/mii/batching/ragged_batching.py b/mii/batching/ragged_batching.py index 61a4323b..3a9ad964 100644 --- a/mii/batching/ragged_batching.py +++ b/mii/batching/ragged_batching.py @@ -40,7 +40,7 @@ TEMP_NAME, SAMPLER_NAME, STOP_NAME) -from mii.batching.data_classes import Response, Request, ResponseBatch, RequestBatch +from mii.batching.data_classes import Response, Request, RequestBatch from mii.batching.generation.logit_processors import TopPLogitProcessor, TopKLogitProcessor, TemperatureLogitProcessor from mii.batching.generation.samplers import LogitsSampler, GreedySampler from mii.batching.generation.stop_criterion import EosGenerationStopCriterion, TokenStopCriterion @@ -463,10 +463,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.tid = threading.get_ident() - def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch: + def __call__(self, inputs: Union[str, List[str]], **kwargs) -> List[Response]: if isinstance(inputs, str): inputs = [inputs] - outputs: ResponseBatch = ResponseBatch([]) + outputs: List[Response] = [] uids_running: List[int] = list(range(len(inputs))) uids_complete_order: List[int] = [] @@ -494,12 +494,12 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch: while self.scheduled_requests: self.generate() - outputs = ResponseBatch([ + outputs = [ r for idx, r in sorted(zip(uids_complete_order, outputs), key=lambda pair: pair[0]) - ]) + ] if self.model_config.all_rank_output: outputs = self._bcast_responses(outputs) @@ -519,15 +519,15 @@ def _get_response(self) -> Tuple[int, Response]: response = self.make_response(generated_tokens, result[2], result[3], result[4]) return uid, response - def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch: + def _bcast_responses(self, responses: List[Response]) -> List[Response]: if self.is_rank_0: - data_dicts = [r.to_msg() for r in responses] + data_dicts = [r.to_msg_dict() for r in responses] json_data = ujson.dumps(data_dicts) self.socket.send_string(json_data) else: json_data = self.socket.recv_string() data_dicts = ujson.loads(json_data) - responses = ResponseBatch([Response.from_msg(msg) for msg in data_dicts]) + responses = [Response.from_msg_dict(msg) for msg in data_dicts] return responses diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 8a8a68d7..5092b817 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -244,7 +244,7 @@ def invoke_intercept_method_stream(request_proto, context): response_proto = result_queue.get( timeout=STREAM_RESPONSE_QUEUE_TIMEOUT) yield response_proto - if response_proto.details[0].finish_reason != str( + if response_proto.response[0].finish_reason != str( GenerationFinishReason.NONE): break except queue.Empty: diff --git a/mii/grpc_related/proto/modelresponse.proto b/mii/grpc_related/proto/modelresponse.proto index 5ad1f194..9ea04a9c 100644 --- a/mii/grpc_related/proto/modelresponse.proto +++ b/mii/grpc_related/proto/modelresponse.proto @@ -25,17 +25,8 @@ package modelresponse; service ModelResponse { rpc Terminate (google.protobuf.Empty) returns (google.protobuf.Empty) {} - rpc CreateSession (SessionID) returns (google.protobuf.Empty) {} - rpc DestroySession (SessionID) returns (google.protobuf.Empty) {} - rpc GeneratorReply (MultiStringRequest) returns (GenerationReply) {} - rpc ClassificationReply (SingleStringRequest) returns (SingleStringReply) {} - rpc QuestionAndAnswerReply(QARequest) returns (SingleStringReply) {} - rpc FillMaskReply(SingleStringRequest) returns (SingleStringReply) {} - rpc TokenClassificationReply(SingleStringRequest) returns (SingleStringReply) {} - rpc ConversationalReply(ConversationRequest) returns (ConversationReply) {} - rpc Txt2ImgReply(MultiStringRequest) returns (ImageReply) {} - - rpc GeneratorReplyStream (MultiStringRequest) returns (stream GenerationReply) {} + rpc GeneratorReply (MultiStringRequest) returns (MultiGenerationReply) {} + rpc GeneratorReplyStream (MultiStringRequest) returns (stream MultiGenerationReply) {} } message Dictionary { @@ -52,10 +43,6 @@ message Value { } } -message SessionID { - string session_id = 1; -} - message SingleStringRequest { string request = 1; map query_kwargs = 2; @@ -66,62 +53,15 @@ message MultiStringRequest { map query_kwargs = 2; } -message SingleStringReply { +message SingleGenerationReply { string response = 1; - float time_taken = 2; - float model_time_taken = 3; -} - -message MultiStringReply { - repeated string response = 1; - float time_taken = 2; - float model_time_taken = 3; -} - -message GenerationDetails { - string finish_reason = 1; - int64 prompt_tokens = 2; - int64 generated_tokens = 3; -} - -message GenerationReply { - repeated string response = 1; - // A request may contain multiple prompts and they produce different number of tokens. - // When streaming output is enabled, a response may contain generated tokens only for some prompts. - // `indices` represents the indices of prompts for which `response` and `details` are provided. - repeated int64 indices = 2; - repeated GenerationDetails details = 3; - float time_taken = 4; - float model_time_taken = 5; -} - -message QARequest { - string question = 1; - string context = 2; - map query_kwargs = 3; -} - -message ConversationRequest { - string text = 1; - string conversation_id = 2; - repeated string past_user_inputs = 3; - repeated string generated_responses = 4; - map query_kwargs = 5; -} - -message ConversationReply { - string conversation_id = 1; - repeated string past_user_inputs = 2; - repeated string generated_responses = 3; - float time_taken = 4; - float model_time_taken = 5; + string finish_reason = 2; + int64 prompt_tokens = 3; + int64 generated_tokens = 4; + float time_taken = 5; + float model_time_taken = 6; } -message ImageReply { - repeated bytes images = 1; - repeated bool nsfw_content_detected = 2; - string mode = 3; - int64 size_w = 4; - int64 size_h = 5; - float time_taken = 6; +message MultiGenerationReply { + repeated SingleGenerationReply response = 1; } diff --git a/mii/grpc_related/proto/modelresponse_pb2.py b/mii/grpc_related/proto/modelresponse_pb2.py index 88505cdd..c152e207 100644 --- a/mii/grpc_related/proto/modelresponse_pb2.py +++ b/mii/grpc_related/proto/modelresponse_pb2.py @@ -1,3 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team # Generated by the protocol buffer compiler. DO NOT EDIT! # source: modelresponse.proto """Generated protocol buffer code.""" @@ -12,7 +16,7 @@ from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\x1f\n\tSessionID\x12\x12\n\nsession_id\x18\x01 \x01(\t\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"S\n\x11SingleStringReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"R\n\x10MultiStringReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x12\n\ntime_taken\x18\x02 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x03 \x01(\x02\"[\n\x11GenerationDetails\x12\x15\n\rfinish_reason\x18\x01 \x01(\t\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x03 \x01(\x03\"\x95\x01\n\x0fGenerationReply\x12\x10\n\x08response\x18\x01 \x03(\t\x12\x0f\n\x07indices\x18\x02 \x03(\x03\x12\x31\n\x07\x64\x65tails\x18\x03 \x03(\x0b\x32 .modelresponse.GenerationDetails\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"\xb9\x01\n\tQARequest\x12\x10\n\x08question\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x02 \x01(\t\x12?\n\x0cquery_kwargs\x18\x03 \x03(\x0b\x32).modelresponse.QARequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x88\x02\n\x13\x43onversationRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x17\n\x0f\x63onversation_id\x18\x02 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x03 \x03(\t\x12\x1b\n\x13generated_responses\x18\x04 \x03(\t\x12I\n\x0cquery_kwargs\x18\x05 \x03(\x0b\x32\x33.modelresponse.ConversationRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x91\x01\n\x11\x43onversationReply\x12\x17\n\x0f\x63onversation_id\x18\x01 \x01(\t\x12\x18\n\x10past_user_inputs\x18\x02 \x03(\t\x12\x1b\n\x13generated_responses\x18\x03 \x03(\t\x12\x12\n\ntime_taken\x18\x04 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x05 \x01(\x02\"}\n\nImageReply\x12\x0e\n\x06images\x18\x01 \x03(\x0c\x12\x1d\n\x15nsfw_content_detected\x18\x02 \x03(\x08\x12\x0c\n\x04mode\x18\x03 \x01(\t\x12\x0e\n\x06size_w\x18\x04 \x01(\x03\x12\x0e\n\x06size_h\x18\x05 \x01(\x03\x12\x12\n\ntime_taken\x18\x06 \x01(\x02\x32\xb2\x07\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12\x43\n\rCreateSession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12\x44\n\x0e\x44\x65stroySession\x12\x18.modelresponse.SessionID\x1a\x16.google.protobuf.Empty\"\x00\x12U\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x12]\n\x13\x43lassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12V\n\x16QuestionAndAnswerReply\x12\x18.modelresponse.QARequest\x1a .modelresponse.SingleStringReply\"\x00\x12W\n\rFillMaskReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12\x62\n\x18TokenClassificationReply\x12\".modelresponse.SingleStringRequest\x1a .modelresponse.SingleStringReply\"\x00\x12]\n\x13\x43onversationalReply\x12\".modelresponse.ConversationRequest\x1a .modelresponse.ConversationReply\"\x00\x12N\n\x0cTxt2ImgReply\x12!.modelresponse.MultiStringRequest\x1a\x19.modelresponse.ImageReply\"\x00\x12]\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a\x1e.modelresponse.GenerationReply\"\x00\x30\x01\x62\x06proto3' + b'\n\x13modelresponse.proto\x12\rmodelresponse\x1a\x1bgoogle/protobuf/empty.proto\"\x88\x01\n\nDictionary\x12\x35\n\x06values\x18\x01 \x03(\x0b\x32%.modelresponse.Dictionary.ValuesEntry\x1a\x43\n\x0bValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x8c\x01\n\x05Value\x12\x10\n\x06svalue\x18\x01 \x01(\tH\x00\x12\x10\n\x06ivalue\x18\x02 \x01(\x03H\x00\x12\x10\n\x06\x66value\x18\x03 \x01(\x02H\x00\x12\x10\n\x06\x62value\x18\x04 \x01(\x08H\x00\x12+\n\x06mvalue\x18\x05 \x01(\x0b\x32\x19.modelresponse.DictionaryH\x00\x42\x0e\n\x0coneof_values\"\xbb\x01\n\x13SingleStringRequest\x12\x0f\n\x07request\x18\x01 \x01(\t\x12I\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x33.modelresponse.SingleStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\xb9\x01\n\x12MultiStringRequest\x12\x0f\n\x07request\x18\x01 \x03(\t\x12H\n\x0cquery_kwargs\x18\x02 \x03(\x0b\x32\x32.modelresponse.MultiStringRequest.QueryKwargsEntry\x1aH\n\x10QueryKwargsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12#\n\x05value\x18\x02 \x01(\x0b\x32\x14.modelresponse.Value:\x02\x38\x01\"\x9f\x01\n\x15SingleGenerationReply\x12\x10\n\x08response\x18\x01 \x01(\t\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x03\x12\x18\n\x10generated_tokens\x18\x04 \x01(\x03\x12\x12\n\ntime_taken\x18\x05 \x01(\x02\x12\x18\n\x10model_time_taken\x18\x06 \x01(\x02\"N\n\x14MultiGenerationReply\x12\x36\n\x08response\x18\x01 \x03(\x0b\x32$.modelresponse.SingleGenerationReply2\x8e\x02\n\rModelResponse\x12=\n\tTerminate\x12\x16.google.protobuf.Empty\x1a\x16.google.protobuf.Empty\"\x00\x12Z\n\x0eGeneratorReply\x12!.modelresponse.MultiStringRequest\x1a#.modelresponse.MultiGenerationReply\"\x00\x12\x62\n\x14GeneratorReplyStream\x12!.modelresponse.MultiStringRequest\x1a#.modelresponse.MultiGenerationReply\"\x00\x30\x01\x62\x06proto3' ) _globals = globals() @@ -26,46 +30,24 @@ _SINGLESTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _MULTISTRINGREQUEST_QUERYKWARGSENTRY._options = None _MULTISTRINGREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _QAREQUEST_QUERYKWARGSENTRY._options = None - _QAREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._options = None - _CONVERSATIONREQUEST_QUERYKWARGSENTRY._serialized_options = b'8\001' _globals['_DICTIONARY']._serialized_start = 68 _globals['_DICTIONARY']._serialized_end = 204 _globals['_DICTIONARY_VALUESENTRY']._serialized_start = 137 _globals['_DICTIONARY_VALUESENTRY']._serialized_end = 204 _globals['_VALUE']._serialized_start = 207 _globals['_VALUE']._serialized_end = 347 - _globals['_SESSIONID']._serialized_start = 349 - _globals['_SESSIONID']._serialized_end = 380 - _globals['_SINGLESTRINGREQUEST']._serialized_start = 383 - _globals['_SINGLESTRINGREQUEST']._serialized_end = 570 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_MULTISTRINGREQUEST']._serialized_start = 573 - _globals['_MULTISTRINGREQUEST']._serialized_end = 758 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_SINGLESTRINGREPLY']._serialized_start = 760 - _globals['_SINGLESTRINGREPLY']._serialized_end = 843 - _globals['_MULTISTRINGREPLY']._serialized_start = 845 - _globals['_MULTISTRINGREPLY']._serialized_end = 927 - _globals['_GENERATIONDETAILS']._serialized_start = 929 - _globals['_GENERATIONDETAILS']._serialized_end = 1020 - _globals['_GENERATIONREPLY']._serialized_start = 1023 - _globals['_GENERATIONREPLY']._serialized_end = 1172 - _globals['_QAREQUEST']._serialized_start = 1175 - _globals['_QAREQUEST']._serialized_end = 1360 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_QAREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_CONVERSATIONREQUEST']._serialized_start = 1363 - _globals['_CONVERSATIONREQUEST']._serialized_end = 1627 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_start = 498 - _globals['_CONVERSATIONREQUEST_QUERYKWARGSENTRY']._serialized_end = 570 - _globals['_CONVERSATIONREPLY']._serialized_start = 1630 - _globals['_CONVERSATIONREPLY']._serialized_end = 1775 - _globals['_IMAGEREPLY']._serialized_start = 1777 - _globals['_IMAGEREPLY']._serialized_end = 1902 - _globals['_MODELRESPONSE']._serialized_start = 1905 - _globals['_MODELRESPONSE']._serialized_end = 2851 + _globals['_SINGLESTRINGREQUEST']._serialized_start = 350 + _globals['_SINGLESTRINGREQUEST']._serialized_end = 537 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 465 + _globals['_SINGLESTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 537 + _globals['_MULTISTRINGREQUEST']._serialized_start = 540 + _globals['_MULTISTRINGREQUEST']._serialized_end = 725 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_start = 465 + _globals['_MULTISTRINGREQUEST_QUERYKWARGSENTRY']._serialized_end = 537 + _globals['_SINGLEGENERATIONREPLY']._serialized_start = 728 + _globals['_SINGLEGENERATIONREPLY']._serialized_end = 887 + _globals['_MULTIGENERATIONREPLY']._serialized_start = 889 + _globals['_MULTIGENERATIONREPLY']._serialized_end = 967 + _globals['_MODELRESPONSE']._serialized_start = 970 + _globals['_MODELRESPONSE']._serialized_end = 1240 # @@protoc_insertion_point(module_scope) diff --git a/mii/grpc_related/proto/modelresponse_pb2_grpc.py b/mii/grpc_related/proto/modelresponse_pb2_grpc.py index e94ec498..8da300b6 100644 --- a/mii/grpc_related/proto/modelresponse_pb2_grpc.py +++ b/mii/grpc_related/proto/modelresponse_pb2_grpc.py @@ -1,3 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! """Client and server classes corresponding to protobuf-defined services.""" import grpc @@ -20,55 +24,15 @@ def __init__(self, channel): SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) - self.CreateSession = channel.unary_unary( - '/modelresponse.ModelResponse/CreateSession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) - self.DestroySession = channel.unary_unary( - '/modelresponse.ModelResponse/DestroySession', - request_serializer=modelresponse__pb2.SessionID.SerializeToString, - response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, - ) self.GeneratorReply = channel.unary_unary( '/modelresponse.ModelResponse/GeneratorReply', request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.GenerationReply.FromString, - ) - self.ClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/ClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.QuestionAndAnswerReply = channel.unary_unary( - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - request_serializer=modelresponse__pb2.QARequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.FillMaskReply = channel.unary_unary( - '/modelresponse.ModelResponse/FillMaskReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.TokenClassificationReply = channel.unary_unary( - '/modelresponse.ModelResponse/TokenClassificationReply', - request_serializer=modelresponse__pb2.SingleStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.SingleStringReply.FromString, - ) - self.ConversationalReply = channel.unary_unary( - '/modelresponse.ModelResponse/ConversationalReply', - request_serializer=modelresponse__pb2.ConversationRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ConversationReply.FromString, - ) - self.Txt2ImgReply = channel.unary_unary( - '/modelresponse.ModelResponse/Txt2ImgReply', - request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.ImageReply.FromString, + response_deserializer=modelresponse__pb2.MultiGenerationReply.FromString, ) self.GeneratorReplyStream = channel.unary_stream( '/modelresponse.ModelResponse/GeneratorReplyStream', request_serializer=modelresponse__pb2.MultiStringRequest.SerializeToString, - response_deserializer=modelresponse__pb2.GenerationReply.FromString, + response_deserializer=modelresponse__pb2.MultiGenerationReply.FromString, ) @@ -80,60 +44,12 @@ def Terminate(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def CreateSession(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def DestroySession(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def GeneratorReply(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def ClassificationReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def QuestionAndAnswerReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def FillMaskReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def TokenClassificationReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ConversationalReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Txt2ImgReply(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def GeneratorReplyStream(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -150,67 +66,19 @@ def add_ModelResponseServicer_to_server(servicer, server): response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. SerializeToString, ), - 'CreateSession': - grpc.unary_unary_rpc_method_handler( - servicer.CreateSession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), - 'DestroySession': - grpc.unary_unary_rpc_method_handler( - servicer.DestroySession, - request_deserializer=modelresponse__pb2.SessionID.FromString, - response_serializer=google_dot_protobuf_dot_empty__pb2.Empty. - SerializeToString, - ), 'GeneratorReply': grpc.unary_unary_rpc_method_handler( servicer.GeneratorReply, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, - ), - 'ClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.ClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'QuestionAndAnswerReply': - grpc.unary_unary_rpc_method_handler( - servicer.QuestionAndAnswerReply, - request_deserializer=modelresponse__pb2.QARequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'FillMaskReply': - grpc.unary_unary_rpc_method_handler( - servicer.FillMaskReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'TokenClassificationReply': - grpc.unary_unary_rpc_method_handler( - servicer.TokenClassificationReply, - request_deserializer=modelresponse__pb2.SingleStringRequest.FromString, - response_serializer=modelresponse__pb2.SingleStringReply.SerializeToString, - ), - 'ConversationalReply': - grpc.unary_unary_rpc_method_handler( - servicer.ConversationalReply, - request_deserializer=modelresponse__pb2.ConversationRequest.FromString, - response_serializer=modelresponse__pb2.ConversationReply.SerializeToString, - ), - 'Txt2ImgReply': - grpc.unary_unary_rpc_method_handler( - servicer.Txt2ImgReply, - request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.ImageReply.SerializeToString, + response_serializer=modelresponse__pb2.MultiGenerationReply. + SerializeToString, ), 'GeneratorReplyStream': grpc.unary_stream_rpc_method_handler( servicer.GeneratorReplyStream, request_deserializer=modelresponse__pb2.MultiStringRequest.FromString, - response_serializer=modelresponse__pb2.GenerationReply.SerializeToString, + response_serializer=modelresponse__pb2.MultiGenerationReply. + SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler('modelresponse.ModelResponse', @@ -247,58 +115,6 @@ def Terminate(request, timeout, metadata) - @staticmethod - def CreateSession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/CreateSession', - modelresponse__pb2.SessionID.SerializeToString, - google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def DestroySession(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/DestroySession', - modelresponse__pb2.SessionID.SerializeToString, - google_dot_protobuf_dot_empty__pb2.Empty.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - @staticmethod def GeneratorReply(request, target, @@ -315,163 +131,7 @@ def GeneratorReply(request, target, '/modelresponse.ModelResponse/GeneratorReply', modelresponse__pb2.MultiStringRequest.SerializeToString, - modelresponse__pb2.GenerationReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def ClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/ClassificationReply', - modelresponse__pb2.SingleStringRequest.SerializeToString, - modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def QuestionAndAnswerReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/QuestionAndAnswerReply', - modelresponse__pb2.QARequest.SerializeToString, - modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def FillMaskReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/FillMaskReply', - modelresponse__pb2.SingleStringRequest.SerializeToString, - modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def TokenClassificationReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/TokenClassificationReply', - modelresponse__pb2.SingleStringRequest.SerializeToString, - modelresponse__pb2.SingleStringReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def ConversationalReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/ConversationalReply', - modelresponse__pb2.ConversationRequest.SerializeToString, - modelresponse__pb2.ConversationReply.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata) - - @staticmethod - def Txt2ImgReply(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/modelresponse.ModelResponse/Txt2ImgReply', - modelresponse__pb2.MultiStringRequest.SerializeToString, - modelresponse__pb2.ImageReply.FromString, + modelresponse__pb2.MultiGenerationReply.FromString, options, channel_credentials, insecure, @@ -497,7 +157,7 @@ def GeneratorReplyStream(request, target, '/modelresponse.ModelResponse/GeneratorReplyStream', modelresponse__pb2.MultiStringRequest.SerializeToString, - modelresponse__pb2.GenerationReply.FromString, + modelresponse__pb2.MultiGenerationReply.FromString, options, channel_credentials, insecure, diff --git a/mii/grpc_related/restful_gateway.py b/mii/grpc_related/restful_gateway.py index dc2dab71..5c2bc48a 100644 --- a/mii/grpc_related/restful_gateway.py +++ b/mii/grpc_related/restful_gateway.py @@ -2,11 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team +import json import threading import time + from flask import Flask, request from flask_restful import Resource, Api -from google.protobuf.json_format import MessageToJson from werkzeug.serving import make_server import mii @@ -29,7 +30,8 @@ def __init__(self): def post(self): data = request.get_json() result = client.generate(**data) - return MessageToJson(result) + result_json = json.dumps([r.to_msg_dict() for r in result]) + return result_json app = Flask("RestfulGateway") diff --git a/mii/grpc_related/task_methods.py b/mii/grpc_related/task_methods.py index 37dc64a2..5a4cf230 100644 --- a/mii/grpc_related/task_methods.py +++ b/mii/grpc_related/task_methods.py @@ -8,7 +8,7 @@ from google.protobuf.message import Message -from mii.batching.data_classes import Response, ResponseBatch +from mii.batching.data_classes import Response from mii.constants import TaskType from mii.grpc_related.proto import modelresponse_pb2 from mii.utils import kwarg_dict_to_proto, unpack_proto_query_kwargs @@ -76,38 +76,30 @@ def unpack_request_from_proto(self, kwargs = unpack_proto_query_kwargs(proto_request.query_kwargs) return prompts, kwargs - def pack_response_to_proto(self, responses: ResponseBatch) -> Message: - text_responses = [] - details = [] - - # Response a nested list of dicts - # [Sample, 1, Dict] - for response in responses: - text = response.generated_text - text_responses.append(text) - details.append( - modelresponse_pb2.GenerationDetails( - finish_reason=str(response.finish_reason), - prompt_tokens=response.prompt_length, - generated_tokens=response.generated_length)) - - return modelresponse_pb2.GenerationReply( - response=text_responses, - indices=[0], - details=details, - time_taken=-1, - model_time_taken=-1, - ) + def pack_response_to_proto(self, responses: List[Response]) -> Message: + proto_responses = [] + for r in responses: + proto_responses.append( + modelresponse_pb2.SingleGenerationReply( + response=r.generated_text, + finish_reason=str(r.finish_reason), + prompt_tokens=r.prompt_length, + generated_tokens=r.generated_length, + time_taken=-1, + model_time_taken=-1, + )) + + return modelresponse_pb2.MultiGenerationReply(response=proto_responses, ) - def unpack_response_from_proto(self, response: Message) -> ResponseBatch: - response_batch = ResponseBatch() - for i, r in enumerate(response.response): + def unpack_response_from_proto(self, response: Message) -> List[Response]: + response_batch = [] + for r in response.response: response_batch.append( Response( - generated_text=r, - prompt_length=response.details[i].prompt_tokens, - generated_length=response.details[i].generated_tokens, - finish_reason=response.details[i].finish_reason, + generated_text=r.response, + prompt_length=r.prompt_tokens, + generated_length=r.generated_tokens, + finish_reason=r.finish_reason, )) return response_batch diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 125b34c7..6e897d82 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -14,30 +14,30 @@ def test_single_gpu(deployment, query): - output = deployment(query) - assert output, "output is empty" + outputs = deployment(query) + assert outputs[0], "output is empty" def test_streaming(deployment, query): - output = [] + outputs = [] def callback(response): - output.append(response.response) + outputs.append(response[0].generated_text) deployment(query, streaming_fn=callback) - assert output, "output is empty" + assert outputs, "output is empty" def test_multi_prompt(deployment, query): - output = deployment([query] * 4) - for r in output.response: + outputs = deployment([query] * 4) + for r in outputs: assert r, "output is empty" @pytest.mark.parametrize("tensor_parallel", [2]) def test_multi_gpu(deployment, query): - output = deployment(query) - assert output, "output is empty" + outputs = deployment(query) + assert outputs[0], "output is empty" @pytest.mark.parametrize("replica_num", [2]) @@ -45,9 +45,9 @@ def test_multi_replica(deployment, query): deployment_name = deployment.mii_config.deployment_name start = time.time() - output = mii.client(deployment_name)(query, max_length=128, ignore_eos=True) + outputs = mii.client(deployment_name)(query, max_length=128, ignore_eos=True) end = time.time() - assert output, "output is empty" + assert outputs[0], "output is empty" single_query_time = end - start procs = [] @@ -77,7 +77,7 @@ def test_multi_replica(deployment, query): def test_query_kwargs(deployment, query): # test ignore_eos - output = deployment( + outputs = deployment( query, max_length=128, min_new_tokens=16, @@ -86,14 +86,14 @@ def test_query_kwargs(deployment, query): top_k=50, temperature=0.9, ) - assert output, "output is empty" + assert outputs[0], "output is empty" def test_do_sample(deployment, query): output_0 = deployment(query, do_sample=False, max_length=128) output_1 = deployment(query, do_sample=False, max_length=128) assert ( - output_0.response == output_1.response + output_0[0] == output_1[0] ), "do_sample=False should always return the same output" @@ -105,15 +105,15 @@ def test_stop_token(deployment, query): def test_return_full_text(deployment, query): - output = deployment(query, max_length=128, return_full_text=True) - assert output.response[0].startswith(query), "output should start with the prompt" + outputs = deployment(query, max_length=128, return_full_text=True) + assert outputs[0].generated_text.startswith(query), "output should start with the prompt" @pytest.mark.parametrize("enable_restful_api", [True]) def test_restful_api(deployment, query, deployment_name, restful_api_port): # Verify deployment is running - output = deployment(query, max_length=128) - assert output, "output is empty" + outputs = deployment(query, max_length=128) + assert outputs[0], "output is empty" # Verify REST API url = f"http://localhost:{restful_api_port}/mii/{deployment_name}" @@ -123,4 +123,4 @@ def test_restful_api(deployment, query, deployment_name, restful_api_port): data=json_params, headers={"Content-Type": "application/json"}) assert result.status_code == 200 - assert "response" in result.json() + assert "generated_text" in result.json() From 3a035032661df7374cedba0c988ac44f9127ba36 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 16 Nov 2023 16:59:37 -0800 Subject: [PATCH 6/9] fix bug in default value --- mii/batching/data_classes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mii/batching/data_classes.py b/mii/batching/data_classes.py index c73ea60b..4bc46f73 100644 --- a/mii/batching/data_classes.py +++ b/mii/batching/data_classes.py @@ -150,7 +150,9 @@ def set_next_as_input(self) -> None: class RequestBatch: - def __init__(self, requests: List[Request] = []) -> None: + def __init__(self, requests: List[Request] = None) -> None: + if requests is None: + requests = [] self.requests = requests def __len__(self) -> int: From 85b9d02d3dd0cb9c21ed1c247de910c6e617b515 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 17 Nov 2023 14:20:12 -0800 Subject: [PATCH 7/9] update docs --- README.md | 20 +++++++++++++++++--- mii/grpc_related/task_methods.py | 2 +- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b1e09bea..0d76590f 100644 --- a/README.md +++ b/README.md @@ -116,10 +116,17 @@ A non-persistent pipeline is a great way to try DeepSpeed-MII. Non-persistent pi ```python import mii pipe = mii.pipeline("mistralai/Mistral-7B-v0.1") -response = pipe("DeepSpeed is", max_new_tokens=128) +response = pipe(["DeepSpeed is", "Seattle is"], max_new_tokens=128) print(response) ``` +The returned `response` is a list of `Response` objects. We can access several details about the generation (e.g., `response[0].prompt_length`): + +- `generated_text: str` Text generated by the model. +- `prompt_length: int` Number of tokens in the original prompt. +- `generated_length: int` Number of tokens generated. +- `finish_reason: str` Reason for stopping generation. `stop` indicates the EOS token was generated and `length` indicates the generation reached `max_new_tokens` or `max_length`. + ### Tensor parallelism Taking advantage of multi-GPU systems for greater performance is easy with MII. When run with the `deepspeed` launcher, tensor parallelism is automatically controlled by the `--num_gpus` flag: @@ -158,10 +165,17 @@ A persistent deployment is ideal for use with long-running and production applic ```python import mii client = mii.serve("mistralai/Mistral-7B-v0.1") -response = client.generate("Deepspeed is", max_new_tokens=128) -print(response.response) +response = client.generate(["Deepspeed is", "Seattle is"], max_new_tokens=128) +print(response) ``` +The returned `response` is a list of `Response` objects. We can access several details about the generation (e.g., `response[0].prompt_length`): + +- `generated_text: str` Text generated by the model. +- `prompt_length: int` Number of tokens in the original prompt. +- `generated_length: int` Number of tokens generated. +- `finish_reason: str` Reason for stopping generation. `stop` indicates the EOS token was generated and `length` indicates the generation reached `max_new_tokens` or `max_length`. + If we want to generate text from other processes, we can do that too: ```python diff --git a/mii/grpc_related/task_methods.py b/mii/grpc_related/task_methods.py index 5a4cf230..77c4a3fc 100644 --- a/mii/grpc_related/task_methods.py +++ b/mii/grpc_related/task_methods.py @@ -82,7 +82,7 @@ def pack_response_to_proto(self, responses: List[Response]) -> Message: proto_responses.append( modelresponse_pb2.SingleGenerationReply( response=r.generated_text, - finish_reason=str(r.finish_reason), + finish_reason=str(r.finish_reason.value), prompt_tokens=r.prompt_length, generated_tokens=r.generated_length, time_taken=-1, From ed296c22efb230b62f8d93b3b1312b67cc0b97a3 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Fri, 17 Nov 2023 14:45:32 -0800 Subject: [PATCH 8/9] update to install latest DS --- .github/workflows/formatting.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 6aa26187..82cd1b9f 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -27,6 +27,10 @@ jobs: which python python --version + - name: Install DeepSpeed + run: | + pip install git+https://github.com/microsoft/DeepSpeed.git + - name: Install MII run: | pip install .[dev] From 45a4a112f9965514cc711ddf422cf7f689d44c44 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 21 Nov 2023 10:51:39 -0800 Subject: [PATCH 9/9] fix for streaming bug --- mii/grpc_related/modelresponse_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mii/grpc_related/modelresponse_server.py b/mii/grpc_related/modelresponse_server.py index 5092b817..69d37890 100644 --- a/mii/grpc_related/modelresponse_server.py +++ b/mii/grpc_related/modelresponse_server.py @@ -245,7 +245,7 @@ def invoke_intercept_method_stream(request_proto, context): timeout=STREAM_RESPONSE_QUEUE_TIMEOUT) yield response_proto if response_proto.response[0].finish_reason != str( - GenerationFinishReason.NONE): + GenerationFinishReason.NONE.value): break except queue.Empty: print(