From c653294e996c47fd5a0b139b700fa17cbd3d9e8b Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Wed, 3 May 2023 05:51:17 +0000 Subject: [PATCH 01/10] Fix GPTAttention --- cacheflow/models/attention.py | 45 +++++++++++++---------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 4c085644b5fe8..3293e2f7a3d69 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -1,8 +1,8 @@ -from typing import Optional +from typing import List, Optional -from flash_attn.flash_attn_interface import _flash_attn_forward import torch import torch.nn as nn +from xformers import ops as xops from cacheflow import attention_ops from cacheflow import cache_ops @@ -15,6 +15,7 @@ class GPTCacheFlowAttention(nn.Module): def __init__(self, scale: float) -> None: super().__init__() self.scale = float(scale) + self.attn_op = xops.fmha.cutlass.FwOp() def multi_query_kv_attention( self, @@ -22,32 +23,21 @@ def multi_query_kv_attention( query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] - cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1] - max_prompt_len: int, + prompt_lens: List[int], ) -> None: - if query.dtype == torch.float: - raise ValueError('The float data type is not supported by ' - 'FlashAttention. Use the half data type instead.') - head_size = query.shape[-1] - if head_size > 128: - raise ValueError('FlashAttention does not support head_size > 128.') - - # Directly call FlashAttention's internal function to avoid allocating - # a new tensor for the output. - _flash_attn_forward( - query, - key, - value, - output, - cumulative_prompt_lens, - cumulative_prompt_lens, - max_prompt_len, - max_prompt_len, - dropout_p=0.0, - softmax_scale=self.scale, - causal=True, - return_softmax=False, + # FIXME + attn_bias = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(prompt_lens) + out = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=self.scale, + op=self.attn_op, ) + output.copy_(out.squeeze(0)) + return output def single_query_cached_kv_attention( self, @@ -109,8 +99,7 @@ def forward( query[:num_prompt_tokens], key[:num_prompt_tokens], value[:num_prompt_tokens], - input_metadata.cumulative_prompt_lens, - input_metadata.max_prompt_len, + input_metadata.prompt_lens, ) # Wait until the cache op is done. 
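Note (illustrative, not part of the patch): the rewritten multi_query_kv_attention packs every prompt's tokens into a single call and relies on xformers' block-diagonal causal bias to keep different prompts from attending to each other. A minimal standalone sketch of the same call pattern, assuming an xformers release that exposes memory_efficient_attention_forward and BlockDiagonalCausalMask; the tensor sizes below are made up for illustration:

    import torch
    from xformers import ops as xops

    # Made-up sizes, only for illustration.
    prompt_lens = [5, 3, 7]            # three prompts packed back to back
    num_tokens = sum(prompt_lens)
    num_heads, head_size = 12, 64
    scale = head_size ** -0.5

    query = torch.randn(num_tokens, num_heads, head_size,
                        dtype=torch.half, device='cuda')
    key = torch.randn_like(query)
    value = torch.randn_like(query)

    # The bias encodes both causality and the prompt boundaries, so tokens
    # of one prompt never attend to tokens of another prompt.
    attn_bias = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(prompt_lens)

    # xformers expects a leading batch dimension; the packed prompts go in
    # as a single "batch" of size 1, mirroring the unsqueeze in the patch.
    out = xops.memory_efficient_attention_forward(
        query.unsqueeze(0), key.unsqueeze(0), value.unsqueeze(0),
        attn_bias=attn_bias, p=0.0, scale=scale,
    )
    out = out.squeeze(0)               # [num_tokens, num_heads, head_size]

Omitting op= lets xformers choose a backend on its own; the patch pins xops.fmha.cutlass.FwOp() explicitly.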
From 00c6129ed92bf6de3029d7b103fcde2a7c7b2c1f Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 08:23:19 +0000 Subject: [PATCH 02/10] Remove OPT and LLaMA attention --- cacheflow/models/attention.py | 11 ----------- cacheflow/models/llama.py | 4 ++-- cacheflow/models/opt.py | 4 ++-- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 3293e2f7a3d69..fee656063fd8e 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -132,13 +132,6 @@ def forward( return output.view(-1, num_heads * head_size) -class OPTCacheFlowAttention(GPTCacheFlowAttention): - """OPT uses the same attention mechanism as GPT.""" - - def __init__(self, scale: float) -> None: - super().__init__(scale) - - class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention): """Attention with GPT-NeoX style rotary embedding.""" @@ -196,7 +189,3 @@ def forward( input_metadata, cache_event, ) - - -class LlamaCacheFlowAttention(GPTNeoXCacheFlowAttention): - """LLaMA uses the GPT-NeoX style rotary embedding.""" diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 1eda7f23d077d..0587858205ad8 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -7,7 +7,7 @@ from cacheflow.models import InputMetadata from cacheflow.models.activation import SiluAndMul -from cacheflow.models.attention import LlamaCacheFlowAttention +from cacheflow.models.attention import GPTNeoXCacheFlowAttention from cacheflow.models.layernorm import RMSNorm from cacheflow.models.sample import Sampler from cacheflow.models.utils import (hf_model_weights_iterator, @@ -79,7 +79,7 @@ def __init__( input_is_parallel=True, perform_initialization=False, ) - self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim) + self.attn = GPTNeoXCacheFlowAttention(self.scaling, self.head_dim) def forward( self, diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 15f0f688d1af6..920ee415ce81c 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -6,7 +6,7 @@ from transformers import OPTConfig from cacheflow.models import InputMetadata -from cacheflow.models.attention import OPTCacheFlowAttention +from cacheflow.models.attention import GPTCacheFlowAttention from cacheflow.models.sample import Sampler from cacheflow.models.utils import (hf_model_weights_iterator, load_tensor_parallel_weights) @@ -55,7 +55,7 @@ def __init__( self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True, perform_initialization=False) - self.attn = OPTCacheFlowAttention(scale=self.scaling) + self.attn = GPTCacheFlowAttention(scale=self.scaling) def forward( self, From a8afbca35161a20a7edc66aa853dafecef1b25ff Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 08:42:13 +0000 Subject: [PATCH 03/10] Replace flash attention with xformers --- cacheflow/models/attention.py | 10 +++++----- cacheflow/models/input_metadata.py | 21 +++++++++------------ cacheflow/worker/worker.py | 8 -------- 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index fee656063fd8e..179fbd0b32c8a 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import torch import torch.nn as nn @@ -23,10 +23,9 @@ def multi_query_kv_attention( query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] key: torch.Tensor, # 
[num_prompt_tokens, num_heads, head_size] value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size] - prompt_lens: List[int], + attn_bias: xops.AttentionBias, ) -> None: - # FIXME - attn_bias = xops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(prompt_lens) + # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. out = xops.memory_efficient_attention_forward( query.unsqueeze(0), key.unsqueeze(0), @@ -36,6 +35,7 @@ def multi_query_kv_attention( scale=self.scale, op=self.attn_op, ) + # TODO(woosuk): Unnecessary copy. Optimize. output.copy_(out.squeeze(0)) return output @@ -99,7 +99,7 @@ def forward( query[:num_prompt_tokens], key[:num_prompt_tokens], value[:num_prompt_tokens], - input_metadata.prompt_lens, + input_metadata.attn_bias, ) # Wait until the cache op is done. diff --git a/cacheflow/models/input_metadata.py b/cacheflow/models/input_metadata.py index c61bfff20a66b..943524c9cca70 100644 --- a/cacheflow/models/input_metadata.py +++ b/cacheflow/models/input_metadata.py @@ -1,6 +1,7 @@ from typing import List, Dict, Tuple import torch +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from cacheflow.sampling_params import SamplingParams @@ -12,7 +13,6 @@ def __init__( seq_groups: List[Tuple[List[int], SamplingParams]], seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs. prompt_lens: List[int], - cumulative_prompt_lens: torch.Tensor, slot_mapping: torch.Tensor, context_lens: torch.Tensor, max_context_len: int, @@ -21,15 +21,14 @@ def __init__( self.seq_groups = seq_groups self.seq_logprobs = seq_logprobs self.prompt_lens = prompt_lens - self.cumulative_prompt_lens = cumulative_prompt_lens self.slot_mapping = slot_mapping self.context_lens = context_lens self.max_context_len = max_context_len self.block_tables = block_tables + self.attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens) self.num_prompts = len(prompt_lens) self.num_prompt_tokens = sum(prompt_lens) - self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 self.num_generation_tokens = context_lens.shape[0] self.num_valid_tokens = slot_mapping.shape[0] if block_tables.numel() > 0: @@ -41,15 +40,13 @@ def __init__( def __repr__(self) -> str: return (f'InputMetadata(' - f'num_prompts={self.num_prompts}, ' - f'num_prompt_tokens={self.num_prompt_tokens}, ' - f'max_prompt_len={self.max_prompt_len}, ' - f'num_generation_tokens={self.num_generation_tokens}, ' f'num_valid_tokens={self.num_valid_tokens}, ' - f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' - f'max_context_len={self.max_context_len}), ' + f'num_prompt_tokens={self.num_prompt_tokens}, ' + f'num_prompts={self.num_prompts}, ' f'prompt_lens={self.prompt_lens}, ' - f'cumulative_prompt_lens={self.cumulative_prompt_lens}, ' - f'slot_mapping={self.slot_mapping}, ' + f'num_generation_tokens={self.num_generation_tokens}, ' f'context_lens={self.context_lens}, ' - f'block_tables={self.block_tables})') + f'max_context_len={self.max_context_len}), ' + f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, ' + f'block_tables={self.block_tables}), ' + f'slot_mapping={self.slot_mapping}') diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 9b76d04ec8b17..59001b9d8fdcb 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -136,11 +136,6 @@ def prepare_inputs( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - cumulative_prompt_lens: List[int] = [0] - for prompt_len in prompt_lens: - cumulative_prompt_lens.append( - 
cumulative_prompt_lens[-1] + prompt_len) - # Add generation tokens. max_context_len = 0 max_num_blocks_per_seq = 0 @@ -196,14 +191,11 @@ def prepare_inputs( for block_table in generation_block_tables] block_tables_tensor = torch.tensor( padded_block_tables, dtype=torch.int, device='cuda') - cumulative_prompt_lens_tensor = torch.tensor( - cumulative_prompt_lens, dtype=torch.int, device='cuda') input_metadata = InputMetadata( seq_groups=seq_groups, seq_logprobs=seq_logprobs, prompt_lens=prompt_lens, - cumulative_prompt_lens=cumulative_prompt_lens_tensor, slot_mapping=slot_mapping_tensor, context_lens=context_lens_tensor, max_context_len=max_context_len, From 4a4717960902f2db8664b3e744157dacb9c507a0 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 09:02:41 +0000 Subject: [PATCH 04/10] Remove flash attention in comments --- cacheflow/master/server.py | 2 +- cacheflow/models/memory_analyzer.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cacheflow/master/server.py b/cacheflow/master/server.py index e2a9956faf989..0481cc538de0c 100644 --- a/cacheflow/master/server.py +++ b/cacheflow/master/server.py @@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser): parser.add_argument('--use-np-cache', action='store_true', help='save a numpy copy of model weights for faster loading') parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights') - # NOTE(woosuk): FlashAttention does not support float32. + # TODO(woosuk): Support FP32 for debugging. parser.add_argument('--dtype', type=str, default='half', choices=['half', 'bfloat16'], help='data type') # Parallel arguments parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU') diff --git a/cacheflow/models/memory_analyzer.py b/cacheflow/models/memory_analyzer.py index 738c6d11d023e..b6c9ddabbceb6 100644 --- a/cacheflow/models/memory_analyzer.py +++ b/cacheflow/models/memory_analyzer.py @@ -132,8 +132,8 @@ def get_max_act_size( # estimating # 1) the maximum activation tensor size during inference # 2) the residual tensor size during inference - # Here, we assume that FlashAttention is used and - # thus the attention maps are never materialized in GPU DRAM. + # Here, we assume that we use memory-efficient attention which + # does not materialize the attention maps in GPU DRAM. residual = max_num_batched_tokens * self.hidden_size qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size @@ -207,8 +207,8 @@ def get_max_act_size( # estimating # 1) the maximum activation tensor size during inference # 2) the residual tensor size during inference - # Here, we assume that FlashAttention is used and - # thus the attention maps are never materialized in GPU DRAM. + # Here, we assume that we use memory-efficient attention which + # does not materialize the attention maps in GPU DRAM. 
residual = max_num_batched_tokens * self.hidden_size qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size @@ -283,8 +283,8 @@ def get_max_act_size( # estimating # 1) the maximum activation tensor size during inference # 2) the residual tensor size during inference - # Here, we assume that FlashAttention is used and - # thus the attention maps are never materialized in GPU DRAM. + # Here, we assume that we use memory-efficient attention which + # does not materialize the attention maps in GPU DRAM. residual = max_num_batched_tokens * self.hidden_size qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size From 76ad4dc19afa68c2feb772d39f3fa62b53d7f112 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 09:03:02 +0000 Subject: [PATCH 05/10] Fix attention kernel tests --- tests/kernels/attention.py | 51 ++++++++++++++++---------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index 4567315d2e7a8..b9ff74d770edd 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -1,8 +1,9 @@ import random from typing import List, Optional -from flash_attn.flash_attn_interface import _flash_attn_forward import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from cacheflow import attention_ops @@ -228,39 +229,31 @@ def test_multi_query_kv_attention( dtype: torch.dtype, ) -> None: seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs) - max_seq_len = max(seq_lens) num_tokens = sum(seq_lens) - cu_seq_lens = [0] - for seq_len in seq_lens: - cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - cu_seq_lens = torch.tensor(cu_seq_lens, dtype=torch.int, device='cuda') - scale = float(1.0 / (head_size ** 0.5)) qkv = torch.randn( num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') # Adjust the range of the values to reduce precision errors. qkv = qkv / (head_size ** 0.5) - query, key, value = qkv.unbind(dim=1) - output = torch.empty( - num_tokens, num_heads, head_size, dtype=dtype, device='cuda') - _flash_attn_forward( - query, - key, - value, - output, - cu_seq_lens, - cu_seq_lens, - max_seq_len, - max_seq_len, - dropout_p=0.0, - softmax_scale=scale, - causal=True, - return_softmax=False, + + attn_op = xops.fmha.cutlass.FwOp() + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, ) + output = output.squeeze(0) - cu_seq_lens = cu_seq_lens.cpu().tolist() + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) ref_output = ref_multi_query_kv_attention( cu_seq_lens, query, @@ -277,9 +270,9 @@ def test_attention(seed: int) -> None: # the test fails due to the precision issue. Re-run the test if it fails. 
torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - for dtype in [torch.half, torch.float]: - for block_size in [8, 16, 32]: - for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: + for dtype in [torch.half, torch.bfloat16]: + for block_size in [8, 16, 32, 64]: + for head_size in [64, 80, 96, 128, 256]: print(f'Testing single_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' f'head_size={head_size}') @@ -292,9 +285,7 @@ def test_attention(seed: int) -> None: dtype=dtype, ) - # NOTE(woosuk): FlashAttention does not support FP32. - for dtype in [torch.half]: - # NOTE(woosuk): FlashAttention does not support head_size > 128. + for dtype in [torch.half, torch.bfloat16]: for head_size in [64, 80, 96, 128]: print(f'Testing multi_query_kv_attention with dtype={dtype}, ' f'head_size={head_size}') From 0330f833c137c36876c49f60352d4b8a585dfc91 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 09:04:07 +0000 Subject: [PATCH 06/10] Minor --- tests/kernels/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index b9ff74d770edd..d004746859b7c 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -286,7 +286,7 @@ def test_attention(seed: int) -> None: ) for dtype in [torch.half, torch.bfloat16]: - for head_size in [64, 80, 96, 128]: + for head_size in [64, 80, 96, 128, 256]: print(f'Testing multi_query_kv_attention with dtype={dtype}, ' f'head_size={head_size}') test_multi_query_kv_attention( From e6bfe3ac6457a14d1b451ea3dc6e16d26e319a89 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 09:17:46 +0000 Subject: [PATCH 07/10] Make tests robust to precision errors --- tests/kernels/attention.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index d004746859b7c..a5b6b6f207341 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -82,8 +82,10 @@ def ref_multi_query_kv_attention( end_idx = cu_seq_lens[i + 1] seq_len = end_idx - start_idx - # Create attention mask - attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * -1e5 + # Create attention mask. + attn_mask = torch.triu( + torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min attn_mask = attn_mask.to(dtype=dtype, device='cuda') ref_output = ref_masked_attention( @@ -161,21 +163,20 @@ def test_single_query_cached_kv_attention( num_blocks: int, dtype: torch.dtype, ) -> None: - qkv = torch.randn( + qkv = torch.empty( num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') + qkv.uniform_(-1e-3, 1e-3) query, _, _ = qkv.unbind(dim=1) + x = 16 // torch.tensor([], dtype=dtype).element_size() key_block_shape = (num_heads, head_size // x, block_size, x) - key_cache = torch.randn( + key_cache = torch.empty( size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda') + key_cache.uniform_(-1e-3, 1e-3) value_block_shape = (num_heads, head_size, block_size) - value_cache = torch.randn( + value_cache = torch.empty( size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda') - - # Adjust the range of the values to reduce precision errors. 
- query = query / (head_size ** 0.5) - key_cache = key_cache / (head_size ** 0.5) - value_cache = value_cache / (head_size ** 0.5) + value_cache.uniform_(-1e-3, 1e-3) context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)] max_context_len = max(context_lens) @@ -232,10 +233,9 @@ def test_multi_query_kv_attention( num_tokens = sum(seq_lens) scale = float(1.0 / (head_size ** 0.5)) - qkv = torch.randn( + qkv = torch.empty( num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - # Adjust the range of the values to reduce precision errors. - qkv = qkv / (head_size ** 0.5) + qkv.uniform_(-1e-3, 1e-3) query, key, value = qkv.unbind(dim=1) attn_op = xops.fmha.cutlass.FwOp() From f0c462bbaac7a02950bac8d8f50a6a0393bd04e0 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 10:16:03 +0000 Subject: [PATCH 08/10] Add bfloat16 in kernel unit tests --- tests/kernels/activation.py | 2 +- tests/kernels/cache.py | 19 ++++++++++--------- tests/kernels/layernorm.py | 5 +++-- tests/kernels/pos_encoding.py | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/tests/kernels/activation.py b/tests/kernels/activation.py index 3d9a9a644f6d0..b35bea61d04d1 100644 --- a/tests/kernels/activation.py +++ b/tests/kernels/activation.py @@ -23,7 +23,7 @@ def test_silu_and_mul( if __name__ == '__main__': - for dtype in [torch.half, torch.float]: + for dtype in [torch.half, torch.bfloat16, torch.float]: for num_tokens in [7, 83, 2048]: for d in [512, 4096, 13824]: print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}') diff --git a/tests/kernels/cache.py b/tests/kernels/cache.py index f444ac16a49dc..b750ca97e985a 100644 --- a/tests/kernels/cache.py +++ b/tests/kernels/cache.py @@ -142,15 +142,16 @@ def test_gather_cached_kv( @torch.inference_mode() def test_cache() -> None: - test_copy_blocks( - num_mappings=23, num_layers=7, num_heads=17, head_size=16, - block_size=8, num_blocks=1024, dtype=torch.half) - test_reshape_and_cache( - num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, - dtype=torch.half) - test_gather_cached_kv( - num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, - dtype=torch.half) + for dtype in [torch.half, torch.bfloat16, torch.float]: + test_copy_blocks( + num_mappings=23, num_layers=7, num_heads=17, head_size=16, + block_size=8, num_blocks=1024, dtype=dtype) + test_reshape_and_cache( + num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, + dtype=dtype) + test_gather_cached_kv( + num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, + dtype=dtype) if __name__ == '__main__': diff --git a/tests/kernels/layernorm.py b/tests/kernels/layernorm.py index 0e0072d879c21..a61fa9b67aa75 100644 --- a/tests/kernels/layernorm.py +++ b/tests/kernels/layernorm.py @@ -8,7 +8,8 @@ class RefRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() - weight = torch.randn(hidden_size) / (hidden_size ** 0.5) + weight = torch.empty(hidden_size) + weight.uniform_(-1e-3, 1e-3) self.weight = nn.Parameter(weight) self.variance_epsilon = eps @@ -41,7 +42,7 @@ def test_rms_norm( if __name__ == '__main__': - for dtype in [torch.half, torch.float]: + for dtype in [torch.half, torch.bfloat16, torch.float]: for num_tokens in [7, 128, 2048]: for hidden_size in [13, 64, 1024, 5120]: print(f'Testing RMS kernel with dtype={dtype}, num_tokens=' diff --git a/tests/kernels/pos_encoding.py b/tests/kernels/pos_encoding.py index 11fd6695919ec..16b3992a30f1d 100644 --- 
a/tests/kernels/pos_encoding.py +++ b/tests/kernels/pos_encoding.py @@ -129,7 +129,7 @@ def test_rotary_embedding_neox( if __name__ == '__main__': - for dtype in [torch.half, torch.float]: + for dtype in [torch.half, torch.bfloat16, torch.float]: for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: print(f'Running tests for head_size={head_size} and dtype={dtype}') test_rotary_embedding_neox( From 02ad58ba3bcd18586b0e41afd9d343f4dad46700 Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 10:25:44 +0000 Subject: [PATCH 09/10] Update README --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 8df3fecec4b56..0543b9def659c 100644 --- a/README.md +++ b/README.md @@ -3,11 +3,7 @@ ## Installation ```bash -pip install psutil numpy ray torch -pip install git+https://github.com/huggingface/transformers # Required for LLaMA. -pip install sentencepiece # Required for LlamaTokenizer. -pip install ninja # To parallelize the compilation of flash-attn. -pip install flash-attn # This may take up to 10 mins. +pip install ninja psutil numpy sentencepiece ray torch transformers xformers pip install -e . ``` From bd7657b39748b8e011d872f60d28f4e2124aceec Mon Sep 17 00:00:00 2001 From: woWoosuk Kwon Date: Thu, 4 May 2023 10:25:59 +0000 Subject: [PATCH 10/10] Head size --- tests/kernels/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index a5b6b6f207341..ae46fd6bcc06e 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -272,7 +272,7 @@ def test_attention(seed: int) -> None: torch.cuda.manual_seed(seed) for dtype in [torch.half, torch.bfloat16]: for block_size in [8, 16, 32, 64]: - for head_size in [64, 80, 96, 128, 256]: + for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: print(f'Testing single_query_cached_kv_attention with ' f'dtype={dtype}, block_size={block_size}, ' f'head_size={head_size}') @@ -286,11 +286,11 @@ def test_attention(seed: int) -> None: ) for dtype in [torch.half, torch.bfloat16]: - for head_size in [64, 80, 96, 128, 256]: + for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: print(f'Testing multi_query_kv_attention with dtype={dtype}, ' f'head_size={head_size}') test_multi_query_kv_attention( - num_seqs=11, + num_seqs=5, num_heads=3, head_size=head_size, dtype=dtype,
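Note (illustrative, not part of the patches): the test changes in this series draw inputs from uniform_(-1e-3, 1e-3) and build the reference causal mask from torch.finfo(dtype).min, so the half/bfloat16 reference path stays far from overflow and the kernel outputs can be compared against it. A rough sketch of such a reference computation, with made-up shapes and a hand-written O(n^2) attention that is only meant to mirror the masking idea:

    import torch

    def ref_causal_attention(query, key, value, scale):
        # query, key, value: [seq_len, num_heads, head_size]; plain O(n^2) reference.
        seq_len, dtype = query.shape[0], query.dtype
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=dtype, device=query.device),
            diagonal=1) * torch.finfo(dtype).min
        scores = scale * torch.einsum('qhd,khd->hqk', query, key) + mask
        probs = torch.softmax(scores.float(), dim=-1).to(dtype)
        return torch.einsum('hqk,khd->qhd', probs, value)

    # Small-magnitude inputs keep the low-precision dtypes well inside range.
    qkv = torch.empty(16, 3, 4, 64, dtype=torch.bfloat16, device='cuda')
    qkv.uniform_(-1e-3, 1e-3)
    q, k, v = qkv.unbind(dim=1)
    out = ref_causal_attention(q, k, v, scale=64 ** -0.5)

Sampling from this narrow range is what lets the bfloat16 cases added in these patches pass the element-wise comparison against the reference output.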