Commit 421bf86

fix 2

ydshieh committed Feb 6, 2025
1 parent 6ba13f5 commit 421bf86
Showing 3 changed files with 24 additions and 7 deletions.
16 changes: 15 additions & 1 deletion src/transformers/integrations/sdpa_attention.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional, Tuple, TypedDict
 
 import torch
 
@@ -62,3 +62,17 @@ def sdpa_attention_forward(
     attn_output = attn_output.transpose(1, 2).contiguous()
 
     return attn_output, None
+
+
+class SdpaAttentionKwargs(TypedDict, total=False):
+    """
+    Keyword arguments for sdpa Attention.
+
+    Attributes:
+        is_causal (`bool`, *optional*)
+            The value for the argument `is_causal` that is passed to `torch.nn.functional.scaled_dot_product_attention`.
+            An error is thrown if both attention_mask and is_causal are set. If `None`, it is inferred in
+            `sdpa_attention_forward`.
+    """
+
+    is_causal: Optional[bool]
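For readers unfamiliar with typed keyword arguments, the sketch below (not part of the commit) shows how a `total=False` TypedDict such as `SdpaAttentionKwargs` pairs with `Unpack`-annotated `**kwargs`, and one plausible way the `is_causal` fallback described in the docstring could look; `toy_sdpa_forward` and its inference rule are illustrative assumptions, not the library's actual code.

from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack  # available from typing on Python 3.12+


class SdpaAttentionKwargs(TypedDict, total=False):
    # total=False makes every key optional, so callers may simply omit is_causal.
    is_causal: Optional[bool]


def toy_sdpa_forward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs: Unpack[SdpaAttentionKwargs],
) -> torch.Tensor:
    # Illustrative fallback: if is_causal was not supplied (or is None),
    # only request a causal mask when no explicit attention_mask is given.
    is_causal = kwargs.get("is_causal")
    if is_causal is None:
        is_causal = attention_mask is None and query.shape[2] > 1
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, is_causal=is_causal
    )


q = k = v = torch.randn(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim)
out = toy_sdpa_forward(q, k, v, is_causal=True)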
8 changes: 6 additions & 2 deletions src/transformers/modeling_utils.py
@@ -30,7 +30,7 @@
 from dataclasses import dataclass
 from functools import partial, wraps
 from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypedDict, TypeVar, Union
 from zipfile import is_zipfile
 
 import torch
@@ -48,8 +48,9 @@
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
-from .integrations.sdpa_attention import sdpa_attention_forward
+from .integrations.sdpa_attention import sdpa_attention_forward, SdpaAttentionKwargs
 from .loss.loss_utils import LOSS_MAPPING
+from .modeling_flash_attention_utils import FlashAttentionKwargs
 from .pytorch_utils import (  # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
@@ -5702,3 +5703,6 @@ def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix):
         "sdpa": sdpa_attention_forward,
     }
 )
+
+
+AttentionKwargs = Union[FlashAttentionKwargs, SdpaAttentionKwargs]
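A brief, self-contained sketch (not part of the commit) of what the new `AttentionKwargs` union expresses: a forward signature annotated this way advertises that it may receive either the FlashAttention-style or the SDPA-style keyword arguments, while at runtime the annotation is purely informative and the keys are forwarded as-is. The two TypedDicts below are trimmed stand-ins for the real ones, and `attention_layer_forward` is a hypothetical example function.

from typing import Optional, TypedDict, Union

from typing_extensions import Unpack


class FlashAttentionKwargs(TypedDict, total=False):
    # Trimmed stand-in; the real TypedDict lives in modeling_flash_attention_utils.py.
    max_length_q: Optional[int]
    max_length_k: Optional[int]


class SdpaAttentionKwargs(TypedDict, total=False):
    is_causal: Optional[bool]


AttentionKwargs = Union[FlashAttentionKwargs, SdpaAttentionKwargs]


def attention_layer_forward(hidden_states, **kwargs: Unpack[AttentionKwargs]):
    # The layer never interprets the keys itself; it just hands them to
    # whichever attention backend is configured.
    print("forwarding attention kwargs:", sorted(kwargs))
    return hidden_states


attention_layer_forward("h", is_causal=False)   # SDPA-style keyword
attention_layer_forward("h", max_length_q=128)  # FlashAttention-style keyword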
7 changes: 3 additions & 4 deletions src/transformers/models/llama/modeling_llama.py
@@ -27,7 +27,6 @@
 from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -36,7 +35,7 @@
     TokenClassifierOutput,
 )
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, AttentionKwargs, PreTrainedModel
 from ...processing_utils import Unpack
 from ...pytorch_utils import ALL_LAYERNORM_LAYERS
 from ...utils import (
@@ -262,7 +261,7 @@ def forward(
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
+        **kwargs: Unpack[AttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
@@ -528,7 +527,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+        **flash_attn_kwargs: Unpack[AttentionKwargs],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
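To show the pattern the llama changes follow end to end, here is a schematic, self-contained chain (not transformers code): the typed `**kwargs` are declared on the model and layer signatures mainly for documentation and IDE support, and are forwarded unchanged down to the attention backend. `ToyModel`, `ToyAttention`, and `sdpa_backend` are invented names for this illustration.

from typing import Optional, TypedDict

import torch
from typing_extensions import Unpack


class SdpaAttentionKwargs(TypedDict, total=False):
    is_causal: Optional[bool]


def sdpa_backend(q, k, v, attention_mask=None, **kwargs: Unpack[SdpaAttentionKwargs]):
    return torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attention_mask, is_causal=kwargs.get("is_causal", False)
    )


class ToyAttention(torch.nn.Module):
    def forward(self, hidden_states, attention_mask=None, **kwargs: Unpack[SdpaAttentionKwargs]):
        # Real models project to q/k/v; the identity is enough to show the flow.
        q = k = v = hidden_states
        return sdpa_backend(q, k, v, attention_mask=attention_mask, **kwargs)


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def forward(self, hidden_states, attention_mask=None, **kwargs: Unpack[SdpaAttentionKwargs]):
        # Analogous to the annotated forward signatures in the diff:
        # the kwargs pass through untouched.
        return self.attn(hidden_states, attention_mask=attention_mask, **kwargs)


hidden = torch.randn(1, 2, 4, 8)  # (batch, num_heads, seq_len, head_dim)
output = ToyModel()(hidden, is_causal=True)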
