diff --git a/energon/server/batch_manager.py b/energon/server/batch_manager.py
index 73b44a0..98e37a7 100644
--- a/energon/server/batch_manager.py
+++ b/energon/server/batch_manager.py
@@ -8,7 +8,6 @@
 from scipy import stats
 import numpy as np
 from energon.engine import InferenceEngine
-from transformers import GPT2Tokenizer
 import random
 import redis
 import os
@@ -19,7 +18,7 @@
 def generate_cached_cost(engine, max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1,
-                         repeat_round: int = 3):
+                         repeat_round: int = 3, tokenizer=None):
     """
     Test the running time for different sequence length and batch size on the current machine.
     :param engine: InferenceEngine from energon.engine
@@ -29,20 +28,6 @@
     :param step: Run time is measured every other 'step' of sequence length
     :param repeat_round: We inference current batch 'repeat_round' times and take average.
     """
-
-    def select_top_k(temp_predictions, top_k: int = 10):
-        """
-        Pick out a word from the top k of 50257 words according to the possibility given by temp_predictions
-        for each sequence in this batch.
-        :param temp_predictions: Transformer output tensor with size of (batch size, sequence length, vocab size)
-            which contains the possibilities for each word in this batch.
-        :type temp_predictions: torch.Tensor
-        :param top_k: How many top words to choose from.
-        """
-        temp_predicted_index = random.choice(
-            temp_predictions[0, -1, :].sort(descending=True)[1][:top_k]).item()
-        return temp_predicted_index
-
     logging.log(0, "fetching cached cost")
     cached_name = "cached_cost_{}_{}_{}_{}.npy".format(max_seq_len, max_batch_size, step, repeat_round)
     if os.path.exists(cached_name):
@@ -52,18 +37,20 @@ def select_top_k(temp_predictions, top_k: int = 10):
         logging.log(0, "generating new cached cost")
         cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)]
         input_text = ""
-        tokenizer = GPT2Tokenizer.from_pretrained("./")
         for tmp_len in trange(1, max_seq_len + 1, step):
             input_text += "test "
             for tmp_batch in range(1, max_batch_size + 1):
                 batched_text = [input_text for _ in range(tmp_batch)]
                 start_time = time.time()
                 for k in range(repeat_round):
-                    input_token = tokenizer(batched_text, return_tensors="pt")
+                    if tokenizer:
+                        input_token = tokenizer(batched_text, return_tensors="pt")
+                    else:
+                        input_token = batched_text
                     output = engine.run(input_token)
                     predictions = output.to_here()
-                    predicted_index = select_top_k(predictions, k=1)
-                    tokenizer.decode(predicted_index)
+                    if tokenizer:
+                        tokenizer.decode(predictions)
                 time_cost = (time.time() - start_time) / repeat_round
                 cached_cost[tmp_len][tmp_batch] = time_cost
                 for k in range(1, step):
@@ -92,6 +79,7 @@ class Manager:
     """
     Base class of batch manager.
     """
+
     def __init__(self):
         pass

@@ -105,8 +93,9 @@ class Batch_Manager(Manager):
    queue is wrapped into batches according to the sequence length and the priority calculated with the
    equation in function cal_priority and then sent into the inference engine.
    """
+
    def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 512, init_theta: int = 180,
-                 max_batch_size: int = 32, lr: float = 0.01):
+                 max_batch_size: int = 32, lr: float = 0.01, tokenizer=None, pad_token=None):
        """
        :param engine: The InferenceEngine from energon.engine
        :param cached_cost: The output of function generate_cached_cost
@@ -126,8 +115,9 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51
        self.req_list_lock = rwlock.RWLockFair()
        self.write_lock = self.req_list_lock.gen_wlock()
        self.cached_cost = cached_cost
-        self.tokenizer = GPT2Tokenizer.from_pretrained('/home/lcdjs/hf_gpt2')
-        self.tokenizer.pad_token = GPT2Tokenizer.eos_token
+        self.tokenizer = tokenizer
+        if self.tokenizer and pad_token:
+            self.tokenizer.pad_token = pad_token  # GPT2Tokenizer.eos_token
        self.running_flag = True
        self.publisher = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
        self.main_thread = threading.Thread(target=self.processing_batch)
@@ -241,7 +231,10 @@ def processing_batch(self):
            pad_len = target_batch[-1].seq_len
            logging.log(0, "A batch with {} requests and length of {} packed".format(len(target_batch), pad_len))
            input_text = [i.text for i in target_batch]
-            input_ids = self.tokenizer(input_text, padding="longest", return_tensors="pt")
+            if self.tokenizer:
+                input_ids = self.tokenizer(input_text, padding="longest", return_tensors="pt")
+            else:
+                input_ids = input_text
            # print("input_ids shape: {}".format(input_ids['input_ids'].shape))
            # print("attention_mask shape: {}".format(input_ids['attention_mask'].shape))
            output = self.engine.run(input_ids)
