This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

return in order #43

Merged
merged 1 commit on May 5, 2022
51 changes: 17 additions & 34 deletions energon/engine/gpt_pipeline_wrapper.py
@@ -24,18 +24,12 @@ def __init__(self,
         self.model = model
         self.dtype = dtype

-        # input
-        self.static_input = dict()
-        self.static_name = []
-        self.comm_input = dict()
-        self.comm_name = None
-
         # get the hidden_size
         input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda()
         attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda()
         hidden_states = None
         self.sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask)
-        self._init_input()

         self.tensor_dim = 0
         self.hidden_size = 0
         self.max_batch_size = max_batch_size
@@ -47,25 +41,13 @@ def __init__(self,
         self.lock = threading.Lock()
         self.key = CircleInt()

-    def _init_input(self):
-
-        sig = inspect.signature(self.model.forward)
-        parameters = sig.parameters    # dict
-        for name, _ in parameters.items():
-            if self.sample[name] is not None:
-                self.static_input[name] = self.sample[name]
-                self.static_name.append(name)
-            else:
-                self.comm_input[name] = None
-                self.comm_name = name
-
     def _init_tensor_meta(self):

         with torch.inference_mode():
             recv_tensor_shape = None
             if gpc.is_first_rank(ParallelMode.PIPELINE):
-                output = self.model(**self.comm_input, **self.static_input)    # ([32, 512, 1600])
+                output = self.model(hidden_states=None, input_ids=self.sample['input_ids'], attention_mask=self.sample['attention_mask'])    # ([32, 512, 1600])
                 send_tensor_meta(output)
                 send_forward(output)
                 self.tensor_dim = output.dim()
@@ -80,8 +62,7 @@ def _init_tensor_meta(self):
                 input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)    # only a tensor now
                 self.tensor_dim = input_tensor.dim()
                 self.hidden_size = input_tensor.size()[-1]
-                self.comm_input[self.comm_name] = input_tensor
-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=input_tensor, input_ids=self.sample['input_ids'], attention_mask=self.sample['attention_mask'])
                 send_tensor_meta(output)
                 send_forward(output)
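Both branches of `_init_tensor_meta` run one dummy forward pass so that every pipeline stage learns the shape of the tensor it will receive: the sender ships tensor metadata first, and the receiver uses it to allocate a buffer before the actual activations arrive. A minimal sketch of that handshake written against raw `torch.distributed` (energon's `send_tensor_meta`/`recv_forward` wrap similar point-to-point calls; the function names below are illustrative assumptions, not the repo's API):

```python
# Sketch of the shape handshake, using raw torch.distributed. energon's
# send_tensor_meta / recv_forward wrap similar point-to-point calls; the
# function names here are illustrative, not the repo's API.
import torch
import torch.distributed as dist

def send_meta_then_tensor(t: torch.Tensor, dst: int) -> None:
    ndims = torch.tensor([t.dim()], dtype=torch.long, device=t.device)
    dist.send(ndims, dst)                    # 1) how many dimensions
    shape = torch.tensor(t.size(), dtype=torch.long, device=t.device)
    dist.send(shape, dst)                    # 2) the shape itself
    dist.send(t.contiguous(), dst)           # 3) the payload

def recv_meta_then_tensor(src: int, dtype: torch.dtype, device) -> torch.Tensor:
    ndims = torch.empty(1, dtype=torch.long, device=device)
    dist.recv(ndims, src)
    shape = torch.empty(int(ndims.item()), dtype=torch.long, device=device)
    dist.recv(shape, src)
    buf = torch.empty(*shape.tolist(), dtype=dtype, device=device)
    dist.recv(buf, src)                      # buffer sized from the metadata
    return buf
```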

@@ -96,23 +77,24 @@ def fill_meta_tensor(self, inputs, pipe_meta):
         pipe_meta.get_meta_tensor()[3] = self.hidden_size
         pipe_meta.update_meta()

     def run(self, key, inputs):
         pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
         self.fill_meta_tensor(inputs, pipe_meta)
         self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

         self.lock.acquire()
-        sample, pipe_meta = self.pipe_msg_queue.top(self.key.val)
+        cur_key = self.key.val
+        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
         self.key.addOne()

-        for name in self.static_name:
-            self.static_input[name] = sample[name]
-
         with torch.inference_mode():

             if gpc.is_first_rank(ParallelMode.PIPELINE):

-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=None,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'])

                 send_forward(output)
                 self.lock.release()
                 return None
@@ -121,17 +103,18 @@ def run(self, key, inputs):

                 # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}')
                 input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
-                # print(f'input_tensor.shape:{input_tensor.shape}')
-                self.comm_input[self.comm_name] = input_tensor
-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=input_tensor,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'])
                 self.lock.release()
-                return output
+                return output, cur_key

             else:

                 input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
-                self.comm_input[self.comm_name] = input_tensor
-                output = self.model(**self.comm_input, **self.static_input)
+                output = self.model(hidden_states=input_tensor,
+                                    input_ids=sample['input_ids'],
+                                    attention_mask=sample['attention_mask'])
                 send_forward(output)
                 self.lock.release()
                 return None
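Beyond dropping the `inspect.signature`-driven `_init_input` routing in favor of explicit `hidden_states`/`input_ids`/`attention_mask` arguments, the key behavioral change in `run()` is that the last pipeline rank now returns `(output, cur_key)` instead of a bare `output`, so a caller can match each result to the request that produced it even when requests complete out of order. A minimal sketch of how a caller might use that pair (`worker`, `results`, and `get_in_order` are illustrative names, not part of this repo):

```python
# Sketch (not from this repo): a caller reordering results with the
# (output, cur_key) pair that run() now returns on the last pipeline rank.
import threading

results = {}                          # cur_key -> output, filled as batches finish
results_ready = threading.Condition()

def worker(wrapper, key, inputs):
    ret = wrapper.run(key, inputs)
    if ret is not None:               # only the last pipeline rank gets output back
        output, cur_key = ret
        with results_ready:
            results[cur_key] = output
            results_ready.notify_all()

def get_in_order(next_key):
    """Block until the result for next_key arrives, then hand it out."""
    with results_ready:
        while next_key not in results:
            results_ready.wait()
        return results.pop(next_key)
```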
2 changes: 1 addition & 1 deletion energon/engine/pipeline_msg_dict.py
@@ -37,7 +37,7 @@ def top(self, key):
         while key not in self.pipeline_msg_dict:
             time.sleep(0.002)

-        pipe_msg = self.pipeline_msg_dict[key]
+        pipe_msg = self.pipeline_msg_dict.pop(key)

         return pipe_msg.sample, pipe_msg.pipe_meta

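Switching `top()` from `self.pipeline_msg_dict[key]` to `.pop(key)` makes the read destructive: each message is consumed exactly once, and its entry is freed instead of accumulating for every request the engine ever served. The pattern in isolation (a minimal sketch; `MsgStore` is an illustrative stand-in for the repo's queue class):

```python
# Minimal sketch of the blocking consume pattern top() now follows: poll
# until the key appears, then pop so the dict cannot grow without bound.
import time

class MsgStore:
    def __init__(self):
        self.msgs = {}

    def enqueue(self, key, msg):
        self.msgs[key] = msg

    def top(self, key, poll_interval=0.002):
        while key not in self.msgs:          # busy-wait until the producer lands
            time.sleep(poll_interval)
        return self.msgs.pop(key)            # destructive read: entry is freed
```

Under CPython's GIL, single `dict` operations on built-in keys are effectively atomic, so this poll-and-pop loop is safe with one producer and one consumer per key without extra locking.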
2 changes: 1 addition & 1 deletion energon/engine/rpc_worker.py
@@ -25,7 +25,7 @@ def enqueue(self, key, output):
     def top(self, key):
         while key not in self.rd:
             time.sleep(0.001)
-        output = self.rd[key]
+        output = self.rd.pop(key)
         return output

 class RPCWorker:
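The same one-line fix lands in the RPC worker's result dict. A quick self-contained demonstration of the poll-and-pop hand-off between a producer and a consumer thread (`rd`, `produce`, and `consume` are illustrative, mirroring `RPCWorker`'s pattern):

```python
# Self-contained demonstration of the poll-and-pop hand-off between a
# producer and a consumer thread, mirroring RPCWorker's result dict.
import threading
import time

rd = {}

def produce(key, value):
    time.sleep(0.01)                 # simulate the model finishing a batch
    rd[key] = value

def consume(key):
    while key not in rd:             # same loop as the patched top()
        time.sleep(0.001)
    return rd.pop(key)               # result is removed once read

threading.Thread(target=produce, args=(7, "hello")).start()
print(consume(7))                    # -> hello
```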