From 723dbacef1146efa19de338329a3e2af8b5f6444 Mon Sep 17 00:00:00 2001 From: maruyama Date: Tue, 17 May 2022 14:23:23 +0800 Subject: [PATCH 1/3] refactor the structure of batch manager --- energon/server/batch_manager.py | 127 +++++++++++++++----------- energon/server/naive_batch_manager.py | 14 ++- examples/gpt/gpt_batch_server.py | 36 +++----- 3 files changed, 96 insertions(+), 81 deletions(-) diff --git a/energon/server/batch_manager.py b/energon/server/batch_manager.py index 30a2f4a..18c6c9e 100644 --- a/energon/server/batch_manager.py +++ b/energon/server/batch_manager.py @@ -20,58 +20,6 @@ from concurrent.futures import ThreadPoolExecutor -def generate_cached_cost(engine, model_name: str, pp: int, tp: int, - max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1, - repeat_round: int = 3, tokenizer=None): - """ - Test the running time for different sequence length and batch size on the current machine. - :param engine: InferenceEngine from energon.engine - :type engine: InferenceEngine - :param max_seq_len: The max sequence length that is measured. - :param max_batch_size: The max batch size that is measured. - :param step: Run time is measured every other 'step' of sequence length - :param repeat_round: We inference current batch 'repeat_round' times and take average. - """ - logging.log(0, "fetching cached cost") - cached_name = "cached_cost_{}_pp{}_tp{}_{}_{}_{}_{}.npy".format(model_name, pp, tp, max_seq_len, max_batch_size, step, repeat_round) - if os.path.exists(cached_name): - logging.log(0, "loading cached cost from file") - cached_cost = np.load(cached_name).tolist() - else: - logging.log(0, "generating new cached cost") - cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)] - warm_up_str = "test test test" - if tokenizer: - warm_up_input = tokenizer(warm_up_str, return_tensors="pt") - else: - warm_up_input = warm_up_str - for tt in range(5): - output = engine.run(warm_up_input) - predictions = output.to_here() - input_text = "" - for tmp_len in trange(1, max_seq_len + 1, step): - input_text += "test " - for tmp_batch in range(1, max_batch_size + 1): - batched_text = [input_text for _ in range(tmp_batch)] - start_time = time.time() - for k in range(repeat_round): - if tokenizer: - input_token = tokenizer(batched_text, return_tensors="pt") - else: - input_token = batched_text - output = engine.run(input_token) - predictions = output.to_here() - if tokenizer: - tokenizer.decode(predictions) - time_cost = (time.time() - start_time) / repeat_round - cached_cost[tmp_len][tmp_batch] = time_cost - for k in range(1, step): - cached_cost[tmp_len + k][tmp_batch] = time_cost - np.save(cached_name, np.array(cached_cost)) - logging.log(0, "cached cost loaded") - return cached_cost - - class single_request: def __init__(self, input_, time_stamp: float, input_str: str): """ @@ -106,8 +54,11 @@ class Batch_Manager(Manager): in function cal_priority and then sent into the inference engine. 
""" - def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 512, init_theta: int = 180, - max_batch_size: int = 32, lr: float = 0.01, tokenizer=None, pad_token=None, rm_padding=False): + def __init__(self, engine: InferenceEngine, model_name: str, pp: int, tp: int, + max_sequence_length: int, init_mu: int = 512, init_theta: int = 180, + max_batch_size: int = 32, lr: float = 0.01, tokenizer=None, pad_token=None, rm_padding=False, + step: int = 1, repeat_round: int = 3 + ): """ :param engine: The InferenceEngine from energon.engine :param cached_cost: The output of function generate_cached_cost @@ -126,7 +77,9 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51 self.req_list = [] self.req_list_lock = rwlock.RWLockFair() self.write_lock = self.req_list_lock.gen_wlock() - self.cached_cost = cached_cost + self.cached_cost = self.generate_cached_cost(engine, model_name, pp=pp, tp=tp, max_seq_len=max_sequence_length, + max_batch_size=max_batch_size, step=step, repeat_round=repeat_round, + tokenizer=tokenizer) self.tokenizer = tokenizer self.rm_padding = rm_padding if self.tokenizer and pad_token: @@ -137,6 +90,70 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51 self.main_thread = threading.Thread(target=self.processing_batch) self.main_thread.start() + def generate_cached_cost(self, engine, model_name: str, pp: int, tp: int, + max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1, + repeat_round: int = 3, tokenizer=None): + """ + Test the running time for different sequence length and batch size on the current machine. + :param engine: InferenceEngine from energon.engine + :type engine: InferenceEngine + :param max_seq_len: The max sequence length that is measured. + :param max_batch_size: The max batch size that is measured. + :param step: Run time is measured every other 'step' of sequence length + :param repeat_round: We inference current batch 'repeat_round' times and take average. 
+ """ + logging.log(0, "fetching cached cost") + cached_name = "cached_cost_{}_pp{}_tp{}_{}_{}_{}_{}.npy".format(model_name, pp, tp, max_seq_len, max_batch_size, + step, repeat_round) + if os.path.exists(cached_name): + logging.log(0, "loading cached cost from file") + cached_cost = np.load(cached_name).tolist() + else: + logging.log(0, "generating new cached cost") + cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)] + warm_up_str = "test test test" + if tokenizer: + warm_up_input = tokenizer(warm_up_str, return_tensors="pt") + else: + warm_up_input = warm_up_str + for tt in range(5): + output = engine.run(warm_up_input) + predictions = output.to_here() + input_text = "" + for tmp_len in trange(1, max_seq_len + 1, step): + input_text += "test " + for tmp_batch in range(1, max_batch_size + 1): + batched_text = [input_text for _ in range(tmp_batch)] + start_time = time.time() + for k in range(repeat_round): + if tokenizer: + input_token = tokenizer(batched_text, return_tensors="pt") + else: + input_token = batched_text + output = engine.run(input_token) + predictions = output.to_here() + if tokenizer: + tokenizer.decode(predictions) + time_cost = (time.time() - start_time) / repeat_round + cached_cost[tmp_len][tmp_batch] = time_cost + for k in range(1, step): + cached_cost[tmp_len + k][tmp_batch] = time_cost + np.save(cached_name, np.array(cached_cost)) + logging.log(0, "cached cost loaded") + return cached_cost + + def subscribe_result(self, time_stamp): + red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) + sub = red.pubsub() + sub.subscribe(str(time_stamp)) + predictions = '' + for message in sub.listen(): + if message is not None and isinstance(message, dict): + predictions = message.get('data') + if not isinstance(predictions, int): + break + return predictions + def insert_req(self, time_stamp: float, input_ids, input_str: str): """ Build a single_request class with the input string and then insert it into the queue. 
diff --git a/energon/server/naive_batch_manager.py b/energon/server/naive_batch_manager.py index 2f91588..4dc3879 100644 --- a/energon/server/naive_batch_manager.py +++ b/energon/server/naive_batch_manager.py @@ -84,6 +84,18 @@ def insert_req(self, time_stamp: float, input_ids, input_str: str): self.req_list.append(tmp_req) self.write_lock.release() + def subscribe_result(self, time_stamp): + red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) + sub = red.pubsub() + sub.subscribe(str(time_stamp)) + predictions = '' + for message in sub.listen(): + if message is not None and isinstance(message, dict): + predictions = message.get('data') + if not isinstance(predictions, int): + break + return predictions + def wrap_batch(self): """ Given a sorted sequence list, calculate the best way to wrap the batch with DP according to the @@ -120,7 +132,7 @@ def processing_batch(self): # self.publish_result(output, target_batch, start_time) # pub_thread = threading.Thread(target=self.publish_result, args=(output, target_batch, start_time)) # pub_thread.start() - time.sleep(0.05) + time.sleep(0.08) def publish_result(self, output, target_batch): """ diff --git a/examples/gpt/gpt_batch_server.py b/examples/gpt/gpt_batch_server.py index 5777c6e..1e11ea0 100644 --- a/examples/gpt/gpt_batch_server.py +++ b/examples/gpt/gpt_batch_server.py @@ -9,7 +9,7 @@ from fastapi import Response, Body import torch.distributed.rpc as rpc from energon.engine import InferenceEngine -from energon.server.batch_manager import Batch_Manager, generate_cached_cost, Manager +from energon.server.batch_manager import Batch_Manager, Manager from energon.server.naive_batch_manager import Naive_Batch_Manager app = FastAPI() @@ -21,19 +21,10 @@ def root(): @app.post("/model_with_padding_naive") def run_without_batch(input_str: str = Body(..., title="input_str", embed=True)): - red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) - sub = red.pubsub() input_token = tokenizer(input_str, return_tensors="pt") time_stamp = time.time() naive_manager.insert_req(time_stamp, input_token, input_str) - sub.subscribe(str(time_stamp)) - predictions = input_str - for message in sub.listen(): - if message is not None and isinstance(message, dict): - predictions = message.get('data') - if not isinstance(predictions, int): - break - + predictions = batch_manager.subscribe_result(time_stamp) return {predictions} @app.post("/model_with_padding") @@ -43,22 +34,14 @@ def run( """Receive user request with post function. The input string is sent to the batch manager and then the result will be sent back with Redis pub-sub. 
The time stamp is used as the channel name that the current request process subscribes.""" - red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) - sub = red.pubsub() input_token = tokenizer(input_str, return_tensors="pt") time_stamp = time.time() batch_manager.insert_req(time_stamp, input_token, input_str) - sub.subscribe(str(time_stamp)) - predictions = input_str - for message in sub.listen(): - if message is not None and isinstance(message, dict): - predictions = message.get('data') - if not isinstance(predictions, int): - break - + predictions = batch_manager.subscribe_result(time_stamp) return {predictions} + @app.get("/shutdown") async def shutdown(): engine.clear() @@ -103,12 +86,15 @@ def launch_engine(model_class, port=port, dtype=dtype) - global cached_cost - cached_cost = generate_cached_cost(engine, max_seq_len=1024, max_batch_size=4, step=8, repeat_round=2, tokenizer=tokenizer) + # global cached_cost + # cached_cost = generate_cached_cost(engine, max_seq_len=1024, max_batch_size=4, step=8, repeat_round=2, tokenizer=tokenizer) global batch_manager - batch_manager = Batch_Manager(engine, cached_cost, max_batch_size=4, tokenizer=tokenizer, - pad_token=GPT2Tokenizer.eos_token) + batch_manager = Batch_Manager(engine, model_name="gpt2_8B", pp=4, tp=2, + max_sequence_length=256, + max_batch_size=16, tokenizer=tokenizer, + pad_token=GPT2Tokenizer.eos_token, + step=8, repeat_round=2) global naive_manager naive_manager = Naive_Batch_Manager(engine, max_batch_size=4, tokenizer=tokenizer, From ccd76d6260edda474f1d3ece56ea9c179f71db1c Mon Sep 17 00:00:00 2001 From: maruyama Date: Tue, 17 May 2022 14:26:01 +0800 Subject: [PATCH 2/3] refactor the structure of batch manager --- examples/gpt/gpt_batch_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gpt/gpt_batch_server.py b/examples/gpt/gpt_batch_server.py index 1e11ea0..c02f1a2 100644 --- a/examples/gpt/gpt_batch_server.py +++ b/examples/gpt/gpt_batch_server.py @@ -24,7 +24,7 @@ def run_without_batch(input_str: str = Body(..., title="input_str", embed=True)) input_token = tokenizer(input_str, return_tensors="pt") time_stamp = time.time() naive_manager.insert_req(time_stamp, input_token, input_str) - predictions = batch_manager.subscribe_result(time_stamp) + predictions = naive_manager.subscribe_result(time_stamp) return {predictions} @app.post("/model_with_padding") From a7a0054d46e6cbe3c905874d1eec528ec0638483 Mon Sep 17 00:00:00 2001 From: maruyama Date: Tue, 17 May 2022 14:33:35 +0800 Subject: [PATCH 3/3] refactor the structure of batch manager --- energon/server/batch_manager.py | 5 +++-- energon/server/naive_batch_manager.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/energon/server/batch_manager.py b/energon/server/batch_manager.py index 18c6c9e..6ca6cc5 100644 --- a/energon/server/batch_manager.py +++ b/energon/server/batch_manager.py @@ -143,8 +143,9 @@ def generate_cached_cost(self, engine, model_name: str, pp: int, tp: int, return cached_cost def subscribe_result(self, time_stamp): - red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) - sub = red.pubsub() + # red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) + # sub = red.pubsub() + sub = self.publisher.pubsub() sub.subscribe(str(time_stamp)) predictions = '' for message in sub.listen(): diff --git a/energon/server/naive_batch_manager.py b/energon/server/naive_batch_manager.py index 4dc3879..9ece836 100644 --- 
a/energon/server/naive_batch_manager.py +++ b/energon/server/naive_batch_manager.py @@ -85,8 +85,9 @@ def insert_req(self, time_stamp: float, input_ids, input_str: str): self.write_lock.release() def subscribe_result(self, time_stamp): - red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) - sub = red.pubsub() + # red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True) + # sub = red.pubsub() + sub = self.publisher.pubsub() sub.subscribe(str(time_stamp)) predictions = '' for message in sub.listen():
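
Patch 3 applies the same change to both managers: subscribe_result no longer opens a fresh StrictRedis connection per request and instead derives a subscriber from the manager's existing client via self.publisher.pubsub(). The initialization of self.publisher is not shown in these hunks; the sketch below assumes it is the StrictRedis client the manager already uses for publishing, created once in __init__:

    import redis

    class BatchManagerSketch:
        def __init__(self):
            # Assumed setup: one shared client per manager, reused by the publish and subscribe paths.
            self.publisher = redis.StrictRedis('localhost', 6379, decode_responses=True)

        def subscribe_result(self, time_stamp: float) -> str:
            # pubsub() draws a connection from the shared client's pool instead of
            # performing a new TCP handshake for every incoming request.
            sub = self.publisher.pubsub()
            sub.subscribe(str(time_stamp))
            predictions = ''
            for message in sub.listen():
                if message is not None and isinstance(message, dict):
                    predictions = message.get('data')
                    if not isinstance(predictions, int):  # skip the subscribe confirmation
                        break
            return predictions

Reusing one client keeps per-request overhead to a channel subscription rather than a full connection setup, which matters when many endpoint handlers wait on results concurrently.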