
Unify input/output types #295

Merged (11 commits) on Nov 21, 2023
mii/backend/client.py (32 changes: 16 additions & 16 deletions)
@@ -7,6 +7,7 @@
 import requests
 from typing import Dict, Any, Callable, List, Union

+from mii.batching.data_classes import Response
 from mii.config import MIIConfig
 from mii.constants import GRPC_MAX_MSG_SIZE
 from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
@@ -37,18 +38,18 @@ def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
         channel = create_channel(host, self.port)
         self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)

-    def __call__(self, *args, **kwargs):
+    def __call__(self, *args, **kwargs) -> List[Response]:
         return self.generate(*args, **kwargs)

-    async def _request_async_response(self, request_dict, **query_kwargs):
+    async def _request_async_response(self, prompts, **query_kwargs):
         task_methods = TASK_METHODS_DICT[self.task]
-        proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
+        proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs)
         proto_response = await getattr(self.stub, task_methods.method)(proto_request)
         return task_methods.unpack_response_from_proto(proto_response)

-    async def _request_async_response_stream(self, request_dict, **query_kwargs):
+    async def _request_async_response_stream(self, prompts, **query_kwargs):
         task_methods = TASK_METHODS_DICT[self.task]
-        proto_request = task_methods.pack_request_to_proto(request_dict, **query_kwargs)
+        proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs)
         assert hasattr(task_methods, "method_stream_out"), f"{self.task} does not support streaming response"
         async for response in getattr(self.stub,
                                       task_methods.method_stream_out)(proto_request):
@@ -59,30 +60,29 @@ def generate(self,
                                 List[str]],
                  streaming_fn: Callable = None,
                  **query_kwargs: Dict[str,
-                                      Any]):
+                                      Any]) -> Union[None,
+                                                     List[Response]]:
         if isinstance(prompts, str):
             prompts = [prompts]
         if streaming_fn is not None:
             if len(prompts) > 1:
                 raise RuntimeError(
                     "MII client streaming only supports a single prompt input.")
-            request_dict = {"query": prompts}
-            return self._generate_stream(streaming_fn, request_dict, **query_kwargs)
+            query_kwargs["stream"] = True
+            return self._generate_stream(streaming_fn, prompts, **query_kwargs)

-        request_dict = {"query": prompts}
         return self.asyncio_loop.run_until_complete(
-            self._request_async_response(request_dict,
+            self._request_async_response(prompts,
                                          **query_kwargs))

     def _generate_stream(self,
                          callback,
-                         request_dict: Dict[str,
-                                            str],
+                         prompts: List[str],
                          **query_kwargs: Dict[str,
-                                              Any]):
+                                              Any]) -> None:
         async def put_result():
             response_stream = self._request_async_response_stream(
-                request_dict,
+                prompts,
                 **query_kwargs)

             while True:
@@ -94,11 +94,11 @@ async def put_result():

         self.asyncio_loop.run_until_complete(put_result())

-    async def terminate_async(self):
+    async def terminate_async(self) -> None:
         await self.stub.Terminate(
             modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())

-    def terminate_server(self):
+    def terminate_server(self) -> None:
         self.asyncio_loop.run_until_complete(self.terminate_async())
         if self.mii_config.enable_restful_api:
             requests.get(
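For orientation, here is a minimal usage sketch of the updated client surface. Only generate(), streaming_fn, and the Response fields come from the diff above; obtaining the client via mii.client(), the deployment name, and max_new_tokens as an accepted query kwarg are assumptions about the surrounding MII API.

import mii

# Hypothetical deployment name; any model already started with mii.serve() would do.
client = mii.client("my-deployment")

# Non-streaming: a str or List[str] of prompts returns a List[Response].
responses = client.generate(["DeepSpeed is", "Seattle is"], max_new_tokens=64)
for r in responses:
    # str(r) is r.generated_text, per Response.__str__ in this PR.
    print(r.generated_text, r.prompt_length, r.generated_length, r.finish_reason)

# Streaming: a single prompt only; generate() returns None and results arrive
# through the callback ("stream" is injected into query_kwargs by the client).
def on_chunk(chunk):
    # The exact payload type passed to streaming_fn is not shown in this hunk;
    # it is assumed here to stringify to the newly generated text.
    print(str(chunk), end="", flush=True)

client.generate("DeepSpeed is", streaming_fn=on_chunk, max_new_tokens=64)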
mii/batching/data_classes.py (236 changes: 236 additions & 0 deletions, new file)
@@ -0,0 +1,236 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Iterator, Union
from typing_extensions import Self

import torch

from mii.constants import GenerationFinishReason


@dataclass
class Response:
    generated_text: str
    prompt_length: int
    generated_length: int
    finish_reason: GenerationFinishReason

    @staticmethod
    def from_msg_dict(msg: Dict[str, Union[str, int]]) -> Self:
        return Response(**msg)

    def to_msg_dict(self) -> Dict[str, Union[str, int]]:
        return asdict(self)

    def __repr__(self) -> str:
        return self.generated_text

    def __str__(self) -> str:
        return self.generated_text


@dataclass
class RequestMsg:
    uid: int
    input_tokens: Union[torch.Tensor, List[int]]

    @property
    def is_flush_request(self):
        return self.input_tokens is None

    @staticmethod
    def from_msg_dict(msg: Dict[str, Any]) -> Self:
        input_tokens = msg["input_tokens"]
        if input_tokens is not None:
            input_tokens = torch.tensor(msg["input_tokens"],
                                        dtype=torch.int32,
                                        device=torch.device("cpu"))
        return RequestMsg(uid=msg["uid"], input_tokens=input_tokens)


@dataclass
class Request:
    tid: int
    uid: int
    input_tokens: torch.Tensor
    prompt_tokens: torch.Tensor
    seq_length: int
    max_length: int
    max_new_tokens: int
    min_new_tokens: int
    last_in_prompt: bool
    post_processing: List[object]
    stream: bool = False
    ignore_eos: bool = False
    return_full_text: bool = False

    _next_token: Union[None, torch.Tensor] = None
    _is_done: bool = False
    _generated_tokens: List[torch.Tensor] = field(default_factory=list)
    _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE

    @property
    def prompt_length(self) -> int:
        return len(self.prompt_tokens)

    @property
    def next_token(self) -> Union[None, torch.Tensor]:
        return self._next_token

    @next_token.setter
    def next_token(self, next_token: Union[None, torch.Tensor]) -> None:
        self._next_token = next_token

    @property
    def is_done(self) -> bool:
        if self.ignore_eos:
            return False
        if self.seq_length < self.min_new_tokens:
            return False
        return self._is_done

    @is_done.setter
    def is_done(self, is_done: bool) -> None:
        self._is_done = is_done

    @property
    def generated_tokens(self) -> List[torch.Tensor]:
        return self._generated_tokens

    @property
    def finish_reason(self) -> GenerationFinishReason:
        return self._finish_reason

    @property
    def is_flush_request(self):
        return self.input_tokens is None

    @property
    def num_generated_tokens(self) -> int:
        # We return zero while we are processing decomposed prompts
        return self.seq_length - self.prompt_length + 1 if self.seq_length >= self.prompt_length else 0

    @property
    def stop_generation(self) -> bool:
        # Returns whether to stop generation for request
        if self.is_done:
            self._finish_reason = GenerationFinishReason.STOP
            return True
        if (self.seq_length >= self.max_length) or (self.num_generated_tokens >=
                                                    self.max_new_tokens):
            self._finish_reason = GenerationFinishReason.LENGTH
            return True
        return False

    def to_msg_dict(self) -> Dict[str, Any]:
        # Returns a minimal version of the request of purposes of broadcasting to all ranks
        input_tokens = self.input_tokens
        if input_tokens is not None:
            input_tokens = self.input_tokens.tolist()
        return {"uid": self.uid, "input_tokens": input_tokens}

    def accumulate_generated_token(self) -> None:
        # Append the latest token to the list of generated tokens
        if not self.is_done:
            self._generated_tokens.append(self.next_token)

    def clear_generated_token(self) -> None:
        self._generated_tokens.clear()

    def set_next_as_input(self) -> None:
        # Places the next token into the input token for next round of generation
        if self.next_token is not None:
            self.input_tokens = self.next_token.unsqueeze(0)
        self.last_in_prompt = True
        self.next_token = None
        self.is_done = False


class RequestBatch:
    def __init__(self, requests: List[Request] = []) -> None:
        self.requests = requests

    def __len__(self) -> int:
        return len(self.requests)

    def __contains__(self, r: Request) -> bool:
        return r in self.requests

    def __nonzero__(self) -> bool:
        if len(self.requests) != 0:
            return True
        return False

    def __iter__(self) -> Iterator[Request]:
        return iter(self.requests)

    def __repr__(self) -> str:
        return f"RequestBatch({self.requests})"

    @property
    def requests_to_run(self) -> Self:
        return RequestBatch([r for r in self.requests if not r.is_flush_request])

    @property
    def requests_to_flush(self) -> Self:
        return RequestBatch([r for r in self.requests if r.is_flush_request])

    @property
    def last_in_prompt(self) -> Self:
        return RequestBatch([r for r in self.requests if r.last_in_prompt])

    @property
    def completed(self) -> Self:
        return RequestBatch([r for r in self.requests if r.stop_generation])

    @property
    def uids(self) -> List[int]:
        return [r.uid for r in self.requests]

    @property
    def lengths(self) -> List[int]:
        return [len(r.input_tokens) for r in self.requests]

    @property
    def tokens(self) -> List[torch.Tensor]:
        return [r.input_tokens for r in self.requests]

    @property
    def next_tokens(self) -> List[torch.Tensor]:
        return [r.next_token for r in self.requests]

    @property
    def done_tokens(self) -> List[torch.Tensor]:
        return [r.is_done for r in self.requests]

    @next_tokens.setter
    def next_tokens(self, next_tokens: List[torch.Tensor]) -> None:
        assert len(next_tokens) == len(self.requests)
        for idx, r in enumerate(self.requests):
            r.next_token = next_tokens[idx]

    @done_tokens.setter
    def done_tokens(self, done_tokens: List[torch.Tensor]) -> None:
        assert len(done_tokens) == len(self.requests)
        for idx, r in enumerate(self.requests):
            r.is_done = done_tokens[idx]

    def to_msg_dicts(self) -> List[Dict[str, Any]]:
        return [r.to_msg_dict() for r in self.requests]

    @staticmethod
    def from_msg_dicts(msg_dicts: List[Dict[str, Any]]) -> Self:
        return RequestBatch([RequestMsg.from_msg_dict(msg) for msg in msg_dicts])

    def prune(self, uids: List[int]) -> None:
        self.requests = [r for r in self.requests if r.uid not in uids]

    def append(self, r: Request) -> None:
        self.requests.append(r)

    def update_seq_length(self) -> None:
        for r in self.requests:
            r.seq_length += r.input_tokens.size(0)
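
As a quick sanity check of the new data classes, a self-contained sketch; the values are illustrative, but every class, method, and enum member used below appears in the file above.

import torch

from mii.batching.data_classes import RequestBatch, RequestMsg, Response
from mii.constants import GenerationFinishReason

# Response round-trips through plain dicts, which is how results are shipped
# between processes elsewhere in MII.
resp = Response(generated_text="hello world",
                prompt_length=5,
                generated_length=2,
                finish_reason=GenerationFinishReason.LENGTH)
assert Response.from_msg_dict(resp.to_msg_dict()) == resp
print(str(resp))  # "hello world"

# RequestMsg rebuilds token tensors from broadcast dicts; input_tokens=None
# marks a flush request.
msg = RequestMsg.from_msg_dict({"uid": 1, "input_tokens": [101, 102, 103]})
flush = RequestMsg.from_msg_dict({"uid": 2, "input_tokens": None})
assert not msg.is_flush_request and flush.is_flush_request

# RequestBatch wraps a list of requests/messages and exposes batch-level views.
batch = RequestBatch([msg, flush])
print(batch.uids)                    # [1, 2]
print(len(batch.requests_to_flush))  # 1
batch.prune(uids=[2])
print(batch.uids)                    # [1]

Passing an explicit list to RequestBatch, as above, also sidesteps the shared mutable default (requests: List[Request] = []) in its constructor.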