
Refactored code from different base based on leyan_branch #588

Open
wants to merge 25 commits into base: master
Conversation

@cesposo commented Jan 9, 2025

Summary
This PR addresses two things. First, it extends model_ext.py and train_sat.py from leyan_branch with my additions from the previous PR. Second, it fixes runtime errors in the flash attention path caused by a dtype and shape mismatch when passing the attention_mask to PyTorch's scaled_dot_product_attention function. Flash attention expects either a boolean mask or a floating-point mask matching the query tensor's dtype, broadcastable to [batch_size, n_heads, seq_len, seq_len]. Our previous code passed a [batch_size, seq_len] mask of int64, which triggered the runtime error.

With these fixes in place, we now proceed with the training runs.
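For context, here is a minimal sketch of the kind of mask conversion involved. The helper name `build_sdpa_mask` is hypothetical and not part of this PR; the actual fix lives in model_ext.py, and the shapes below are illustrative only.

```python
import torch
import torch.nn.functional as F

def build_sdpa_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert an int64 [batch_size, seq_len] padding mask (1 = keep, 0 = pad)
    into an additive float mask broadcastable to
    [batch_size, n_heads, seq_len, seq_len] for scaled_dot_product_attention."""
    # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len]
    mask = attention_mask[:, None, None, :].to(dtype)
    # 0.0 at attended positions, a large negative value at padded positions
    return (1.0 - mask) * torch.finfo(dtype).min

# Illustrative usage:
q = k = v = torch.randn(2, 8, 16, 64)            # [batch, n_heads, seq_len, head_dim]
attention_mask = torch.ones(2, 16, dtype=torch.int64)  # [batch, seq_len], e.g. from the tokenizer
attn_mask = build_sdpa_mask(attention_mask, q.dtype)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Passing a boolean mask of shape [batch_size, 1, 1, seq_len] (True = attend) would also satisfy the broadcasting requirement; the key point is that the raw int64 [batch_size, seq_len] tensor cannot be passed through unchanged.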
