This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Merge pull request #47 from hpcaitech/lzm_develop_2
move tokenizer out of manager and move select_top_k into models
dujiangsu authored May 6, 2022
2 parents 8e65da0 + d9a09a6 commit a896d85
Showing 3 changed files with 60 additions and 39 deletions.
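In short, the tokenizer is no longer created inside the batch manager: the caller builds it and passes it to both generate_cached_cost and Batch_Manager (both now accept tokenizer=None), while top-k selection moves into the models' forward pass. A minimal sketch of the new wiring, assuming a GPT-2 tokenizer and an already-constructed InferenceEngine (the engine setup shown in gpt_batch_server.py is elided, and the checkpoint name "gpt2" is illustrative):

from transformers import GPT2Tokenizer
from energon.server.batch_manager import Batch_Manager, generate_cached_cost

# Caller-owned tokenizer; previously the manager built this itself.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# `engine` is an energon InferenceEngine built elsewhere (see gpt_batch_server.py).
cached_cost = generate_cached_cost(engine, max_seq_len=256, max_batch_size=4,
                                   step=4, repeat_round=2, tokenizer=tokenizer)

batch_manager = Batch_Manager(engine, cached_cost, max_batch_size=4,
                              tokenizer=tokenizer,
                              pad_token=GPT2Tokenizer.eos_token)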
56 changes: 22 additions & 34 deletions energon/server/batch_manager.py
@@ -8,7 +8,6 @@
from scipy import stats
import numpy as np
from energon.engine import InferenceEngine
from transformers import GPT2Tokenizer
import random
import redis
import os
@@ -19,7 +18,7 @@


def generate_cached_cost(engine, max_seq_len: int = 1024, max_batch_size: int = 16, step: int = 1,
repeat_round: int = 3):
repeat_round: int = 3, tokenizer=None):
"""
Test the running time for different sequence length and batch size on the current machine.
:param engine: InferenceEngine from energon.engine
@@ -29,20 +28,6 @@ def generate_cached_cost(engine, max_seq_len: int = 1024, max_batch_size: int =
:param step: Run time is measured every other 'step' of sequence length
:param repeat_round: We inference current batch 'repeat_round' times and take average.
"""

def select_top_k(temp_predictions, top_k: int = 10):
"""
Pick out a word from the top k of 50257 words according to the possibility given by temp_predictions
for each sequence in this batch.
:param temp_predictions: Transformer output tensor with size of (batch size, sequence length, vocab size)
which contains the possibilities for each word in this batch.
:type temp_predictions: torch.Tensor
:param top_k: How many top words to choose from.
"""
temp_predicted_index = random.choice(
temp_predictions[0, -1, :].sort(descending=True)[1][:top_k]).item()
return temp_predicted_index

logging.log(0, "fetching cached cost")
cached_name = "cached_cost_{}_{}_{}_{}.npy".format(max_seq_len, max_batch_size, step, repeat_round)
if os.path.exists(cached_name):
@@ -52,18 +37,20 @@ def select_top_k(temp_predictions, top_k: int = 10):
logging.log(0, "generating new cached cost")
cached_cost = [[0 for i in range(max_batch_size + 1)] for j in range(max_seq_len + 1)]
input_text = ""
tokenizer = GPT2Tokenizer.from_pretrained("./")
for tmp_len in trange(1, max_seq_len + 1, step):
input_text += "test "
for tmp_batch in range(1, max_batch_size + 1):
batched_text = [input_text for _ in range(tmp_batch)]
start_time = time.time()
for k in range(repeat_round):
input_token = tokenizer(batched_text, return_tensors="pt")
if tokenizer:
input_token = tokenizer(batched_text, return_tensors="pt")
else:
input_token = batched_text
output = engine.run(input_token)
predictions = output.to_here()
predicted_index = select_top_k(predictions, k=1)
tokenizer.decode(predicted_index)
if tokenizer:
tokenizer.decode(predictions)
time_cost = (time.time() - start_time) / repeat_round
cached_cost[tmp_len][tmp_batch] = time_cost
for k in range(1, step):
@@ -92,6 +79,7 @@ class Manager:
"""
Base class of batch manager.
"""

def __init__(self):
pass

@@ -105,8 +93,9 @@ class Batch_Manager(Manager):
queue is wrapped into batches according to the sequence length and the priority calculated with the equation
in function cal_priority and then sent into the inference engine.
"""

def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 512, init_theta: int = 180,
max_batch_size: int = 32, lr: float = 0.01):
max_batch_size: int = 32, lr: float = 0.01, tokenizer=None, pad_token=None):
"""
:param engine: The InferenceEngine from energon.engine
:param cached_cost: The output of function generate_cached_cost
@@ -126,8 +115,9 @@ def __init__(self, engine: InferenceEngine, cached_cost: list, init_mu: int = 51
self.req_list_lock = rwlock.RWLockFair()
self.write_lock = self.req_list_lock.gen_wlock()
self.cached_cost = cached_cost
self.tokenizer = GPT2Tokenizer.from_pretrained('/home/lcdjs/hf_gpt2')
self.tokenizer.pad_token = GPT2Tokenizer.eos_token
self.tokenizer = tokenizer
if self.tokenizer and pad_token:
self.tokenizer.pad_token = pad_token # GPT2Tokenizer.eos_token
self.running_flag = True
self.publisher = redis.StrictRedis('localhost', 6379, charset="utf-8", decode_responses=True)
self.main_thread = threading.Thread(target=self.processing_batch)
@@ -241,7 +231,10 @@ def processing_batch(self):
pad_len = target_batch[-1].seq_len
logging.log(0, "A batch with {} requests and length of {} packed".format(len(target_batch), pad_len))
input_text = [i.text for i in target_batch]
input_ids = self.tokenizer(input_text, padding="longest", return_tensors="pt")
if self.tokenizer:
input_ids = self.tokenizer(input_text, padding="longest", return_tensors="pt")
else:
input_ids = input_text
# print("input_ids shape: {}".format(input_ids['input_ids'].shape))
# print("attention_mask shape: {}".format(input_ids['attention_mask'].shape))
output = self.engine.run(input_ids)
@@ -255,18 +248,13 @@ def publish_result(self, output, target_batch):
:param output: the rpc reference of the inference result.
:param target_batch: the input batch
"""
def select_top_k(batch_id, predictions, k=10):
predicted_index = random.choice(
predictions[batch_id, -1, :].sort(descending=True)[1][:k]).item()
return predicted_index

# print("output: {}".format(output))
predictions = output.to_here()
# print("predictions: {}".format(predictions), flush=True)
for i in range(len(target_batch)):
# print(i, predictions.shape, target_batch)
temp_st = target_batch[i].time_
chosen_pred = select_top_k(i, predictions, k=5)
text_ = self.tokenizer.decode(chosen_pred)
print("text: {}".format(text_))
chosen_pred = predictions[i]
if self.tokenizer:
text_ = self.tokenizer.decode(int(chosen_pred))
else:
text_ = chosen_pred
self.publisher.publish(str(temp_st), text_)
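With select_top_k gone from the manager, publish_result now treats the engine output as one already-chosen token id per request and only decodes it (or publishes the raw value when no tokenizer is configured). A toy sketch of that contract, with a hypothetical helper name and the Redis publishing step left out:

def decode_batch(predictions, target_batch, tokenizer=None):
    # `predictions` holds one sampled token id per request; the sampling itself
    # now happens inside the model's forward pass (see examples/gpt/gpt.py below).
    texts = []
    for i in range(len(target_batch)):
        chosen = predictions[i]
        texts.append(tokenizer.decode(int(chosen)) if tokenizer else chosen)
    return texts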
38 changes: 35 additions & 3 deletions examples/gpt/gpt.py
@@ -3,6 +3,7 @@
import os

import torch
import random
from torch import nn as nn, Tensor, dtype

from energon.context import ParallelMode
@@ -268,6 +269,19 @@ def __init__(self,
word_embeding_weight=self.embed.word_embedding_weight,
dtype=dtype)

def select_top_k(self, batch_id, temp_predictions, top_k: int = 10):
"""
Pick out a word from the top k of 50257 words according to the possibility given by temp_predictions
for each sequence in this batch.
:param temp_predictions: Transformer output tensor with size of (batch size, sequence length, vocab size)
which contains the possibilities for each word in this batch.
:type temp_predictions: torch.Tensor
:param top_k: How many top words to choose from.
"""
temp_predicted_index = random.choice(
temp_predictions[batch_id, -1, :].sort(descending=True)[1][:top_k]).item()
return temp_predicted_index

def forward(self, input_ids, attention_mask=None):
x = self.embed(input_ids)

@@ -282,8 +296,10 @@ def forward(self, input_ids, attention_mask=None):
x, attention_mask = block(x, attention_mask)

x = self.head(self.norm(x))

return x
res = []
for i in range(x.shape[0]):
res.append(self.select_top_k(i, x))
return res


class PipelineGPT1D(nn.Module):
@@ -348,6 +364,19 @@ def __init__(self,
self.head = GPTLMHead1D(dim=dim, vocab_size=vocab_size,
dtype=dtype) # word_embeeding_weight=self.embed.word_embedding_weight not in the same process

def select_top_k(self, batch_id, temp_predictions, top_k: int = 10):
"""
Pick out a word from the top k of 50257 words according to the possibility given by temp_predictions
for each sequence in this batch.
:param temp_predictions: Transformer output tensor with size of (batch size, sequence length, vocab size)
which contains the possibilities for each word in this batch.
:type temp_predictions: torch.Tensor
:param top_k: How many top words to choose from.
"""
temp_predicted_index = random.choice(
temp_predictions[batch_id, -1, :].sort(descending=True)[1][:top_k]).item()
return temp_predicted_index

def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
if self.first:
hidden_states = self.embed(input_ids)
@@ -371,7 +400,10 @@ def forward(self, hidden_states=None, input_ids=None, attention_mask=None):

if self.last:
hidden_states = self.head(self.norm(hidden_states))

res = []
for i in range(hidden_states.shape[0]):
res.append(self.select_top_k(i, hidden_states))
hidden_states = torch.Tensor(res)
return hidden_states


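The select_top_k helper both GPT1D variants gained performs the sampling the manager used to do: for each sequence it sorts the vocabulary logits of the last position and picks one index uniformly at random from the top_k entries, so forward() now returns one token id per sequence (on the last pipeline stage) instead of raw logits. A standalone sketch of the same rule on fake logits (shapes and values are illustrative):

import random
import torch

def select_top_k(batch_id, logits, top_k=10):
    # Sort the last position's vocabulary logits and pick one of the top_k indices.
    return random.choice(
        logits[batch_id, -1, :].sort(descending=True)[1][:top_k]).item()

logits = torch.randn(2, 5, 50257)  # (batch, seq_len, vocab) dummy model output
token_ids = [select_top_k(i, logits, top_k=5) for i in range(logits.shape[0])]
print(token_ids)  # one sampled token id per sequence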
5 changes: 3 additions & 2 deletions examples/gpt/gpt_batch_server.py
@@ -85,10 +85,11 @@ def launch_engine(model_name,
dtype=dtype)

global cached_cost
cached_cost = generate_cached_cost(engine, max_seq_len=256, max_batch_size=4, step=4, repeat_round=2)
cached_cost = generate_cached_cost(engine, max_seq_len=256, max_batch_size=4, step=4, repeat_round=2, tokenizer=tokenizer)

global batch_manager
batch_manager = Batch_Manager(engine, cached_cost, max_seq_len=256, max_batch_size=4)
batch_manager = Batch_Manager(engine, cached_cost, max_batch_size=4, tokenizer=tokenizer,
pad_token=GPT2Tokenizer.eos_token)
print("batch manager initialized")

global server
