Merge pull request #43 from hpcaitech/feature/variable_len
return in order
MaruyamaAya authored May 5, 2022
2 parents 110c30a + 26eae06 commit a1406ec
Showing 3 changed files with 19 additions and 36 deletions.
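
The change in brief: the `static_input`/`comm_input` machinery built by `_init_input` is dropped in favor of passing `hidden_states`, `input_ids`, and `attention_mask` to the model explicitly; both keyed lookups now `pop` a message when it is read instead of leaving it in the dict; and on the last pipeline stage `run` returns `(output, cur_key)`, where `cur_key` is the key actually served, so a caller can hand results back in request order. A minimal caller-side sketch (the driver below is hypothetical, not part of this repository):

import threading

pending = {}                       # request key -> output
pending_lock = threading.Lock()

def drive(wrapper, key, inputs):
    # `wrapper` is assumed to be the pipeline wrapper from gpt_pipeline_wrapper.py.
    ret = wrapper.run(key, inputs)
    if ret is None:                # first and middle stages return nothing
        return
    output, cur_key = ret          # cur_key names the request actually served,
    with pending_lock:             # which may differ from `key` under concurrency
        pending[cur_key] = output
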
51 changes: 17 additions & 34 deletions energon/engine/gpt_pipeline_wrapper.py
@@ -24,18 +24,12 @@ def __init__(self,
         self.model = model
         self.dtype = dtype

-        # input
-        self.static_input = dict()
-        self.static_name = []
-        self.comm_input = dict()
-        self.comm_name = None

         # get the hidden_size
         input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
         attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
         hidden_states = None
         self.sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
-        self._init_input()

         self.tensor_dim = 0
         self.hidden_size = 0
         self.max_batch_size = max_batch_size
@@ -47,25 +41,13 @@ def __init__(self,
         self.lock = threading.Lock()
         self.key = CircleInt()

-    def _init_input(self):
-
-        sig = inspect.signature(self.model.forward)
-        parameters = sig.parameters  # dict
-        for name, _ in parameters.items():
-            if self.sample[name] is not None:
-                self.static_input[name] = self.sample[name]
-                self.static_name.append(name)
-            else:
-                self.comm_input[name] = None
-                self.comm_name = name

     def _init_tensor_meta(self):

         with torch.inference_mode():
             recv_tensor_shape = None
             if gpc.is_first_rank(ParallelMode.PIPELINE):
-                output = self.model(**self.comm_input, **self.static_input)  # ([32, 512, 1600])
+                output = self.model(hidden_states=None, input_ids=self.sample['input_ids'], attention_mask=self.sample['attention_mask'])  # ([32, 512, 1600])
                 send_tensor_meta(output)
                 send_forward(output)
                 self.tensor_dim = output.dim()
@@ -80,8 +62,7 @@ def _init_tensor_meta(self):
                 input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                 self.tensor_dim = input_tensor.dim()
                 self.hidden_size = input_tensor.size()[-1]
-                self.comm_input[self.comm_name] = input_tensor
-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=None, input_ids=input_tensor, attention_mask=self.sample['attention_mask'])
                 send_tensor_meta(output)
                 send_forward(output)

@@ -96,23 +77,24 @@ def fill_meta_tensor(self, inputs, pipe_meta):
         pipe_meta.get_meta_tensor()[3] = self.hidden_size
         pipe_meta.update_meta()

     def run(self, key, inputs):
         pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
         self.fill_meta_tensor(inputs, pipe_meta)
         self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

         self.lock.acquire()
-        sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
+        cur_key = self.key.val
+        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
         self.key.addOne()

-        for name in self.static_name:
-            self.static_input[name] = sample[name]

         with torch.inference_mode():

             if gpc.is_first_rank(ParallelMode.PIPELINE):

-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=None,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'])

                 send_forward(output)
                 self.lock.release()
                 return None
@@ -121,17 +103,18 @@ def run(self, key, inputs):

                 # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
                 input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                 # print(f'input_tensor.shape:{input_tensor.shape}')
-                self.comm_input[self.comm_name] = input_tensor
-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=input_tensor,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'])
                 self.lock.release()
-                return output
+                return output, cur_key

             else:

                 input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
-                self.comm_input[self.comm_name] = input_tensor
-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=input_tensor,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'])
                 send_forward(output)
                 self.lock.release()
                 return None
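
For context on what was removed: `_init_input` used `inspect.signature` to split the forward parameters into static inputs (the non-None entries of the sample) and the one communicated tensor later filled by `recv_forward`. A self-contained sketch of that pattern, with a toy function standing in for the real GPT forward:

import inspect

def forward(hidden_states, input_ids, attention_mask):
    # Toy stand-in; only the signature matters for this sketch.
    return hidden_states if hidden_states is not None else input_ids

sample = dict(hidden_states=None, input_ids=[1, 2, 3], attention_mask=[1, 1, 1])
static_input, comm_name = {}, None
for name in inspect.signature(forward).parameters:
    if sample[name] is not None:
        static_input[name] = sample[name]   # known up front, reused every call
    else:
        comm_name = name                    # bound later to the received tensor

print(comm_name)                            # -> hidden_states
output = forward(**{comm_name: [9, 9, 9]}, **static_input)

Replacing this with explicit keyword arguments trades generality for readability: each pipeline branch now states exactly which tensor it feeds as `hidden_states`.
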
2 changes: 1 addition & 1 deletion energon/engine/pipeline_msg_dict.py
@@ -37,7 +37,7 @@ def top(self, key):
         while key not in self.pipeline_msg_dict:
             time.sleep(0.002)

-        pipe_msg = self.pipeline_msg_dict[key]
+        pipe_msg = self.pipeline_msg_dict.pop(key)

         return pipe_msg.sample, pipe_msg.pipe_meta

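
The one-word change above (indexing → `pop`) means each message is consumed exactly once, so completed keys no longer accumulate in `pipeline_msg_dict` and a key cannot be served twice. A minimal self-contained sketch of the pattern (illustrative only; the real class stores richer message objects):

import threading
import time

class KeyedMsgDict:
    # Sketch of the pop-on-read pattern; not the actual energon class.
    def __init__(self):
        self._d = {}
        self._lock = threading.Lock()

    def enqueue(self, key, msg):
        with self._lock:
            self._d[key] = msg

    def top(self, key):
        while True:                           # block until `key` is produced
            with self._lock:
                if key in self._d:
                    return self._d.pop(key)   # consume exactly once
            time.sleep(0.002)                 # same polling interval as the original
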
2 changes: 1 addition & 1 deletion energon/engine/rpc_worker.py
@@ -25,7 +25,7 @@ def enqueue(self, key, output):
     def top(self, key):
         while key not in self.rd:
             time.sleep(0.001)
-        output = self.rd[key]
+        output = self.rd.pop(key)
         return output

 class RPCWorker:
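
Both `top` implementations block by polling with a short `time.sleep`, which is simple and cheap at these rates. For reference, the same blocking-read-then-pop behavior can be written without polling using `threading.Condition` (an alternative sketch, not what this commit does):

import threading

class ReturnDict:
    # Condition-variable variant of the keyed blocking read.
    def __init__(self):
        self._d = {}
        self._cond = threading.Condition()

    def enqueue(self, key, output):
        with self._cond:
            self._d[key] = output
            self._cond.notify_all()           # wake any thread waiting in top()

    def top(self, key):
        with self._cond:
            self._cond.wait_for(lambda: key in self._d)
            return self._d.pop(key)           # same consume-once semantics
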
