Replace FlashAttention with xformers (#70)
WoosukKwon authored May 5, 2023
1 parent 189ae23 commit c9d5b6d
Showing 13 changed files with 87 additions and 131 deletions.
6 changes: 1 addition & 5 deletions README.md
@@ -3,11 +3,7 @@
## Installation

```bash
pip install psutil numpy ray torch
pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
pip install sentencepiece # Required for LlamaTokenizer.
pip install ninja # To parallelize the compilation of flash-attn.
pip install flash-attn # This may take up to 10 mins.
pip install ninja psutil numpy sentencepiece ray torch transformers xformers
pip install -e .
```
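
A quick way to confirm that the xformers dependency this commit switches to is installed correctly is to call its memory-efficient attention forward op directly. This is a minimal sketch, not part of the README; the shapes are hypothetical and a CUDA device with FP16 support is assumed.

```python
import torch
from xformers import ops as xops

# Hypothetical shapes: [batch, seq_len, num_heads, head_size].
q = torch.randn(1, 16, 4, 64, dtype=torch.half, device='cuda')
out = xops.memory_efficient_attention_forward(q, q, q, p=0.0)
print(out.shape)  # torch.Size([1, 16, 4, 64])
```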

2 changes: 1 addition & 1 deletion cacheflow/master/server.py
@@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): FlashAttention does not support float32.
# TODO(woosuk): Support FP32 for debugging.
parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
54 changes: 16 additions & 38 deletions cacheflow/models/attention.py
@@ -1,8 +1,8 @@
from typing import Optional

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn
from xformers import ops as xops

from cacheflow import attention_ops
from cacheflow import cache_ops
@@ -15,39 +15,29 @@ class GPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
super().__init__()
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()

def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
max_prompt_len: int,
attn_bias: xops.AttentionBias,
) -> None:
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[-1]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')

# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
query,
key,
value,
output,
cumulative_prompt_lens,
cumulative_prompt_lens,
max_prompt_len,
max_prompt_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
return_softmax=False,
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output.copy_(out.squeeze(0))
return output

def single_query_cached_kv_attention(
self,
@@ -109,8 +99,7 @@ def forward(
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.cumulative_prompt_lens,
input_metadata.max_prompt_len,
input_metadata.attn_bias,
)

# Wait until the cache op is done.
@@ -143,13 +132,6 @@ def forward(
return output.view(-1, num_heads * head_size)


class OPTCacheFlowAttention(GPTCacheFlowAttention):
"""OPT uses the same attention mechanism as GPT."""

def __init__(self, scale: float) -> None:
super().__init__(scale)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
"""Attention with GPT-NeoX style rotary embedding."""

@@ -207,7 +189,3 @@ def forward(
input_metadata,
cache_event,
)


class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
"""LLaMA uses the GPT-NeoX style rotary embedding."""
21 changes: 9 additions & 12 deletions cacheflow/models/input_metadata.py
@@ -1,6 +1,7 @@
from typing import List, Dict, Tuple

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from cacheflow.sampling_params import SamplingParams

@@ -12,7 +13,6 @@ def __init__(
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
cumulative_prompt_lens: torch.Tensor,
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
@@ -21,15 +21,14 @@
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.cumulative_prompt_lens = cumulative_prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables

self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
@@ -41,15 +40,13 @@

def __repr__(self) -> str:
return (f'InputMetadata('
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'max_prompt_len={self.max_prompt_len}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'max_context_len={self.max_context_len}), '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
f'slot_mapping={self.slot_mapping}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'block_tables={self.block_tables})')
f'max_context_len={self.max_context_len}), '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}), '
f'slot_mapping={self.slot_mapping}')
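
For context, the `attn_bias` constructed in `InputMetadata` replaces the `cumulative_prompt_lens` tensor that the worker previously built for FlashAttention; the prompt boundaries are now derived directly from `prompt_lens`. A small illustration with hypothetical lengths:

```python
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

prompt_lens = [5, 3, 7]  # hypothetical example

# Previous bookkeeping (removed from the worker below): explicit cumulative
# sequence lengths for FlashAttention's variable-length interface.
cumulative_prompt_lens = [0]
for prompt_len in prompt_lens:
    cumulative_prompt_lens.append(cumulative_prompt_lens[-1] + prompt_len)
assert cumulative_prompt_lens == [0, 5, 8, 15]

# New bookkeeping: the block-diagonal causal mask encodes the same prompt
# boundaries, plus causality, in a single object.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
```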
4 changes: 2 additions & 2 deletions cacheflow/models/llama.py
@@ -7,7 +7,7 @@

from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ def __init__(
input_is_parallel=True,
perform_initialization=False,
)
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

def forward(
self,
12 changes: 6 additions & 6 deletions cacheflow/models/memory_analyzer.py
@@ -202,8 +202,8 @@ def get_max_act_size(
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
@@ -277,8 +277,8 @@ def get_max_act_size(
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -353,8 +353,8 @@ def get_max_act_size(
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
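
A back-of-the-envelope illustration of the activation terms above, using hypothetical model dimensions; the point is that with memory-efficient attention no term scales with the square of the number of batched tokens. The combination into a peak figure is a simplification, not the analyzer's exact formula.

```python
# Hypothetical sizes, not taken from any real config.
max_num_batched_tokens = 2560
hidden_size = 4096
ffn_size = 16384
tensor_parallel_size = 1

residual = max_num_batched_tokens * hidden_size                            # ~10.5M elements
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size   # ~31.5M elements
ffn = 2 * (max_num_batched_tokens * ffn_size) // tensor_parallel_size      # ~83.9M elements

# Simplified peak estimate: the largest intermediate plus the residual stream.
peak_bytes = (max(qkv, ffn) + residual) * 2   # FP16 -> 2 bytes per element
print(f'{peak_bytes / 1024**2:.0f} MiB')      # 180 MiB
```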
4 changes: 2 additions & 2 deletions cacheflow/models/opt.py
@@ -6,7 +6,7 @@
from transformers import OPTConfig

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.attention import GPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ def __init__(
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = OPTCacheFlowAttention(scale=self.scaling)
self.attn = GPTCacheFlowAttention(scale=self.scaling)

def forward(
self,
8 changes: 0 additions & 8 deletions cacheflow/worker/worker.py
@@ -136,11 +136,6 @@ def prepare_inputs(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

cumulative_prompt_lens: List[int] = [0]
for prompt_len in prompt_lens:
cumulative_prompt_lens.append(
cumulative_prompt_lens[-1] + prompt_len)

# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
@@ -196,14 +191,11 @@ def prepare_inputs(
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=torch.int, device='cuda')
cumulative_prompt_lens_tensor = torch.tensor(
cumulative_prompt_lens, dtype=torch.int, device='cuda')

input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_logprobs=seq_logprobs,
prompt_lens=prompt_lens,
cumulative_prompt_lens=cumulative_prompt_lens_tensor,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,
2 changes: 1 addition & 1 deletion tests/kernels/activation.py
@@ -23,7 +23,7 @@ def test_silu_and_mul(


if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')