Add flash attention support for inference #367

Merged
merged 3 commits on Oct 26, 2023
3 changes: 2 additions & 1 deletion scripts/attn_and_long_ctx_patches.py
@@ -39,6 +39,7 @@ def xformers_forward(
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

@@ -216,7 +217,7 @@ def apply_attention_patch(
    global USE_MEM_EFF_ATTENTION, STORE_KV_BEFORE_ROPE
    if use_memory_efficient_attention is True and xops is not None:
        USE_MEM_EFF_ATTENTION = use_memory_efficient_attention
        print("USE_MEM_EFF_ATTENTION: ",USE_MEM_EFF_ATTENTION)
        print("USE_XFORMERS_ATTENTION: ", USE_MEM_EFF_ATTENTION)
    STORE_KV_BEFORE_ROPE = store_kv_before_rope
    print("STORE_KV_BEFORE_ROPE:", STORE_KV_BEFORE_ROPE)
    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
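For context, the new padding_mask=None keyword matters because transformers releases around 4.34 pass a padding_mask argument through to LlamaAttention.forward; a patched forward without that parameter would raise a TypeError. A small, illustrative check (assuming the scripts/ directory is on sys.path; this check is not part of the PR):

# Illustrative compatibility check: the patched forward must tolerate the
# padding_mask kwarg that newer transformers versions pass to attention modules.
import inspect

from attn_and_long_ctx_patches import xformers_forward

sig = inspect.signature(xformers_forward)
print("accepts padding_mask:", "padding_mask" in sig.parameters)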
95 changes: 95 additions & 0 deletions scripts/inference/flash_attn_patch_for_inference.py
@@ -0,0 +1,95 @@
# The code below is based on https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
from typing import Optional, Tuple
import torch

import transformers

from einops import rearrange
try:
    from flash_attn.flash_attn_interface import flash_attn_with_kvcache
except ImportError:
    flash_attn_with_kvcache = None
    print(
        "FlashAttention-2 is not installed correctly. If you want to use flash attention for inference, flash-attention >= 2.2 is required. "
        "Please check https://github.com/Dao-AILab/flash-attention for installation and usage details."
    )


def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
    )

    kv_seq_len = key_states.shape[1]
    past_kv_len = 0
    if past_key_value is not None:
        past_kv_len = past_key_value[0].shape[-2]
    kv_seq_len += past_kv_len

    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    rotary_dim = cos.shape[-1]
    cos, sin = cos.squeeze(0,1)[:,:rotary_dim//2].contiguous(), sin.squeeze(0,1)[:,:rotary_dim//2].contiguous()

    if past_key_value is not None:
        key_cache = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
        value_cache = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
    else:
        key_cache = key_states
        value_cache = value_states

    assert not output_attentions, "output_attentions is not supported"

    q = query_states  # [bsz, q_len, nh, hd]
    k, v = key_states, value_states  # [bsz, q_len, nh, hd]

    output = flash_attn_with_kvcache(
        q, key_cache, value_cache, k, v, rotary_cos=cos, rotary_sin=sin, cache_seqlens=past_kv_len, softmax_scale=None, causal=True, rotary_interleaved=False
    )
    output = rearrange(output, "b s h d -> b s (h d)", b=bsz)

    past_key_value = (key_cache[:,:kv_seq_len].transpose(1,2), value_cache[:,:kv_seq_len].transpose(1,2)) if use_cache else None

    output = self.o_proj(output)

    return output, None, past_key_value


# Disable the transformation of the attention mask in LlamaModel, as flash attention
# requires the attention mask to be the same as the key_padding_mask.
def _prepare_decoder_attention_mask(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    return attention_mask


def replace_llama_attn_with_flash_attn():
    if flash_attn_with_kvcache is not None:
        print("USE_FLASH_ATTENTION: ", True)
        transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
        transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
    else:
        print("USE_FLASH_ATTENTION: ", False)
13 changes: 11 additions & 2 deletions scripts/inference/gradio_demo.py
@@ -110,6 +110,10 @@
    "--draft_model_load_in_4bit",
    action='store_true',
    help="Load the draft model in the 4bit mode")
parser.add_argument(
    '--flash_attn',
    action='store_true',
    help="Use flash attention to replace the LLaMA attention")

args = parser.parse_args()

@@ -132,9 +136,14 @@
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
if not args.only_cpu:
    apply_attention_patch(use_memory_efficient_attention=True)
    if args.flash_attn:
        from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()
    else:
        from attn_and_long_ctx_patches import apply_attention_patch
        apply_attention_patch(use_memory_efficient_attention=True)
from attn_and_long_ctx_patches import apply_ntk_scaling_patch
apply_ntk_scaling_patch(args.alpha)
if args.speculative_sampling:
    if args.draft_base_model == None:
10 changes: 8 additions & 2 deletions scripts/inference/inference_hf.py
@@ -33,6 +33,7 @@
parser.add_argument('--draft_lora_model', default=None, type=str, help="If None, perform inference on the draft base model")
parser.add_argument('--draft_model_load_in_8bit', action='store_true', help="Load the draft model in the 8bit mode")
parser.add_argument('--draft_model_load_in_4bit', action='store_true', help="Load the draft model in the 4bit mode")
parser.add_argument('--flash_attn', action='store_true', help="Use flash attention to replace the LLaMA attention")
args = parser.parse_args()

if args.guidance_scale > 1:
@@ -71,9 +72,14 @@
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
if not args.only_cpu:
    apply_attention_patch(use_memory_efficient_attention=True)
    if args.flash_attn:
        from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
        replace_llama_attn_with_flash_attn()
    else:
        from attn_and_long_ctx_patches import apply_attention_patch
        apply_attention_patch(use_memory_efficient_attention=True)
from attn_and_long_ctx_patches import apply_ntk_scaling_patch
apply_ntk_scaling_patch(args.alpha)
if args.speculative_sampling:
    if args.draft_base_model == None:
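Both gradio_demo.py and inference_hf.py gate the new code path behind --flash_attn and fall back to the xformers patch otherwise. A quick, illustrative way to confirm that the flash-attn forward was actually installed (again assuming scripts/inference is on sys.path; only meaningful when flash-attn >= 2.2 is importable):

# Illustrative sanity check: did the patch replace LlamaAttention.forward?
import transformers

from flash_attn_patch_for_inference import (
    forward as flash_attn_forward,
    replace_llama_attn_with_flash_attn,
)

replace_llama_attn_with_flash_attn()
patched = (
    transformers.models.llama.modeling_llama.LlamaAttention.forward
    is flash_attn_forward
)
print("LlamaAttention.forward patched with flash-attn:", patched)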
2 changes: 2 additions & 0 deletions scripts/training/flash_attn_patch.py
@@ -22,6 +22,7 @@ def forward(
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask=None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

@@ -110,6 +111,7 @@ def _prepare_decoder_attention_mask(


def replace_llama_attn_with_flash_attn():
    print("USE_FLASH_ATTENTION: ", True)
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask
    )