
Update the attention_mask reformat for MHA #802

Closed
wants to merge 1 commit into from

Conversation

@apsonawane (Contributor)

For Phi-3 models, the attention_mask reformat was incorrect. This PR updates the pattern.
It also helps with the open issue #552.

@@ -1689,61 +1694,73 @@ def make_attention_mask_reformatting_for_mha(self):
# Make nodes for the attention mask subgraphs that reformat the
# 2D attention mask (B, S) to 4D causal attention mask (B, N, S, T)
@tianleiwu commented Aug 17, 2024

I do not understand this. I think MHA supports a 2D mask with shape (B, T). Shall we use that directly instead of converting to 4D in the ONNX graph? (MHA may need to support a causal mask in the CUDA EP.)

It would be better to use a 1D mask (total KV lengths, assuming right padding) to be consistent with GQA.
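To make the mask shapes under discussion concrete, here is a minimal NumPy sketch (illustrative only, not the code from this PR) of expanding a 2D padding mask of shape (B, T) into a 4D additive causal mask of shape (B, N, S, T):

```python
import numpy as np

def reformat_mask_2d_to_4d(mask_2d: np.ndarray, num_heads: int, seq_len: int) -> np.ndarray:
    """Expand a 2D padding mask (B, T) into a 4D additive causal mask (B, N, S, T).

    In mask_2d, 1 marks a real token and 0 marks padding. The output uses 0.0 for
    attendable positions and a large negative value for masked ones, the additive
    convention most attention kernels expect.
    """
    batch, total_len = mask_2d.shape
    min_val = np.finfo(np.float32).min
    # Causal part: query i (of the last `seq_len` positions) may attend to keys <= i.
    causal = np.tril(np.ones((total_len, total_len), dtype=bool))[-seq_len:]   # (S, T)
    # Combine the causal structure with each sequence's padding mask.
    allowed = causal[None, :, :] & mask_2d[:, None, :].astype(bool)            # (B, S, T)
    mask_4d = np.where(allowed, 0.0, min_val).astype(np.float32)
    # Broadcast across attention heads: (B, N, S, T).
    return np.broadcast_to(mask_4d[:, None, :, :], (batch, num_heads, seq_len, total_len)).copy()
```

The 1D variant suggested above would simply be `mask_2d.sum(axis=1)`: the total KV length per sequence, which is sufficient when padding is on the right.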

@natke natke self-requested a review August 21, 2024 00:08
@natke (Contributor) left a comment

Does this change get validated by tests?

@baijumeswani (Collaborator)

@apsonawane, is this pull request still relevant? Could you please address the comments and update the PR when possible?

@apsonawane (Contributor, Author)

Yes, I will be updating the PR

@apsonawane (Contributor, Author)

The Phi-3.5 ONNX models have been released here. The issue is not seen in those models, so I would recommend using Phi-3.5 instead of Phi-3.

Closing this PR as it is no longer required.

@apsonawane apsonawane closed this Sep 16, 2024
@myadav2 commented Sep 19, 2024

Looks like this has been solved with the latest ONNX release, but fine-tuning these ONNX models by converting them to torch is really tricky. Has the fix been made to the non-ONNX models as well? Any workaround for that?

@yufenglee (Member)

> Looks like this has been solved with the latest ONNX release, but fine-tuning these ONNX models by converting them to torch is really tricky. Has the fix been made to the non-ONNX models as well? Any workaround for that?

@myadav2, the overall process for fine-tuning is: fine-tune with PyTorch -> convert the model to ONNX with ModelBuilder or the PyTorch ONNX exporter -> serve with the ORT GenAI API
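The conversion step of that pipeline can be sketched as a ModelBuilder invocation. This is a hypothetical example: the paths are made up, and the flag names should be checked against the current onnxruntime-genai model builder README before use.

```shell
# Fine-tune with PyTorch first, then convert the saved checkpoint to ONNX.
# -i: path to the fine-tuned Hugging Face checkpoint (hypothetical path)
# -o: output folder for the generated ONNX model
# -p: precision, -e: execution provider
python -m onnxruntime_genai.models.builder \
    -i ./phi3-finetuned \
    -o ./phi3-finetuned-onnx \
    -p int4 \
    -e cpu
```

The output folder can then be loaded with the ORT GenAI API for serving.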

@myadav2 commented Sep 19, 2024

> Looks like this has been solved with the latest ONNX release, but fine-tuning these ONNX models by converting them to torch is really tricky. Has the fix been made to the non-ONNX models as well? Any workaround for that?

> @myadav2, the overall process for fine-tuning is that: fine-tune with Pytorch -> use ModelBuilder or PyTorch ONNX export to convert the model to ONNX -> serve with ORT GenAI API

Yes, but the issue of gibberish output after fine-tuning on long-context text is present for the base PyTorch models. I don't think we have a fix for that, as far as I know.

@yufenglee (Member)

> Looks like this has been solved with the latest ONNX release, but fine-tuning these ONNX models by converting them to torch is really tricky. Has the fix been made to the non-ONNX models as well? Any workaround for that?

> @myadav2, the overall process for fine-tuning is that: fine-tune with Pytorch -> use ModelBuilder or PyTorch ONNX export to convert the model to ONNX -> serve with ORT GenAI API

> Yes, but the issue of gibberish output after fine-tuning long context text is present for the base Pytorch models. I don't think we have a fix for that as far as I know

@myadav2, that's annoying. Could you please describe your issue in detail and share some examples? I can bring the issues to the Phi-3 model training team and see if they can help.

@kunal-vaishnavi (Contributor)

> Has the fix been made to the non-ONNX models as well? Any workaround for that?

The Phi-3.5 PyTorch models should also have this fix. If you still observe this behavior, you can open an issue in this repo so this can be tracked.
