
Replace FlashAttention with xformers #70

Merged 12 commits on May 5, 2023

README.md: 1 addition & 5 deletions
@@ -3,11 +3,7 @@
## Installation

```bash
pip install psutil numpy ray torch
pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
pip install sentencepiece # Required for LlamaTokenizer.
pip install ninja # To parallelize the compilation of flash-attn.
pip install flash-attn # This may take up to 10 mins.
pip install ninja psutil numpy sentencepiece ray torch transformers xformers
[Collaborator Author] TODO (in the next PR): specify the exact dependencies in setup.py.

pip install -e .
```
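
Regarding the TODO above about moving the dependencies into setup.py, a minimal sketch of what that might look like (the package list mirrors the new pip line; the project name and any version pins are assumptions, not part of this PR):

```python
# Hypothetical setup.py snippet; the real dependency list is deferred to a later PR.
from setuptools import setup, find_packages

setup(
    name="cacheflow",
    packages=find_packages(),
    install_requires=[
        "ninja",
        "psutil",
        "numpy",
        "sentencepiece",
        "ray",
        "torch",
        "transformers",
        "xformers",
    ],
)
```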

cacheflow/master/server.py: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
# NOTE(woosuk): FlashAttention does not support float32.
# TODO(woosuk): Support FP32 for debugging.
[Member] Does xformers support FP32?

[Collaborator Author] Yes, it does. It is our attention kernel that does not support FP32; more precisely, it currently does not support some block sizes when FP32 is used. I will fix this in the future.
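
As a side note, a quick sanity check of that claim might look like the following (a sketch, assuming a CUDA device and an xformers build whose cutlass forward op accepts FP32; not part of this PR):

```python
import torch
from xformers import ops as xops

# FP32 inputs in the [batch, tokens, heads, head_size] layout xformers expects.
q = torch.randn(1, 16, 8, 64, dtype=torch.float32, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# The cutlass forward op handles FP32; the limitation discussed above is in
# CacheFlow's own single-query attention kernel, not in xformers.
out = xops.memory_efficient_attention_forward(q, k, v, op=xops.fmha.cutlass.FwOp())
print(out.shape)  # torch.Size([1, 16, 8, 64])
```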

parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
help=('data type for model weights and activations. '
'The "default" option will use FP16 precision '
cacheflow/models/attention.py: 16 additions & 38 deletions
@@ -1,8 +1,8 @@
from typing import Optional

from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn
from xformers import ops as xops

from cacheflow import attention_ops
from cacheflow import cache_ops
@@ -15,39 +15,29 @@ class GPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
super().__init__()
self.scale = float(scale)
self.attn_op = xops.fmha.cutlass.FwOp()

def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
max_prompt_len: int,
attn_bias: xops.AttentionBias,
) -> None:
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[-1]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')

# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
query,
key,
value,
output,
cumulative_prompt_lens,
cumulative_prompt_lens,
max_prompt_len,
max_prompt_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
return_softmax=False,
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward(
query.unsqueeze(0),
key.unsqueeze(0),
value.unsqueeze(0),
attn_bias=attn_bias,
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output.copy_(out.squeeze(0))
return output

def single_query_cached_kv_attention(
self,
@@ -109,8 +99,7 @@ def forward(
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.cumulative_prompt_lens,
input_metadata.max_prompt_len,
input_metadata.attn_bias,
)

# Wait until the cache op is done.
@@ -143,13 +132,6 @@ def forward(
return output.view(-1, num_heads * head_size)


class OPTCacheFlowAttention(GPTCacheFlowAttention):
"""OPT uses the same attention mechanism as GPT."""

def __init__(self, scale: float) -> None:
super().__init__(scale)


class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
"""Attention with GPT-NeoX style rotary embedding."""

@@ -207,7 +189,3 @@ def forward(
input_metadata,
cache_event,
)


class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention):
"""LLaMA uses the GPT-NeoX style rotary embedding."""
cacheflow/models/input_metadata.py: 9 additions & 12 deletions
@@ -1,6 +1,7 @@
from typing import List, Dict, Tuple

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

from cacheflow.sampling_params import SamplingParams

@@ -12,7 +13,6 @@ def __init__(
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
prompt_lens: List[int],
cumulative_prompt_lens: torch.Tensor,
slot_mapping: torch.Tensor,
context_lens: torch.Tensor,
max_context_len: int,
@@ -21,15 +21,14 @@
self.seq_groups = seq_groups
self.seq_logprobs = seq_logprobs
self.prompt_lens = prompt_lens
self.cumulative_prompt_lens = cumulative_prompt_lens
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables

self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0:
@@ -41,15 +40,13 @@

def __repr__(self) -> str:
return (f'InputMetadata('
f'num_prompts={self.num_prompts}, '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'max_prompt_len={self.max_prompt_len}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'num_valid_tokens={self.num_valid_tokens}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'max_context_len={self.max_context_len}), '
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'cumulative_prompt_lens={self.cumulative_prompt_lens}, '
f'slot_mapping={self.slot_mapping}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'block_tables={self.block_tables})')
f'max_context_len={self.max_context_len}), '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}), '
f'slot_mapping={self.slot_mapping}')
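
For readers comparing the before and after: the removed cumulative_prompt_lens tensor encoded the same prompt boundaries that BlockDiagonalCausalMask now derives directly from prompt_lens. A tiny illustration (values made up):

```python
import itertools
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

prompt_lens = [3, 5, 2]

# Old interface: FlashAttention wanted cumulative lengths, e.g. [0, 3, 8, 10].
cumulative_prompt_lens = [0] + list(itertools.accumulate(prompt_lens))

# New interface: xformers builds the block-diagonal causal bias from the raw lengths.
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
```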
cacheflow/models/llama.py: 2 additions & 2 deletions
@@ -7,7 +7,7 @@

from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.attention import GPTNeoXCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
@@ -79,7 +79,7 @@ def __init__(
input_is_parallel=True,
perform_initialization=False,
)
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim)

def forward(
self,
cacheflow/models/memory_analyzer.py: 6 additions & 6 deletions
@@ -202,8 +202,8 @@ def get_max_act_size(
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
@@ -277,8 +277,8 @@ def get_max_act_size(
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
@@ -353,8 +353,8 @@ def get_max_act_size(
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
# Here, we assume that we use memory-efficient attention which
# does not materialize the attention maps in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
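
A rough worked example of the activation arithmetic these comments describe (numbers are illustrative, not from the PR): with memory-efficient attention the attention maps never land in GPU DRAM, so only the residual, QKV, and FFN activations are counted.

```python
# Illustrative activation-size arithmetic in the spirit of get_max_act_size.
max_num_batched_tokens = 2560
hidden_size = 5120
ffn_size = 4 * hidden_size
tensor_parallel_size = 1
bytes_per_element = 2  # FP16

residual = max_num_batched_tokens * hidden_size
qkv = 3 * (max_num_batched_tokens * hidden_size) // tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * ffn_size) // tensor_parallel_size

# How these terms are combined into the final estimate is outside the shown
# hunks; the point is that no num_tokens^2 attention-map term appears.
for name, n in [("residual", residual), ("qkv", qkv), ("ffn", ffn)]:
    print(f"{name}: {n * bytes_per_element / 1024**2:.0f} MiB")
```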
cacheflow/models/opt.py: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
from transformers import OPTConfig

from cacheflow.models import InputMetadata
from cacheflow.models.attention import OPTCacheFlowAttention
from cacheflow.models.attention import GPTCacheFlowAttention
from cacheflow.models.sample import Sampler
from cacheflow.models.utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
@@ -55,7 +55,7 @@ def __init__(
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.attn = OPTCacheFlowAttention(scale=self.scaling)
self.attn = GPTCacheFlowAttention(scale=self.scaling)

def forward(
self,
cacheflow/worker/worker.py: 0 additions & 8 deletions
@@ -136,11 +136,6 @@ def prepare_inputs(
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

cumulative_prompt_lens: List[int] = [0]
for prompt_len in prompt_lens:
cumulative_prompt_lens.append(
cumulative_prompt_lens[-1] + prompt_len)

# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
@@ -196,14 +191,11 @@ def prepare_inputs(
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=torch.int, device='cuda')
cumulative_prompt_lens_tensor = torch.tensor(
cumulative_prompt_lens, dtype=torch.int, device='cuda')

input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_logprobs=seq_logprobs,
prompt_lens=prompt_lens,
cumulative_prompt_lens=cumulative_prompt_lens_tensor,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,
tests/kernels/activation.py: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def test_silu_and_mul(


if __name__ == '__main__':
for dtype in [torch.half, torch.float]:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [7, 83, 2048]:
for d in [512, 4096, 13824]:
print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')