diff --git a/scripts/attn_and_long_ctx_patches.py b/scripts/attn_and_long_ctx_patches.py
index d971923..5baaebd 100644
--- a/scripts/attn_and_long_ctx_patches.py
+++ b/scripts/attn_and_long_ctx_patches.py
@@ -39,6 +39,7 @@ def xformers_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()

@@ -216,7 +217,7 @@ def apply_attention_patch(
     global USE_MEM_EFF_ATTENTION, STORE_KV_BEFORE_ROPE
     if use_memory_efficient_attention is True and xops is not None:
         USE_MEM_EFF_ATTENTION = use_memory_efficient_attention
-        print("USE_MEM_EFF_ATTENTION: ",USE_MEM_EFF_ATTENTION)
+        print("USE_XFORMERS_ATTENTION: ", USE_MEM_EFF_ATTENTION)
     STORE_KV_BEFORE_ROPE = store_kv_before_rope
     print("STORE_KV_BEFORE_ROPE:", STORE_KV_BEFORE_ROPE)
     transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
diff --git a/scripts/inference/flash_attn_patch_for_inference.py b/scripts/inference/flash_attn_patch_for_inference.py
new file mode 100644
index 0000000..f15e2d0
--- /dev/null
+++ b/scripts/inference/flash_attn_patch_for_inference.py
@@ -0,0 +1,95 @@
+# The code below is based on https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py.
+from typing import Optional, Tuple
+import torch
+
+import transformers
+
+from einops import rearrange
+try:
+    from flash_attn.flash_attn_interface import flash_attn_with_kvcache
+except ImportError:
+    flash_attn_with_kvcache = None
+    print(
+        "FlashAttention-2 is not installed correctly. If you want to use flash attention for inference, flash-attention >= 2.2 is needed. "
+        "Please check the usage in https://github.com/Dao-AILab/flash-attention for more details."
+    )
+
+
+def forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    padding_mask=None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    """Input shape: Batch x Time x Channel
+
+    attention_mask: [bsz, q_len]
+    """
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = (
+        self.q_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+    )
+    key_states = (
+        self.k_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+    )
+    value_states = (
+        self.v_proj(hidden_states)
+        .view(bsz, q_len, self.num_heads, self.head_dim)
+    )
+
+    kv_seq_len = key_states.shape[1]
+    past_kv_len = 0
+    if past_key_value is not None:
+        past_kv_len = past_key_value[0].shape[-2]
+        kv_seq_len += past_kv_len
+
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    rotary_dim = cos.shape[-1]
+    cos, sin = cos.squeeze(0,1)[:,:rotary_dim//2].contiguous(), sin.squeeze(0,1)[:,:rotary_dim//2].contiguous()
+
+    if past_key_value is not None:
+        key_cache = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+        value_cache = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+    else:
+        key_cache = key_states
+        value_cache = value_states
+
+    assert not output_attentions, "output_attentions is not supported"
+
+    q = query_states  # [bsz, q_len, nh, hd]
+    k, v = key_states, value_states  # [bsz, q_len, nh, hd]
+
+    output = flash_attn_with_kvcache(
+        q, key_cache, value_cache, k, v, rotary_cos=cos, rotary_sin=sin, cache_seqlens=past_kv_len, softmax_scale=None, causal=True, rotary_interleaved=False
+    )
+    output = rearrange(output, "b s h d -> b s (h d)", b=bsz)
+
+    past_key_value = (key_cache[:,:kv_seq_len].transpose(1,2), value_cache[:,:kv_seq_len].transpose(1,2)) if use_cache else None
+
+    output = self.o_proj(output)
+
+    return output, None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+    return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+    if flash_attn_with_kvcache != None:
+        print("USE_FLASH_ATTENTION: ", True)
+        transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
+    else:
+        print("USE_FLASH_ATTENTION: ", False)
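A minimal usage sketch of the new module, mirroring the inference-script changes that follow: the monkey patch is applied before the model is built, so that LlamaAttention.forward and LlamaModel._prepare_decoder_attention_mask are already replaced when the model runs. The checkpoint path, dtype, and device settings below are placeholders, not part of the patch.

# Usage sketch (illustrative, not part of the diff). Assumes flash-attn >= 2.2,
# a CUDA GPU, and that scripts/inference/ is the working directory or on sys.path.
import torch
from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn

# Apply the monkey patch before loading the model.
replace_llama_attn_with_flash_attn()

from transformers import LlamaForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-model")  # placeholder path
model = LlamaForCausalLM.from_pretrained(
    "path/to/llama-model",          # placeholder path
    torch_dtype=torch.float16,      # fp16/bf16 is required by FlashAttention
    device_map="auto",
).eval()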
diff --git a/scripts/inference/gradio_demo.py b/scripts/inference/gradio_demo.py
index fafddee..0b3e25e 100644
--- a/scripts/inference/gradio_demo.py
+++ b/scripts/inference/gradio_demo.py
@@ -110,6 +110,10 @@
     "--draft_model_load_in_4bit",
     action='store_true',
     help="Load the draft model in the 4bit mode")
+parser.add_argument(
+    '--flash_attn',
+    action='store_true',
+    help="Use flash attention to replace the LLaMA attention")
 args = parser.parse_args()


@@ -132,9 +136,14 @@
 import sys
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
-from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
 if not args.only_cpu:
-    apply_attention_patch(use_memory_efficient_attention=True)
+    if args.flash_attn:
+        from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
+        replace_llama_attn_with_flash_attn()
+    else:
+        from attn_and_long_ctx_patches import apply_attention_patch
+        apply_attention_patch(use_memory_efficient_attention=True)
+from attn_and_long_ctx_patches import apply_ntk_scaling_patch
 apply_ntk_scaling_patch(args.alpha)
 if args.speculative_sampling:
     if args.draft_base_model == None:
diff --git a/scripts/inference/inference_hf.py b/scripts/inference/inference_hf.py
index 1cb029c..07cf80a 100644
--- a/scripts/inference/inference_hf.py
+++ b/scripts/inference/inference_hf.py
@@ -33,6 +33,7 @@
 parser.add_argument('--draft_lora_model', default=None, type=str, help="If None, perform inference on the draft base model")
 parser.add_argument('--draft_model_load_in_8bit', action='store_true', help="Load the draft model in the 8bit mode")
 parser.add_argument('--draft_model_load_in_4bit', action='store_true', help="Load the draft model in the 4bit mode")
+parser.add_argument('--flash_attn', action='store_true', help="Use flash attention to replace the LLaMA attention")
 args = parser.parse_args()

 if args.guidance_scale > 1:
@@ -71,9 +72,14 @@
 import sys
 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(parent_dir)
-from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
 if not args.only_cpu:
-    apply_attention_patch(use_memory_efficient_attention=True)
+    if args.flash_attn:
+        from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
+        replace_llama_attn_with_flash_attn()
+    else:
+        from attn_and_long_ctx_patches import apply_attention_patch
+        apply_attention_patch(use_memory_efficient_attention=True)
+from attn_and_long_ctx_patches import apply_ntk_scaling_patch
 apply_ntk_scaling_patch(args.alpha)
 if args.speculative_sampling:
     if args.draft_base_model == None:
diff --git a/scripts/training/flash_attn_patch.py b/scripts/training/flash_attn_patch.py
index 68dc0ec..890e989 100644
--- a/scripts/training/flash_attn_patch.py
+++ b/scripts/training/flash_attn_patch.py
@@ -22,6 +22,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask=None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """Input shape: Batch x Time x Channel

@@ -110,6 +111,7 @@ def _prepare_decoder_attention_mask(


 def replace_llama_attn_with_flash_attn():
+    print("USE_FLASH_ATTENTION: ", True)
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
         _prepare_decoder_attention_mask
     )
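Taken together, the inference-side changes select the attention backend at runtime: FlashAttention-2 when --flash_attn is passed, the xformers memory-efficient attention patch otherwise, and no attention patch at all on CPU, while NTK scaling is applied in every case. The sketch below condenses that selection logic into one helper for reuse in a custom script; the wrapper name, its parameters, and the default alpha value are illustrative assumptions, while the imported modules and calls come from the diff above.

# Condensed sketch of the backend selection added to gradio_demo.py and
# inference_hf.py (illustrative helper; not part of the diff).
import os
import sys

# The shared patch module lives one directory above scripts/inference/.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)


def apply_inference_patches(flash_attn: bool, only_cpu: bool, alpha: str = "1.0"):
    if not only_cpu:
        if flash_attn:
            # FlashAttention-2 path (needs flash-attn >= 2.2 and a GPU).
            from flash_attn_patch_for_inference import replace_llama_attn_with_flash_attn
            replace_llama_attn_with_flash_attn()
        else:
            # xformers memory-efficient attention path.
            from attn_and_long_ctx_patches import apply_attention_patch
            apply_attention_patch(use_memory_efficient_attention=True)
    # NTK scaling is applied regardless of the attention backend.
    from attn_and_long_ctx_patches import apply_ntk_scaling_patch
    apply_ntk_scaling_patch(alpha)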