Skip to content

Commit

Permalink
fix error on gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
mshr-h committed Nov 29, 2023
1 parent d46a6fc commit 9e18215
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5526,6 +5526,8 @@ def test_fn(attn_mask=None, is_causal=False):

# Test with explicit attn_mask
attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0)
if torch.cuda.is_available():
attn_mask = attn_mask.cuda()
verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_4d])
verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_3d])
verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_4d])
Expand Down

0 comments on commit 9e18215

Please sign in to comment.