Update the attention_mask reformat for MHA #802
Conversation
@@ -1689,61 +1694,73 @@ def make_attention_mask_reformatting_for_mha(self):
# Make nodes for the attention mask subgraphs that reformat the
# 2D attention mask (B, S) to 4D causal attention mask (B, N, S, T)
I do not understand this. I think MHA supports a 2D mask with shape (B, T). Shall we use that directly instead of converting to 4D in the ONNX graph? (This may require MHA to support a causal mask in the CUDA EP.)
It would be better to use a 1D mask (total KV lengths, assuming right padding) to be consistent with GQA.
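To illustrate the mask shapes being discussed, here is a minimal numpy sketch of both conversions: expanding a 2D padding mask (B, S) into a 4D additive causal mask (B, N, S, T), and collapsing it into 1D total KV lengths as suggested for GQA. This is an illustrative sketch, not the actual ONNX subgraph; it assumes the prompt phase (S == T), right padding, and a 1 = valid / 0 = padding convention.

```python
import numpy as np

def expand_2d_to_4d_causal(mask_2d, num_heads):
    """Expand a 2D padding mask (B, S) into a 4D additive causal mask
    (B, N, S, T), assuming S == T and 1 = valid token, 0 = padding."""
    B, S = mask_2d.shape
    # Lower-triangular (S, T) matrix: position i may attend to j only if j <= i.
    causal = np.tril(np.ones((S, S), dtype=bool))
    # Key positions that are real tokens, broadcast to (B, 1, 1, T).
    keys_valid = mask_2d.astype(bool)[:, None, None, :]
    # A position is attendable only if it is both causal-visible and not padding.
    allowed = causal[None, None, :, :] & keys_valid  # (B, 1, S, T)
    # Additive mask: 0 where allowed, large negative where masked out.
    additive = np.where(allowed, 0.0, np.finfo(np.float32).min).astype(np.float32)
    return np.broadcast_to(additive, (B, num_heads, S, S))

def total_kv_lengths(mask_2d):
    """Collapse a 2D padding mask to 1D total KV lengths per batch entry,
    assuming right padding (the 1D form suggested for GQA)."""
    return mask_2d.sum(axis=1).astype(np.int64)

mask = np.array([[1, 1, 1, 0]], dtype=np.int64)  # one padded position
m4d = expand_2d_to_4d_causal(mask, num_heads=2)  # shape (1, 2, 4, 4)
lens = total_kv_lengths(mask)                    # [3]
```

The 1D form is clearly cheaper to build and pass around, which is the motivation behind the reviewer's suggestion.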
Does this change get validated by tests?
@apsonawane is this pull request still relevant? Could you please address the comments and update the PR when possible?
Yes, I will be updating the PR.
Phi-3.5 ONNX models have been released here. The issue is not seen in these models, so I would recommend using Phi-3.5 instead of Phi-3. Closing this PR as it is no longer required.
Looks like this has been solved with the latest ONNX release, but fine-tuning these ONNX models by converting them to torch is really tricky. Has the fix been made to the non-ONNX models as well? Any workaround for that?
@myadav2, the overall process for fine-tuning is: fine-tune with PyTorch -> use Model Builder or the PyTorch ONNX exporter to convert the model to ONNX -> serve with the ORT GenAI API.
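As a rough sketch of that flow: the ONNX Runtime GenAI package ships a Model Builder CLI that can convert a fine-tuned Hugging Face checkpoint to ONNX. The paths, precision, and execution provider below are placeholders, not values from this thread.

```shell
# Assumes the onnxruntime-genai package, which provides the Model Builder CLI.
pip install onnxruntime-genai

# Convert a fine-tuned PyTorch checkpoint (placeholder path) to ONNX:
python -m onnxruntime_genai.models.builder \
  -i ./my-finetuned-phi3 \
  -o ./my-finetuned-phi3-onnx \
  -p fp16 \
  -e cuda
```

The resulting output folder can then be loaded directly by the ORT GenAI API for serving.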
Yes, but the issue of gibberish output after fine-tuning on long-context text is present for the base PyTorch models. I don't think we have a fix for that, as far as I know.
@myadav2, that's annoying. Could you please describe your issue in detail and share some examples? I can bring the issues to the Phi-3 model training team and see if they can help.
The Phi-3.5 PyTorch models should also have this fix. If you still observe this behavior, you can open an issue in this repo so it can be tracked.
For Phi-3 models the attention_mask reformat was incorrect. This PR updates the pattern.
It also helps with the issue open here: #552