-
Notifications
You must be signed in to change notification settings - Fork 635
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add nd-tensor support for triton softmax #210
Conversation
LGTM, but could you just add one more shape here ? https://github.com/facebookresearch/xformers/blob/main/tests/test_triton_softmax.py#L24 This would cover this case, better safe than sorry :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a one liner addition, please test this case just to make sure we don't break it in the future (and forget about it)
@@ -28,6 +28,8 @@ | |||
(1, 2048, 2048), | |||
(1, 3136, 3136), | |||
(1, 4096, 4096), | |||
(2, 2, 384, 384), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks !
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM again, happy to land it anytime
What does this PR do?
This PR generalizes
xformers.triton.softmax
to support tensors of more than 3 dimensions.It does it by flattening the input so that it is 3d.
A similar patch was already added in the past so that it supports 2d tensors, so this is follows in a similar way.
I can add more tests if needed