@@ -255,18 +248,13 @@ def publish_result(self, output, target_batch):
        :param output: the rpc reference of the inference result.
        :param target_batch: the input batch
        """
-        def select_top_k(batch_id, predictions, k=10):
-            predicted_index = random.choice(
-                predictions[batch_id, -1, :].sort(descending=True)[1][:k]).item()
-            return predicted_index
-
        # print("output: {}".format(output))
        predictions = output.to_here()
-        # print("predictions: {}".format(predictions), flush=True)
        for i in range(len(target_batch)):
-            # print(i, predictions.shape, target_batch)
            temp_st = target_batch[i].time_
-            chosen_pred = select_top_k(i, predictions, k=5)
-            text_ = self.tokenizer.decode(chosen_pred)
-            print("text: {}".format(text_))
+            chosen_pred = predictions[i]
+            if self.tokenizer:
+                text_ = self.tokenizer.decode(int(chosen_pred))
+            else:
+                text_ = chosen_pred
            self.publisher.publish(str(temp_st), text_)
diff --git a/examples/gpt/gpt.py b/examples/gpt/gpt.py
index 69ae90c..c48c109 100644
--- a/examples/gpt/gpt.py
+++ b/examples/gpt/gpt.py
@@ -3,6 +3,7 @@
 import os

 import torch
+import random
 from torch import nn as nn, Tensor, dtype

 from energon.context import ParallelMode
@@ -268,6 +269,19 @@ def __init__(self,
                                word_embeding_weight=self.embed.word_embedding_weight,
                                dtype=dtype)

+    def select_top_k(self, batch_id, temp_predictions, top_k: int = 10):
+        """
+        Pick out a word from the top k of 50257 words according to the possibility given by temp_predictions
+        for each sequence in this batch.
+        :param temp_predictions: Transformer output tensor with size of (batch size, sequence length, vocab size)
+            which contains the possibilities for each word in this batch.
+        :type temp_predictions: torch.Tensor
+        :param top_k: How many top words to choose from.
+        """
+        temp_predicted_index = random.choice(
+            temp_predictions[batch_id, -1, :].sort(descending=True)[1][:top_k]).item()
+        return temp_predicted_index
+
    def forward(self, input_ids, attention_mask=None):

        x = self.embed(input_ids)
@@ -282,8 +296,10 @@ def forward(self, input_ids, attention_mask=None):
            x, attention_mask = block(x, attention_mask)

        x = self.head(self.norm(x))
-
-        return x
+        res = []
+        for i in range(x.shape[0]):
+            res.append(self.select_top_k(i, x))
+        return res


 class PipelineGPT1D(nn.Module):
@@ -348,6 +364,19 @@ def __init__(self,
            self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size, dtype=dtype)
            # word_embeeding_weight=self.embed.word_embedding_weight not in the same process

+    def select_top_k(self, batch_id, temp_predictions, top_k: int = 10):
+        """
+        Pick out a word from the top k of 50257 words according to the possibility given by temp_predictions
+        for each sequence in this batch.
+        :param temp_predictions: Transformer output tensor with size of (batch size, sequence length, vocab size)
+            which contains the possibilities for each word in this batch.
+        :type temp_predictions: torch.Tensor
+        :param top_k: How many top words to choose from.
+        """
+        temp_predicted_index = random.choice(
+            temp_predictions[batch_id, -1, :].sort(descending=True)[1][:top_k]).item()
+        return temp_predicted_index
+
    def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
        if self.first:
            hidden_states = self.embed(input_ids)
@@ -371,7 +400,10 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):

        if self.last:
            hidden_states = self.head(self.norm(hidden_states))
-
+            res = []
+            for i in range(hidden_states.shape[0]):
+                res.append(self.select_top_k(i, hidden_states))
+            hidden_states = torch.Tensor(res)
        return hidden_states


diff --git a/examples/gpt/gpt_batch_server.py b/examples/gpt/gpt_batch_server.py
index 2fd5662..0965a25 100644
--- a/examples/gpt/gpt_batch_server.py
+++ b/examples/gpt/gpt_batch_server.py
@@ -85,10 +85,11 @@ def launch_engine(model_name,
                            dtype=dtype)

    global cached_cost
-    cached_cost = generate_cached_cost(engine, max_seq_len=256, max_batch_size=4, step=4, repeat_round=2)
+    cached_cost = generate_cached_cost(engine, max_seq_len=256, max_batch_size=4, step=4, repeat_round=2, tokenizer=tokenizer)

    global batch_manager
-    batch_manager = Batch_Manager(engine, cached_cost, max_seq_len=256, max_batch_size=4)
+    batch_manager = Batch_Manager(engine, cached_cost, max_batch_size=4, tokenizer=tokenizer,
+                                  pad_token=GPT2Tokenizer.eos_token)
    print("batch manager initialized")

    global server