diff --git a/energon/engine/gpt_pipeline_wrapper.py b/energon/engine/gpt_pipeline_wrapper.py index c1e8bb5..e78359b 100644 --- a/energon/engine/gpt_pipeline_wrapper.py +++ b/energon/engine/gpt_pipeline_wrapper.py @@ -24,18 +24,12 @@ def __init__(self, self.model = model self.dtype = dtype - # input - self.static_input = dict() - self.static_name = [] - self.comm_input = dict() - self.comm_name = None - # get the hidden_size input_ids = torch.randint(1, 10, (max_batch_size, 512), dtype=torch.int64).cuda() attention_mask = torch.randint(0, 1, (max_batch_size, 1, 512), dtype=torch.int64).cuda() hidden_states = None self.sample = dict(hidden_states=hidden_states, input_ids=input_ids, attention_mask=attention_mask) - self._init_input() + self.tensor_dim = 0 self.hidden_size = 0 self.max_batch_size = max_batch_size @@ -47,25 +41,13 @@ def __init__(self, self.lock = threading.Lock() self.key = CircleInt() - def _init_input(self): - - sig = inspect.signature(self.model.forward) - parameters = sig.parameters # dict - for name, _ in parameters.items(): - if self.sample[name] is not None: - self.static_input[name] = self.sample[name] - self.static_name.append(name) - else: - self.comm_input[name] = None - self.comm_name = name - def _init_tensor_meta(self): with torch.inference_mode(): recv_tensor_shape = None if gpc.is_first_rank(ParallelMode.PIPELINE): - output = self.model(**self.comm_input, **self.static_input) # ([32, 512, 1600]) + output = self.model(hidden_states = None, input_ids=self.sample['input_ids'], attention_mask = self.sample['attention_mask']) # ([32, 512, 1600]) send_tensor_meta(output) send_forward(output) self.tensor_dim = output.dim() @@ -80,8 +62,7 @@ def _init_tensor_meta(self): input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype) # only a tensor now self.tensor_dim = input_tensor.dim() self.hidden_size = input_tensor.size()[-1] - self.comm_input[self.comm_name] = input_tensor - output = self.model(**self.comm_input, **self.static_input) + output = self.model(hidden_states = None, input_ids=input_tensor, attention_mask = self.sample['attention_mask']) send_tensor_meta(output) send_forward(output) @@ -96,23 +77,24 @@ def fill_meta_tensor(self, inputs, pipe_meta): pipe_meta.get_meta_tensor()[3] = self.hidden_size pipe_meta.update_meta() - def run(self, key, inputs): + def run(self, key, inputs): pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size) self.fill_meta_tensor(inputs, pipe_meta) self.pipe_msg_queue.enqueue(key, inputs, pipe_meta) self.lock.acquire() - sample, pipe_meta = self.pipe_msg_queue.top(self.key.val) + cur_key = self.key.val + sample, pipe_meta = self.pipe_msg_queue.top(cur_key) self.key.addOne() - for name in self.static_name: - self.static_input[name] = sample[name] - with torch.inference_mode(): if gpc.is_first_rank(ParallelMode.PIPELINE): - output = self.model(**self.comm_input, **self.static_input) + output = self.model(hidden_states = None, + input_ids = sample['input_ids'], + attention_mask = sample['attention_mask']) + send_forward(output) self.lock.release() return None @@ -121,17 +103,18 @@ def run(self, key, inputs): # print(f'get_tensor_shapes:{pipe_meta.get_tensor_shapes()}') input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype) - # print(f'input_tensor.shape:{input_tensor.shape}') - self.comm_input[self.comm_name] = input_tensor - output = self.model(**self.comm_input, **self.static_input) + output = self.model(hidden_states = input_tensor, + input_ids = sample['input_ids'], + attention_mask = sample['attention_mask']) self.lock.release() - return output + return output, cur_key else: input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype) - self.comm_input[self.comm_name] = input_tensor - output = self.model(**self.comm_input, **self.static_input) + output = self.model(hidden_states = input_tensor, + input_ids = sample['input_ids'], + attention_mask = sample['attention_mask']) send_forward(output) self.lock.release() return None diff --git a/energon/engine/pipeline_msg_dict.py b/energon/engine/pipeline_msg_dict.py index 2086baa..d0aa122 100644 --- a/energon/engine/pipeline_msg_dict.py +++ b/energon/engine/pipeline_msg_dict.py @@ -37,7 +37,7 @@ def top(self, key): while key not in self.pipeline_msg_dict: time.sleep(0.002) - pipe_msg = self.pipeline_msg_dict[key] + pipe_msg = self.pipeline_msg_dict.pop(key) return pipe_msg.sample, pipe_msg.pipe_meta diff --git a/energon/engine/rpc_worker.py b/energon/engine/rpc_worker.py index 3e0c118..2568524 100644 --- a/energon/engine/rpc_worker.py +++ b/energon/engine/rpc_worker.py @@ -25,7 +25,7 @@ def enqueue(self, key, output): def top(self, key): while key not in self.rd: time.sleep(0.001) - output = self.rd[key] + output = self.rd.pop(key) return output class RPCWorker: