Remove type hint Unpack[FlashAttentionKwargs] #36049

Open · wants to merge 3 commits into base: main

Conversation

@ydshieh (Collaborator) commented Feb 5, 2025

What does this PR do?

Before #35235, the type hint Unpack[FlashAttentionKwargs] was only used in the forward of xxxFlashAttention2, not in xxxAttention or xxxSdpaAttention.

In #35235, xxxFlashAttention2 and xxxSdpaAttention were removed, and xxxAttention became more general: it handles the different attention implementations within its forward via attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation].

But #35235 adds Unpack[FlashAttentionKwargs] as the type hint for **kwargs, which is inaccurate and misleading: the passed kwargs could be the ones needed by the eager or sdpa attention implementations (although such usage is rare, I believe).

It seems to me this type hint was added because of a copy-paste error.

This PR removes the type hint to avoid confusion.

TODO: apply the same changes to other places once the PR is approved.
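
For context, Unpack over a TypedDict (PEP 692) is what lets type checkers constrain **kwargs. A minimal sketch of the mechanism, with placeholder field names rather than the real FlashAttentionKwargs definition:

# Illustrative only (PEP 692): Unpack[SomeTypedDict] on **kwargs tells type checkers
# which keyword arguments are expected. The fields below are placeholders, not the
# actual FlashAttentionKwargs from transformers.
from typing import TypedDict

import torch
from typing_extensions import Unpack  # or typing.Unpack on recent Python versions


class ExampleFlashKwargs(TypedDict, total=False):
    # hypothetical flash-attention-only arguments
    cu_seq_lens_q: torch.Tensor
    max_length_q: int


def forward(hidden_states: torch.Tensor, **kwargs: Unpack[ExampleFlashKwargs]) -> torch.Tensor:
    # with this hint, a type checker flags any kwarg not declared above, which is why
    # keeping it on the generic xxxAttention.forward is misleading for eager/sdpa callers
    ...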

@@ -262,7 +262,7 @@ def forward(
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
@ydshieh (Collaborator, Author) commented on the diff:

Unpack[FlashAttentionKwargs] was added to the xxxAttention class in #35235, which is a copy-paste mistake IMO.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment:

Not sure how useful this is, but IMO we should leave them as FlashAttentionKwargs: these are known kwargs, and eager, sdpa, and flex don't need extra ones while flash does. So let's rather update the places where the typing is missing, please.

@Cyrilvallez (Member) commented:

IMO a more general AttentionKwargs would maybe be better/make more sense, as we indeed handle several attention functions, but up to @ArthurZucker

@ydshieh (Collaborator, Author) commented Feb 6, 2025

sdpa needs a kwarg in some special cases (e.g. is_causal). For example, Kosmos2_5ImageToTextProjection(nn.Module) re-uses Kosmos2_5TextAttention to avoid writing yet another attention layer class for that intermediate (image-to-text) module.

However, it needs to be called with is_causal=False explicitly (for the sdpa implementation), since there a text token attends to all the image tokens:

        hidden_states, attn_weights = self.x_attn(
            # ... other arguments omitted ...
            is_causal=False,
        )

That is because in sdpa_attention_forward we have

    if is_causal is None:
        is_causal = causal_mask is None and query.shape[2] > 1

And in the Kosmos2_5ImageToTextProjection case, if we don't specify is_causal, it defaults to None, and we also have query.shape[2] (target length) > 1, which causes is_causal to end up True when it shouldn't be.

Not sure if there is a better way though. @zucchini-nlp, do you already have some experience applying the attention refactor from #35235 to VLMs (or encoder-decoder-like models) and running into a similar issue?

@ydshieh (Collaborator, Author) commented Feb 6, 2025

But yeah, not really useful to the community; I'm just trying to keep things more aligned with the facts 😄

@zucchini-nlp (Member) commented:

@ydshieh yeah, that is pretty much how encoder-decoder models are used. I don't think we have to ask users to pass is_causal; it is usually pre-set as a class attribute when initializing the model, and cannot be changed easily. CMIIW
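
A minimal sketch (not the actual Kosmos2_5 or transformers code; all names here are illustrative) of pre-setting is_causal as a class attribute so call sites never have to pass it:

# Hypothetical cross-attention module: is_causal is fixed at init time because a text
# query attending over all image tokens is never causal.
import torch
import torch.nn.functional as F
from torch import nn


class ImageToTextCrossAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.is_causal = False  # pre-set once instead of passing is_causal=False at every call site
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        return x.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, text_states: torch.Tensor, image_states: torch.Tensor) -> torch.Tensor:
        query = self._split_heads(self.q_proj(text_states))
        key = self._split_heads(self.k_proj(image_states))
        value = self._split_heads(self.v_proj(image_states))
        # the explicit flag replaces the "causal_mask is None and query.shape[2] > 1" default
        attn = F.scaled_dot_product_attention(query, key, value, is_causal=self.is_causal)
        bsz, _, q_len, _ = attn.shape
        return attn.transpose(1, 2).reshape(bsz, q_len, -1)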

@ArthurZucker (Collaborator) commented:

A general AttentionKwargs is fine, but TL;DR we need to type what we explicitly support being passed in here.

@ydshieh (Collaborator, Author) commented Feb 6, 2025

I am trying to do a small update, but it would be nice if you could suggest how we should define the set of keyword arguments 🙏

So the following are unique to each:

  • sdpa

    • is_causal
  • flash

    • sliding_window
    • softcap
  • flex

    • softcap
    • head_mask
    • (but it doesn't have dropout, unlike sdpa / flash / eager)

Reference

def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None,
    **kwargs,
)
def flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
)
def flex_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    head_mask: Optional[torch.Tensor] = None,
    **kwargs,
)
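
A rough sketch (just combining the arguments listed above, not an agreed-on definition) of what a shared AttentionKwargs TypedDict could look like:

# Hypothetical combined TypedDict; field names follow the reference signatures above,
# the grouping comments are only a summary of that list.
from typing import Optional, TypedDict

import torch


class AttentionKwargs(TypedDict, total=False):
    dropout: float                     # sdpa / flash (flex has no dropout)
    scaling: Optional[float]           # sdpa / flash / flex
    is_causal: Optional[bool]          # sdpa-specific
    sliding_window: Optional[int]      # flash-specific
    softcap: Optional[float]           # flash and flex
    head_mask: Optional[torch.Tensor]  # flex-specific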

@ArthurZucker (Collaborator) left a comment:

LGTM!
One thing I wanted to take into account here was for people to be able to overwrite the typing class easily (say TGI has different kwargs: how can they overwrite transformers.AttentionKwargs easily?).
Let's add this to the doc!
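
One possible pattern (purely illustrative, not an existing transformers or TGI API): since a TypedDict can be subclassed, a downstream library could extend the base kwargs class and use the subclass as the Unpack hint on its own attention function:

# Hypothetical extension of a general AttentionKwargs by a downstream serving library;
# BaseAttentionKwargs and the extra fields are stand-ins, not real transformers/TGI names.
from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack


class BaseAttentionKwargs(TypedDict, total=False):
    # stand-in for a general transformers-level AttentionKwargs
    is_causal: Optional[bool]
    sliding_window: Optional[int]
    softcap: Optional[float]


class ServingAttentionKwargs(BaseAttentionKwargs, total=False):
    # extra arguments a paged-attention serving stack might need (illustrative only)
    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


def custom_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs: Unpack[ServingAttentionKwargs],
) -> torch.Tensor:
    # the downstream implementation would read kwargs["block_tables"], etc.
    ...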
