fix llama and baichuan typo (#1883)
lvyufeng authored Dec 23, 2024
1 parent f007978 commit 68a6247
Showing 3 changed files with 6 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .github/pylint.conf
@@ -217,7 +217,8 @@ disable=raw-checker-failed,
         fixme,
         use-a-generator,
         nested-min-max,
-        method-hidden
+        method-hidden,
+        unsubscriptable-object
 
 # Enable the message, report, category or checker with the given id(s). You can
 # either give multiple identifier separated by comma (,) or put this option
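Note on the pylint.conf change: the commit additionally disables pylint's unsubscriptable-object check (E1136), presumably so that subscripting tensor-like values whose type pylint cannot infer (such as the result of the `ops.split` call introduced further below) is not flagged; that motivation is a guess, not stated in the commit. A toy illustration of what the check normally reports, not taken from the repository:

```python
# Toy example (not from the repository) of pylint's unsubscriptable-object
# message, E1136: `value` is inferred as None, which defines no __getitem__,
# so the subscript is reported even if the line never executes.
value = None
first = value[0]  # E1136: Value 'value' is unsubscriptable
```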
6 changes: 3 additions & 3 deletions mindnlp/transformers/models/baichuan/modeling_baichuan.py
@@ -1550,11 +1550,11 @@ def forward(
         if attention_mask is not None:
             if len(attention_mask.shape) == 2:
                 expanded_mask = attention_mask.to(alibi_mask.dtype)
-                expanded_mask = ops.tril(ops.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
-                ) * ops.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0)
+                expanded_mask = ops.tril((ops.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0)
+                ) * ops.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0).int()).bool()
             else:
                 expanded_mask = attention_mask
-            bsz = inputs_embeds.size(0)
+            bsz = inputs_embeds.shape[0]
             src_len, tgt_len = alibi_mask.shape[-2:]
             expanded_mask = expanded_mask.unsqueeze(1).broadcast_to((bsz, 1, src_len, tgt_len)).to(alibi_mask.dtype)
             inverted_mask = 1.0 - expanded_mask
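Note on the baichuan change: the 2-D attention-mask expansion now does the element-wise product on an integer tensor before `ops.tril` and casts the result back to bool, and `inputs_embeds.size(0)` becomes `inputs_embeds.shape[0]` (in plain MindSpore, `Tensor.size` is a scalar property rather than a per-dimension method, which is presumably why the old call failed). A minimal sketch of the patched mask expression, written against plain `mindspore.ops` with an illustrative padding mask; the module's own `ops` is mindnlp's wrapper, so treat this as an approximation rather than the exact runtime path:

```python
# Minimal sketch of the patched 2-D mask expansion with an illustrative
# (bsz, seq_len) padding mask; the real code runs inside BaichuanModel.forward
# and uses mindnlp's ops wrapper rather than plain mindspore.ops.
import numpy as np
import mindspore
from mindspore import ops

attention_mask = mindspore.Tensor(np.array([[1, 1, 1, 0]]), mindspore.float16)

# The product and ops.tril happen on an integer tensor, and the mask is cast
# back to bool at the end, mirroring the committed lines.
expanded_mask = ops.tril((ops.gt(attention_mask[:, :, None] * attention_mask[:, None, :], 0)
                          ) * ops.eq(attention_mask[:, :, None] - attention_mask[:, None, :], 0).int()).bool()
print(expanded_mask.shape)  # (1, 4, 4)
```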
2 changes: 1 addition & 1 deletion mindnlp/transformers/models/llama/modeling_llama.py
@@ -300,7 +300,7 @@ def forward(self, x):
             )
             up_proj = ops.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
 
-            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
+            intermediate_states = ops.split((self.act_fn(gate_proj) * up_proj), slice, dim=2)
             down_proj = [
                 F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
             ]
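Note on the llama change: the intermediate activation is now split with the functional `ops.split(tensor, slice, dim=2)` instead of the `Tensor.split(slice, dim=2)` method. In plain MindSpore the tensor method takes `axis` rather than `dim`, so a keyword mismatch of that kind is presumably the "typo" the commit title refers to; that the module's `ops` is a torch-style wrapper accepting `dim=` is an assumption on my part. A small sketch of the two call styles in plain MindSpore (2.x signatures):

```python
# Small sketch of splitting along the last axis in plain MindSpore.
# In modeling_llama.py the `ops` module is mindnlp's torch-style wrapper,
# so it accepts `dim=`; plain mindspore spells the keyword `axis`.
import numpy as np
import mindspore
from mindspore import ops

x = mindspore.Tensor(np.ones((2, 4, 6), dtype=np.float32))

# Functional form, as used by the fix: two chunks of shape (2, 4, 3).
chunks = ops.split(x, 3, axis=2)

# Tensor-method form, as in the old code; note the keyword is `axis`,
# not `dim`, when calling it on a plain MindSpore tensor.
chunks_method = x.split(3, axis=2)
print([c.shape for c in chunks], [c.shape for c in chunks_method])
```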
