Remove type hint Unpack[FlashAttentionKwargs]
#36049
base: main
Conversation
```
@@ -262,7 +262,7 @@ def forward(
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
```
`Unpack[FlashAttentionKwargs]` was added in #35235 to the `xxxAttention` class, which is a copy-paste mistake IMO.
Not sure how useful this is, but IMO we should leave them as `FlashAttentionKwargs`: these are known kwargs, and eager, sdpa and flex don't need extra ones while flash does. So let's rather update the places where the typing is missing, please.
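For reference, a rough sketch of what `FlashAttentionKwargs` looks like (paraphrased from memory of `transformers.modeling_flash_attention_utils`; the exact fields and module may differ from the current source): it is a `TypedDict` of varlen-kernel arguments that only the flash path consumes.

```python
from typing import Optional

import torch
from typing_extensions import TypedDict


class FlashAttentionKwargs(TypedDict, total=False):
    """Extra kwargs consumed only by the varlen flash-attention kernels."""

    cu_seq_lens_q: Optional[torch.LongTensor]  # cumulative query sequence lengths
    cu_seq_lens_k: Optional[torch.LongTensor]  # cumulative key sequence lengths
    max_length_q: Optional[int]  # longest query sequence in the batch
    max_length_k: Optional[int]  # longest key sequence in the batch
```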
IMO a more general
However, it needs to call with
That is because in
And in the case
Not sure if there is a better way though. @zucchini-nlp do you already have some experience with the attention refactor in #35235 applied to VLMs (or encoder-decoder-like models), and have you seen a similar issue?
but yeah, not really useful to the community; I just try to keep things more aligned with the facts 😄
@ydshieh yeah, that is pretty much how encoder-decoder models are used. I don't think we have to ask users to pass
An `AttentionKwargs` that is general is fine, but TL;DR we need to type what we explicitly support being passed in here.
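A hypothetical sketch of that direction (the names `SdpaAttentionKwargs` and `AttentionKwargs` are illustrative, not actual transformers definitions): keep the backend-specific TypedDicts and annotate `forward` with one union type.

```python
from typing_extensions import TypedDict, Unpack

# Assumes the FlashAttentionKwargs sketch above.


class SdpaAttentionKwargs(TypedDict, total=False):
    # Illustrative placeholder: eager/sdpa need no extra kwargs today,
    # so nothing is declared here.
    pass


class AttentionKwargs(FlashAttentionKwargs, SdpaAttentionKwargs, total=False):
    """Union of every kwarg explicitly supported by some attention backend."""


def forward_example(hidden_states, attention_mask=None, **kwargs: Unpack[AttentionKwargs]):
    """Illustrative signature: **kwargs is typed with the union, not the flash-only class."""
    return hidden_states
```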
I am trying to do a small update, but it would be nice if you could suggest how we should define the set of keyword arguments 🙏 So the following are unique to each:
Reference
LGTM!
One thing I wanted to take into account here was for people to be able to overwrite the typing class easily (say TGI has different kwargs; how can they overwrite `transformers.AttentionKwargs` easily?)
Let's add this to the doc!
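A hypothetical sketch of what such a doc example could show; everything TGI-flavored here (`TGIAttentionKwargs`, `tgi_paged_attention_forward`, the field names) is made up for illustration, and it assumes `ALL_ATTENTION_FUNCTIONS` in `transformers.modeling_utils` accepts new backends by plain key assignment, as it does at the time of this PR.

```python
from typing import Optional

import torch
from typing_extensions import TypedDict, Unpack

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class TGIAttentionKwargs(TypedDict, total=False):
    # Made-up fields standing in for whatever a serving stack would forward.
    block_tables: Optional[torch.Tensor]
    slot_mapping: Optional[torch.Tensor]


def tgi_paged_attention_forward(module, query, key, value, attention_mask, **kwargs: Unpack[TGIAttentionKwargs]):
    # Placeholder only: a real backend would call its paged-attention kernel here
    # and return (attn_output, attn_weights) like the built-in implementations.
    raise NotImplementedError


# Register the custom backend; models pick it up via config._attn_implementation = "tgi_paged".
ALL_ATTENTION_FUNCTIONS["tgi_paged"] = tgi_paged_attention_forward
```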
What does this PR do?
Before #35235, the type hint `Unpack[FlashAttentionKwargs]` was only used in the `forward` of `xxxFlashAttention2`, but not in `xxxAttention` or `xxxSdpaAttention`.

In #35235, `xxxFlashAttention2` and `xxxSdpaAttention` were removed, and `xxxAttention` became more general: it handles different attention implementations within its `forward` via `attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]`.
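For context, the dispatch inside the refactored `forward` looks roughly like this (a simplified excerpt in the style of the Llama modeling code, not an exact copy):

```python
# Simplified excerpt of xxxAttention.forward after #35235, once q/k/v states are computed.
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
    **kwargs,  # annotated as Unpack[FlashAttentionKwargs] before this PR
)
```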
However, #35235 also adds `Unpack[FlashAttentionKwargs]` as the type hint for `**kwargs`, which is inaccurate and misleading: the passed kwargs could just as well be the ones needed by the `eager` or `sdpa` attention implementations (although such usage is rare, I believe). It seems to me this type hint was added because of a copy-paste error.
This PR removes this type hint to avoid confusion.
TODO: apply the same changes to other places once the PR is approved.