Add generation server scripts using HF accelerate and DS-inference #328

Merged · 33 commits · Sep 1, 2022
Changes from 1 commit
remove MaxTokensError
Mayank Mishra authored and committed on Aug 16, 2022
commit 46ade324888c2280406d6c7e8cd25777fc04aed8
23 changes: 15 additions & 8 deletions inference/benchmark.py
@@ -12,7 +12,15 @@
 from ds_inference import DSInferenceModel
 from ds_zero import DSZeROModel
 from hf_accelerate import HFAccelerateModel
-from utils import Execute, Model, get_argument_parser, get_dummy_batch, print_rank_n, GenerateRequest
+from utils import (
+    Execute,
+    GenerateRequest,
+    Model,
+    get_argument_parser,
+    get_dummy_batch,
+    parse_generate_kwargs,
+    print_rank_n
+)
 
 
 def run_and_log_time(execs: Union[List[Execute], Execute]) -> Union[List[Any], float]:
@@ -67,17 +75,16 @@ def benchmark_end_to_end(args: argparse.Namespace,
         Execute(model_class, {"args": args})
     )
 
-    print_rank_n(
-        f"*** Starting to generate {args.generate_kwargs['max_new_tokens']} tokens with bs={args.batch_size}")
-
-    input_sentences = get_dummy_batch(args.batch_size)
+    request = parse_generate_kwargs(
+        get_dummy_batch(args.batch_size),
+        args.generate_kwargs
+    )
 
-    print_rank_n(f"Generate args {args.generate_kwargs}")
+    print_rank_n(f"generate_kwargs = {request}")
+    print_rank_n(f"batch_size = {args.batch_size}")
 
     # warmup is a must if measuring speed as it's when all the optimizations are performed
     # e.g. on 8x80 a100 the first pass of 100 tokens takes 23sec, and the next one is 4secs
-    request = GenerateRequest(input_sentences, args.generate_kwargs)
-
     response, generation_time = run_and_log_time(
         Execute(model.generate, {"request": request}))
 
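A short usage sketch of the new benchmark flow, assuming the script is run from the inference directory so that the utils package resolves; batch size and kwargs values below are only illustrative:

```python
from utils import get_dummy_batch, parse_generate_kwargs, print_rank_n

batch_size = 4
generate_kwargs = {"max_new_tokens": 100, "do_sample": "false"}  # illustrative values

# parse_generate_kwargs coerces string values and returns a GenerateRequest,
# replacing the old GenerateRequest(input_sentences, args.generate_kwargs) call.
request = parse_generate_kwargs(get_dummy_batch(batch_size), generate_kwargs)

print_rank_n(f"generate_kwargs = {request}")
print_rank_n(f"batch_size = {batch_size}")
# response = model.generate(request)  # model: HFAccelerateModel / DSInferenceModel / DSZeROModel
```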
9 changes: 2 additions & 7 deletions inference/cli.py
@@ -1,8 +1,6 @@
 import argparse
 import json
 
-import deepspeed
-
 import constants
 import utils
 from ds_inference import DSInferenceGRPCServer
@@ -48,7 +46,6 @@ def main() -> None:
f"Unknown deployment framework {args.deployment_framework}")

generate_kwargs = args.generate_kwargs
request = parse_generate_kwargs(generate_kwargs)

while (True):
# currently only 1 process is running so its
@@ -61,11 +58,9 @@ def main() -> None:
             model.shutdown()
 
         if (input("change generate_kwargs? [y/n] ") == "y"):
-            generate_kwargs = input("Generate kwargs: ")
-            generate_kwargs = json.loads(generate_kwargs)
-            request = parse_generate_kwargs(generate_kwargs)
+            generate_kwargs = json.loads(input("generate_kwargs: "))
 
-        request.text = input_text
+        request = parse_generate_kwargs(input_text, generate_kwargs)
         response = model.generate(request)
 
         print_rank_n("Output text:", response.text)
29 changes: 2 additions & 27 deletions inference/ds_inference/grpc_server.py
@@ -6,7 +6,7 @@
 from transformers import AutoTokenizer
 
 import mii
-from utils import GenerateRequest, GenerateResponse, Model, get_str_dtype
+from utils import GenerateRequest, GenerateResponse, Model, get_filter_dict, get_str_dtype
 
 
 class DSInferenceGRPCServer(Model):
@@ -51,32 +51,7 @@ def generate(self, request: GenerateRequest) -> GenerateResponse:

         output_text = self.model.query(
             {"query": text},
-            min_length=request.min_length,
-            do_sample=request.do_sample,
-            early_stopping=request.early_stopping,
-            num_beams=request.num_beams,
-            temperature=request.temperature,
-            top_k=request.top_k,
-            top_p=request.top_p,
-            typical_p=request.typical_p,
-            repitition_penalty=request.repitition_penalty,
-            bos_token_id=request.bos_token_id,
-            pad_token_id=request.pad_token_id,
-            eos_token_id=request.eos_token_id,
-            length_penalty=request.length_penalty,
-            no_repeat_ngram_size=request.no_repeat_ngram_size,
-            encoder_no_repeat_ngram_size=request.encoder_no_repeat_ngram_size,
-            num_return_sequences=request.num_return_sequences,
-            max_time=request.max_time,
-            max_new_tokens=request.max_new_tokens,
-            decoder_start_token_id=request.decoder_start_token_id,
-            num_beam_groups=request.num_beam_groups,
-            diversity_penalty=request.diversity_penalty,
-            forced_bos_token_id=request.forced_bos_token_id,
-            forced_eos_token_id=request.forced_eos_token_id,
-            exponential_decay_length_penalty=request.exponential_decay_length_penalty,
-            bad_words_ids=request.bad_words_ids,
-            force_words_ids=request.force_words_ids
+            **get_filter_dict(request)
         ).response
 
         output_text = [_ for _ in output_text]
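get_filter_dict lets the caller forward only the generation options that were actually set, instead of spelling out every field by hand. A self-contained sketch of the idea (TinyRequest is a trimmed stand-in for GenerateRequest, and the dict comprehension is equivalent to the loop added in inference/utils/requests.py below):

```python
from typing import List, Union

from pydantic import BaseModel


class TinyRequest(BaseModel):
    # Trimmed stand-in for GenerateRequest: unset fields default to None.
    text: Union[List[str], str] = None
    max_new_tokens: int = None
    top_k: int = None


def get_filter_dict(d: BaseModel) -> dict:
    # Keep only the fields the caller actually set, so they can be splatted
    # into a **kwargs call such as self.model.query(...).
    return {k: v for k, v in dict(d).items() if v is not None}


req = TinyRequest(text="hello", max_new_tokens=40)
print(get_filter_dict(req))  # {'text': 'hello', 'max_new_tokens': 40} -- top_k is dropped
```

In DSInferenceGRPCServer.generate, this is what allows the long hand-written keyword list to collapse into a single **get_filter_dict(request).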
9 changes: 3 additions & 6 deletions inference/server.py
@@ -7,7 +7,7 @@
 from ds_inference import DSInferenceGRPCServer
 from fastapi import FastAPI, HTTPException
 from hf_accelerate import HFAccelerateModel
-from utils import GenerateRequest, MaxTokensError, get_argument_parser
+from utils import GenerateRequest, get_argument_parser, get_num_tokens_to_generate
 from uvicorn import run


@@ -69,11 +69,8 @@ def generate(request: GenerateRequest) -> dict:
     try:
         start_time = time.time()
 
-        if (request.max_new_tokens > args.allowed_max_new_tokens):
-            raise MaxTokensError(
-                request.max_new_tokens,
-                args.allowed_max_new_tokens
-            )
+        request.max_new_tokens = get_num_tokens_to_generate(
+            request.max_new_tokens, args.allowed_max_new_tokens)
 
         response = model.generate(request)
         response.query_id = query_id
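The user-visible effect of this hunk: a request that asks for more new tokens than the server allows is no longer rejected with MaxTokensError; its max_new_tokens is clamped to the cap instead. A minimal illustration (the cap value here is made up; get_num_tokens_to_generate is the helper added to inference/utils/utils.py at the end of this commit):

```python
# Illustrative cap; in server.py it comes from args.allowed_max_new_tokens.
allowed_max_new_tokens = 100

assert get_num_tokens_to_generate(None, allowed_max_new_tokens) == 100  # unset: use the cap
assert get_num_tokens_to_generate(40, allowed_max_new_tokens) == 40     # under the cap: unchanged
assert get_num_tokens_to_generate(500, allowed_max_new_tokens) == 100   # over the cap: clamp, no error
```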
5 changes: 2 additions & 3 deletions inference/utils/__init__.py
@@ -1,13 +1,12 @@
 from .model import Model
-from .requests import GenerateRequest, GenerateResponse
+from .requests import GenerateRequest, GenerateResponse, get_filter_dict, parse_generate_kwargs
 from .utils import (
     Execute,
-    MaxTokensError,
     get_args,
     get_argument_parser,
     get_dummy_batch,
+    get_num_tokens_to_generate,
     get_str_dtype,
-    parse_generate_kwargs,
     print_rank_n,
     run_rank_n
 )
127 changes: 69 additions & 58 deletions inference/utils/requests.py
@@ -1,27 +1,8 @@
-from typing import List, Union
+from typing import Any, List, Union
 
 from pydantic import BaseModel
 
 
-def parse_bool(value: str) -> bool:
-    if (value.lower() == "true"):
-        return True
-    elif (value.lower() == "false"):
-        return False
-    else:
-        raise ValueError("{} is not a valid boolean value".format(value))
-
-
-def parse_field(kwargs: dict, field: str, dtype: int, default_value: Any = None) -> Any:
-    if (field in kwargs):
-        if (dtype == bool):
-            return parse_bool(kwargs[field])
-        else:
-            return dtype(kwargs[field])
-    else:
-        return default_value
-
-
 class GenerateRequest(BaseModel):
     text: Union[List[str], str]
     min_length: int = None
@@ -52,47 +33,77 @@ class GenerateRequest(BaseModel):
     force_words_ids: Union[List[int], List[List[int]]] = None
     remove_input_from_output: bool = False
 
-    def __init__(self, text: Union[List[str], str], kwargs: dict) -> None:
-        self.text = text
-        self.min_length = parse_field(kwargs, "min_length", int)
-        self.do_sample = parse_field(kwargs, "do_sample", bool)
-        self.early_stopping = parse_field(kwargs, "early_stopping", bool)
-        self.num_beams = parse_field(kwargs, "num_beams", int)
-        self.temperature = parse_field(kwargs, "temperature", float)
-        self.top_k = parse_field(kwargs, "top_k", int)
-        self.top_p = parse_field(kwargs, "top_p", float)
-        self.typical_p = parse_field(kwargs, "typical_p", float)
-        self.repitition_penalty = parse_field(
-            kwargs, "repitition_penalty", float)
-        self.bos_token_id = parse_field(kwargs, "bos_token_id", int)
-        self.pad_token_id = parse_field(kwargs, "pad_token_id", int)
-        self.eos_token_id = parse_field(kwargs, "eos_token_id", int)
-        self.length_penalty = parse_field(kwargs, "length_penalty", float)
-        self.no_repeat_ngram_size = parse_field(
-            kwargs, "no_repeat_ngram_size", int)
-        self.encoder_no_repeat_ngram_size = parse_field(
-            kwargs, "encoder_no_repeat_ngram_size", int)
-        self.num_return_sequences = parse_field(
-            kwargs, "num_return_sequences", int)
-        self.max_time = parse_field(kwargs, "max_time", float)
-        self.max_new_tokens = parse_field(kwargs, "max_new_tokens", int)
-        self.decoder_start_token_id = parse_field(
-            kwargs, "decoder_start_token_id", int)
-        self.num_beam_group = parse_field(kwargs, "num_beam_group", int)
-        self.diversity_penalty = parse_field(
-            kwargs, "diversity_penalty", float)
-        self.forced_bos_token_id = parse_field(
-            kwargs, "forced_bos_token_id", int)
-        self.forced_eos_token_id = parse_field(
-            kwargs, "forced_eos_token_id", int)
-        self.exponential_decay_length_penalty = parse_field(
-            kwargs, "exponential_decay_length_penalty", float),
-        self.remove_input_from_output = parse_field(
-            kwargs, "remove_input_from_output", bool, False)
-
 
 class GenerateResponse(BaseModel):
     text: Union[List[str], str] = None
     num_generated_tokens: Union[List[int], int] = None
     query_id: int = None
     total_time_taken: float = None
+
+
+def parse_bool(value: str) -> bool:
+    if (value.lower() == "true"):
+        return True
+    elif (value.lower() == "false"):
+        return False
+    else:
+        raise ValueError("{} is not a valid boolean value".format(value))
+
+
+def parse_field(kwargs: dict,
+                field: str,
+                dtype: int,
+                default_value: Any = None) -> Any:
+    if (field in kwargs):
+        if (type(kwargs[field]) == dtype):
+            return kwargs[field]
+        elif (dtype == bool):
+            return parse_bool(kwargs[field])
+        else:
+            return dtype(kwargs[field])
+    else:
+        return default_value
+
+
+def parse_generate_kwargs(text: Union[List[str], str], kwargs: dict) -> GenerateRequest:
+    return GenerateRequest(
+        text=text,
+        min_length=parse_field(kwargs, "min_length", int),
+        do_sample=parse_field(kwargs, "do_sample", bool),
+        early_stopping=parse_field(kwargs, "early_stopping", bool),
+        num_beams=parse_field(kwargs, "num_beams", int),
+        temperature=parse_field(kwargs, "temperature", float),
+        top_k=parse_field(kwargs, "top_k", int),
+        top_p=parse_field(kwargs, "top_p", float),
+        typical_p=parse_field(kwargs, "typical_p", float),
+        repitition_penalty=parse_field(kwargs, "repitition_penalty", float),
+        bos_token_id=parse_field(kwargs, "bos_token_id", int),
+        pad_token_id=parse_field(kwargs, "pad_token_id", int),
+        eos_token_id=parse_field(kwargs, "eos_token_id", int),
+        length_penalty=parse_field(kwargs, "length_penalty", float),
+        no_repeat_ngram_size=parse_field(kwargs, "no_repeat_ngram_size", int),
+        encoder_no_repeat_ngram_size=parse_field(
+            kwargs, "encoder_no_repeat_ngram_size", int),
+        num_return_sequences=parse_field(kwargs, "num_return_sequences", int),
+        max_time=parse_field(kwargs, "max_time", float),
+        max_new_tokens=parse_field(kwargs, "max_new_tokens", int),
+        decoder_start_token_id=parse_field(
+            kwargs, "decoder_start_token_id", int),
+        num_beam_group=parse_field(kwargs, "num_beam_group", int),
+        diversity_penalty=parse_field(kwargs, "diversity_penalty", float),
+        forced_bos_token_id=parse_field(kwargs, "forced_bos_token_id", int),
+        forced_eos_token_id=parse_field(kwargs, "forced_eos_token_id", int),
+        exponential_decay_length_penalty=parse_field(
+            kwargs, "exponential_decay_length_penalty", float),
+        remove_input_from_output=parse_field(
+            kwargs, "remove_input_from_output", bool, False)
+    )
+
+
+def get_filter_dict(d: BaseModel) -> dict:
+    d = dict(d)
+    q = {}
+    for i in d:
+        if (d[i] != None):
+            q[i] = d[i]
+    return q
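The coercion rules in the relocated parse_field are worth spelling out, since both cli.py and server.py may hand it string-valued kwargs. Using the functions defined in this file (values are illustrative):

```python
kwargs = {"top_k": "50", "temperature": 0.7, "do_sample": "false"}

parse_field(kwargs, "top_k", int)            # -> 50    (string cast to int)
parse_field(kwargs, "temperature", float)    # -> 0.7   (already the right type, passed through)
parse_field(kwargs, "do_sample", bool)       # -> False (routed through parse_bool)
parse_field(kwargs, "min_length", int)       # -> None  (absent key falls back to the default)
parse_field(kwargs, "remove_input_from_output", bool, False)  # -> False (explicit default)
```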
14 changes: 8 additions & 6 deletions inference/utils/utils.py
@@ -21,12 +21,6 @@
 ]
 
 
-class MaxTokensError(Exception):
-    def __init__(self, max_new_tokens: int, allowed_max_new_tokens: int) -> None:
-        super().__init__("max_new_tokens = {} > {} is not supported.".format(
-            max_new_tokens, allowed_max_new_tokens))
-
-
 class Execute:
     def __init__(self, func: callable, kwargs: dict) -> None:
         self.func = func
@@ -113,3 +107,11 @@ def get_dummy_batch(batch_size: int, input_sentences: List[str] = None) -> List[
         input_sentences = input_sentences[:batch_size]
 
     return input_sentences
+
+
+def get_num_tokens_to_generate(max_new_tokens: int,
+                               allowed_max_new_tokens: int) -> int:
+    if (max_new_tokens == None):
+        return allowed_max_new_tokens
+    else:
+        return min(max_new_tokens, allowed_max_new_tokens)
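Putting the pieces of this commit together, the intended request pipeline looks roughly like this; a sketch that assumes the inference directory is importable and that a model object (HF accelerate or DS-inference) has been constructed:

```python
from utils import get_num_tokens_to_generate, parse_generate_kwargs

# Raw kwargs as they might arrive from the CLI or an HTTP client (illustrative values).
raw_kwargs = {"max_new_tokens": "500", "top_k": "50", "do_sample": "true"}

request = parse_generate_kwargs("DeepSpeed is a", raw_kwargs)  # string values are coerced

# Server-side cap applied by clamping instead of raising MaxTokensError.
request.max_new_tokens = get_num_tokens_to_generate(request.max_new_tokens, 100)

# response = model.generate(request)
```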