Support FP32 #141

Merged
21 commits, merged on Jun 7, 2023
Changes from all commits
7 changes: 3 additions & 4 deletions cacheflow/config.py
@@ -164,7 +164,7 @@ def _get_and_verify_dtype(
         config_dtype = torch.float32

     dtype = dtype.lower()
-    if dtype == "default":
+    if dtype == "auto":
         if config_dtype == torch.float32:
             # Following the common practice, we use float16 for float32 models.
             torch_dtype = torch.float16
@@ -184,9 +184,8 @@ def _get_and_verify_dtype(
             # Downcasting from float32 to float16 or bfloat16 is allowed.
             pass
         else:
-            # Casting between float16 and bfloat16 is not allowed.
-            raise ValueError(
-                f"Cannot use {torch_dtype} for {config_dtype} model.")
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warn(f"Casting {config_dtype} to {torch_dtype}.")

     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
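For readers following the dtype logic, here is a self-contained sketch of the resolution rule this hunk implements. The helper name `resolve_dtype` and the string-to-dtype map are illustrative only; the real `_get_and_verify_dtype` also reads the Hugging Face config and checks GPU support for bfloat16.

# Illustrative sketch, not part of this PR.
import torch

_STR_TO_DTYPE = {
    "half": torch.float16, "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float": torch.float32, "float32": torch.float32,
}

def resolve_dtype(dtype: str, config_dtype: torch.dtype) -> torch.dtype:
    """Simplified mirror of the updated _get_and_verify_dtype branching."""
    if dtype == "auto":
        # float32 checkpoints are downcast to float16 by default.
        return torch.float16 if config_dtype == torch.float32 else config_dtype
    torch_dtype = _STR_TO_DTYPE[dtype.lower()]
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32 or config_dtype == torch.float32:
            # Casting to or from float32 is silently allowed.
            pass
        else:
            # float16 <-> bfloat16 now warns instead of raising.
            print(f"Warning: casting {config_dtype} to {torch_dtype}.")
    return torch_dtype

assert resolve_dtype("auto", torch.float32) == torch.float16
assert resolve_dtype("float", torch.float32) == torch.float32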
9 changes: 5 additions & 4 deletions cacheflow/entrypoints/llm.py
@@ -28,17 +28,18 @@ class LLM:
         tensor_parallel_size: The number of GPUs to use for distributed
             execution with tensor parallelism.
         dtype: The data type for the model weights and activations. Currently,
-            we support `float16` and `bfloat16`. If `default`, we use the
-            `torch_dtype` attribute of the model config. If the `torch_dtype`
-            is `float32`, we use `float16` instead.
+            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
+            the `torch_dtype` attribute specified in the model config file.
+            However, if the `torch_dtype` in the config is `float32`, we will
+            use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
     """

     def __init__(
         self,
         model: str,
         tensor_parallel_size: int = 1,
-        dtype: str = "default",
+        dtype: str = "auto",
         seed: int = 0,
         **kwargs,
     ) -> None:
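With the new `auto` default, FP32 checkpoints are still downcast to FP16 unless full precision is requested explicitly. A brief usage sketch; the top-level import, model name, and `generate` call are illustrative assumptions rather than part of this diff:

from cacheflow import LLM  # top-level export assumed

# "float" selects torch.float32 (the same spelling the CLI exposes);
# the default dtype="auto" would downcast this fp32 checkpoint to fp16.
llm = LLM(model="facebook/opt-125m", dtype="float")
outputs = llm.generate(["The capital of France is"])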
8 changes: 3 additions & 5 deletions cacheflow/model_executor/layers/attention.py
@@ -10,7 +10,7 @@
 from cacheflow import pos_encoding_ops
 from cacheflow.model_executor.input_metadata import InputMetadata

-_SUPPORTED_HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]
+_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]


 class GPTCacheFlowAttention(nn.Module):
@@ -49,10 +49,8 @@ def __init__(self, num_heads: int, head_size: int, scale: float) -> None:
         self.attn_op = xops.fmha.cutlass.FwOp()

         if self.head_size not in _SUPPORTED_HEAD_SIZES:
-            raise ValueError(f'head_size ({self.head_size}) is not supported by '
-                             'the single_query_cached_kv_attention kernel. '
-                             'Use one of the following head sizes: '
-                             f'{_SUPPORTED_HEAD_SIZES}.')
+            raise ValueError(f"head_size ({self.head_size}) is not supported. "
+                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

     def multi_query_kv_attention(
         self,
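As a quick sanity check of the tightened head-size list: the value being validated is the per-head dimension, i.e. the hidden size divided by the number of attention heads. The config numbers below are illustrative:

# e.g. a 13B-class model config
hidden_size, num_attention_heads = 5120, 40
head_size = hidden_size // num_attention_heads  # -> 128, which is supported
assert head_size in [64, 80, 96, 128], f"head_size {head_size} not supported"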
8 changes: 4 additions & 4 deletions cacheflow/server/arg_utils.py
@@ -13,7 +13,7 @@ class ServerArgs:
     download_dir: Optional[str] = None
     use_np_weights: bool = False
     use_dummy_weights: bool = False
-    dtype: str = "default"
+    dtype: str = "auto"
     seed: int = 0
     worker_use_ray: bool = False
     pipeline_parallel_size: int = 1
@@ -49,9 +49,9 @@ def add_cli_args(
                             help='use dummy values for model weights')
        # TODO(woosuk): Support FP32.
        parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
-                            choices=['default', 'half', 'bfloat16'],
+                            choices=['auto', 'half', 'bfloat16', 'float'],
                            help='data type for model weights and activations. '
-                                 'The "default" option will use FP16 precision '
+                                 'The "auto" option will use FP16 precision '
                                  'for FP32 and FP16 models, and BF16 precision '
                                  'for BF16 models.')
        # Parallel arguments
@@ -67,7 +67,7 @@
        # KV cache arguments
        parser.add_argument('--block-size', type=int,
                            default=ServerArgs.block_size,
-                            choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                            choices=[8, 16, 32],
                            help='token block size')
        # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
        parser.add_argument('--seed', type=int, default=ServerArgs.seed,
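A standalone mirror of the two updated arguments, showing how the new choices parse (the block-size default of 16 is an assumption, since `ServerArgs.block_size` is not shown in this hunk):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dtype', type=str, default='auto',
                    choices=['auto', 'half', 'bfloat16', 'float'],
                    help='data type for model weights and activations')
parser.add_argument('--block-size', type=int, default=16,
                    choices=[8, 16, 32], help='token block size')

args = parser.parse_args(['--dtype', 'float', '--block-size', '16'])
assert args.dtype == 'float' and args.block_size == 16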
69 changes: 37 additions & 32 deletions csrc/attention/attention_kernels.cu
@@ -370,9 +370,11 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    case 32:
-      LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
+    // 32, 160, 192, 256.
+    // case 32:
+    //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     case 64:
       LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
       break;
@@ -385,15 +387,15 @@ void single_query_cached_kv_attention_launcher(
     case 128:
       LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
       break;
-    case 160:
-      LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 192:
-      LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
-      break;
-    case 256:
-      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-      break;
+    // case 160:
+    //   LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 192:
+    //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
+    //   break;
+    // case 256:
+    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+    //   break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
@@ -411,17 +413,19 @@ void single_query_cached_kv_attention_launcher(
     context_lens, \
     max_context_len);

+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
   switch (block_size) { \
-    case 1: \
-      CALL_KERNEL_LAUNCHER(T, 1); \
-      break; \
-    case 2: \
-      CALL_KERNEL_LAUNCHER(T, 2); \
-      break; \
-    case 4: \
-      CALL_KERNEL_LAUNCHER(T, 4); \
-      break; \
+    /* case 1: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 1); */ \
+    /*   break; */ \
+    /* case 2: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 2); */ \
+    /*   break; */ \
+    /* case 4: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 4); */ \
+    /*   break; */ \
     case 8: \
       CALL_KERNEL_LAUNCHER(T, 8); \
       break; \
@@ -431,15 +435,15 @@ void single_query_cached_kv_attention_launcher(
     case 32: \
       CALL_KERNEL_LAUNCHER(T, 32); \
       break; \
-    case 64: \
-      CALL_KERNEL_LAUNCHER(T, 64); \
-      break; \
-    case 128: \
-      CALL_KERNEL_LAUNCHER(T, 128); \
-      break; \
-    case 256: \
-      CALL_KERNEL_LAUNCHER(T, 256); \
-      break; \
+    /* case 64: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 64); */ \
+    /*   break; */ \
+    /* case 128: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 128); */ \
+    /*   break; */ \
+    /* case 256: */ \
+    /*   CALL_KERNEL_LAUNCHER(T, 256); */ \
+    /*   break; */ \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       break; \
@@ -455,8 +459,9 @@ void single_query_cached_kv_attention(
   torch::Tensor& context_lens, // [num_seqs]
   int block_size,
   int max_context_len) {
-  // TODO(woosuk): Support FP32.
-  if (query.dtype() == at::ScalarType::Half) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
   } else if (query.dtype() == at::ScalarType::BFloat16) {
     CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
3 changes: 3 additions & 0 deletions docs/source/getting_started/installation.rst
@@ -18,8 +18,11 @@ CacheFlow can run on systems that meet the following requirements:

 .. code-block:: console

+    $ # Pull the Docker image with CUDA 11.8.
     $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3

+Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow.
+
 Install with pip
 ----------------

5 changes: 5 additions & 0 deletions setup.py
@@ -66,6 +66,11 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     raise RuntimeError(
         "CUDA 11.8 or higher is required for GPUs with compute capability 9.0.")

+# Use NVCC threads to parallelize the build.
+if nvcc_cuda_version >= Version("11.2"):
+    num_threads = min(os.cpu_count(), 8)
+    NVCC_FLAGS += ["--threads", str(num_threads)]
+
 ext_modules = []

 # Cache operations.
10 changes: 5 additions & 5 deletions tests/kernels/test_attention.py
@@ -270,9 +270,9 @@ def run_multi_query_kv_attention(
 def test_single_query_cached_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for block_size in [8, 16, 32, 64]:
-            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for block_size in [8, 16, 32]:
+            for head_size in [64, 80, 96, 128]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
@@ -289,8 +289,8 @@ def test_single_query_cached_kv_attention() -> None:
 def test_multi_query_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for head_size in [64, 80, 96, 128]:
             print(f'Testing multi_query_kv_attention with dtype={dtype}, '
                   f'head_size={head_size}')
             run_multi_query_kv_attention(