From de63b9caccdb054af81128caae48f4aa6ac8785d Mon Sep 17 00:00:00 2001
From: GoGoJoestar
Date: Wed, 25 Oct 2023 14:34:27 +0800
Subject: [PATCH 1/3] Add flash attention support for inference

---
 scripts/attn_and_long_ctx_patches.py   |  1 +
 .../flash_attn_patch_for_inference.py  | 92 +++++++++++++++++++
 scripts/inference/gradio_demo.py       | 13 ++-
 scripts/inference/inference_hf.py      | 10 +-
 scripts/training/flash_attn_patch.py   |  1 +
 5 files changed, 113 insertions(+), 4 deletions(-)
 create mode 100644 scripts/inference/flash_attn_patch_for_inference.py

diff --git a/scripts/attn_and_long_ctx_patches.py b/scripts/attn_and_long_ctx_patches.py
index d971923..f9a8827 100644
--- a/scripts/attn_and_long_ctx_patches.py
+++ b/scripts/attn_and_long_ctx_patches.py
@@ -39,6 +39,7 @@ def xformers_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
diff --git a/scripts/inference/flash_attn_patch_for_inference.py b/scripts/inference/flash_attn_patch_for_inference.py
new file mode 100644
index 0000000..fa64a4b
--- /dev/null
+++ b/scripts/inference/flash_attn_patch_for_inference.py
@@ -0,0 +1,92 @@
+# Below code is based on https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
+from typing import Optional, Tuple
+import torch
+
+import transformers
+
+from einops import rearrange
+try:
+    from flash_attn.flash_attn_interface import flash_attn_with_kvcache
+except ImportError:
+    flash_attn_with_kvcache = None
+    print(
+        "FlashAttention-2 is not installed correctly. If you want to use flash attention for inference, flash-attention >= 2.2 is needed. "
+        "Please check the usage in https://github.com/Dao-AILab/flash-attention for more details."
+    )
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask=None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel
+
+    attention_mask: [bsz, q_len]
+    """
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+    )
+
+    kv_seq_len = key_states.shape[1]
+    past_kv_len = 0
+    if past_key_value is not None:
+        past_kv_len = past_key_value[0].shape[-2]
+        kv_seq_len += past_kv_len
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    rotary_dim = cos.shape[-1]
+    cos, sin = cos.squeeze(0,1)[:,:rotary_dim//2].contiguous(), sin.squeeze(0,1)[:,:rotary_dim//2].contiguous()
+
+    if past_key_value is not None:
+        key_cache = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+        value_cache = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+    else:
+        key_cache = key_states
+        value_cache = value_states
+
+    assert not output_attentions, "output_attentions is not supported"
+
+    q = query_states  # [bsz, q_len, nh, hd]
+    k, v = key_states, value_states  # [bsz, q_len, nh, hd]
+
+    output = flash_attn_with_kvcache(
+        q, key_cache, value_cache, k, v, rotary_cos=cos, rotary_sin=sin, cache_seqlens=past_kv_len, softmax_scale=None, causal=True, rotary_interleaved=False
+    )
+    output = rearrange(output, "b s h d -> b s (h d)", b=bsz)
+    
+    past_key_value = (key_cache[:,:kv_seq_len].transpose(1,2), value_cache[:,:kv_seq_len].transpose(1,2)) if use_cache else None
+
+    output = self.o_proj(output)
+
+    return output, None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+    return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+    if flash_attn_with_kvcache != None:
+        transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index fafddee..0b3e25e 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -110,6 +110,10 @@
     "--draft_model_load_in_4bit",
     action='store_true',
     help="Load the draft model in the 4bit mode")
+parser.add_argument(
+    '--flash_attn',
+    action='store_true',
+    help="Use flash attention to replace the LLaMA attention")
 args = parser.parse_args()
 
 
@@ -132,9 +136,14 @@
 import sys
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
-from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
 if not args.only_cpu:
-    apply_attention_patch(use_memory_efficient_attention=True)
+    if args.flash_attn:
+        from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
+        replace_llama_attn_with_flash_attn()
+    else:
+        from attn_and_long_ctx_patches import apply_attention_patch
+        apply_attention_patch(use_memory_efficient_attention=True)
+from attn_and_long_ctx_patches import apply_ntk_scaling_patch
 apply_ntk_scaling_patch(args.alpha)
 if args.speculative_sampling:
     if args.draft_base_model == None:
diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py
index 1cb029c..07cf80a 100644
--- a/scripts/inference/inference_hf.py
+++ b/scripts/inference/inference_hf.py
@@ -33,6 +33,7 @@
 parser.add_argument('--draft_lora_model', default=None, type=str, help="If None, perform inference on the draft base model")
 parser.add_argument('--draft_model_load_in_8bit', action='store_true', help="Load the draft model in the 8bit mode")
 parser.add_argument('--draft_model_load_in_4bit', action='store_true', help="Load the draft model in the 4bit mode")
+parser.add_argument('--flash_attn', action='store_true', help="Use flash attention to replace the LLaMA attention")
 args = parser.parse_args()
 
 if args.guidance_scale > 1:
@@ -71,9 +72,14 @@
 import sys
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
-from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
 if not args.only_cpu:
-    apply_attention_patch(use_memory_efficient_attention=True)
+    if args.flash_attn:
+        from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
+        replace_llama_attn_with_flash_attn()
+    else:
+        from attn_and_long_ctx_patches import apply_attention_patch
+        apply_attention_patch(use_memory_efficient_attention=True)
+from attn_and_long_ctx_patches import apply_ntk_scaling_patch
 apply_ntk_scaling_patch(args.alpha)
 if args.speculative_sampling:
     if args.draft_base_model == None:
diff --git a/scripts/training/flash_attn_patch.py b/scripts/training/flash_attn_patch.py
index 68dc0ec..68fc7ec 100644
--- a/scripts/training/flash_attn_patch.py
+++ b/scripts/training/flash_attn_patch.py
@@ -22,6 +22,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """Input shape: Batch x Time x Channel
 

From 9a155b00aac458692eee8197be8d01f65f28218a Mon Sep 17 00:00:00 2001
From: GoGoJoestar
Date: Wed, 25 Oct 2023 17:53:52 +0800
Subject: [PATCH 2/3] update: output information when using flash-attention/xformers-attention

---
 scripts/attn_and_long_ctx_patches.py                | 2 +-
 scripts/inference/flash_attn_patch_for_inference.py | 3 +++
 scripts/training/flash_attn_patch.py                | 1 +
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/scripts/attn_and_long_ctx_patches.py b/scripts/attn_and_long_ctx_patches.py
index f9a8827..5baaebd 100644
--- a/scripts/attn_and_long_ctx_patches.py
+++ b/scripts/attn_and_long_ctx_patches.py
@@ -217,7 +217,7 @@ def apply_attention_patch(
     global USE_MEM_EFF_ATTENTION, STORE_KV_BEFORE_ROPE
     if use_memory_efficient_attention is True and xops is not None:
         USE_MEM_EFF_ATTENTION = use_memory_efficient_attention
-        print("USE_MEM_EFF_ATTENTION: ",USE_MEM_EFF_ATTENTION)
+        print("USE_XFORMERS_ATTENTION: ", USE_MEM_EFF_ATTENTION)
         STORE_KV_BEFORE_ROPE = store_kv_before_rope
         print("STORE_KV_BEFORE_ROPE:", STORE_KV_BEFORE_ROPE)
         transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
diff --git a/scripts/inference/flash_attn_patch_for_inference.py b/scripts/inference/flash_attn_patch_for_inference.py
index fa64a4b..8b5f866 100644
--- a/scripts/inference/flash_attn_patch_for_inference.py
+++ b/scripts/inference/flash_attn_patch_for_inference.py
@@ -88,5 +88,8 @@ def _prepare_decoder_attention_mask(
 
 def replace_llama_attn_with_flash_attn():
     if flash_attn_with_kvcache != None:
+        print("USE_FLASH_ATTENTION: ", True)
         transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
         transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
+    else:
+        print("USE_FLASH_ATTENTION: ", False)
diff --git a/scripts/training/flash_attn_patch.py b/scripts/training/flash_attn_patch.py
index 68fc7ec..890e989 100644
--- a/scripts/training/flash_attn_patch.py
+++ b/scripts/training/flash_attn_patch.py
@@ -111,6 +111,7 @@ def _prepare_decoder_attention_mask(
 
 
 def replace_llama_attn_with_flash_attn():
+    print("USE_FLASH_ATTENTION: ", True)
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
         _prepare_decoder_attention_mask
     )

From c89237fef7f7f90a76d9ef934ae1034ef24c9a88 Mon Sep 17 00:00:00 2001
From: Ziqing Yang
Date: Thu, 26 Oct 2023 08:47:45 +0800
Subject: [PATCH 3/3] Remove trailing whitespace

---
 scripts/inference/flash_attn_patch_for_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/inference/flash_attn_patch_for_inference.py b/scripts/inference/flash_attn_patch_for_inference.py
index 8b5f866..f15e2d0 100644
--- a/scripts/inference/flash_attn_patch_for_inference.py
+++ b/scripts/inference/flash_attn_patch_for_inference.py
@@ -70,7 +70,7 @@ def forward(
         q, key_cache, value_cache, k, v, rotary_cos=cos, rotary_sin=sin, cache_seqlens=past_kv_len, softmax_scale=None, causal=True, rotary_interleaved=False
     )
     output = rearrange(output, "b s h d -> b s (h d)", b=bsz)
-    
+
    past_key_value = (key_cache[:,:kv_seq_len].transpose(1,2), value_cache[:,:kv_seq_len].transpose(1,2)) if use_cache else None
 
     output = self.o_proj(output)
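
Usage sketch (not part of the patches above; a minimal illustration, assuming flash-attn >= 2.2, a CUDA device, and an fp16/bf16 LLaMA checkpoint whose path below is a placeholder). The patch module is applied by calling replace_llama_attn_with_flash_attn() before loading the model, which is what gradio_demo.py and inference_hf.py do when --flash_attn is passed and --only_cpu is not:

    # Hypothetical standalone example; the patched scripts wire this up via --flash_attn.
    import sys
    sys.path.append("scripts/inference")  # assumed repo layout, so the patch module is importable
    from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
    replace_llama_attn_with_flash_attn()  # prints USE_FLASH_ATTENTION: True/False (patch 2/3)

    import torch
    from transformers import LlamaForCausalLM, LlamaTokenizer

    model_path = "path/to/llama-2-checkpoint"  # placeholder
    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,  # flash attention kernels require fp16/bf16
        device_map="auto",
    ).eval()

    inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))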