This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

refactor batch manager #62

Merged
merged 3 commits on May 17, 2022
128 changes: 73 additions & 55 deletions energon/server/batch_manager.py
@@ -20,58 +20,6 @@
from concurrent.futures import ThreadPoolExecutor


def generate_cached_cost(engine, model_name: str, pp: int, tp: int,
max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1,
repeat_round: int = 3, tokenizer=None):
"""
Measure the running time for different sequence lengths and batch sizes on the current machine.
:param engine: InferenceEngine from energon.engine
:type engine: InferenceEngine
:param max_seq_len: The max sequence length that is measured.
:param max_batch_size: The max batch size that is measured.
:param step: The running time is measured once every 'step' sequence lengths.
:param repeat_round: Each batch is run 'repeat_round' times and the average time is recorded.
"""
logging.log(0, "fetching cached cost")
cached_name = "cached_cost_{}_pp{}_tp{}_{}_{}_{}_{}.npy".format(model_name, pp, tp, max_seq_len, max_batch_size, step, repeat_round)
if os.path.exists(cached_name):
logging.log(0, "loading cached cost from file")
cached_cost = np.load(cached_name).tolist()
else:
logging.log(0, "generating new cached cost")
cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)]
warm_up_str = "test test test"
if tokenizer:
warm_up_input = tokenizer(warm_up_str, return_tensors="pt")
else:
warm_up_input = warm_up_str
for tt in range(5):
output = engine.run(warm_up_input)
predictions = output.to_here()
input_text = ""
for tmp_len in trange(1, max_seq_len + 1, step):
input_text += "test "
for tmp_batch in range(1, max_batch_size + 1):
batched_text = [input_text for _ in range(tmp_batch)]
start_time = time.time()
for k in range(repeat_round):
if tokenizer:
input_token = tokenizer(batched_text, return_tensors="pt")
else:
input_token = batched_text
output = engine.run(input_token)
predictions = output.to_here()
if tokenizer:
tokenizer.decode(predictions)
time_cost = (time.time() - start_time) / repeat_round
cached_cost[tmp_len][tmp_batch] = time_cost
for k in range(1, step):
cached_cost[tmp_len + k][tmp_batch] = time_cost
np.save(cached_name, np.array(cached_cost))
logging.log(0, "cached cost loaded")
return cached_cost


class single_request:
def __init__(self, input_, time_stamp: float, input_str: str):
"""
@@ -106,8 +54,11 @@ class Batch_Manager(Manager):
in function cal_priority and then sent into the inference engine.
"""

def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 512, init_theta: int = 180,
max_batch_size: int = 32, lr: float = 0.01, tokenizer=None, pad_token=None, rm_padding=False):
def __init__(self, engine: InferenceEngine, model_name: str, pp: int, tp: int,
max_sequence_length: int, init_mu: int = 512, init_theta: int = 180,
max_batch_size: int = 32, lr: float = 0.01, tokenizer=None, pad_token=None, rm_padding=False,
step: int = 1, repeat_round: int = 3
):
"""
:param engine: The InferenceEngine from energon.engine
:param cached_cost: The output of function generate_cached_cost
@@ -126,7 +77,9 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51
self.req_list = []
self.req_list_lock = rwlock.RWLockFair()
self.write_lock = self.req_list_lock.gen_wlock()
self.cached_cost = cached_cost
self.cached_cost = self.generate_cached_cost(engine, model_name, pp=pp, tp=tp, max_seq_len=max_sequence_length,
max_batch_size=max_batch_size, step=step, repeat_round=repeat_round,
tokenizer=tokenizer)
self.tokenizer = tokenizer
self.rm_padding = rm_padding
if self.tokenizer and pad_token:
@@ -137,6 +90,71 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51
self.main_thread = threading.Thread(target=self.processing_batch)
self.main_thread.start()

def generate_cached_cost(self, engine, model_name: str, pp: int, tp: int,
max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1,
repeat_round: int = 3, tokenizer=None):
"""
Measure the running time for different sequence lengths and batch sizes on the current machine.
:param engine: InferenceEngine from energon.engine
:type engine: InferenceEngine
:param max_seq_len: The max sequence length that is measured.
:param max_batch_size: The max batch size that is measured.
:param step: The running time is measured once every 'step' sequence lengths.
:param repeat_round: Each batch is run 'repeat_round' times and the average time is recorded.
"""
logging.log(0, "fetching cached cost")
cached_name = "cached_cost_{}_pp{}_tp{}_{}_{}_{}_{}.npy".format(model_name, pp, tp, max_seq_len, max_batch_size,
step, repeat_round)
if os.path.exists(cached_name):
logging.log(0, "loading cached cost from file")
cached_cost = np.load(cached_name).tolist()
else:
logging.log(0, "generating new cached cost")
cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)]
warm_up_str = "test test test"
if tokenizer:
warm_up_input = tokenizer(warm_up_str, return_tensors="pt")
else:
warm_up_input = warm_up_str
for tt in range(5):
output = engine.run(warm_up_input)
predictions = output.to_here()
input_text = ""
for tmp_len in trange(1, max_seq_len + 1, step):
input_text += "test "
for tmp_batch in range(1, max_batch_size + 1):
batched_text = [input_text for _ in range(tmp_batch)]
start_time = time.time()
for k in range(repeat_round):
if tokenizer:
input_token = tokenizer(batched_text, return_tensors="pt")
else:
input_token = batched_text
output = engine.run(input_token)
predictions = output.to_here()
if tokenizer:
tokenizer.decode(predictions)
time_cost = (time.time() - start_time) / repeat_round
cached_cost[tmp_len][tmp_batch] = time_cost
for k in range(1, step):
cached_cost[tmp_len + k][tmp_batch] = time_cost
np.save(cached_name, np.array(cached_cost))
logging.log(0, "cached cost loaded")
return cached_cost

def subscribe_result(self, time_stamp):
# red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
# sub = red.pubsub()
sub = self.publisher.pubsub()
sub.subscribe(str(time_stamp))
predictions = ''
for message in sub.listen():
if message is not None and isinstance(message, dict):
predictions = message.get('data')
if not isinstance(predictions, int):
break
return predictions

def insert_req(self, time_stamp: float, input_ids, input_str: str):
"""
Build a single_request class with the input string and then insert it into the queue.
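For orientation, a minimal usage sketch of the refactored manager follows. The constructor now builds (or loads from disk) the cached-cost table itself instead of receiving it as an argument; the table is indexed as cached_cost[seq_len][batch_size] and stores the averaged running time measured on the current machine. The concrete values below mirror the example server further down and are otherwise arbitrary.

# Hedged sketch: assumes `engine` is an already-launched InferenceEngine and
# `tokenizer` is a Hugging Face tokenizer; numeric values are illustrative.
from energon.server.batch_manager import Batch_Manager

manager = Batch_Manager(engine,
                        model_name="gpt2_8B",     # model_name/pp/tp only name the cached .npy file
                        pp=4, tp=2,
                        max_sequence_length=256,  # upper bound of the profiled sequence lengths
                        max_batch_size=16,
                        tokenizer=tokenizer,
                        pad_token=tokenizer.eos_token,
                        step=8, repeat_round=2)

# The profiling table can then be consulted directly, e.g. to estimate the
# latency (in seconds) of a batch of 4 sequences padded to length 128:
estimated_seconds = manager.cached_cost[128][4]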
15 changes: 14 additions & 1 deletion energon/server/naive_batch_manager.py
@@ -84,6 +84,19 @@ def insert_req(self, time_stamp: float, input_ids, input_str: str):
self.req_list.append(tmp_req)
self.write_lock.release()

def subscribe_result(self, time_stamp):
# red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
# sub = red.pubsub()
sub = self.publisher.pubsub()
sub.subscribe(str(time_stamp))
predictions = ''
for message in sub.listen():
if message is not None and isinstance(message, dict):
predictions = message.get('data')
if not isinstance(predictions, int):
break
return predictions

def wrap_batch(self):
"""
Given a sorted sequence list, calculate the best way to wrap the batch with DP according to the
@@ -120,7 +133,7 @@ def processing_batch(self):
# self.publish_result(output, target_batch, start_time)
# pub_thread = threading.Thread(target=self.publish_result, args=(output, target_batch, start_time))
# pub_thread.start()
time.sleep(0.05)
time.sleep(0.08)

def publish_result(self, output, target_batch):
"""
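Both managers now expose the same subscribe_result helper, so the FastAPI handlers no longer open their own Redis connections. A hedged sketch of the resulting round trip (assumes a running manager instance and a Hugging Face tokenizer; the prompt is a placeholder):

import time

input_str = "hello world"                      # hypothetical prompt
input_token = tokenizer(input_str, return_tensors="pt")

time_stamp = time.time()                       # doubles as the pub/sub channel name
manager.insert_req(time_stamp, input_token, input_str)

# Blocks until publish_result() publishes the decoded output on channel
# str(time_stamp). The isinstance(..., int) check inside subscribe_result
# skips Redis's initial subscribe acknowledgement, whose data field is an
# integer channel count rather than a prediction string.
prediction = manager.subscribe_result(time_stamp)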
36 changes: 11 additions & 25 deletions examples/gpt/gpt_batch_server.py
@@ -9,7 +9,7 @@
from fastapi import Response, Body
import torch.distributed.rpc as rpc
from energon.engine import InferenceEngine
from energon.server.batch_manager import Batch_Manager, generate_cached_cost, Manager
from energon.server.batch_manager import Batch_Manager, Manager
from energon.server.naive_batch_manager import Naive_Batch_Manager

app = FastAPI()
@@ -21,19 +21,10 @@ def root():

@app.post("/model_with_padding_naive")
def run_without_batch(input_str: str = Body(..., title="input_str", embed=True)):
red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
sub = red.pubsub()
input_token = tokenizer(input_str, return_tensors="pt")
time_stamp = time.time()
naive_manager.insert_req(time_stamp, input_token, input_str)
sub.subscribe(str(time_stamp))
predictions = input_str
for message in sub.listen():
if message is not None and isinstance(message, dict):
predictions = message.get('data')
if not isinstance(predictions, int):
break

predictions = naive_manager.subscribe_result(time_stamp)
return {predictions}

@app.post("/model_with_padding")
@@ -43,22 +34,14 @@ def run(
"""Receive user request with post function. The input string is sent to the batch manager
and then the result will be sent back with Redis pub-sub. The time stamp is used as the
channel name that the current request process subscribes."""
red = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
sub = red.pubsub()
input_token = tokenizer(input_str, return_tensors="pt")
time_stamp = time.time()
batch_manager.insert_req(time_stamp, input_token, input_str)
sub.subscribe(str(time_stamp))
predictions = input_str
for message in sub.listen():
if message is not None and isinstance(message, dict):
predictions = message.get('data')
if not isinstance(predictions, int):
break

predictions = batch_manager.subscribe_result(time_stamp)
return {predictions}



@app.get("/shutdown")
async def shutdown():
engine.clear()
@@ -103,12 +86,15 @@ def launch_engine(model_class,
port=port,
dtype=dtype)

global cached_cost
cached_cost = generate_cached_cost(engine, max_seq_len=1024, max_batch_size=4, step=8, repeat_round=2, tokenizer=tokenizer)
# global cached_cost
# cached_cost = generate_cached_cost(engine, max_seq_len=1024, max_batch_size=4, step=8, repeat_round=2, tokenizer=tokenizer)

global batch_manager
batch_manager = Batch_Manager(engine, cached_cost, max_batch_size=4, tokenizer=tokenizer,
pad_token=GPT2Tokenizer.eos_token)
batch_manager = Batch_Manager(engine, model_name="gpt2_8B", pp=4, tp=2,
max_sequence_length=256,
max_batch_size=16, tokenizer=tokenizer,
pad_token=GPT2Tokenizer.eos_token,
step=8, repeat_round=2)

global naive_manager
naive_manager = Naive_Batch_Manager(engine, max_batch_size=4, tokenizer=tokenizer,
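Finally, a hedged sketch of how a client might call the simplified /model_with_padding route (host and port are placeholders and must match the values the FastAPI server was launched with; because the route uses Body(..., embed=True), the prompt travels as a JSON field named input_str):

import requests

# Hypothetical address; substitute the real server host and port.
resp = requests.post("http://127.0.0.1:8020/model_with_padding",
                     json={"input_str": "hello world"})
print(resp.json())   # the route returns the decoded prediction wrapped in a one-element collection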