diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index e33d5fb2dc247..c66742d9a202a 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -23,7 +23,9 @@ def main(args: argparse.Namespace):
         tensor_parallel_size=args.tensor_parallel_size,
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
+        device=args.device,
         enforce_eager=args.enforce_eager,
+        swap_space=args.swap_space,
     )

     sampling_params = SamplingParams(
@@ -127,5 +129,15 @@ def run_to_completion(profile_dir: Optional[str] = None):
         'path to save the pytorch profiler output. Can be visualized '
         'with ui.perfetto.dev or Tensorboard.'
     ))
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda", "cpu"],
+        help='device type for vLLM execution, supporting CUDA and CPU.')
+    parser.add_argument("--swap-space",
+                        type=int,
+                        default=4,
+                        help="memory space available for CPU (GB).")
     args = parser.parse_args()
     main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 3aac479c01bd2..223704342799d 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -69,8 +69,10 @@ def run_vllm(
     use_beam_search: bool,
     trust_remote_code: bool,
     dtype: str,
+    device: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
+    swap_space: int,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -83,6 +85,8 @@ def run_vllm(
         dtype=dtype,
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
+        device=device,
+        swap_space=swap_space,
     )

     # Add the requests to the engine.
@@ -109,21 +113,16 @@ def run_vllm(
     return end - start


-def run_hf(
-    requests: List[Tuple[str, int, int]],
-    model: str,
-    tokenizer: PreTrainedTokenizerBase,
-    n: int,
-    use_beam_search: bool,
-    max_batch_size: int,
-    trust_remote_code: bool,
-) -> float:
+def run_hf(requests: List[Tuple[str, int, int]], model: str,
+           tokenizer: PreTrainedTokenizerBase, n: int, use_beam_search: bool,
+           max_batch_size: int, trust_remote_code: bool) -> float:
     assert not use_beam_search
     llm = AutoModelForCausalLM.from_pretrained(
         model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
     if llm.config.model_type == "llama":
         # To enable padding in the HF backend.
         tokenizer.pad_token = tokenizer.eos_token
+    llm = llm.cuda()

     pbar = tqdm(total=len(requests))
@@ -206,7 +205,8 @@ def main(args: argparse.Namespace):
                                 args.quantization, args.tensor_parallel_size,
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
-                                args.max_model_len, args.enforce_eager)
+                                args.device, args.max_model_len, args.enforce_eager,
+                                args.swap_space)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -284,6 +284,16 @@ def main(args: argparse.Namespace):
     parser.add_argument("--enforce-eager",
                         action="store_true",
                         help="enforce eager execution")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda", "cpu"],
+        help='device type for vLLM execution, supporting CUDA and CPU.')
+    parser.add_argument("--swap-space",
+                        type=int,
+                        default=4,
+                        help="memory space available for CPU (GB).")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
diff --git a/cpu.Dockerfile b/cpu.Dockerfile
new file mode 100644
index 0000000000000..15f9d5be6408a
--- /dev/null
+++ b/cpu.Dockerfile
@@ -0,0 +1,77 @@
+FROM python:3.10 AS dev
+
+RUN apt-get update -y \
+    && apt-get install -y python3-pip
+
+WORKDIR /workspace
+
+# install build and runtime dependencies
+COPY requirements-cpu.txt requirements-cpu.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-cpu.txt
+
+# install development dependencies
+COPY requirements-dev.txt requirements-dev.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-dev.txt
+
+# image to build pytorch extensions
+FROM dev AS build
+
+# install build dependencies
+COPY requirements-build-cpu.txt requirements-build-cpu.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-build-cpu.txt
+
+# copy input files
+COPY csrc csrc
+COPY setup.py setup.py
+COPY requirements-cpu.txt requirements-cpu.txt
+COPY pyproject.toml pyproject.toml
+COPY vllm/__init__.py vllm/__init__.py
+
+# max jobs used by Ninja to build extensions
+ENV MAX_JOBS=$max_jobs
+RUN python3 setup.py build_ext --inplace
+
+# image to run unit testing suite
+FROM dev AS test
+
+# copy pytorch extensions separately to avoid having to rebuild
+# when python code changes
+COPY --from=build /workspace/vllm/*.so /workspace/vllm/
+COPY tests tests
+COPY vllm vllm
+
+ENTRYPOINT ["python3", "-m", "pytest", "tests"]
+
+# base image for the CPU runtime; runtime dependencies are installed via pip
+FROM python:3.10 AS vllm-base
+
+# python3-pip is required to install the runtime requirements
+RUN apt-get update -y \
+    && apt-get install -y python3-pip
+
+WORKDIR /workspace
+COPY requirements-cpu.txt requirements-cpu.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install -r requirements-cpu.txt
+
+FROM vllm-base AS vllm
+COPY --from=build /workspace/vllm/*.so /workspace/vllm/
+COPY vllm vllm
+
+EXPOSE 8000
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
+
+# openai api server alternative
+FROM vllm-base AS vllm-openai
+# install additional dependencies for openai api server
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install accelerate fschat
+
+COPY --from=build /workspace/vllm/*.so /workspace/vllm/
+COPY vllm vllm
+
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+
diff --git a/csrc/cache.h b/csrc/cache.h
index b26faad2ca814..7802f33689426 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -27,4 +27,4 @@ void gather_cached_kv(
   torch::Tensor& value,
   torch::Tensor&
key_cache, torch::Tensor& value_cache, - torch::Tensor& slot_mapping); + torch::Tensor& slot_mapping); \ No newline at end of file diff --git a/csrc/cpu/activation_impl.cpp b/csrc/cpu/activation_impl.cpp new file mode 100644 index 0000000000000..3704b3090c294 --- /dev/null +++ b/csrc/cpu/activation_impl.cpp @@ -0,0 +1,54 @@ +#include "cpu_types.hpp" + +namespace { +template +void silu_and_mul_cpu_impl(int num_tokens, int d, scalar_t *__restrict__ input, + scalar_t *__restrict__ output) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + TORCH_CHECK(d % VEC_ELEM_NUM == 0); + + const vec_op::FP32Vec8 zeros(0.0); + const vec_op::FP32Vec8 ones(1.0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + for (int j = 0; j < d; j += VEC_ELEM_NUM) { + const int start = i * 2 * d; + const scalar_vec_t x(input + start + j); + const scalar_vec_t y(input + start + d + j); + + const vec_op::FP32Vec8 f32_x(x.reg); + const vec_op::FP32Vec8 f32_y(y.reg); + + const vec_op::FP32Vec8 f32_ans = + f32_y * (f32_x / (ones + (zeros - f32_x).exp())); + + const scalar_vec_t ans(f32_ans.reg); + ans.save(output + i * d + j); + } + } +} +}; // namespace + +void silu_and_mul_cpu(torch::Tensor &out, torch::Tensor &input) { + int num_tokens = input.numel() / input.size(-1); + int d = input.size(-1) / 2; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "silu_and_mul_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(silu_and_mul_cpu_impl) + silu_and_mul_cpu_impl(num_tokens, d, input.data_ptr(), + out.data_ptr()); + CPU_KERNEL_GUARD_OUT(silu_and_mul_cpu_impl) + }); +} + +void gelu_new_cpu(torch::Tensor &out, torch::Tensor &input) { + TORCH_CHECK(false, "gelu_new is unsupported on CPU.") +} + +void gelu_fast_cpu(torch::Tensor &out, torch::Tensor &input) { + TORCH_CHECK(false, "gelu_fast is unsupported on CPU.") +} diff --git a/csrc/cpu/attention_impl.cpp b/csrc/cpu/attention_impl.cpp new file mode 100644 index 0000000000000..8ac96f0012db3 --- /dev/null +++ b/csrc/cpu/attention_impl.cpp @@ -0,0 +1,403 @@ +#include "cpu_types.hpp" + +namespace { + +template +struct paged_attention_v1_impl { + static void + call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { + TORCH_CHECK(HEAD_SIZE % 16 == 0); + TORCH_CHECK(alibi_slopes == nullptr, "Unsupport alibi_slopes for CPU"); + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + + size_t logits_bytes = num_heads * max_context_len_padded * sizeof(float); + float *logits = (float *)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. 
+ // [head_num, max_context_len_padded] + + std::memset(out, 0, num_seqs * num_heads * HEAD_SIZE * sizeof(scalar_t)); + + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + int context_len = context_lens[seq_idx]; + const int *seq_block_table = + block_tables + max_num_blocks_per_seq * seq_idx; + const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + std::memset(logits, 0, logits_bytes); + + // Compute attention logits +#pragma omp parallel for collapse(2) + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + logits + head_idx * max_context_len_padded + + block_idx * BLOCK_SIZE; + + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, q_vec_ptr += x) { + for (int token_idx = 0; token_idx < BLOCK_SIZE; + ++token_idx, k_block_cache_ptr += x) { + for (int i = 0; i < x; ++i) { + head_block_logits[token_idx] += + q_vec_ptr[i] * k_block_cache_ptr[i] * scale; + } + } + } + } + } + + // Compute softmax +#pragma omp parallel for + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + float *head_logit_ptr = logits + head_idx * max_context_len_padded; + float max_logit = head_logit_ptr[0]; + for (int i = 1; i < context_len; ++i) { + max_logit = + max_logit >= head_logit_ptr[i] ? max_logit : head_logit_ptr[i]; + } + + float sum = 0; + for (int i = 0; i < context_len; ++i) { + head_logit_ptr[i] = std::exp(head_logit_ptr[i] - max_logit); + sum += head_logit_ptr[i]; + } + + for (int i = 0; i < context_len; ++i) { + head_logit_ptr[i] /= sum; + } + + int remaining_seq_upper = block_num * BLOCK_SIZE; + for (int i = context_len; i < remaining_seq_upper; ++i) { + head_logit_ptr[i] = 0; + } + } + + // Compute value + constexpr int head_partition_num = HEAD_SIZE / 16; +#pragma omp parallel for collapse(2) + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + logits + head_idx * max_context_len_padded + + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + BLOCK_SIZE * head_part_idx * 16; + scalar_t *__restrict__ out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + head_part_idx * 16; + + for (int i = 0; i < 16; ++i, v_block_cache_ptr += BLOCK_SIZE) { + for (int j = 0; j < BLOCK_SIZE; ++j) { + out_ptr[i] += prob_vec_ptr[j] * v_block_cache_ptr[j]; + } + } + } + } + } + } + std::free(logits); + } +}; + +template +struct paged_attention_v1_impl { + using scalar_t = c10::BFloat16; + + static void + call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const scalar_t *__restrict__ 
v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int + *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const int num_seqs, const int num_heads) { + TORCH_CHECK(alibi_slopes == nullptr, "Unsupport alibi_slopes for CPU"); + constexpr int x = 16 / sizeof(scalar_t); + const int num_queries_per_kv = num_heads / num_kv_heads; + + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + + static_assert(x == VEC_ELEM_NUM); + static_assert(BLOCK_SIZE == 16); + static_assert(BLOCK_SIZE % VEC_ELEM_NUM == 0); + + int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + + const int parallel_work_item_num = omp_get_max_threads(); + + size_t logits_bytes = + parallel_work_item_num * max_context_len_padded * sizeof(float); + float *logits = (float *)std::aligned_alloc( + 64, logits_bytes); // Cacheline alignment for each context token. + // [parallel_work_item_num, max_context_len_padded] + +#pragma omp parallel for schedule(dynamic) collapse(2) + for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + int context_len = context_lens[seq_idx]; + const int *seq_block_table = + block_tables + max_num_blocks_per_seq * seq_idx; + const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int64_t kv_head_idx = head_idx / num_queries_per_kv; + const scalar_t *__restrict__ q_vec_ptr = + q + seq_idx * q_stride + head_idx * HEAD_SIZE; + float *__restrict__ thread_block_logits = + logits + omp_get_thread_num() * max_context_len_padded; + + // Compute logits + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const scalar_t *__restrict__ k_block_cache_ptr = + k_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride; + float *__restrict__ head_block_logits = + thread_block_logits + block_idx * BLOCK_SIZE; + + static_assert(vec_op::BF16Vec32::get_elem_num() % x == 0); + constexpr int TOKEN_PER_GROUP = vec_op::BF16Vec32::get_elem_num() / x; + static_assert(BLOCK_SIZE % TOKEN_PER_GROUP == 0); + constexpr int TOKEN_GROUPS = BLOCK_SIZE / TOKEN_PER_GROUP; + + // vec_op::FP32Vec8 accums[BLOCK_SIZE]; + vec_op::FP32Vec16 group_accums[TOKEN_GROUPS]; + + for (int q_offset = 0; q_offset < HEAD_SIZE; + q_offset += x, k_block_cache_ptr += x * BLOCK_SIZE) { + scalar_vec_t q_vec(q_vec_ptr + q_offset); + vec_op::BF16Vec32 q_group_vec(q_vec); + + vec_op::unroll_loop( + [k_block_cache_ptr, &q_group_vec, + &group_accums](int token_group_idx) { + vec_op::BF16Vec32 k_group_vec(k_block_cache_ptr + + token_group_idx * x * + TOKEN_PER_GROUP); + + group_accums[token_group_idx] = vec_op::fma( + q_group_vec, k_group_vec, group_accums[token_group_idx]); + }); + } + + vec_op::unroll_loop([&group_accums, + head_block_logits, + scale](int token_group_idx) { + vec_op::unroll_loop([&group_accums, + head_block_logits, scale, + token_group_idx]( + int token_idx) { + float dot_v = + group_accums[token_group_idx] + .template reduce_sub_sum< + vec_op::FP32Vec16::get_elem_num() / 
TOKEN_PER_GROUP>( + token_idx); + head_block_logits[token_group_idx * TOKEN_PER_GROUP + token_idx] = + dot_v * scale; + }); + }); + } + + // Compute softmax + float max_logit = thread_block_logits[0]; + for (int i = 1; i < context_len; ++i) { + max_logit = max_logit >= thread_block_logits[i] + ? max_logit + : thread_block_logits[i]; + } + + float sum = 0; + for (int i = 0; i < context_len; ++i) { + thread_block_logits[i] = std::exp(thread_block_logits[i] - max_logit); + sum += thread_block_logits[i]; + } + + for (int i = 0; i < context_len; ++i) { + thread_block_logits[i] /= sum; + } + + int remaining_seq_upper = block_num * BLOCK_SIZE; + for (int i = context_len; i < remaining_seq_upper; ++i) { + thread_block_logits[i] = 0; + } + + // Compute value + constexpr int head_elem_num_per_partition = 16; + constexpr int head_partition_num = + HEAD_SIZE / head_elem_num_per_partition; + for (int head_part_idx = 0; head_part_idx < head_partition_num; + ++head_part_idx) { + vec_op::FP32Vec16 accums[head_elem_num_per_partition]; + scalar_t *__restrict__ out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + + head_part_idx * head_elem_num_per_partition; + for (int block_idx = 0; block_idx < block_num; ++block_idx) { + const int64_t physical_block_idx = seq_block_table[block_idx]; + const float *__restrict__ prob_vec_ptr = + thread_block_logits + block_idx * BLOCK_SIZE; + const scalar_t *__restrict__ v_block_cache_ptr = + v_cache + physical_block_idx * kv_block_stride + + kv_head_idx * kv_head_stride + + BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; + + vec_op::FP32Vec16 prob_vec(prob_vec_ptr); + + vec_op::unroll_loop( + [&](int head_elem_idx) { + vec_op::BF16Vec16 v_vec(v_block_cache_ptr + + BLOCK_SIZE * head_elem_idx); + vec_op::FP32Vec16 fp32_v_vec(v_vec.reg); + accums[head_elem_idx] = + accums[head_elem_idx] + prob_vec * fp32_v_vec; + }); + } + + vec_op::unroll_loop( + [&](int head_elem_idx) { + float value = accums[head_elem_idx].reduce_sum(); + vec_op::storeFP32ToT(value, out_ptr + head_elem_idx); + }); + } + } + } + std::free(logits); + } +}; + +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ + paged_attention_v1_impl::call( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ + num_heads); + +template +void paged_attention_v1_impl_launcher( + torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, + int max_context_len, const c10::optional &alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = + alibi_slopes + ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + switch (head_size) { + case 64: + LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); + break; + case 80: + LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); + break; + case 96: + LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); + break; + case 112: + LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); + break; + case 128: + LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); + break; + case 256: + LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_impl_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + context_lens, max_context_len, alibi_slopes); + +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } +} // namespace + +void paged_attention_v1_cpu(torch::Tensor &out, torch::Tensor &query, + torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes) { + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", + [&] { + CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(scalar_t); + CPU_KERNEL_GUARD_OUT(paged_attention_v1_impl) + }); +} + +void paged_attention_v2_cpu(torch::Tensor &out, torch::Tensor &exp_sums, + torch::Tensor &max_logits, torch::Tensor &tmp_out, + torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, + torch::Tensor &context_lens, int block_size, + int max_context_len, + const c10::optional &alibi_slopes) { + TORCH_CHECK(false, "paged_attention_v2 is unsupported on CPU.") +} \ No newline at end of file diff --git a/csrc/cpu/cache_impl.cpp b/csrc/cpu/cache_impl.cpp new file mode 100644 index 0000000000000..6bed99b402fa0 --- /dev/null +++ b/csrc/cpu/cache_impl.cpp @@ -0,0 +1,145 @@ +#include +#include + +#include "cpu_types.hpp" + +namespace { +template +void copy_blocks_cpu_impl( + std::vector &key_caches, + std::vector &value_caches, + const std::vector> mapping_pairs, + const int element_num_per_block, const int layer_num) { + const size_t pair_num = mapping_pairs.size(); + const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; +#pragma omp parallel for collapse(2) + for (int layer = 0; layer < layer_num; ++layer) { + for (size_t pair = 0; pair < pair_num; ++pair) { + int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t target_offset = + element_num_per_block * mapping_pairs[pair].second; + scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); + scalar_t *source_ptr = key_cache_ptr + source_offset; + scalar_t *target_ptr = key_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + + scalar_t *value_cache_ptr = value_caches[layer].data_ptr(); + source_ptr = value_cache_ptr + 
source_offset; + target_ptr = value_cache_ptr + target_offset; + std::memcpy(target_ptr, source_ptr, block_bytes); + } + } +} + +template +void reshape_and_cache_cpu_impl( + const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, + scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, + const int64_t *__restrict__ slot_mapping, const int num_tokens, + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x) { + const int block_elem_num = num_heads * head_size * block_size; + +#pragma omp parallel for collapse(2) + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (int head_idx = 0; head_idx < num_heads; ++head_idx) { + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx >= 0) { + int src_key_head_idx = token_idx * key_stride + head_idx * head_size; + int src_value_head_idx = + token_idx * value_stride + head_idx * head_size; + const scalar_t *src_key_head_ptr = key + src_key_head_idx; + const scalar_t *src_value_head_ptr = value + src_value_head_idx; + const int64_t block_index = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + scalar_t *target_key_head_ptr = key_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + scalar_t *target_value_head_ptr = value_cache + + block_elem_num * block_index + + head_idx * block_size * head_size; + + for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) { + const int64_t target_offset = + src_key_idx * block_size + block_offset * x; + for (int i = 0; i < x; ++i) { + target_key_head_ptr[target_offset + i] = + src_key_head_ptr[src_key_idx + i]; + } + } + + for (int src_value_idx = 0; src_value_idx < head_size; + ++src_value_idx) { + const int64_t target_offset = + src_value_idx * block_size + block_offset; + target_value_head_ptr[target_offset] = + src_value_head_ptr[src_value_idx]; + } + } + } + } +} +}; // namespace + +void copy_blocks_cpu( + std::vector &key_caches, + std::vector &value_caches, + const std::map> &block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + + std::vector> mapping_pairs; + mapping_pairs.reserve(block_mapping.size()); + for (const auto &pair : block_mapping) { + for (const auto &dst : pair.second) { + mapping_pairs.emplace_back(pair.first, dst); + } + } + + const int element_num_per_block = key_caches[0][0].numel(); + VLLM_DISPATCH_FLOATING_TYPES( + key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) + copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + element_num_per_block, num_layers); + CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) + }); +} + +void reshape_and_cache_cpu(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), "reshape_and_cache_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl) + reshape_and_cache_cpu_impl( + key.data_ptr(), value.data_ptr(), + key_cache.data_ptr(), value_cache.data_ptr(), + slot_mapping.data_ptr(), num_tokens, key_stride, + value_stride, num_heads, head_size, block_size, 
x); + CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl) + }); +} + +void swap_blocks_cpu(torch::Tensor &src, torch::Tensor &dst, + const std::map &block_mapping) { + TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") +} + +void gather_cached_kv_cpu(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping) { + TORCH_CHECK(false, "gather_cached_kv is unsupported on CPU.") +} \ No newline at end of file diff --git a/csrc/cpu/cpu_ops.h b/csrc/cpu/cpu_ops.h new file mode 100644 index 0000000000000..6b42ac549d633 --- /dev/null +++ b/csrc/cpu/cpu_ops.h @@ -0,0 +1,73 @@ +#include + +void rotary_embedding_cpu(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox); + +void silu_and_mul_cpu(torch::Tensor &out, torch::Tensor &input); + +void gelu_new_cpu(torch::Tensor &out, torch::Tensor &input); + +void gelu_fast_cpu(torch::Tensor &out, torch::Tensor &input); + +void paged_attention_v1_cpu( + torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, + int max_context_len, const c10::optional &alibi_slopes); + +void paged_attention_v2_cpu( + torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, + torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, + int max_context_len, const c10::optional &alibi_slopes); + +void copy_blocks_cpu( + std::vector &key_caches, + std::vector &value_caches, + const std::map> &block_mapping); + +void reshape_and_cache_cpu(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping); + +void swap_blocks_cpu(torch::Tensor &src, torch::Tensor &dst, + const std::map &block_mapping); + +void gather_cached_kv_cpu(torch::Tensor &key, torch::Tensor &value, + torch::Tensor &key_cache, torch::Tensor &value_cache, + torch::Tensor &slot_mapping); + +void rms_norm_cpu(torch::Tensor &out, torch::Tensor &input, + torch::Tensor &weight, float epsilon); + +void fused_add_rms_norm_cpu(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, float epsilon); + +inline torch::Tensor awq_gemm_cpu(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, + int split_k_iters) { + TORCH_CHECK(false, "Quantization is not supported on CPU."); +} + +inline void squeezellm_gemm_cpu(torch::Tensor vec, torch::Tensor mat, + torch::Tensor mul, torch::Tensor lookup_table) { + TORCH_CHECK(false, "Quantization is not supported on CPU."); +} + +inline torch::Tensor gptq_gemm_cpu( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama) { + TORCH_CHECK(false, "Quantization is not supported on CPU."); +} + +inline void gptq_shuffle_cpu( + torch::Tensor q_weight, + torch::Tensor q_perm) { + TORCH_CHECK(false, "Quantization is not supported on CPU."); +} \ No newline at end of file diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp new file mode 100644 index 0000000000000..5c0207cd979c8 --- /dev/null +++ b/csrc/cpu/cpu_types.hpp @@ -0,0 +1,273 @@ + +#ifndef CPU_TYPES_HPP +#define CPU_TYPES_HPP + +#include +#include + 
+namespace vec_op { + +// FIXME: FP16 is not fully supported in Torch-CPU +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifndef CPU_OP_GUARD +#define CPU_KERNEL_GUARD_IN(NAME) +#define CPU_KERNEL_GUARD_OUT(NAME) +#else +#define CPU_KERNEL_GUARD_IN(NAME) \ + std::cout << #NAME << " invoked." << std::endl; +#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; +#endif + +namespace { +template +constexpr void unroll_loop_item(std::integer_sequence, F &&f) { + (f(std::integral_constant{}), ...); +} +}; // namespace + +template >> +constexpr void unroll_loop(F &&f) { + unroll_loop_item(std::make_integer_sequence{}, std::forward(f)); +} + +template struct Vec { + constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; } +}; + +#ifdef __AVX512FP16__ +struct FP16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128h reg; + + explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {} + + explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {} + + explicit FP16Vec8(__m128h data) : reg(data) {} + + explicit FP16Vec8(__m256 v) + : reg(_mm_castsi128_ph(_mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT))){}; + + FP16Vec8 operator*(const FP16Vec8 &b) const { + return FP16Vec8(_mm_mul_ph(reg, b.reg)); + } + + FP16Vec8 operator+(const FP16Vec8 &b) const { + return FP16Vec8(_mm_add_ph(reg, b.reg)); + } + + FP16Vec8 operator-(const FP16Vec8 &b) const { + return FP16Vec8(_mm_sub_ph(reg, b.reg)); + } + + FP16Vec8 operator/(const FP16Vec8 &b) const { + return FP16Vec8(_mm_div_ph(reg, b.reg)); + } + + void save(void *ptr) const { _mm_storeu_ph(ptr, reg); } +}; +#endif + +struct BF16Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + + __m128bh reg; + + explicit BF16Vec8(const void *ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec8(__m128bh data) : reg(data) {} + + explicit BF16Vec8(__m256 v) : reg(_mm256_cvtneps_pbh(v)){}; + + void save(void *ptr) const { *reinterpret_cast<__m128bh *>(ptr) = reg; } +}; + +struct BF16Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + + __m256bh reg; + + explicit BF16Vec16(const void *ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec16(__m256bh data) : reg(data) {} + + explicit BF16Vec16(__m512 v) : reg(_mm512_cvtneps_pbh(v)){}; + + void save(void *ptr) const { *reinterpret_cast<__m256bh *>(ptr) = reg; } +}; + +struct BF16Vec32 : public Vec { + constexpr static int VEC_ELEM_NUM = 32; + + __m512bh reg; + + explicit BF16Vec32(const void *ptr) + : reg(*reinterpret_cast(ptr)) {} + + explicit BF16Vec32(__m512bh data) : reg(data) {} + + explicit BF16Vec32(BF16Vec8 &vec8_data) + : reg((__m512bh)_mm512_inserti32x4( + _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512( + (__m128i)vec8_data.reg), + (__m128i)vec8_data.reg, 1), + (__m128i)vec8_data.reg, 2), + (__m128i)vec8_data.reg, 3)) {} + + void save(void *ptr) const { *reinterpret_cast<__m512bh *>(ptr) = reg; } +}; + +struct FP32Vec8 : public Vec { + constexpr static int VEC_ELEM_NUM = 8; + union AliasReg { + __m256 reg; + float values[VEC_ELEM_NUM]; + }; + + __m256 reg; + + explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {} + + explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {} + + explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {} + + explicit 
FP32Vec8(__m256 data) : reg(data) {} + +#ifdef __AVX512FP16__ + explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {} +#endif + + explicit FP32Vec8(__m128bh v) : reg(_mm256_cvtpbh_ps(v)) {} + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float ans = 0; + unroll_loop([&ans, &ar](int i) { ans += ar.values[i]; }); + + return ans; + } + + FP32Vec8 exp() const { + AliasReg ar; + ar.reg = reg; + return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]), + expf(ar.values[5]), expf(ar.values[4]), + expf(ar.values[3]), expf(ar.values[2]), + expf(ar.values[1]), expf(ar.values[0]))); + } + + FP32Vec8 operator*(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_mul_ps(reg, b.reg)); + } + + FP32Vec8 operator+(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_add_ps(reg, b.reg)); + } + + FP32Vec8 operator-(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_sub_ps(reg, b.reg)); + } + + FP32Vec8 operator/(const FP32Vec8 &b) const { + return FP32Vec8(_mm256_div_ps(reg, b.reg)); + } + + void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } +}; + +struct FP32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512 reg; + float values[VEC_ELEM_NUM]; + }; + + __m512 reg; + + explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {} + + explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {} + + explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {} + + explicit FP32Vec16(__m512 data) : reg(data) {} + + explicit FP32Vec16(__m256bh v) : reg(_mm512_cvtpbh_ps(v)) {} + + FP32Vec16 operator+(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_add_ps(reg, b.reg)); + } + + FP32Vec16 operator*(const FP32Vec16 &b) const { + return FP32Vec16(_mm512_mul_ps(reg, b.reg)); + } + + float reduce_sum() const { + AliasReg ar; + ar.reg = reg; + float ans = 0; + unroll_loop([&ans, &ar](int i) { ans += ar.values[i]; }); + + return ans; + } + + template float reduce_sub_sum(int idx) { + static_assert(VEC_ELEM_NUM % group_size == 0); + + AliasReg ar; + ar.reg = reg; + float ans = 0; + const int start = idx * group_size; + unroll_loop( + [&ans, &start, ar](int i) { ans += ar.values[start + i]; }); + + return ans; + } + + void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } +}; + +template struct VecType { using vec_type = void; }; + +template using vec_t = typename VecType::vec_type; + +template <> struct VecType { using vec_type = FP32Vec8; }; + +#ifdef __AVX512FP16__ +template <> struct VecType { using vec_type = FP16Vec8; }; +#endif + +template <> struct VecType { using vec_type = BF16Vec8; }; + +template void storeFP32ToT(float v, T *ptr) { *ptr = v; } + +#ifdef __AVX512FP16__ +template <> inline void storeFP32ToT(float v, c10::Half *ptr) { + *reinterpret_cast<_Float16 *>(ptr) = v; +} +#endif + +inline FP32Vec16 fma(BF16Vec32 &a, BF16Vec32 &b, FP32Vec16 &c) { + return FP32Vec16(_mm512_dpbf16_ps(c.reg, a.reg, b.reg)); +} + +template <> +inline void storeFP32ToT(float v, c10::BFloat16 *ptr) { + *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v); +} + +}; // namespace vec_op + +#endif \ No newline at end of file diff --git a/csrc/cpu/layernorm_impl.cpp b/csrc/cpu/layernorm_impl.cpp new file mode 100644 index 0000000000000..fb1373746b290 --- /dev/null +++ b/csrc/cpu/layernorm_impl.cpp @@ -0,0 +1,117 @@ +#include "cpu_types.hpp" + +namespace { +template +void rms_norm_cpu_impl(scalar_t *__restrict__ out, + const scalar_t *__restrict__ input, + const scalar_t *__restrict__ weight, const float epsilon, + const int 
num_tokens, const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto output_p = out + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + vec_op::FP32Vec8 fp32_x(x.reg); + variance = variance + fp32_x * fp32_x; + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t w(weight + j); + + vec_op::FP32Vec8 fp32_x(x.reg); + vec_op::FP32Vec8 fp32_w(w.reg); + + vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out.reg); + out.save(output_p + j); + } + } +} + +template +void fused_add_rms_norm_cpu_impl(scalar_t *__restrict__ input, + scalar_t *__restrict__ residual, + const scalar_t *__restrict__ weight, + const float epsilon, const int num_tokens, + const int hidden_size) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + vec_op::FP32Vec8 variance(0.0); + auto input_p = input + i * hidden_size; + auto residual_p = residual + i * hidden_size; + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t x(input_p + j); + scalar_vec_t res(residual_p + j); + vec_op::FP32Vec8 fp32_x(x.reg); + vec_op::FP32Vec8 fp32_res(res.reg); + + fp32_x = fp32_x + fp32_res; + variance = variance + fp32_x * fp32_x; + scalar_vec_t out(fp32_x.reg); + out.save(residual_p + j); + } + + float s_variance = + 1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon); + vec_op::FP32Vec8 fp32_s_variance(s_variance); + + for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) { + scalar_vec_t w(weight + j); + scalar_vec_t res(residual_p + j); + + vec_op::FP32Vec8 fp32_w(w.reg); + vec_op::FP32Vec8 fp32_res(res.reg); + + vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w; + + scalar_vec_t out(fp32_out.reg); + out.save(input_p + j); + } + } +} +} // namespace + +void rms_norm_cpu(torch::Tensor &out, torch::Tensor &input, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(rms_norm_cpu_impl) + rms_norm_cpu_impl(out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, + hidden_size); + CPU_KERNEL_GUARD_OUT(rms_norm_cpu_impl) + }); +} + +void fused_add_rms_norm_cpu(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, float epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "fused_add_rms_norm_cpu_impl", [&] { + CPU_KERNEL_GUARD_IN(fused_add_rms_norm_cpu_impl) + fused_add_rms_norm_cpu_impl( + input.data_ptr(), residual.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_cpu_impl) + }); +} \ No newline at end of file diff --git a/csrc/cpu/pos_encoding_impl.cpp b/csrc/cpu/pos_encoding_impl.cpp new file mode 100644 index 
0000000000000..7661c8488cbde --- /dev/null +++ b/csrc/cpu/pos_encoding_impl.cpp @@ -0,0 +1,117 @@ + +#include "cpu_types.hpp" + +namespace { +template +void rotary_embedding_impl( + const int64_t + *__restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t + *__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or + /// [num_tokens, num_heads, head_size] + scalar_t + *__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or + // [num_tokens, num_kv_heads, head_size] + const scalar_t + *__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size, + const int num_tokens) { + using scalar_vec_t = vec_op::vec_t; + constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); + constexpr int ELEM_SIZE = sizeof(scalar_t); + + const int embed_dim = rot_dim / 2; + TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); + +#pragma omp parallel for + for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + int64_t pos = positions[token_idx]; + const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; + + for (int i = 0; i < num_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t q_x(query + out_x); + const scalar_vec_t q_y(query + out_y); + + vec_op::FP32Vec8 fp32_cos(cos.reg); + vec_op::FP32Vec8 fp32_sin(sin.reg); + + vec_op::FP32Vec8 fp32_q_x(q_x.reg); + vec_op::FP32Vec8 fp32_q_y(q_y.reg); + + auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin; + scalar_vec_t(out1.reg).save(query + out_x); + + auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin; + scalar_vec_t(out2.reg).save(query + out_y); + } + } + + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { + const int rot_offset = j; + const int x_index = rot_offset; + const int y_index = embed_dim + rot_offset; + + const int64_t out_x = token_head + x_index; + const int64_t out_y = token_head + y_index; + + const scalar_vec_t cos(cache_ptr + x_index); + const scalar_vec_t sin(cache_ptr + y_index); + + const scalar_vec_t k_x(key + out_x); + const scalar_vec_t k_y(key + out_y); + + vec_op::FP32Vec8 fp32_cos(cos.reg); + vec_op::FP32Vec8 fp32_sin(sin.reg); + + vec_op::FP32Vec8 fp32_k_x(k_x.reg); + vec_op::FP32Vec8 fp32_k_y(k_y.reg); + + auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin; + scalar_vec_t(out1.reg).save(key + out_x); + auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin; + scalar_vec_t(out2.reg).save(key + out_y); + } + } + } +} +}; // namespace + +void rotary_embedding_cpu(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox) { + TORCH_CHECK(is_neox); + int num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t key_stride = key.stride(-2); + int64_t query_stride = 
query.stride(-2); + + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), "rotary_embedding_impl", [&] { + CPU_KERNEL_GUARD_IN(rotary_embedding_impl) + rotary_embedding_impl( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size, num_tokens); + CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) + }); +} \ No newline at end of file diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 0ae9cd6415982..c15cda7c39950 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -6,11 +6,40 @@ #include -#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#ifdef VLLM_BUILD_CPU_ONLY +#define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...) +#else +#define VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, ...) \ + case c10::DeviceType::CUDA: { \ + return BASENAME(__VA_ARGS__); \ + } +#endif + +#ifdef VLLM_BUILD_CPU_OPS +#define VLLM_DISPATCH_TO_CPU_CASE(BASENAME, ...) \ + case c10::DeviceType::CPU: { \ + return BASENAME##_cpu(__VA_ARGS__); \ + } +#else +#define VLLM_DISPATCH_TO_CPU_CASE(BASENAME, ...) +#endif + +#define VLLM_DISPATCH_DEVICES(DEVICE, BASENAME, ...) 
\ + { \ + auto device = DEVICE.type(); \ + switch (device) { \ + VLLM_DISPATCH_TO_CUDA_CASE(BASENAME, __VA_ARGS__) \ + VLLM_DISPATCH_TO_CPU_CASE(BASENAME, __VA_ARGS__) \ + default: \ + AT_ERROR('"', #BASENAME, "\" not implemented for '", \ + c10::DeviceTypeName(device), "'"); \ + } \ + } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 95f557686f337..4d5279b2e5a1d 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -1,8 +1,98 @@ #include "cache.h" #include "cuda_utils.h" #include "ops.h" +#include "cpu/cpu_ops.h" +#include "dispatch_utils.h" #include +void rotary_embedding_dispatch(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int head_size, + torch::Tensor &cos_sin_cache, bool is_neox) { + VLLM_DISPATCH_DEVICES(key.device(), rotary_embedding, positions, query, key, head_size, cos_sin_cache, is_neox); +} + +void paged_attention_v1_dispatch( + torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, + torch::Tensor &value_cache, int num_kv_heads, float scale, + torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, + int max_context_len, const c10::optional &alibi_slopes) { + VLLM_DISPATCH_DEVICES(out.device(), paged_attention_v1, out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes); +} + +void paged_attention_v2_dispatch(torch::Tensor &out, torch::Tensor &exp_sums, + torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor &query, + torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, + float scale, torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, + int max_context_len, const c10::optional &alibi_slopes) { + VLLM_DISPATCH_DEVICES(out.device(), paged_attention_v2, out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens, block_size,max_context_len, alibi_slopes); +} + +void silu_and_mul_dispatch(torch::Tensor &out, torch::Tensor &input) { + VLLM_DISPATCH_DEVICES(out.device(), silu_and_mul, out, input); +} + +void gelu_new_dispatch(torch::Tensor &out, torch::Tensor &input) { + VLLM_DISPATCH_DEVICES(out.device(), gelu_new, out, input); +} + +void gelu_fast_dispatch(torch::Tensor &out, torch::Tensor &input) { + VLLM_DISPATCH_DEVICES(out.device(), gelu_fast, out, input); +} + +void rms_norm_dispatch(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, float epsilon) { + VLLM_DISPATCH_DEVICES(out.device(), rms_norm, out, input, weight, epsilon); +} + +void fused_add_rms_norm_dispatch(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, float epsilon) { + VLLM_DISPATCH_DEVICES(input.device(), fused_add_rms_norm, input, residual, weight, epsilon); +} + +torch::Tensor awq_gemm_dispatch(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters) { + VLLM_DISPATCH_DEVICES(_in_feats.device(), awq_gemm, _in_feats, _kernel, _scaling_factors, _zeros, split_k_iters); +} + +void squeezellm_gemm_dispatch(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table) { + VLLM_DISPATCH_DEVICES(vec.device(), squeezellm_gemm, vec, mat, mul, lookup_table); +} + +void swap_blocks_dispatch(torch::Tensor& src, torch::Tensor& dst, const std::map& block_mapping) { + VLLM_DISPATCH_DEVICES(src.device(), swap_blocks, src, dst, block_mapping); +} + +void copy_blocks_dispatch(std::vector& key_caches, std::vector& value_caches, const std::map>& block_mapping) { + 
VLLM_DISPATCH_DEVICES(key_caches[0].device(), copy_blocks, key_caches, value_caches, block_mapping); +} + +void reshape_and_cache_dispatch(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping) { + VLLM_DISPATCH_DEVICES(key.device(), reshape_and_cache, key, value, key_cache, value_cache, slot_mapping); +} + +void gather_cached_kv_dispatch(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping) { + VLLM_DISPATCH_DEVICES(key.device(), gather_cached_kv, key, value, key_cache, value_cache, slot_mapping); +} + +torch::Tensor gptq_gemm_dispatch( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama) { + VLLM_DISPATCH_DEVICES(a.device(), gptq_gemm, a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama); +} + +void gptq_shuffle_dispatch( + torch::Tensor q_weight, + torch::Tensor q_perm) { + VLLM_DISPATCH_DEVICES(q_weight.device(), gptq_shuffle, q_weight, q_perm); +} + +#ifdef VLLM_BUILD_CPU_ONLY +int get_device_attribute( + int attribute, + int device_id) { return 94387; } +#endif + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // vLLM custom ops pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); @@ -10,69 +100,69 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Attention ops ops.def( "paged_attention_v1", - &paged_attention_v1, + &paged_attention_v1_dispatch, "Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def( "paged_attention_v2", - &paged_attention_v2, + &paged_attention_v2_dispatch, "PagedAttention V2."); // Activation ops ops.def( "silu_and_mul", - &silu_and_mul, + &silu_and_mul_dispatch, "Activation function used in SwiGLU."); ops.def( "gelu_new", - &gelu_new, + &gelu_new_dispatch, "GELU implementation used in GPT-2."); ops.def( "gelu_fast", - &gelu_fast, + &gelu_fast_dispatch, "Approximate GELU implementation."); // Layernorm ops.def( "rms_norm", - &rms_norm, + &rms_norm_dispatch, "Apply Root Mean Square (RMS) Normalization to the input tensor."); ops.def( "fused_add_rms_norm", - &fused_add_rms_norm, + &fused_add_rms_norm_dispatch, "In-place fused Add and RMS Normalization"); // Rotary embedding ops.def( "rotary_embedding", - &rotary_embedding, + &rotary_embedding_dispatch, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); #ifndef USE_ROCM // Quantization ops - ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("awq_gemm", &awq_gemm_dispatch, "Quantized GEMM for AWQ"); #endif - ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); - ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); - ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + ops.def("gptq_gemm", &gptq_gemm_dispatch, "Quantized GEMM for GPTQ"); + ops.def("gptq_shuffle", &gptq_shuffle_dispatch, "Post processing for GPTQ"); + ops.def("squeezellm_gemm", &squeezellm_gemm_dispatch, "Quantized GEMM for SqueezeLLM"); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); cache_ops.def( "swap_blocks", - &swap_blocks, + &swap_blocks_dispatch, "Swap in (out) the cache blocks from src to dst"); cache_ops.def( "copy_blocks", - ©_blocks, + ©_blocks_dispatch, "Copy the cache blocks from src to dst"); cache_ops.def( "reshape_and_cache", - &reshape_and_cache, + &reshape_and_cache_dispatch, "Reshape the key 
and value tensors and cache them"); cache_ops.def( "gather_cached_kv", - &gather_cached_kv, + &gather_cached_kv_dispatch, "Gather key and value from the cache into contiguous QKV tensors"); // Cuda utils diff --git a/examples/openai_chatcompletion_client.py b/examples/openai_chatcompletion_client.py index 0b8e4b86ef5e1..bbada3891bd19 100644 --- a/examples/openai_chatcompletion_client.py +++ b/examples/openai_chatcompletion_client.py @@ -32,6 +32,5 @@ model=model, ) - print("Chat completion results:") print(chat_completion) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 7a80c4ac49ab2..58519f978d340 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -21,8 +21,7 @@ echo=False, n=2, stream=stream, - logprobs=3 -) + logprobs=3) print("Completion results:") if stream: diff --git a/Dockerfile b/gpu.Dockerfile similarity index 89% rename from Dockerfile rename to gpu.Dockerfile index 44b1dd17d7e02..a480aac16f144 100644 --- a/Dockerfile +++ b/gpu.Dockerfile @@ -10,9 +10,9 @@ RUN apt-get update -y \ WORKDIR /workspace # install build and runtime dependencies -COPY requirements.txt requirements.txt +COPY requirements-gpu.txt requirements-gpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt + pip install -r requirements-gpu.txt # install development dependencies COPY requirements-dev.txt requirements-dev.txt @@ -25,14 +25,14 @@ RUN --mount=type=cache,target=/root/.cache/pip \ FROM dev AS build # install build dependencies -COPY requirements-build.txt requirements-build.txt +COPY requirements-build-gpu.txt requirements-build-gpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements-build.txt + pip install -r requirements-build-gpu.txt # copy input files COPY csrc csrc COPY setup.py setup.py -COPY requirements.txt requirements.txt +COPY requirements-gpu.txt requirements-gpu.txt COPY pyproject.toml pyproject.toml COPY vllm/__init__.py vllm/__init__.py @@ -75,9 +75,9 @@ RUN apt-get update -y \ && apt-get install -y python3-pip WORKDIR /workspace -COPY requirements.txt requirements.txt +COPY requirements-gpu.txt requirements-gpu.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -r requirements.txt + pip install -r requirements-gpu.txt #################### RUNTIME BASE IMAGE #################### diff --git a/pyproject.toml b/pyproject.toml index b197256f6ff55..54bec2e47013b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.1.2", + "torch == 2.1.2+cpu", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build-cpu.txt b/requirements-build-cpu.txt new file mode 100644 index 0000000000000..04312d11f61b9 --- /dev/null +++ b/requirements-build-cpu.txt @@ -0,0 +1,6 @@ +# Should be mirrored in pyproject.toml +ninja +packaging +setuptools>=49.4.0 +torch==2.1.2+cpu +wheel \ No newline at end of file diff --git a/requirements-build.txt b/requirements-build-gpu.txt similarity index 100% rename from requirements-build.txt rename to requirements-build-gpu.txt diff --git a/requirements-cpu.txt b/requirements-cpu.txt new file mode 100644 index 0000000000000..6f3ff1cce61fe --- /dev/null +++ b/requirements-cpu.txt @@ -0,0 +1,15 @@ +ninja # For faster builds. +psutil +ray >= 2.5.1 +pandas # Required for Ray data. +pyarrow # Required for Ray data. +pybind11 +sentencepiece # Required for LLaMA tokenizer. 
+numpy +einops # Required for phi-1_5 +torch == 2.1.2+cpu +transformers >= 4.34.0 # Required for Mistral. +fastapi +uvicorn[standard] +pydantic == 1.10.13 # Required for OpenAI server. +aioprometheus[starlette] diff --git a/requirements.txt b/requirements-gpu.txt similarity index 100% rename from requirements.txt rename to requirements-gpu.txt diff --git a/setup.py b/setup.py index fe8cd6d75ed76..aff6e86b235ec 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,13 @@ from packaging.version import parse, Version import setuptools import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME + +BUILD_CPU_ONLY = os.getenv('VLLM_BUILD_CPU_ONLY', "1") == "1" + +if not BUILD_CPU_ONLY: + from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME +else: + from torch.utils.cpp_extension import BuildExtension, CppExtension ROOT_DIR = os.path.dirname(__file__) @@ -21,11 +27,11 @@ def _is_hip() -> bool: - return torch.version.hip is not None + return torch.version.hip is not None and not BUILD_CPU_ONLY def _is_cuda() -> bool: - return torch.version.cuda is not None + return torch.version.cuda is not None and not BUILD_CPU_ONLY # Compiler flags. @@ -86,7 +92,6 @@ def get_hipcc_rocm_version(): print("Could not find HIP version in the output") return None - def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -137,6 +142,19 @@ def get_torch_arch_list() -> Set[str]: stacklevel=2) return arch_list +if not BUILD_CPU_ONLY: + # First, check the TORCH_CUDA_ARCH_LIST environment variable. + compute_capabilities = get_torch_arch_list() + if not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") # First, check the TORCH_CUDA_ARCH_LIST environment variable. 
compute_capabilities = get_torch_arch_list() @@ -210,31 +228,62 @@ def get_torch_arch_list() -> Set[str]: f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}" f"amdgpu_arch_found: {amd_arch}") +# Setup CPU Operations +BUILD_CPU_OPS = (os.getenv('VLLM_BUILD_CPU_OPS', "0") == "1" or BUILD_CPU_ONLY) +CPU_OPS_SOURCES = [] +if BUILD_CPU_OPS: + if BUILD_CPU_ONLY: + CXX_FLAGS += ["-DVLLM_BUILD_CPU_ONLY"] + CXX_FLAGS += [ + "-DVLLM_BUILD_CPU_OPS", "-fopenmp", "-mavx512f", "-mavx512bf16", + "-mavx512vl" + ] + CPU_OPS_SOURCES += [ + "csrc/cpu/activation_impl.cpp", + "csrc/cpu/attention_impl.cpp", + "csrc/cpu/cache_impl.cpp", + "csrc/cpu/layernorm_impl.cpp", + "csrc/cpu/pos_encoding_impl.cpp", + ] + ext_modules = [] -vllm_extension_sources = [ - "csrc/cache_kernels.cu", - "csrc/attention/attention_kernels.cu", - "csrc/pos_encoding_kernels.cu", - "csrc/activation_kernels.cu", - "csrc/layernorm_kernels.cu", - "csrc/quantization/squeezellm/quant_cuda_kernel.cu", - "csrc/quantization/gptq/q_gemm.cu", - "csrc/cuda_utils_kernels.cu", - "csrc/pybind.cpp", -] +if not BUILD_CPU_ONLY: + vllm_extension_sources = [ + "csrc/cache_kernels.cu", + "csrc/attention/attention_kernels.cu", + "csrc/pos_encoding_kernels.cu", + "csrc/activation_kernels.cu", + "csrc/layernorm_kernels.cu", + "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + "csrc/quantization/gptq/q_gemm.cu", + "csrc/cuda_utils_kernels.cu", + "csrc/pybind.cpp", + ] + CPU_OPS_SOURCES + + if _is_cuda(): + vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") + + vllm_extension = CUDAExtension( + name="vllm._C", + sources=vllm_extension_sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) +else: + vllm_extension_sources = [ + "csrc/pybind.cpp", + ] + CPU_OPS_SOURCES + vllm_extension = CppExtension( + name="vllm._C", + sources=vllm_extension_sources, + extra_compile_args={ + "cxx": CXX_FLAGS, + }, + ) -if _is_cuda(): - vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") - -vllm_extension = CUDAExtension( - name="vllm._C", - sources=vllm_extension_sources, - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) ext_modules.append(vllm_extension) @@ -264,7 +313,7 @@ def get_vllm_version() -> str: if hipcc_version != MAIN_CUDA_VERSION: rocm_version_str = hipcc_version.replace(".", "")[:3] version += f"+rocm{rocm_version_str}" - else: + elif _is_cuda(): cuda_version = str(nvcc_cuda_version) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] @@ -287,9 +336,13 @@ def get_requirements() -> List[str]: if _is_hip(): with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") + elif _is_cuda(): + with open(get_path("requirements-gpu.txt")) as f: + requirements = f.read().strip().split("\n") else: - with open(get_path("requirements.txt")) as f: + with open(get_path("requirements-cpu.txt")) as f: requirements = f.read().strip().split("\n") + return requirements diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index fca97ab76bf09..bab5c1e93cd09 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -11,11 +11,13 @@ def create_kv_caches( num_heads: int, head_size: int, dtype: torch.dtype, - seed: int, device: str, + seed: int, ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) scale = head_size**-0.5 x = 16 // torch.tensor([], 
dtype=dtype).element_size() diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 826bf8350af17..617c67435c130 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,7 +7,7 @@ NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -20,19 +20,34 @@ def test_silu_and_mul( num_tokens: int, d: int, dtype: torch.dtype, + device: str, seed: int, - device: int, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id) + x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device) layer = SiluAndMul() out = layer(x) ref_out = layer._forward(x) assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("d", D) +@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", ['cpu']) +@torch.inference_mode() +def test_silu_and_mul_cpu( + num_tokens: int, + d: int, + dtype: torch.dtype, + device: str, + seed: int, +) -> None: + test_silu_and_mul(num_tokens, d, dtype, device, seed) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -44,12 +59,11 @@ def test_gelu_new( d: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) + x = torch.randn(num_tokens, d, dtype=dtype, device=device) layer = NewGELU() out = layer(x) ref_out = layer._forward(x) @@ -66,12 +80,11 @@ def test_gelu_fast( d: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: torch.random.manual_seed(seed) torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) + x = torch.randn(num_tokens, d, dtype=dtype, device=device) layer = FastGELU() out = layer(x) ref_out = layer._forward(x) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 3949948e860f7..966caf6948207 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -24,7 +24,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] def ref_masked_attention( @@ -59,8 +59,8 @@ def ref_single_query_cached_kv_attention( block_size = value_cache.shape[3] num_seqs = query.shape[0] - block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() + block_tables = block_tables.cpu() + context_lens = context_lens.cpu() for i in range(num_seqs): q = query[i].unsqueeze(0) block_table = block_tables[i] @@ -107,6 +107,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() def test_paged_attention( kv_cache_factory, version: str, @@ -116,20 +117,22 @@ def 
test_paged_attention( use_alibi: bool, block_size: int, dtype: torch.dtype, + device: str, seed: int, - device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, - device=gpu_id) + device=device) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 @@ -138,12 +141,12 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, - device=gpu_id) + device=device) context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] context_lens[-1] = MAX_SEQ_LEN max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id) + context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) # Create the block tables. max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size @@ -154,12 +157,12 @@ def test_paged_attention( for _ in range(max_num_blocks_per_seq) ] block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id) + block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, dtype, - seed, gpu_id) + device, seed) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. @@ -233,6 +236,34 @@ def test_paged_attention( assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) +@pytest.mark.parametrize("version", ["v1"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", [False]) +@pytest.mark.parametrize( + "block_size", [16] +) # FIXME: Currently we only use 16 due to the limitation of the YMM register number. 
+@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", ['cpu']) +@torch.inference_mode() +def test_paged_attention_cpu( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + device: str, + seed: int, +) -> None: + test_paged_attention(kv_cache_factory, version, num_seqs, num_heads, + head_size, use_alibi, block_size, dtype, device, seed) + + def ref_multi_query_kv_attention( cu_seq_lens: List[int], query: torch.Tensor, @@ -240,6 +271,7 @@ def ref_multi_query_kv_attention( value: torch.Tensor, scale: float, dtype: torch.dtype, + device: str, ) -> torch.Tensor: num_seqs = len(cu_seq_lens) - 1 ref_outputs = [] @@ -252,7 +284,7 @@ def ref_multi_query_kv_attention( attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device=query.device) + attn_mask = attn_mask.to(dtype=dtype, device=device) ref_output = ref_masked_attention( query[start_idx:end_idx], @@ -279,13 +311,15 @@ def test_multi_query_kv_attention( num_heads: Tuple[int, int], head_size: int, dtype: torch.dtype, + device: str, seed: int, - device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use # a smaller MAX_SEQ_LEN here. @@ -299,7 +333,7 @@ def test_multi_query_kv_attention( num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype, - device=gpu_id) + device=device) qkv.uniform_(-scale, scale) query, key, value = qkv.split( [num_query_heads, num_kv_heads, num_kv_heads], dim=1) @@ -309,26 +343,48 @@ def test_multi_query_kv_attention( # Handle MQA and GQA key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + + output = None attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) - output = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - ) - output = output.squeeze(0) + if device.count('cuda'): + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) + else: + output = torch.nn.functional.scaled_dot_product_attention( + query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1), + attn_bias.materialize((1, num_tokens, num_tokens), dtype=dtype), + 0.0).transpose(0, 1) cu_seq_lens = [0] for seq_len in seq_lens: cu_seq_lens.append(cu_seq_lens[-1] + seq_len) - ref_output = ref_multi_query_kv_attention( - cu_seq_lens, - query, - key, - value, - scale, - dtype, - ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + ref_output = ref_multi_query_kv_attention(cu_seq_lens, query, key, value, + scale, dtype, device) + assert torch.allclose(output, ref_output, atol=1e-3, + rtol=1e-5) # type: ignore + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) 
+@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", ['cpu']) +@torch.inference_mode() +def test_multi_query_kv_attention_cpu( + num_seqs: int, + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + device: str, + seed: int, +) -> None: + test_multi_query_kv_attention(num_seqs, num_heads, head_size, dtype, + device, seed) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 7b1cc058f2cb5..87b7dd5c0ea4a 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -14,7 +14,7 @@ NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @@ -36,13 +36,15 @@ def test_copy_blocks( block_size: int, num_blocks: int, dtype: torch.dtype, + device: str, seed: int, - device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + # Generate random block mappings where each source block is mapped to two # destination blocks. assert 2 * num_mappings <= num_blocks @@ -59,7 +61,7 @@ def test_copy_blocks( # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, num_layers, num_heads, - head_size, dtype, seed, gpu_id) + head_size, dtype, device, seed) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] @@ -84,6 +86,32 @@ def test_copy_blocks( assert torch.allclose(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) +@pytest.mark.parametrize("num_layers", NUM_LAYERS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", ['cpu']) +@torch.inference_mode() +def test_copy_blocks_cpu( + kv_cache_factory, + num_mappings: int, + num_layers: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + device: str, + seed: int, +) -> None: + test_copy_blocks(kv_cache_factory, num_mappings, num_layers, num_heads, + head_size, block_size, num_blocks, dtype, device, seed) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -101,30 +129,32 @@ def test_reshape_and_cache( block_size: int, num_blocks: int, dtype: torch.dtype, + device: str, seed: int, - device: int, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device) qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, - device=gpu_id) + device=device) _, key, value = qkv.unbind(dim=1) # Create the KV caches. 
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, num_heads, head_size, dtype, - seed, gpu_id) + device, seed) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. @@ -149,3 +179,27 @@ def test_reshape_and_cache( assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(value_cache, cloned_value_cache) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", ['cpu']) +@torch.inference_mode() +def test_reshape_and_cache_cpu( + kv_cache_factory, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + device: str, + seed: int, +) -> None: + test_reshape_and_cache(kv_cache_factory, num_tokens, num_heads, head_size, + block_size, num_blocks, dtype, device, seed) diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 8a06b3aa268be..6d32cdeff90ae 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -8,7 +8,7 @@ HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -23,16 +23,18 @@ def test_rms_norm( hidden_size: int, add_residual: bool, dtype: torch.dtype, + device: str, seed: int, - device: int, ) -> None: torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + layer = RMSNorm(hidden_size).to(dtype=dtype, device=torch.device(device)) layer.weight.data.normal_(mean=1.0, std=0.1) scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id) + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) x *= scale residual = torch.randn_like(x) * scale if add_residual else None @@ -48,3 +50,21 @@ def test_rms_norm( assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2) else: assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", [64, 768, 2048, 5120, 8192]) +@pytest.mark.parametrize("add_residual", ADD_RESIDUAL) +@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", ['cpu']) +@torch.inference_mode() +def test_rms_norm_cpu( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + device: str, + seed: int, +) -> None: + test_rms_norm(num_tokens, hidden_size, add_residual, dtype, device, seed) diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index aad310e2bc6d2..72412d14278e3 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -13,7 +13,7 @@ BATCH_SIZES = [1, 5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() 
== 1 else 2)] +DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -34,29 +34,31 @@ def test_rotary_embedding( head_size: int, rotary_dim: Optional[int], dtype: torch.dtype, + device: str, seed: int, - device: int, max_position: int = 8192, base: int = 10000, ) -> None: if rotary_dim is None: rotary_dim = head_size torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + if rotary_dim is None: rotary_dim = head_size rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) - rope = rope.to(dtype=dtype, device=gpu_id) + rope = rope.to(dtype=dtype, device=torch.device(device)) positions = torch.randint(0, max_position, (batch_size, seq_len), - device=gpu_id) + device=device) query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype, - device=gpu_id) + device=device) key = torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first @@ -66,3 +68,31 @@ def test_rotary_embedding( # Compare the results. assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5) assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("is_neox_style", [True]) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16]) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", ['cpu']) +@torch.inference_mode() +def test_rotary_embedding_cpu( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + device: str, + seed: int, + max_position: int = 8192, + base: int = 10000, +) -> None: + test_rotary_embedding(is_neox_style, batch_size, seq_len, num_heads, + head_size, rotary_dim, dtype, device, seed, + max_position, base) diff --git a/vllm/config.py b/vllm/config.py index f1efcc66e9097..2139280fb081a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -73,6 +73,7 @@ def __init__( quantization: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + device: str = "cuda", ) -> None: self.model = model self.tokenizer = tokenizer @@ -86,6 +87,10 @@ def __init__( self.quantization = quantization self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + self.device = torch.device(device) + + if device == "cpu" and not enforce_eager: + self.enforce_eager = True if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -99,7 +104,7 @@ def __init__( self.tokenizer = model_path self.hf_config = get_config(self.model, trust_remote_code, revision) - self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self.dtype = _get_and_verify_dtype(self.hf_config, dtype, self.device) self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) self._verify_load_format() @@ -280,11 +285,13 @@ def __init__( gpu_memory_utilization: float, swap_space: int, sliding_window: Optional[int] = None, + cpu_only: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization self.swap_space_bytes 
= swap_space * _GB self.sliding_window = sliding_window + self.cpu_only = cpu_only self._verify_args() # Will be set after profiling. @@ -333,6 +340,7 @@ def __init__( tensor_parallel_size: int, worker_use_ray: bool, max_parallel_loading_workers: Optional[int] = None, + device: str = "cuda", ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size @@ -340,6 +348,16 @@ def __init__( self.max_parallel_loading_workers = max_parallel_loading_workers self.world_size = pipeline_parallel_size * tensor_parallel_size + + self.device = torch.device(device) + + if self.device == torch.device("cpu"): + logger.info( + "CPU-only mode doesn't support parallel execution currently.") + self.pipeline_parallel_size = 1 + self.tensor_parallel_size = 1 + self.world_size = 1 + if self.world_size > 1: self.worker_use_ray = True self._verify_args() @@ -411,6 +429,7 @@ def _verify_args(self) -> None: def _get_and_verify_dtype( config: PretrainedConfig, dtype: Union[str, torch.dtype], + device: torch.device, ) -> torch.dtype: # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # because config.torch_dtype can be None. @@ -444,6 +463,10 @@ def _get_and_verify_dtype( raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " f"Supported dtypes are {rocm_supported_dtypes}") + if torch_dtype == torch.float16 and device == torch.device("cpu"): + torch_dtype = torch.bfloat16 + logger.warning("float16 is not supported on CPU, casting to bfloat16.") + # Verify the dtype. if torch_dtype != config_dtype: if torch_dtype == torch.float32: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 3bde005997bde..0e557f4c9c161 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -78,6 +78,7 @@ def __init__( num_cpu_blocks: int, watermark: float = 0.01, sliding_window: Optional[int] = None, + cpu_only: bool = False, ) -> None: self.block_size = block_size self.num_total_gpu_blocks = num_gpu_blocks @@ -92,7 +93,8 @@ def __init__( self.watermark = watermark assert watermark >= 0.0 - self.watermark_blocks = int(watermark * num_gpu_blocks) + self.watermark_blocks = int( + watermark * (num_gpu_blocks if not cpu_only else num_cpu_blocks)) self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks) self.cpu_allocator = BlockAllocator(Device.CPU, block_size, @@ -100,6 +102,8 @@ def __init__( # Mapping: seq_id -> BlockTable. self.block_tables: Dict[int, BlockTable] = {} + self.cpu_only = cpu_only + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. @@ -108,17 +112,31 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.block_sliding_window) - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + + num_free_blocks = self.gpu_allocator.get_num_free_blocks( + ) if not self.cpu_only else self.cpu_allocator.get_num_free_blocks() + num_total_blocks = self.num_total_gpu_blocks if not self.cpu_only else self.num_total_cpu_blocks # Use watermark to avoid frequent cache eviction. 
- if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): + if (num_total_blocks - num_required_blocks < self.watermark_blocks): return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + if num_free_blocks - num_required_blocks >= self.watermark_blocks: return AllocStatus.OK else: return AllocStatus.LATER + def _allocate(self) -> PhysicalTokenBlock: + if not self.cpu_only: + return self.gpu_allocator.allocate() + else: + return self.cpu_allocator.allocate() + + def _free(self, block: PhysicalTokenBlock): + if block.device == Device.CPU: + self.cpu_allocator.free(block) + else: + self.gpu_allocator.free(block) + def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. @@ -131,7 +149,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] else: - block = self.gpu_allocator.allocate() + block = self._allocate() # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() block_table.append(block) @@ -143,9 +161,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: def can_append_slot(self, seq_group: SequenceGroup) -> bool: # Simple heuristic: If there is at least one free block # for each sequence, we can append. - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + num_free_blocks = self.get_num_free_gpu_blocks( + ) if not self.cpu_only else self.get_num_free_cpu_blocks() num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks + return num_seqs <= num_free_blocks def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: """Allocate a physical slot for a new token.""" @@ -161,22 +180,22 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: else: # The sequence has a new logical block. # Allocate a new physical block. - block = self.gpu_allocator.allocate() + block = self._allocate() block_table.append(block) return None # We want to append the token to the last physical block. last_block = block_table[-1] - assert last_block.device == Device.GPU + assert last_block.device == Device.GPU or self.cpu_only if last_block.ref_count == 1: # Not shared with other sequences. Appendable. return None else: # The last block is shared with other sequences. # Copy on Write: Allocate a new block and copy the tokens. - new_block = self.gpu_allocator.allocate() + new_block = self._allocate() block_table[-1] = new_block - self.gpu_allocator.free(last_block) + self._free(last_block) return last_block.block_number, new_block.block_number def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: @@ -199,6 +218,9 @@ def _get_physical_blocks( return list(blocks) def can_swap_in(self, seq_group: SequenceGroup) -> bool: + if self.cpu_only: + return True + blocks = self._get_physical_blocks(seq_group) num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) num_free_blocks = self.gpu_allocator.get_num_free_blocks() @@ -209,6 +231,9 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool: return num_free_blocks - num_required_blocks >= self.watermark_blocks def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + if self.cpu_only: + return {} + # CPU block -> GPU block. 
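The can_allocate change above applies the existing watermark heuristic to whichever block pool is active (GPU normally, CPU when cpu_only is set). A standalone sketch of that decision, with illustrative numbers; the helper name and values below are not part of the patch.

def can_allocate(num_required_blocks, num_free_blocks, num_total_blocks, watermark=0.01):
    # Mirrors the logic above: NEVER if the request can never fit in the pool,
    # OK if it fits now while keeping the watermark as headroom, LATER otherwise.
    watermark_blocks = int(watermark * num_total_blocks)
    if num_total_blocks - num_required_blocks < watermark_blocks:
        return "NEVER"
    if num_free_blocks - num_required_blocks >= watermark_blocks:
        return "OK"
    return "LATER"

print(can_allocate(num_required_blocks=8, num_free_blocks=100, num_total_blocks=1024))  # OK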
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): @@ -234,10 +259,16 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: return block_number_mapping def can_swap_out(self, seq_group: SequenceGroup) -> bool: + if self.cpu_only: + return True + blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + if self.cpu_only: + return {} + # GPU block -> CPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): @@ -264,10 +295,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: def _free_block_table(self, block_table: BlockTable) -> None: for block in set(block_table): - if block.device == Device.GPU: - self.gpu_allocator.free(block) - else: - self.cpu_allocator.free(block) + self._free(block) def free(self, seq: Sequence) -> None: if seq.seq_id not in self.block_tables: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 9fe01a14aedcc..9d199af3e38af 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -74,7 +74,9 @@ def __init__( block_size=self.cache_config.block_size, num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, - sliding_window=self.cache_config.sliding_window) + sliding_window=self.cache_config.sliding_window, + cpu_only=cache_config.cpu_only, + ) # Sequence groups in the WAITING state. self.waiting: Deque[SequenceGroup] = deque() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e58069e2c22d..b253f7c170e74 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -35,6 +35,7 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 + device: str = 'cuda' def __post_init__(self): if self.tokenizer is None: @@ -115,6 +116,12 @@ def add_cli_args( 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') parser.add_argument('--max-model-len', type=int, default=None, @@ -221,15 +228,17 @@ def create_engine_configs( self.dtype, self.seed, self.revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.enforce_eager, - self.max_context_len_to_capture) + self.max_context_len_to_capture, self.device) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, - model_config.get_sliding_window()) + model_config.get_sliding_window(), + self.device == 'cpu') parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, self.worker_use_ray, - self.max_parallel_loading_workers) + self.max_parallel_loading_workers, + device=self.device) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e30bf5db49283..567d512559f5a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -294,11 +294,11 @@ def _init_cache(self) -> None: logger.info(f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}") - if num_gpu_blocks <= 0: + if num_gpu_blocks < 0 or (not self.cache_config.cpu_only and num_gpu_blocks 
== 0): raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") - max_seq_len = self.cache_config.block_size * num_gpu_blocks + max_seq_len = self.cache_config.block_size * (num_gpu_blocks if not self.cache_config.cpu_only else num_cpu_blocks) if self.model_config.max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({self.model_config.max_model_len}) " @@ -783,7 +783,7 @@ def _log_system_stats( num_free_gpu_blocks = ( self.scheduler.block_manager.get_num_free_gpu_blocks()) num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks - gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks + gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks if total_num_gpu_blocks > 0 else 0.0 total_num_cpu_blocks = self.cache_config.num_cpu_blocks if total_num_cpu_blocks > 0: diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index da615ecccf993..02fd92caa4582 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union import torch @@ -31,8 +31,7 @@ def __init__( self.use_cuda_graph = use_cuda_graph # Set during the execution of the first attention op. - # FIXME(woosuk): This is a hack. - self.attn_bias = None + self.attn_bias: Union[torch.Tensor, None] = None def __repr__(self) -> str: return ("InputMetadata(" diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 1af120d13cd4b..3d4ffe1e5c263 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -27,7 +27,7 @@ class SiluAndMul(nn.Module): def _forward(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" d = x.shape[-1] // 2 - return F.silu(x[..., :d]) * x[..., d:] + return (F.silu(x[..., :d].float()) * x[..., d:].float()).to(x) def forward(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index f1008ec8159f6..54821d0a2a2a7 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -3,9 +3,17 @@ import torch import torch.nn as nn -from xformers import ops as xops -from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, - LowerTriangularMaskWithTensorBias) +try: + from xformers import ops as xops +except: + pass + +try: + from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) +except: + from vllm.model_executor.layers.xformers_cpu.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) from vllm._C import ops from vllm._C import cache_ops @@ -52,6 +60,8 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.cpu_only = torch.zeros((1)).is_cpu + if self.head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError(f"head_size ({self.head_size}) is not supported. 
" f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") @@ -126,8 +136,13 @@ def forward( if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) + if self.cpu_only: + attn_bias = attn_bias.materialize( + (1, seq_len * batch_size, seq_len * batch_size), + dtype=query.dtype) input_metadata.attn_bias = attn_bias else: + assert not self.cpu_only input_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, batch_size, seq_len, query.dtype) @@ -152,7 +167,11 @@ def forward( scale=self.scale, op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (is_hip()) else None, - ) + ) if not self.cpu_only else torch.nn.functional.scaled_dot_product_attention( + query.movedim(1, query.dim() -2), key.movedim(1, query.dim() - 2), + value.movedim(1, value.dim() - 2), + input_metadata.attn_bias, + 0.0).movedim(query.dim() - 2, 1).contiguous() output = out.view_as(query) else: # Decoding run. @@ -229,7 +248,7 @@ def _paged_attention( # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = input_metadata.max_context_len <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512) - if use_v1: + if use_v1 or query.is_cpu: # Run PagedAttention V1. ops.paged_attention_v1( output, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5e1d63a6a62eb..55d38b763b2b5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -54,7 +54,6 @@ def create_weights(self, input_size_per_partition: int, params_dtype: torch.dtype) -> Dict[str, Any]: weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, - device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ -113,9 +112,7 @@ def __init__( self.register_parameter(name, weight) if bias: self.bias = Parameter( - torch.empty(self.output_size, - device=torch.cuda.current_device(), - dtype=self.params_dtype)) + torch.empty(self.output_size, dtype=self.params_dtype)) set_weight_attrs(self.bias, {"output_dim": 0}) else: self.register_parameter("bias", None) @@ -183,7 +180,6 @@ def __init__( if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, - device=torch.cuda.current_device(), dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, @@ -509,9 +505,7 @@ def __init__( if bias: self.bias = Parameter( - torch.empty(self.output_size, - device=torch.cuda.current_device(), - dtype=params_dtype)) + torch.empty(self.output_size, dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, "weight_loader": self.weight_loader, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 91c093e33e3c9..3ee7e054accd3 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -77,16 +77,13 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. 
inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / - self.rotary_dim)) + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: """Compute the cos and sin cache.""" inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, - dtype=torch.float, - device="cuda") + t = torch.arange(self.max_position_embeddings, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() @@ -101,16 +98,19 @@ def _forward( key: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" + device = query.device + dtype = query.dtype + query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] + query_rot = query[..., :self.rotary_dim].float() + key_rot = key[..., :self.rotary_dim].float() if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] + query_pass = query[..., self.rotary_dim:].float() + key_pass = key[..., self.rotary_dim:].float() - cos_sin = self.cos_sin_cache[positions] + cos_sin = self.cos_sin_cache[positions].float() cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: # NOTE(woosuk): Here we assume that the positions tensor has the @@ -131,8 +131,8 @@ def _forward( else: query = query_rot key = key_rot - query = query.flatten(-2) - key = key.flatten(-2) + query = query.flatten(-2).to(dtype=dtype, device=device) + key = key.flatten(-2).to(dtype=dtype, device=device) return query, key def forward( diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b08d5555b0faa..3a888bcb9aa42 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -68,7 +68,6 @@ def __init__(self, self.weight = Parameter( torch.empty(self.num_embeddings_per_partition, self.embedding_dim, - device=torch.cuda.current_device(), dtype=params_dtype)) set_weight_attrs(self.weight, { "parallel_dim": 0, @@ -125,7 +124,6 @@ def __init__(self, if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, - device=torch.cuda.current_device(), dtype=params_dtype)) set_weight_attrs(self.bias, { "parallel_dim": 0, diff --git a/vllm/model_executor/layers/xformers_cpu/__init__.py b/vllm/model_executor/layers/xformers_cpu/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/xformers_cpu/attn_bias.py b/vllm/model_executor/layers/xformers_cpu/attn_bias.py new file mode 100644 index 0000000000000..89b9d53469c87 --- /dev/null +++ b/vllm/model_executor/layers/xformers_cpu/attn_bias.py @@ -0,0 +1,929 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +import torch + + +class AttentionBias: + """Base class for a custom bias that can be applied \ + as the attn_bias argument in + :attr:`xformers.ops.memory_efficient_attention`. 
+ + That function has the ability to add a tensor, the + attention bias, to the QK^T matrix before it is used + in the softmax part of the attention calculation. + The attention bias tensor with shape + (B or 1, n_queries, number of keys) + can be given as the attn_bias input. + The most common use case is for an attention bias is + to contain only zeros and negative infinities, which forms + a mask so that some queries only attend to some keys. + + Children of this class define alternative things which can + be used as the attn_bias input to define an attention bias which + forms such a mask, for some common cases. + + When using an :attr:`xformers.ops.AttentionBias` + instead of a :attr:`torch.Tensor`, the mask matrix does + not need to be materialized, and can be + hardcoded into some kernels for better performance. + + See: + + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularFromBottomRightMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask` + + """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """ + Materializes the bias as a `torch.Tensor`. This is very slow + and we don't attempt to make it fast. Only use for debugging/testing. + + Shape should be like `[*, q_seqlen, k_seqlen]` + """ + raise NotImplementedError() + + +def _materialize_causal_mask( + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + *, + window_size: Optional[int] = None, + from_bottomright: bool = False, +) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + shift = 0 + if from_bottomright: + shift = num_keys - num_queries + + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + return mask.to(dtype) + + +@dataclass +class LocalAttentionFromBottomRightMask(AttentionBias): + """ + A local attention mask + + The query at position :math:`q` can attend the key at position :math:`k` if + :math:`q - window\\_left <= k + s <= q + window\\_right` + + With :math:`s = num\\_queries - num\\_keys` + + :Example: + + .. code-block:: python + + import torch + from xformers.ops import fmha + + bias = fmha.attn_bias.LocalAttentionFromBottomRightMask(window_left=1, window_right=2) + print(bias.materialize(shape=(4, 4)).exp()) + print(bias.materialize(shape=(4, 5)).exp()) + + .. code-block:: text + + # 4x4 + tensor([[1., 1., 1., 0.], + [1., 1., 1., 1.], + [0., 1., 1., 1.], + [0., 0., 1., 1.]]) + + # 4x5 + tensor([[1., 1., 1., 1., 0.], + [0., 1., 1., 1., 1.], + [0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1.]]) + + :Illustration: + + .. 
figure:: /_static/local_attn.png + :width: 240px + + The total window size is :math:`window\\_left + 1 + window\\_right` + """ + + window_left: int + window_right: int + + def __post_init__(self) -> None: + if self.window_left < 0: + raise ValueError( + "Invalid window value passed to " + "`LocalAttentionFromBottomRightMask`: expected" + f"`window_left > 0` but got window_left={self.window_left}" + ) + if self.window_right < 0: + raise ValueError( + "Invalid window value passed to " + "`LocalAttentionFromBottomRightMask`: expected" + f"`window_right > 0` but got window_right={self.window_right}" + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + mask = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + shift = num_keys - num_queries + + mask = torch.triu(mask, diagonal=shift - self.window_left) + mask = torch.tril(mask, diagonal=shift + self.window_right) + mask = torch.log(mask) + return mask.to(dtype) + + +class LowerTriangularMask(AttentionBias): + """ + A lower-triangular (aka causal) mask + + A query Q cannot attend to a key which is farther from the + initial key than Q is from the initial query. + + See also :attr:`LowerTriangularFromBottomRightMask` if the number + of queries is not equal to the number of keys/values. + """ + + def __init__(self, *tensor_args, **tensor_kwargs) -> None: + # NOTE: Unused arguments, we keep them for backward compatibility + super().__init__() + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask(shape, dtype=dtype, device=device) + + def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias": + """ + Creates a new causal mask with an arbitrary ``torch.Tensor`` bias + """ + return LowerTriangularMaskWithTensorBias(bias) + + +class LowerTriangularFromBottomRightMask(AttentionBias): + """ + A causal masking. + + This mask is exactly the same as :attr:`LowerTriangularMask` when there is + the same number of queries and keys. + When the number of queries is different from the number of keys, + it is a triangular mask shifted so that the last query can attend to + the last key. + In other words, a query Q cannot attend to a key which is nearer the + final key than Q is to the final query. + + + .. figure:: /_static/causal_bottom_right.png + + The difference between :attr:`LowerTriangularMask` (left) and + :attr:`LowerTriangularFromBottomRightMask` (right). They become + equivalent if the number of queries equals the number of keys. + """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, dtype=dtype, device=device, from_bottomright=True + ) + + def make_local_attention( + self, window_size: int + ) -> "LowerTriangularFromBottomRightLocalAttentionMask": + """ + Create a new bias which combines local + causal attention. 
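Because these vendored bias classes are only ever materialize()d on the CPU path, a quick check of what the plain causal mask evaluates to may help; the snippet assumes the file lands at the module path introduced by this diff.

import torch
from vllm.model_executor.layers.xformers_cpu.attn_bias import LowerTriangularMask

bias = LowerTriangularMask().materialize((3, 3), dtype=torch.float32)
print(bias)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])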
+ + See :attr:`LowerTriangularFromBottomRightLocalAttentionMask` + """ + return LowerTriangularFromBottomRightLocalAttentionMask(window_size) + + +@dataclass +class LowerTriangularFromBottomRightLocalAttentionMask( + LowerTriangularFromBottomRightMask +): + """ + A mask that combines both :attr:`LowerTriangularFromBottomRightMask` and + local attention. + + A query whose distance from the final query is X cannot attend to a key + whose distance to the final key is either of: + + * less than X (i.e. "causal attention", same as :attr:`LowerTriangularFromBottomRightMask`) + * greater than X + window_size (i.e. "local attention") + + + .. figure:: /_static/causal_bottom_right_local.png + + The mask from :attr:`LowerTriangularFromBottomRightLocalAttentionMask`. + The green area is calculated, and the grey area is masked out. + """ + + _window_size: int + + def __post_init__(self) -> None: + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, + dtype=dtype, + device=device, + window_size=self._window_size, + from_bottomright=True, + ) + + +class LowerTriangularMaskWithTensorBias(LowerTriangularMask): + """A lower-triangular (aka causal) mask with an additive bias""" + + def __init__(self, bias: torch.Tensor) -> None: + self._bias = bias + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return super().materialize(shape, dtype=dtype, device=device) + self._bias + + +@dataclass +class _SeqLenInfo: + """ + (Internal) Represents the division of a dimension into blocks. + + For example, to represents a dimension of length 7 divided into + three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`. 
+ The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 2, 5, 7] + seqstart: torch.IntTensor([0, 2, 5, 7]) + """ + + seqstart: torch.Tensor + max_seqlen: int + min_seqlen: int + seqstart_py: List[int] + + def to(self, device: torch.device) -> None: + self.seqstart = self.seqstart.to(device, non_blocking=True) + + def intervals(self) -> Iterable[Tuple[int, int]]: + yield from zip(self.seqstart_py, self.seqstart_py[1:]) + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = [0] + max_seqlen = -1 + min_seqlen = -1 + for seqlen in seqlens: + min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen + max_seqlen = max(max_seqlen, seqlen) + seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + return cls( + max_seqlen=max_seqlen, + min_seqlen=min_seqlen, + seqstart=seqstart, + seqstart_py=seqstart_py, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1: + raise ValueError( + f"Invalid `torch.Tensor` of shape {x.shape}, expected format " + f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n" + f" seqstart: {self.seqstart_py}" + ) + if batch_sizes is None: + batch_sizes = [1] * (len(self.seqstart_py) - 1) + split_chunks = [] + it = 0 + for batch_size in batch_sizes: + split_chunks.append( + self.seqstart_py[it + batch_size] - self.seqstart_py[it] + ) + it += batch_size + return [ + tensor.reshape([bs, -1, *tensor.shape[2:]]) + for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1)) + ] + + +@dataclass +class _PaddedSeqLenInfo(_SeqLenInfo): + """ + (Internal) Represents the division of a dimension into blocks which are + padded out to the same total length. + + For example, to represent a dimension of length 12 with space for + three blocks of length 4, but where the occupied lengths are + 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`. 
+ + The layout along the dimension is + + 0 ─► block 0 + block 0 + + + 4 ─► block 1 + block 1 + block 1 + + 8 ─► block 2 + block 2 + + + 12 ─► + + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 4, 8, 12] + seqstart: torch.IntTensor([0, 4, 8, 12]) + seqlen_py: [2, 3, 2] + seqlen: torch.IntTensor([2, 3, 2]) + padding: 4 + """ + + seqlen: torch.Tensor + seqlen_py: Sequence[int] + padding: int + # From parent: seqstart[i] contains the start position + # of the i-th sequence + # seqstart: torch.Tensor + + def __post_init__(self) -> None: + assert len(self.seqstart_py) == len(self.seqlen_py) + 1 + + def to(self, device: torch.device) -> None: + self.seqlen = self.seqlen.to(device, non_blocking=True) + super().to(device) + + def intervals(self) -> Iterable[Tuple[int, int]]: + for (start, _), length in zip(super().intervals(), self.seqlen_py): + yield start, start + length + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + raise RuntimeError( + "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`" + ) + + @classmethod + def from_seqlens_padded( + cls, seqlens: Sequence[int], padding: int + ) -> "_PaddedSeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + seqstart = padding * torch.arange(batch_size) + """ + assert not isinstance(seqlens, torch.Tensor) + assert all(seqlen <= padding for seqlen in seqlens) + seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) + return cls( + seqlen=torch.tensor(seqlens, dtype=torch.int32), + seqlen_py=seqlens, + max_seqlen=max(seqlens), + min_seqlen=min(seqlens), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32), + seqstart_py=seqstart_py, + padding=padding, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + raise NotImplementedError("_PaddedSeqLenInfo.split") + + +@dataclass +class BlockDiagonalMask(AttentionBias): + """ + A block-diagonal mask that can be passed as ``attn_bias`` + argument to :attr:`xformers.ops.memory_efficient_attention`. + + Queries and Keys are each divided into the same number of blocks. + Queries in block i only attend to keys in block i. + + .. figure:: /_static/block_diag_bias.png + + This bias can be used to handle a batch of sequences of + different lengths, via :attr:`BlockDiagonalMask.from_tensor_list` + + :Example: + + .. 
code-block:: python + + import torch + from xformers.ops import fmha + + K = 16 + dtype = torch.float16 + device = "cuda" + list_x = [ + torch.randn([1, 3, 1, K], dtype=dtype, device=device), + torch.randn([1, 6, 1, K], dtype=dtype, device=device), + torch.randn([1, 2, 1, K], dtype=dtype, device=device), + ] + attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x) + linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) + + q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2) + out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias) + list_out = attn_bias.split(out) + print(list_out[0].shape) # [1, 3, 1, K] + assert tuple(list_out[0].shape) == (1, 3, 1, K) + + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _SeqLenInfo + _batch_sizes: Optional[Sequence[int]] = None + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return torch.zeros( + shape, + dtype=dtype, + device=device, + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + assert shape[-1] == self.k_seqinfo.seqstart_py[-1], ( + shape[-1], + self.k_seqinfo.seqstart_py[-1], + ) + assert shape[-2] == self.q_seqinfo.seqstart_py[-1], ( + shape[-2], + self.q_seqinfo.seqstart_py[-1], + ) + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqlen: Optional[Sequence[int]] = None, + ) -> "BlockDiagonalMask": + """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value. + + Args: + q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors + kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value. + (Defaults to ``q_seqlen``.) + Returns: + BlockDiagonalMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + if kv_seqlen is None or q_seqlen == kv_seqlen: + k_seqinfo = q_seqinfo + else: + k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + @classmethod + def from_tensor_list( + cls, + tensors: Sequence[torch.Tensor], + ) -> Tuple["BlockDiagonalMask", torch.Tensor]: + """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors + concatenated on the sequence length dimension + + .. figure:: /_static/block_diag_cat_split.png + + See also :attr:`BlockDiagonalMask.split` to split the returned + :attr:`torch.Tensor` back to a list of tensors of varying sequence length + + Args: + tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``. + All tensors should have the same dimension and the same batch size ``B``, but + they can have different sequence length ``M``. 
+ + Returns: + Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention + along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]`` + """ + batch_sizes = [tensor.shape[0] for tensor in tensors] + seqlens = [] + for x in tensors: + for _ in range(x.shape[0]): + seqlens.append(x.shape[1]) + block_diag = cls.from_seqlens(seqlens) + block_diag._batch_sizes = batch_sizes + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors) + concat_tensors = torch.cat(tensors_bs1, dim=1) + return block_diag, concat_tensors + + @classmethod + def from_tensor_lists_qkv( + cls, + tensors_q: Sequence[torch.Tensor], + tensors_k: Sequence[torch.Tensor], + tensors_v: Optional[Sequence[torch.Tensor]] = None, + ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert len(tensors_q) == len(tensors_k) + assert tensors_v is None or len(tensors_v) == len(tensors_q) + batch_sizes = [tensor.shape[0] for tensor in tensors_q] + q_seqlens, kv_seqlens = [], [] + for i, (q, k) in enumerate(zip(tensors_q, tensors_k)): + assert q.shape[0] == k.shape[0] + q_seqlens += [q.shape[1]] * q.shape[0] + kv_seqlens += [k.shape[1]] * k.shape[0] + assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2] + block_diag = cls.from_seqlens(q_seqlens, kv_seqlens) + block_diag._batch_sizes = batch_sizes + return ( + block_diag, + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1) + if tensors_v is not None + else None, + ) + + def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.k_seqinfo.split(tensor, self._batch_sizes) + + def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + """The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list` + + Args: + tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]`` + + Returns: + Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths + """ + assert self.q_seqinfo is self.k_seqinfo + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def make_causal(self) -> "BlockDiagonalCausalMask": + """Makes each block causal""" + return BlockDiagonalCausalMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask": + """Makes each block causal with a possible non-causal prefix""" + return BlockDiagonalCausalFromBottomRightMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_local_attention( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionMask": + """Experimental: Makes each block causal with local attention""" + return BlockDiagonalCausalLocalAttentionMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + def make_local_attention_from_bottomright( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask": + """Experimental: Makes each block causal with local attention, start from bottom right""" + return BlockDiagonalCausalLocalAttentionFromBottomRightMask( + 
q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + +@dataclass +class BlockDiagonalCausalMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + + Queries and Keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which is farther from the initial key in block i than Q + is from the initial query in block i. + """ + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularMask().materialize( + shape, + dtype=dtype, + device=device, + ) + + +@dataclass +class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + This mask allows for a non-causal prefix + NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not + defined (softmax of vector of `-inf` in the attention) + + Queries and keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which nearer the final key in block i than Q is to the + final query in block i. + """ + + def __post_init__(self) -> None: + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + num_queries = q_end - q_start + num_keys = k_end - k_start + if num_keys < num_queries: + raise ValueError( + f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}." + " Expected `num_keys >= num_queries`" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularFromBottomRightMask().materialize( + shape=shape, dtype=dtype, device=device + ) + + +@dataclass +class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`, + except an offset on causality is allowed for each block and we support padding for k/v + + The keys and values are divided into blocks which are padded out to + the same total length. + For example, if there is space for 12 keys, for three blocks of + max length 4, but we only want to use the first 2, 3 and 2 + of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`. + The queries are divided into blocks, without padding, of lengths given by + q_seqlen. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is not in use (i.e. in the padded area), + nor one which is nearer to the final key in block i + than Q is to the final query in block i. + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _PaddedSeqLenInfo + causal_diagonal: Any = None # unused. Exists for BC only. 
+ + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularFromBottomRightMask().materialize( + shape=shape, dtype=dtype, device=device + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape[-1] != self.k_seqinfo.seqstart_py[-1]: + raise ValueError("k shapes wrong") + if shape[-2] != self.q_seqinfo.seqstart_py[-1]: + raise ValueError("q shapes wrong") + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_padding: int, + kv_seqlen: Sequence[int], + causal_diagonal: Any = None, + ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": + """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. + causal_diagonal: unused, for BC only + Returns: + BlockDiagonalCausalWithOffsetPaddedKeysMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + +@dataclass +class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. + """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + q_seqlen = [ + y - x + for x, y in zip( + self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:] + ) + ] + kv_seqlen = [ + y - x + for x, y in zip( + self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:] + ) + ] + for q, k in zip(q_seqlen, kv_seqlen): + if q - self._window_size >= k: + # Each query only attends to keys no further than window_size back. + # When q > k + window_size, there will be a query for which the window doesn't reach any key. 
+ raise RuntimeError( + f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, + dtype=dtype, + device=device, + window_size=self._window_size, + ) + + +@dataclass +class BlockDiagonalCausalLocalAttentionFromBottomRightMask( + BlockDiagonalCausalFromBottomRightMask +): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. + """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + super().__post_init__() + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return _materialize_causal_mask( + shape, + dtype=dtype, + device=device, + window_size=self._window_size, + from_bottomright=True, + ) \ No newline at end of file diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 37543d8c9838e..f1b2511fe1ffb 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -21,6 +21,15 @@ def _set_default_torch_dtype(dtype: torch.dtype): torch.set_default_dtype(old_dtype) +@contextlib.contextmanager +def _set_default_torch_device(device: torch.device): + """Sets the default torch device to the given device.""" + old_device = torch.zeros((1, 1)).device + torch.set_default_device(device) + yield + torch.set_default_device(old_device) + + def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -58,11 +67,12 @@ def get_model(model_config: ModelConfig) -> nn.Module: f"{supported_dtypes}") linear_method = quant_config.get_linear_method() - with _set_default_torch_dtype(model_config.dtype): + with _set_default_torch_dtype( + model_config.dtype), _set_default_torch_device( + model_config.device): # Create a model instance. # The weights will be initialized as empty tensors. - with torch.device("cuda"): - model = model_class(model_config.hf_config, linear_method) + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 2d41d40e04678..901d0a035f407 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -152,7 +152,7 @@ def from_lists(cls, temperatures: List[float], top_ps: List[float], dtype: torch.dtype) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory.
- pin_memory = not in_wsl() + pin_memory = not (in_wsl() or device.type == "cpu") prompt_max_len = max(len(tokens) for tokens in prompt_tokens) prompt_padded_tokens = [ tokens + [vocab_size] * (prompt_max_len - len(tokens)) diff --git a/vllm/utils.py b/vllm/utils.py index 4d82f92129c95..105783913f513 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,10 +7,10 @@ import psutil import torch +import os from vllm._C import cuda_utils - class Device(enum.Enum): GPU = enum.auto() CPU = enum.auto() diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 1dd0243f8f3a3..95b554ab43df4 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -44,11 +44,15 @@ def __init__( self.gpu_cache = self.allocate_gpu_cache() self.cpu_cache = self.allocate_cpu_cache() - # Initialize the stream for caching operations. - self.cache_stream = torch.cuda.Stream() - assert self.cache_stream != torch.cuda.current_stream() - # Initialize the events for stream synchronization. - self.events = [torch.cuda.Event() for _ in range(self.num_layers)] + if not cache_config.cpu_only: + # Initialize the stream for caching operations. + self.cache_stream = torch.cuda.Stream() + assert self.cache_stream != torch.cuda.current_stream() + # Initialize the events for stream synchronization. + self.events = [torch.cuda.Event() for _ in range(self.num_layers)] + else: + self.cache_stream = None + self.events = None def get_key_block_shape(self) -> Tuple[int, int, int, int]: element_size = torch.tensor([], dtype=self.dtype).element_size() @@ -69,6 +73,9 @@ def get_value_block_shape(self) -> Tuple[int, int, int]: def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] + if self.num_gpu_blocks == 0: + return gpu_cache + key_block_shape = self.get_key_block_shape() value_block_shape = self.get_value_block_shape() for _ in range(self.num_layers): @@ -95,6 +102,7 @@ def allocate_cpu_cache(self) -> List[KVCache]: # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications logger.warning("Using 'pin_memory=False' as WSL is detected. " "This may slow down the performance.") + pin_memory = not self.cache_config.cpu_only for _ in range(self.num_layers): key_blocks = torch.empty( size=(self.num_cpu_blocks, *key_block_shape), @@ -134,8 +142,12 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: self._swap(self.gpu_cache, self.cpu_cache, src_to_dst) def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: - key_caches = [key_cache for key_cache, _ in self.gpu_cache] - value_caches = [value_cache for _, value_cache in self.gpu_cache] + if self.cache_config.cpu_only: + key_caches = [key_cache for key_cache, _ in self.cpu_cache] + value_caches = [value_cache for _, value_cache in self.cpu_cache] + else: + key_caches = [key_cache for key_cache, _ in self.gpu_cache] + value_caches = [value_cache for _, value_cache in self.gpu_cache] # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU. cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 460d9907e88cd..060484b1f003f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -43,6 +43,7 @@ def __init__( if model_config is not None else None) self.model = None self.block_size = None # Set after initial profiling. + self.device = self.model_config.device self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool = None # Set during graph capture. 
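A minimal, self-contained sketch of the pinned-memory rule these hunks rely on, assuming nothing beyond stock PyTorch (the helper name `_to_device_async` is illustrative and does not exist in vLLM): host memory is only worth pinning when a tensor will be copied asynchronously to a CUDA device, so pinning is skipped on WSL and whenever the target device is the CPU itself.

    import torch

    def _to_device_async(data: list, dtype: torch.dtype,
                         device: torch.device, in_wsl: bool) -> torch.Tensor:
        # Pinned host memory only speeds up async host-to-device copies to CUDA;
        # it buys nothing for CPU-only execution and is limited on WSL.
        pin_memory = device.type == "cuda" and not in_wsl
        t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
        # non_blocking only takes effect for pinned copies onto a CUDA device.
        return t.to(device=device, non_blocking=pin_memory)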
@@ -128,15 +129,18 @@ def _prepare_prompt( input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, pad=0, - dtype=torch.long) + dtype=torch.long, + device=self.device) input_positions = _make_tensor_with_pad(input_positions, max_prompt_len, pad=0, - dtype=torch.long) + dtype=torch.long, + device=self.device) slot_mapping = _make_tensor_with_pad(slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, - dtype=torch.long) + dtype=torch.long, + device=self.device) input_metadata = InputMetadata( is_prompt=True, @@ -207,24 +211,28 @@ def _prepare_decode( block_tables.append([]) batch_size = graph_batch_size + # When using CUDA graph, we don't need to make the tensors on the GPU + # because they will be eventually copied to the designated GPU buffer. + device = "cpu" if use_captured_graph or self.device.type == "cpu" else "cuda" + pin_memory = use_captured_graph and not self.in_wsl and self.device.type == "cuda" input_tokens = _make_tensor_with_pad(input_tokens, max_len=1, pad=0, dtype=torch.long, - device="cuda") + device=device) input_positions = _make_tensor_with_pad(input_positions, max_len=1, pad=0, dtype=torch.long, - device="cuda") + device=device) slot_mapping = _make_tensor_with_pad(slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, - device="cuda") + device=device) context_lens = torch.tensor(context_lens, dtype=torch.int, - device="cuda") + device=device) if use_captured_graph: # The shape of graph_block_tables is @@ -242,7 +250,7 @@ def _prepare_decode( max_len=max_block_table_len, pad=0, dtype=torch.int, - device="cuda", + device=device, ) input_metadata = InputMetadata( @@ -306,9 +314,11 @@ def _prepare_sample( selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, - pin_memory=not self.in_wsl) + device=self.device, + pin_memory=not self.in_wsl and not self.device.type == "cpu") categorized_sample_indices = { - t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) + t: _async_h2d(seq_ids, dtype=torch.int, device=self.device, + pin_memory=not self.in_wsl and not self.device.type == "cpu") for t, seq_ids in categorized_sample_indices.items() } @@ -667,6 +677,8 @@ def _get_graph_batch_size(batch_size: int) -> int: return (batch_size + 7) // 8 * 8 -def _async_h2d(data: list, dtype, pin_memory): +def _async_h2d(data: list, dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", + pin_memory: bool = False,): t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory) - return t.to(device="cuda", non_blocking=True) + return t.to(device=device, non_blocking=True) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c2a2ac148085b..3794325fc7486 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -53,22 +53,27 @@ def __init__( self.cache_engine = None self.cache_events = None self.gpu_cache = None + self.cpu_cache = None def init_model(self) -> None: - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. - # Related issue: - # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + if self.model_config.device == torch.device('cpu'): + self.rank = 0 + self.device = torch.device("cpu") + else: + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. 
This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - # This env var set by Ray causes exceptions with graph building. - os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) - _check_if_gpu_supports_dtype(self.model_config.dtype) + _check_if_gpu_supports_dtype(self.model_config.dtype) # Initialize the distributed environment. _init_distributed_environment(self.parallel_config, self.rank, @@ -95,6 +100,15 @@ def profile_num_available_blocks( gpu_memory_utilization: The fraction of the total GPU memory to use. cpu_swap_space: The size of the CPU swap space in bytes. """ + if self.model_config.device == torch.device('cpu'): + cache_block_size = CacheEngine.get_cache_block_size( + block_size, self.model_config, self.parallel_config) + num_gpu_blocks = 0 + num_cpu_blocks = int(cpu_swap_space // cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + return num_gpu_blocks, num_cpu_blocks + # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() @@ -126,6 +140,7 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.parallel_config) self.cache_events = self.cache_engine.events self.gpu_cache = self.cache_engine.gpu_cache + self.cpu_cache = self.cache_engine.cpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) def warm_up_model(self) -> None: @@ -194,8 +209,14 @@ def execute_model( if num_seq_groups == 0: return {} + kv_caches = None + if self.model_config.device == torch.device('cpu'): + kv_caches = self.cpu_cache + else: + kv_caches = self.gpu_cache + output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache) + kv_caches) return output @@ -217,15 +238,19 @@ def _init_distributed_environment( "distributed_init_method must be set if torch.distributed " "is not already initialized") else: + backend = "nccl" + if parallel_config.device == torch.device('cpu'): + backend = "gloo" + torch.distributed.init_process_group( - backend="nccl", + backend=backend, world_size=parallel_config.world_size, rank=rank, init_method=distributed_init_method, ) # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cuda()) + torch.distributed.all_reduce(torch.zeros(1, device=parallel_config.device)) initialize_model_parallel(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size)
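To make the CPU-only branch of `profile_num_available_blocks` concrete, here is a rough worked example; the per-block footprint is an assumed illustrative figure, not a value produced by this patch. With no GPU cache, the number of KV-cache blocks is simply the swap space (in bytes) divided by the per-block size that `CacheEngine.get_cache_block_size` would report.

    # Illustrative numbers only; cache_block_size is a hypothetical figure.
    cpu_swap_space = 4 * 1024**3            # 4 GiB of CPU swap space, in bytes
    cache_block_size = 2 * 1024 * 1024      # assumed per-block KV-cache footprint (bytes)
    num_gpu_blocks = 0                      # CPU-only run: no GPU cache is allocated
    num_cpu_blocks = max(int(cpu_swap_space // cache_block_size), 0)
    print(num_gpu_blocks, num_cpu_blocks)   # -> 0 2048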