1.fix dynamic_forward of mtp. 2.fix llama-eagle multi-device (#9947)
freeliuzc authored Feb 27, 2025
1 parent 02bf7c8 commit 85e1238
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 0 additions & 3 deletions llm/predict/predictor.py
@@ -1127,9 +1127,6 @@ def predict(self, input_texts: list[str], return_tokens=False):
         self._infer(self.model_inputs)
         logger.info(f"running spend {time.time() - s_time}")
 
-        if self.proposer is not None:
-            self.proposer.postprocess(base_model_inputs=self.model_inputs)
-
         if self.tensor_parallel_rank == 0:
             outputs = []
             output_tokens = []
4 changes: 3 additions & 1 deletion paddlenlp/experimental/transformers/llama/modeling.py
@@ -1468,7 +1468,7 @@ def __init__(self, config: LlamaConfig):
 
         if config.tensor_parallel_degree > 1:
             self.fc = ColumnParallelLinear(
-                self.hidden_size * 2, self.hidden_size, has_bias=True, gather_output=False, fuse_matmul_bias=True
+                self.hidden_size * 2, self.hidden_size, has_bias=True, gather_output=True, fuse_matmul_bias=True
             )
         else:
             self.fc = nn.Linear(self.hidden_size * 2, self.hidden_size, bias_attr=True)
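
With ColumnParallelLinear the fc output features are sharded across tensor-parallel ranks, so under gather_output=False each rank only holds hidden_size / tp_degree of the projected hidden state; gather_output=True all-gathers the shards so every rank sees the full hidden_size output, matching the shape produced by the single-device nn.Linear branch. A minimal conceptual sketch in plain NumPy (not the actual distributed API) of what the gather changes:

# Conceptual sketch (plain NumPy, not the real fleet API): a column-parallel linear
# shards the output features across ranks; gather_output decides whether the shards
# are all-gathered back into the full-width output.
import numpy as np

hidden_size, tp_degree = 8, 2
x = np.random.randn(1, hidden_size * 2)            # fc input: [batch, hidden_size * 2]
w = np.random.randn(hidden_size * 2, hidden_size)  # full fc weight

w_shards = np.split(w, tp_degree, axis=1)          # each rank owns a slice of the output columns
partial_outs = [x @ w_k for w_k in w_shards]       # per-rank output: [batch, hidden_size // tp_degree]

# gather_output=False: each rank keeps only its shard, here shape (1, 4).
# gather_output=True: the shards are all-gathered, so every rank holds the full
# (1, hidden_size) output, matching the non-parallel nn.Linear branch.
full_out = np.concatenate(partial_outs, axis=1)
assert full_out.shape == (1, hidden_size)
assert np.allclose(full_out, x @ w)
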
@@ -1832,6 +1832,8 @@ def get_tensor_parallel_split_mappings(num_layers):
 
         base_actions = {
             "lm_head.weight": partial(fn, is_column=True),
+            "fc.weight": partial(fn, is_column=True),
+            "fc.bias": partial(fn, is_column=True),
             # Row Linear
             "embed_tokens.weight": partial(fn, is_column=False),
             "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False),
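
The entries returned by get_tensor_parallel_split_mappings tell the checkpoint loader how each parameter is partitioned across tensor-parallel ranks: is_column=True splits along the output dimension (matching ColumnParallelLinear), is_column=False along the input dimension (row-parallel layers). Adding fc.weight and fc.bias keeps the eagle fc parameters sharded consistently with the column-parallel fc layer defined above. A rough sketch using a hypothetical split_tensor helper (not PaddleNLP's actual split/merge function) of what such a mapping does:

# Rough sketch with a hypothetical split_tensor helper (not PaddleNLP's actual
# split/merge function): how an is_column mapping shards a checkpoint tensor.
import numpy as np
from functools import partial

def split_tensor(weight, rank, degree, is_column=True):
    # Column-parallel weights/biases are split along the output axis (last dim);
    # row-parallel weights along the input axis (first dim).
    axis = -1 if is_column else 0
    return np.split(weight, degree, axis=axis)[rank]

fn = split_tensor
base_actions = {
    "fc.weight": partial(fn, is_column=True),  # [2*hidden, hidden] -> [2*hidden, hidden // degree]
    "fc.bias": partial(fn, is_column=True),    # [hidden] -> [hidden // degree]
}

hidden_size, degree = 8, 2
w_shard = base_actions["fc.weight"](np.random.randn(hidden_size * 2, hidden_size), rank=0, degree=degree)
b_shard = base_actions["fc.bias"](np.random.randn(hidden_size), rank=0, degree=degree)
assert w_shard.shape == (hidden_size * 2, hidden_size // degree)
assert b_shard.shape == (hidden_size // degree,)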
