Skip to content

Commit

Permalink
fix mask bias dtype in sdpa (#2407)
Browse files Browse the repository at this point in the history
* fix mask bias dtype in sdpa

* Update decoder.py
  • Loading branch information
Mddct authored Mar 12, 2024
1 parent 983e86f commit 696d161
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions wenet/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def forward(
# tgt_mask: (B, L, L)
tgt_mask = tgt_mask & m
if self.use_sdpa:
tgt_mask = mask_to_bias(tgt_mask, tgt.dtype)
memory_mask = mask_to_bias(memory_mask, memory_mask.dtype)
tgt_mask = mask_to_bias(tgt_mask, memory.dtype)
memory_mask = mask_to_bias(memory_mask, memory.dtype)

x, _ = self.embed(tgt)
if self.gradient_checkpointing and self.training:
Expand Down

0 comments on commit 696d161

Please sign in to comment.