diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 3db77d5f16022..64ba1b32fb074 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -21,7 +21,7 @@ steps: podSpec: priorityClassName: perf-benchmark containers: - - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + - image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT command: - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: @@ -51,7 +51,7 @@ steps: queue: H200 plugins: - docker#v5.12.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT command: - bash - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -65,13 +65,18 @@ steps: - VLLM_USAGE_SOURCE - HF_TOKEN + - block: "Run H100 Benchmark" + key: block-h100 + depends_on: ~ + - label: "H100" # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: H100 + depends_on: block-h100 plugins: - docker#v5.12.0: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + image: public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:$BUILDKITE_COMMIT command: - bash - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh index 19f7160e68a4d..aa0f7ade808e0 100644 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -1,6 +1,6 @@ #!/bin/sh -TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) -URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-postmerge-repo:pull" | jq -r .token) +URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-postmerge-repo/manifests/$BUILDKITE_COMMIT" TIMEOUT_SECONDS=10 diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index f78e360b7afd3..93e118fb3eab8 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,7 +1,7 @@ steps: - label: "Build wheel - CUDA 12.1" agents: - queue: cpu_queue + queue: cpu_queue_postmerge commands: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ." - "mkdir artifacts" @@ -18,7 +18,7 @@ steps: - label: "Build wheel - CUDA 11.8" # depends_on: block-build-cu118-wheel agents: - queue: cpu_queue + queue: cpu_queue_postmerge commands: - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ." - "mkdir artifacts" @@ -26,3 +26,16 @@ steps: - "bash .buildkite/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" + + - block: "Build release image" + depends_on: ~ + key: block-release-image-build + + - label: "Build release image" + depends_on: block-release-image-build + agents: + queue: cpu_queue_postmerge + commands: + - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh index faeac8e2ded36..e0a12afbe7320 100644 --- a/.buildkite/run-xpu-test.sh +++ b/.buildkite/run-xpu-test.sh @@ -12,5 +12,8 @@ remove_docker_container() { docker rm -f xpu-test || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image and launch offline inference -docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test python3 examples/offline_inference.py +# Run the image and test offline inference/tensor parallel +docker run --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test sh -c ' + python3 examples/offline_inference.py + python3 examples/offline_inference_cli.py -tp 2 +' diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index fc23c9cff0d87..8f57006214c88 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -50,9 +50,9 @@ steps: - tests/multimodal - tests/test_utils - tests/worker - - tests/test_lazy_torch_compile.py + - tests/standalone_tests/lazy_torch_compile.py commands: - - python3 test_lazy_torch_compile.py + - python3 standalone_tests/lazy_torch_compile.py - pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py @@ -61,6 +61,13 @@ steps: - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker +- label: Python-only Installation Test + source_file_dependencies: + - tests/standalone_tests/python_only_compile.sh + - setup.py + commands: + - bash standalone_tests/python_only_compile.sh + - label: Basic Correctness Test # 30min #mirror_hardwares: [amd] fast_check: true @@ -230,7 +237,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py parallelism: 4 - label: "PyTorch Fullgraph Smoke Test" # 9min @@ -334,7 +341,6 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model - - pytest -v -s models/embedding/vision_language -m core_model - label: Language Models Test (Extended) # 50min optional: true @@ -346,7 +352,6 @@ steps: commands: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' - - pytest -v -s models/embedding/vision_language -m 'not core_model' - label: Multi-Modal Models Test (Standard) # 26min #mirror_hardwares: [amd] @@ -357,8 +362,10 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'core_model or quant_model' + - pytest -v -s models/embedding/vision_language -m core_model - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model @@ -371,11 +378,13 @@ steps: - tests/models/embedding/vision_language - tests/models/encoder_decoder/vision_language commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s models/embedding/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' @@ -430,6 +439,9 @@ steps: - vllm/model_executor/models/ - tests/distributed/ - vllm/compilation + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/model_runner.py commands: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py @@ -443,6 +455,7 @@ steps: - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/disagg_test.py - label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" @@ -477,7 +490,6 @@ steps: - label: LoRA TP Test (Distributed) num_gpus: 4 - soft_fail: true source_file_dependencies: - vllm/lora - tests/lora diff --git a/CMakeLists.txt b/CMakeLists.txt index 882d4412632a5..c78cdc77a7e42 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -522,7 +522,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 5259c586c403a4e4d8bf69973c159b40cc346fb9 + GIT_TAG 04325b6798bcc326c86fb35af62d05a9c8c8eceb GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index c3fed56e8a956..b67849038cf0d 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -24,6 +24,7 @@ class RequestFuncInput: model: str best_of: int = 1 logprobs: Optional[int] = None + extra_body: Optional[dict] = None multi_modal_content: Optional[dict] = None ignore_eos: bool = False @@ -36,6 +37,7 @@ class RequestFuncOutput: ttft: float = 0.0 # Time to first token itl: List[float] = field( default_factory=list) # List of inter-token latencies + tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" @@ -242,6 +244,8 @@ async def async_request_openai_completions( "stream": True, "ignore_eos": request_func_input.ignore_eos, } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" } @@ -336,6 +340,8 @@ async def async_request_openai_chat_completions( "stream": True, "ignore_eos": request_func_input.ignore_eos, } + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", diff --git a/benchmarks/benchmark_guided.py b/benchmarks/benchmark_guided.py new file mode 100644 index 0000000000000..1a0e62598bfcb --- /dev/null +++ b/benchmarks/benchmark_guided.py @@ -0,0 +1,494 @@ +"""Benchmark guided decoding throughput.""" +import argparse +import dataclasses +import json +import os +import random +import time +from typing import List + +import datasets +import pandas as pd +import uvloop +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.sampling_params import GuidedDecodingParams +from vllm.utils import FlexibleArgumentParser, merge_async_iterators + + +@dataclasses.dataclass +class SampleRequest: + """A class representing a single inference request for benchmarking. + + Attributes: + prompt: The input text prompt for the model. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + prompt_len: The length of the prompt in tokens. + expected_output_len: The expected length of the output in tokens. + """ + prompt: str + prompt_len: int + expected_output_len: int + schema: dict + structure_type: str = 'json' + completion: str = None + + +def run_vllm(requests: List[SampleRequest], + engine_args: EngineArgs, + n: int, + guided_decoding_rate: float = 1.0, + warmup: bool = False) -> float: + from vllm import LLM, SamplingParams + llm = LLM(**vars(engine_args)) + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + # create a list containing random selected true or false + guided_decoding_req_idx = random.sample( + range(len(requests)), int(len(requests) * guided_decoding_rate)) + + if warmup: + print(">>>>> Running warmup prompt, for the first 5") + # We setup the first 5 requests to warmup FSM + # if using xgrammar dataset, we will skip warmup + warmup_requests = requests[:5] + for i, request in enumerate(warmup_requests): + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + guided_decoding=GuidedDecodingParams(json=request.schema) + if guided_decoding_rate > 0 else None, + )) + llm.generate(prompts, sampling_params, use_tqdm=False) + + print(">>>>> Benchmark started...") + prompts = [] + sampling_params = [] + for i, request in enumerate(requests): + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + guided_decoding=GuidedDecodingParams( + **{request.structure_type: request.schema}) + if i in guided_decoding_req_idx else None, + )) + + start = time.perf_counter() + outputs = llm.generate(prompts, sampling_params, use_tqdm=False) + ret = [] + for output, request in zip(outputs, requests): + generated_text = output.outputs[0].text + ret.append({ + "generated": generated_text, + "expected": request.completion + }) + end = time.perf_counter() + return end - start, ret + + +async def run_vllm_async( + requests: List[SampleRequest], + engine_args: AsyncEngineArgs, + n: int, + guided_decoding_rate: float = 1.0, + warmup: bool = False, + disable_frontend_multiprocessing: bool = False) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + guided_decoding_req_idx = random.sample( + range(len(requests)), int(len(requests) * guided_decoding_rate)) + + if warmup: + print(">>>>>> Running warmup prompt, for the first 5") + # We setup the first 5 requests to warmup FSM + # if using xgrammar dataset, we will skip warmup + warmup_requests = requests[:5] + for i, request in enumerate(warmup_requests): + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + guided_decoding=GuidedDecodingParams( + json=request.schema) + if guided_decoding_rate > 0 else None, + )) + generators = [] + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + + print(">>>>> Benchmark started...") + prompts = [] + sampling_params = [] + for i, request in enumerate(requests): + prompts.append(request.prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=request.expected_output_len, + guided_decoding=GuidedDecodingParams(json=request.schema) + if i in guided_decoding_req_idx else None, + )) + + generators = [] + start_time = [] + latencies = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + start_time.append(time.perf_counter()) + latencies.append([]) + all_gens = merge_async_iterators(*generators) + generated_texts = [''] * len(requests) + async for i, res in all_gens: + generated_texts[i] = res.outputs[0].text + lat = time.perf_counter() - start_time[i] + latencies[i].append(lat) + ret = [{ + 'generated': gt, + 'expected': req.completion + } for gt, req in zip(generated_texts, requests)] + end = time.perf_counter() + first_latency = pd.Series([lat[0] * 1000 for lat in latencies]) + next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000 + for lat in latencies]) + return end - start, ret, (first_latency, next_latency) + + +def sample_requests(tokenizer: PreTrainedTokenizerBase, + args: argparse.Namespace) -> List[SampleRequest]: + if args.dataset == 'json': + if args.json_schema_path is None: + dir_path = os.path.dirname(os.path.realpath(__file__)) + args.json_schema_path = os.path.join(dir_path, + "structured_schemas", + "structured_schema_1.json") + with open(args.json_schema_path) as f: + schema = json.load(f) + prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "grammar": + schema = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + """ + prompt = "Generate an SQL query to show the 'username' \ + and 'email' from the 'users' table." + + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "regex": + regex = r"\w+@\w+\.com\n" + args.regex = regex + prompt = "Generate an email address for Alan Turing, \ + who works in Enigma. End in .com and new line. \ + Example result: alan.turing@enigma.com\n" + + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=regex, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "choice": + choice = ["Positive", "Negative"] + args.choice = choice + prompt = "Classify this sentiment: vLLM is wonderful!" + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=choice, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "xgrammar_bench": + args.warmup = False + requests: List[SampleRequest] = [] + dataset = datasets.load_dataset("NousResearch/json-mode-eval", + split="train") + print(f"dataset has {len(dataset)} entries") + len_dataset = len(dataset) + for data_point_idx in range(args.num_prompts): + idx = data_point_idx + while idx >= len_dataset: + idx -= len_dataset + schema = dataset["schema"][idx] + prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], + tokenize=False) + input_len = len(tokenizer(prompt).input_ids) + completion = dataset["completion"][idx] + + requests.append( + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + completion=completion)) + + return requests + + +def evaluate(ret, args): + + def _eval_correctness_json(expected, actual): + # extract json string from string using regex + import re + actual = actual.replace('\n', '').replace(' ', '').strip() + try: + actual = re.search(r'\{.*\}', actual).group() + actual = json.loads(actual) + except Exception: + return False + + return True + + def _eval_correctness_choice(expected, actual): + return actual in args.choice + + def _eval_correctness_regex(expected, actual): + import re + return re.match(args.regex, actual) is not None + + def _eval_correctness(expected, actual): + if args.structure_type == 'json': + return _eval_correctness_json(expected, actual) + elif args.structure_type == 'regex': + return _eval_correctness_regex(expected, actual) + elif args.structure_type == 'choice': + return _eval_correctness_choice(expected, actual) + else: + return None + + scores = [] + for res in ret: + score = _eval_correctness(res['expected'], res['generated']) + res['correctness'] = score + scores.append(score) + + not_none_scores = [score for score in scores if score is not None] + + return (sum(not_none_scores) / len(not_none_scores) * + 100) if len(not_none_scores) > 0 else None + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # async engine is working for 'regex', 'choice' and 'grammar' + if args.dataset == 'grammar': + args.structure_type = 'grammar' + args.async_engine = False + elif args.dataset == 'regex': + args.structure_type = 'regex' + args.async_engine = False + elif args.dataset == 'choice': + args.structure_type = 'choice' + args.async_engine = False + else: + args.structure_type = 'json' + + if args.no_guided_decoding: + args.guided_decoding_ratio = 0 + if args.save_results: + result_file_name = f'{args.guided_decoding_ratio}guided' + result_file_name += f"_{args.model.split('/')[-1]}" + result_file_name += f"_{args.dataset}" + result_file_name += f"_{args.num_prompts}" + result_file_name += f"_out{args.output_len}" + result_file_name += f"_async{args.async_engine}" + result_file_name += f"_warmup{args.warmup}" + result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}" + result_file_name += ".txt" + else: + result_file_name = None + + # Synthesize a prompt with the given input length. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + requests = sample_requests(tokenizer, args) + + if args.async_engine: + engine_args = AsyncEngineArgs.from_cli_args(args) + elapsed_time, ret, (first_latency, next_latency) = uvloop.run( + run_vllm_async(requests, engine_args, args.n, + args.guided_decoding_ratio, args.warmup, + args.disable_frontend_multiprocessing)) + else: + engine_args = EngineArgs.from_cli_args(args) + elapsed_time, ret = run_vllm(requests, engine_args, args.n, + args.guided_decoding_ratio, args.warmup) + first_latency, next_latency = None, None + + score = evaluate(ret, args) + total_num_tokens = sum(request.prompt_len + request.expected_output_len + for request in requests) + total_output_tokens = sum(request.expected_output_len + for request in requests) + if first_latency is not None: + latency_breakdown = "\nFirst token latency(msecs):\n" + latency_breakdown += f"{first_latency.describe()}" + latency_breakdown += "\nNext token latency(msecs):\n" + latency_breakdown += f"{next_latency.describe()}" + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s", + f"Correct rate is {score} %", + f"{latency_breakdown if first_latency is not None else ''}") + + # Output JSON results if specified + if args.output_json or result_file_name: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "total_output_tokens": total_output_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}", + "output_tokens_per_second": + f"{total_output_tokens / elapsed_time:.2f}", + "correct_rate(%)": score + } + results = {"outputs": ret, **results} + if first_latency is not None: + results["first_token_latency(msecs)"] = first_latency.describe( + ).to_dict() + results["next_token_latency(msecs)"] = next_latency.describe( + ).to_dict() + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + elif result_file_name: + with open(result_file_name, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark guided decoding.") + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument("--output-len", + type=int, + default=512, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument( + "--dataset", + default='json', + choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench']) + parser.add_argument("--json_schema_path", + type=str, + default=None, + help="Path to json schema.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=10, + help="Number of prompts to process.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--no-guided-decoding", + action='store_true', + default=False, + help="Whether to disable JSON decoding or not.") + parser.add_argument("--guided-decoding-ratio", + type=float, + default=1.0, + help="Ratio of Guided Decoding requests") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument("--warmup", + action="store_true", + default=False, + help="Run warmup prompts before benchmark.") + parser.add_argument("--save-results", + action="store_true", + default=False, + help="save output results.") + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + main(args) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index e9fc037a46965..3256692142c5e 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -199,6 +199,56 @@ def sample_sonnet_requests( return sampled_requests +def sample_mmmu_pro_vision_requests( + dataset, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in dataset: + if len(sampled_requests) == num_requests: + break + + # MMMU-Pro vision direct prompt + # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5 + prompt = ( + "Answer with the option letter from the given choices directly. " + "The last line of your response should be of the following " + "format: 'Answer: $LETTER' (without quotes) where LETTER is one of " + "options.") + + prompt_token_ids = tokenizer(prompt).input_ids + if fixed_output_len is None: + # Default max output len is set to 128 + print("--hf-output-len is not provided. Using default value 128.") + fixed_output_len = 128 + + prompt_len = len(prompt_token_ids) + output_len = fixed_output_len + + assert isinstance( + data["image"], + Image), ("Input image format must be `PIL.Image.Image`, " + f"given {type(data['image'])}.") + image: Image = data["image"] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) + + return sampled_requests + + def sample_hf_requests( dataset_path: str, dataset_subset: str, @@ -208,6 +258,21 @@ def sample_hf_requests( random_seed: int, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + + # Special case for MMMU-Pro vision dataset + if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision': + assert dataset_split == "test" + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + assert "image" in dataset.features, ( + "MMMU/MMMU_Pro vision dataset must have 'image' column.") + filter_func = lambda x: isinstance(x["image"], Image) + dataset = dataset.shuffle(seed=random_seed).filter(filter_func) + return sample_mmmu_pro_vision_requests(dataset, num_requests, + tokenizer, fixed_output_len) + dataset = load_dataset(dataset_path, name=dataset_subset, split=dataset_split, diff --git a/benchmarks/benchmark_serving_guided.py b/benchmarks/benchmark_serving_guided.py new file mode 100644 index 0000000000000..4435d87e18a8a --- /dev/null +++ b/benchmarks/benchmark_serving_guided.py @@ -0,0 +1,881 @@ +r"""Benchmark online serving throughput with guided decoding. + +On the server side, run one of the following commands: + (vLLM OpenAI API server) + vllm serve --disable-log-requests + + (TGI backend) + ./launch_tgi_server.sh + +On the client side, run: + python benchmarks/benchmark_serving.py \ + --backend \ + --model \ + --dataset json \ + --guided-decoding-ratio 1.0 \ + --guided-decoding-backend xgrammar \ + --request-rate 10 \ + --num-prompts 1000 + + when using tgi backend, add + --endpoint /generate_stream + to the end of the command above. +""" +import argparse +import asyncio +import dataclasses +import json +import os +import random +import time +import warnings +from dataclasses import dataclass +from typing import AsyncGenerator, List, Optional, Tuple + +import datasets +import numpy as np +import pandas as pd +from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, + RequestFuncOutput) +from tqdm.asyncio import tqdm +from transformers import PreTrainedTokenizerBase + +try: + from vllm.transformers_utils.tokenizer import get_tokenizer +except ImportError: + from backend_request_func import get_tokenizer + +try: + from vllm.utils import FlexibleArgumentParser +except ImportError: + from argparse import ArgumentParser as FlexibleArgumentParser + +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + + +@dataclass +class BenchmarkMetrics: + completed: int + total_input: int + total_output: int + request_throughput: float + request_goodput: float + output_throughput: float + total_token_throughput: float + mean_ttft_ms: float + median_ttft_ms: float + std_ttft_ms: float + percentiles_ttft_ms: List[Tuple[float, float]] + mean_tpot_ms: float + median_tpot_ms: float + std_tpot_ms: float + percentiles_tpot_ms: List[Tuple[float, float]] + mean_itl_ms: float + median_itl_ms: float + std_itl_ms: float + percentiles_itl_ms: List[Tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: List[Tuple[float, float]] + + +@dataclasses.dataclass +class SampleRequest: + """A class representing a single inference request for benchmarking. + + Attributes: + prompt: The input text prompt for the model. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + prompt_len: The length of the prompt in tokens. + expected_output_len: The expected length of the output in tokens. + """ + prompt: str + prompt_len: int + expected_output_len: int + schema: dict + structure_type: str + completion: str = None + + +def sample_requests(tokenizer: PreTrainedTokenizerBase, + args: argparse.Namespace) -> List[SampleRequest]: + if args.dataset == 'json': + if args.json_schema_path is None: + dir_path = os.path.dirname(os.path.realpath(__file__)) + args.json_schema_path = os.path.join(dir_path, + "structured_schemas", + "structured_schema_1.json") + with open(args.json_schema_path) as f: + schema = json.load(f) + prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501 + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "grammar": + schema = """ + ?start: select_statement + + ?select_statement: "SELECT " column_list " FROM " table_name + + ?column_list: column_name ("," column_name)* + + ?table_name: identifier + + ?column_name: identifier + + ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/ + """ + prompt = "Generate an SQL query to show the 'username' \ + and 'email' from the 'users' table." + + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "regex": + regex = r"\w+@\w+\.com\n" + args.regex = regex + prompt = "Generate an email address for Alan Turing, \ + who works in Enigma. End in .com and new line. \ + Example result: alan.turing@enigma.com\n" + + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=regex, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "choice": + choice = ["Positive", "Negative"] + args.choice = choice + prompt = "Classify this sentiment: vLLM is wonderful!" + input_len = len(tokenizer(prompt).input_ids) + print(f"Input length of the prompt: {input_len} tokens") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=choice, + structure_type=args.structure_type) + for _ in range(args.num_prompts) + ] + + elif args.dataset == "xgrammar_bench": + requests: List[SampleRequest] = [] + dataset = datasets.load_dataset("NousResearch/json-mode-eval", + split="train") + print(f"dataset has {len(dataset)} entries") + len_dataset = len(dataset) + for data_point_idx in range(args.num_prompts): + idx = data_point_idx + while idx >= len_dataset: + idx -= len_dataset + schema = dataset["schema"][idx] + prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], + tokenize=False) + input_len = len(tokenizer(prompt).input_ids) + completion = dataset["completion"][idx] + + requests.append( + SampleRequest(prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + completion=completion)) + + return requests + + +async def get_request( + input_requests: List[SampleRequest], + request_rate: float, + burstiness: float = 1.0, +) -> AsyncGenerator[Tuple[int, SampleRequest], None]: + """ + Asynchronously generates requests at a specified rate + with OPTIONAL burstiness. + + Args: + input_requests: + A list of input requests, each represented as a tuple. + request_rate: + The rate at which requests are generated (requests/s). + burstiness (optional): + The burstiness factor of the request generation. + Only takes effect when request_rate is not inf. + Default value is 1, which follows a Poisson process. + Otherwise, the request intervals follow a gamma distribution. + A lower burstiness value (0 < burstiness < 1) results + in more bursty requests, while a higher burstiness value + (burstiness > 1) results in a more uniform arrival of requests. + """ + input_requests = iter(input_requests) + + # Calculate scale parameter theta to maintain the desired request_rate. + assert burstiness > 0, ( + f"A positive burstiness factor is expected, but given {burstiness}.") + theta = 1.0 / (request_rate * burstiness) + + for i, request in enumerate(input_requests): + yield i, request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + interval = np.random.gamma(shape=burstiness, scale=theta) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +def calculate_metrics( + input_requests: List[Tuple[str, int, int]], + outputs: List[RequestFuncOutput], + dur_s: float, + tokenizer: PreTrainedTokenizerBase, + selected_percentile_metrics: List[str], + selected_percentiles: List[float], +) -> Tuple[BenchmarkMetrics, List[int]]: + actual_output_lens: List[int] = [] + total_input = 0 + completed = 0 + good_completed = 0 + itls: List[float] = [] + tpots: List[float] = [] + all_tpots: List[float] = [] + ttfts: List[float] = [] + e2els: List[float] = [] + for i in range(len(outputs)): + if outputs[i].success: + # We use the tokenizer to count the number of output tokens for all + # serving backends instead of looking at len(outputs[i].itl) since + # multiple output tokens may be bundled together + # Note : this may inflate the output token count slightly + output_len = len( + tokenizer(outputs[i].generated_text, + add_special_tokens=False).input_ids) + actual_output_lens.append(output_len) + total_input += input_requests[i].prompt_len + tpot = 0 + if output_len > 1: + tpot = (outputs[i].latency - outputs[i].ttft) / (output_len - + 1) + tpots.append(tpot) + outputs[i].tpot = sum(tpots) / len(tpots) if len(tpots) else 0 + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) + itls += outputs[i].itl + ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) + completed += 1 + else: + actual_output_lens.append(0) + + if completed == 0: + warnings.warn( + "All requests failed. This is likely due to a misconfiguration " + "on the benchmark arguments.", + stacklevel=2) + metrics = BenchmarkMetrics( + completed=completed, + total_input=total_input, + total_output=sum(actual_output_lens), + request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, + output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, + mean_ttft_ms=np.mean(ttfts or 0) * + 1000, # ttfts is empty if streaming is not supported by backend + std_ttft_ms=np.std(ttfts or 0) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], + mean_tpot_ms=np.mean(tpots or 0) * 1000, + std_tpot_ms=np.std(tpots or 0) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], + mean_itl_ms=np.mean(itls or 0) * 1000, + std_itl_ms=np.std(itls or 0) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.mean(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.median(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], + ) + + return metrics, actual_output_lens + + +async def benchmark( + backend: str, + api_url: str, + base_url: str, + model_id: str, + tokenizer: PreTrainedTokenizerBase, + input_requests: List[SampleRequest], + request_rate: float, + burstiness: float, + disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: List[str], + selected_percentiles: List[str], + ignore_eos: bool, + max_concurrency: Optional[int], + guided_decoding_ratio: float, + guided_decoding_backend: str, +): + if backend in ASYNC_REQUEST_FUNCS: + request_func = ASYNC_REQUEST_FUNCS[backend] + else: + raise ValueError(f"Unknown backend: {backend}") + + def prepare_extra_body(request) -> dict: + extra_body = {} + # Add the schema to the extra_body + extra_body[request.structure_type] = request.schema + # Add the specific guided_decoding_backend + extra_body["guided_decoding_backend"] = guided_decoding_backend + return extra_body + + print("Starting initial single prompt test run...") + guided_decoding_req_idx = random.sample( + range(len(input_requests)), + int(len(input_requests) * guided_decoding_ratio)) + + test_request = input_requests[0] + test_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=api_url, + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=prepare_extra_body(test_request), + ) + test_output = await request_func(request_func_input=test_input) + if not test_output.success: + raise ValueError( + "Initial test run failed - Please make sure benchmark arguments " + f"are correctly specified. Error: {test_output.error}") + else: + print("Initial test run completed. Starting main benchmark run...") + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=base_url + "/start_profile", + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=prepare_extra_body(test_request), + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + + if burstiness == 1.0: + distribution = "Poisson process" + else: + distribution = "Gamma distribution" + + print(f"Traffic request rate: {request_rate}") + print(f"Burstiness factor: {burstiness} ({distribution})") + print(f"Maximum request concurrency: {max_concurrency}") + + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. + # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + + benchmark_start_time = time.perf_counter() + tasks: List[asyncio.Task] = [] + expected: List[str] = [] + async for i, request in get_request(input_requests, request_rate, + burstiness): + extra_body = prepare_extra_body( + request) if i in guided_decoding_req_idx else None + request_func_input = RequestFuncInput( + model=model_id, + prompt=request.prompt, + api_url=api_url, + prompt_len=request.prompt_len, + output_len=request.expected_output_len, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) + expected.append(request.completion) + tasks.append( + asyncio.create_task( + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) + outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_request.prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_request.prompt_len, + output_len=test_request.expected_output_len, + extra_body={test_request.structure_type: test_request.schema}, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + + if pbar is not None: + pbar.close() + + benchmark_duration = time.perf_counter() - benchmark_start_time + + metrics, actual_output_lens = calculate_metrics( + input_requests=input_requests, + outputs=outputs, + dur_s=benchmark_duration, + tokenizer=tokenizer, + selected_percentile_metrics=selected_percentile_metrics, + selected_percentiles=selected_percentiles, + ) + + print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", + benchmark_duration)) + print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) + print("{:<40} {:<10}".format("Total generated tokens:", + metrics.total_output)) + print("{:<40} {:<10.2f}".format("Request throughput (req/s):", + metrics.request_throughput)) + print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", + metrics.output_throughput)) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) + + result = { + "duration": + benchmark_duration, + "completed": + metrics.completed, + "total_input_tokens": + metrics.total_input, + "total_output_tokens": + metrics.total_output, + "request_throughput": + metrics.request_throughput, + "output_throughput": + metrics.output_throughput, + "total_token_throughput": + metrics.total_token_throughput, + "ttft_description": + pd.Series([output.ttft for output in outputs]).describe().to_dict(), + "tpot_description": + pd.Series([output.tpot for output in outputs]).describe().to_dict(), + "input_lens": [output.prompt_len for output in outputs], + "output_lens": + actual_output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "errors": [output.error for output in outputs], + } + + ret = [{ + 'generated': output.generated_text, + 'expected': gt + } for output, gt in zip(outputs, expected)] + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + + return result, ret + + +def evaluate(ret, args): + + def _eval_correctness_json(expected, actual): + # extract json string from string using regex + import re + actual = actual.replace('\n', '').replace(' ', '').strip() + try: + actual = re.search(r'\{.*\}', actual).group() + actual = json.loads(actual) + except Exception: + return False + + return True + + def _eval_correctness_choice(expected, actual): + return actual in args.choice + + def _eval_correctness_regex(expected, actual): + import re + return re.match(args.regex, actual) is not None + + def _eval_correctness(expected, actual): + if args.structure_type == 'guided_json': + return _eval_correctness_json(expected, actual) + elif args.structure_type == 'guided_regex': + return _eval_correctness_regex(expected, actual) + elif args.structure_type == 'guided_choice': + return _eval_correctness_choice(expected, actual) + else: + return None + + scores = [] + for res in ret: + score = _eval_correctness(res['expected'], res['generated']) + res['correctness'] = score + scores.append(score) + + not_none_scores = [score for score in scores if score is not None] + + return (sum(not_none_scores) / len(not_none_scores) * + 100) if len(not_none_scores) > 0 else None + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + + backend = args.backend + model_id = args.model + tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model + + if args.base_url is not None: + api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" + else: + api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" + + tokenizer = get_tokenizer(tokenizer_id, + trust_remote_code=args.trust_remote_code) + + if args.dataset == 'grammar': + args.structure_type = 'guided_grammar' + elif args.dataset == 'regex': + args.structure_type = 'guided_regex' + elif args.dataset == 'choice': + args.structure_type = 'guided_choice' + else: + args.structure_type = 'guided_json' + + if args.no_guided_decoding: + args.guided_decoding_ratio = 0 + if args.save_results: + result_file_name = f'{args.guided_decoding_ratio}guided' + result_file_name += f"_{backend}" + result_file_name += f"_{args.request_rate}qps" + result_file_name += f"_{args.model.split('/')[-1]}" + result_file_name += f"_{args.dataset}" + result_file_name += f"_{args.num_prompts}" + result_file_name += f"_out{args.output_len}" + result_file_name += ".txt" + else: + result_file_name = None + + input_requests = sample_requests(tokenizer, args) + + benchmark_result, ret = asyncio.run( + benchmark( + backend=backend, + api_url=api_url, + base_url=base_url, + model_id=model_id, + tokenizer=tokenizer, + input_requests=input_requests, + request_rate=args.request_rate, + burstiness=args.burstiness, + disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, + max_concurrency=args.max_concurrency, + guided_decoding_ratio=args.guided_decoding_ratio, + guided_decoding_backend=args.guided_decoding_backend, + )) + + # Save config and results to json + score = evaluate(ret, args) + print("correct_rate(%)", score, '\n') + if args.save_results: + results = { + "backend": + backend, + "model_id": + model_id, + "tokenizer_id": + tokenizer_id, + "num_prompts": + args.num_prompts, + "request_rate": + args.request_rate if args.request_rate < float("inf") else "inf", + "burstiness": + args.burstiness, + "max_concurrency": + args.max_concurrency, + "correct_rate(%)": + score + } + results = {"outputs": ret, **results, **benchmark_result} + + # Save to file + if args.result_filename: + result_file_name = args.result_filename + if args.result_dir: + result_file_name = os.path.join(args.result_dir, result_file_name) + with open(result_file_name, "w", encoding='utf-8') as outfile: + json.dump(results, outfile, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark the online serving throughput.") + parser.add_argument( + "--backend", + type=str, + default="vllm", + choices=list(ASYNC_REQUEST_FUNCS.keys()), + ) + parser.add_argument( + "--base-url", + type=str, + default=None, + help="Server or API base url if not using http host and port.", + ) + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument( + "--endpoint", + type=str, + default="/v1/completions", + help="API endpoint.", + ) + parser.add_argument( + "--dataset", + default='json', + choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench']) + parser.add_argument("--json_schema_path", + type=str, + default=None, + help="Path to json schema.") + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the model.", + ) + parser.add_argument( + "--tokenizer", + type=str, + help= + "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + ) + parser.add_argument( + "--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.", + ) + parser.add_argument( + "--output-len", + type=int, + default=128, + help="Number of output tokens.", + ) + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process or gamma distribution " + "to synthesize the request arrival times.", + ) + parser.add_argument( + "--burstiness", + type=float, + default=1.0, + help="Burstiness factor of the request generation. " + "Only take effect when request_rate is not inf. " + "Default value is 1, which follows Poisson process. " + "Otherwise, the request intervals follow a gamma distribution. " + "A lower burstiness value (0 < burstiness < 1) results in more " + "bursty requests. A higher burstiness value (burstiness > 1) " + "results in a more uniform arrival of requests.", + ) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Trust remote code from huggingface", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Specify to disable tqdm progress bar.", + ) + parser.add_argument( + "--save-results", + action="store_true", + help="Specify to save benchmark results to a json file", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--result-dir", + type=str, + default=None, + help="Specify directory to save benchmark json results." + "If not specified, results are saved in the current directory.", + ) + parser.add_argument( + "--result-filename", + type=str, + default=None, + help="Specify the filename to save benchmark json results." + "If not specified, results will be saved in " + "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" + " format.", + ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-seperated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " + "Default value is \"ttft,tpot,itl\".") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-seperated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\". " + "Use \"--percentile-metrics\" to select metrics.", + ) + parser.add_argument("--no-guided-decoding", + action='store_true', + default=False, + help="Whether to disable JSON decoding or not.") + parser.add_argument("--guided-decoding-ratio", + type=float, + default=1.0, + help="Ratio of Guided Decoding requests") + parser.add_argument("--guided-decoding-backend", + type=str, + choices=["outlines", "lm-format-enforcer", "xgrammar"], + default="xgrammar", + help="Backend to use for guided decoding") + + args = parser.parse_args() + main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 159cf055737ce..1e5967bd9bf8b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -294,23 +294,36 @@ def main(args: argparse.Namespace): tokenizer = AutoTokenizer.from_pretrained( args.tokenizer, trust_remote_code=args.trust_remote_code) if args.dataset is None: - # Synthesize a prompt with the given input length. - # As tokenizer may add additional tokens like BOS, we need to try - # different lengths to get the desired input length. - for i in range(-10, 10): - prompt = "hi " * (args.input_len + i) - tokenized_prompt = tokenizer(prompt).input_ids - if len(tokenized_prompt) == args.input_len: - break - else: - raise ValueError( - f"Failed to synthesize a prompt with {args.input_len} tokens.") - requests = [ - SampleRequest(prompt=prompt, - prompt_len=args.input_len, - expected_output_len=args.output_len) - for _ in range(args.num_prompts) - ] + vocab_size = tokenizer.vocab_size + requests = [] + for _ in range(args.num_prompts): + # Synthesize a prompt with the given input length. + candidate_ids = [ + random.randint(0, vocab_size - 1) + for _ in range(args.input_len) + ] + # As tokenizer may add additional tokens like BOS, we need to try + # different lengths to get the desired input length. + for _ in range(5): # Max attempts to correct + candidate_prompt = tokenizer.decode(candidate_ids) + tokenized_len = len(tokenizer.encode(candidate_prompt)) + + if tokenized_len == args.input_len: + break + + # Adjust length based on difference + diff = args.input_len - tokenized_len + if diff > 0: + candidate_ids.extend([ + random.randint(100, vocab_size - 100) + for _ in range(diff) + ]) + else: + candidate_ids = candidate_ids[:diff] + requests.append( + SampleRequest(prompt=candidate_prompt, + prompt_len=args.input_len, + expected_output_len=args.output_len)) else: requests = sample_requests(tokenizer, args) diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh new file mode 100644 index 0000000000000..2924ea4a49f54 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh @@ -0,0 +1,144 @@ +#!/bin/bash + +# benchmark the overhead of disaggregated prefill. +# methodology: +# - send all request to prefill vLLM instance. It will buffer KV cache. +# - then send all request to decode instance. +# - The TTFT of decode instance is the overhead. + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pkill -f pt_main_thread + sleep 10 + + # remove vllm config file + rm -rf ~/.config/vllm + + # Print the GPU memory usage + # so that we know if all GPU processes are killed. + gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) + # The memory usage should be 0 MB. + echo "GPU 0 Memory Usage: $gpu_memory_usage MB" +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +benchmark() { + + export VLLM_LOGGING_LEVEL=DEBUG + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + # compare chunked prefill with disaggregated prefill + + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=10 + qps=$1 + prefix_len=50 + input_len=2048 + output_len=$2 + + + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + + # let the prefill instance finish prefill + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8100 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate "inf" + + + # send the request to decode. + # The TTFT of this command will be the overhead of disagg prefill impl. + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8200 \ + --save-result \ + --result-dir $results_folder \ + --result-filename disagg_prefill_2xtp4.json \ + --request-rate "$qps" + kill_gpu_processes + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_qps=1 + default_output_len=1 + benchmark $default_qps $default_output_len + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh new file mode 100644 index 0000000000000..d8d9e976dce76 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh @@ -0,0 +1,164 @@ +#!/bin/bash + +# Requirement: 8x H100 GPUs. + + +# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV +# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests +# Resource: 8x H100 +# Approaches: +# 1. Chunked prefill: 1 vllm instance with tp=8 +# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4 +# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance +# Prefilling instance: max_output_token=1 +# Decoding instance: force the input tokens be the same across requests to bypass prefilling + +set -ex + +kill_gpu_processes() { + # kill all processes on GPU. + pgrep pt_main_thread | xargs -r kill -9 + pgrep python3 | xargs -r kill -9 + for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done + sleep 1 +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +launch_chunked_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --enable-chunked-prefill \ + --gpu-memory-utilization 0.6 & + wait_for_server 8100 + wait_for_server 8200 + python3 round_robin_proxy.py & + sleep 1 +} + + +launch_disagg_prefill() { + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + # disagg prefill + CUDA_VISIBLE_DEVICES=0 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8100 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + CUDA_VISIBLE_DEVICES=1 python3 \ + -m vllm.entrypoints.openai.api_server \ + --model $model \ + --port 8200 \ + --max-model-len 10000 \ + --gpu-memory-utilization 0.6 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' & + + wait_for_server 8100 + wait_for_server 8200 + python3 disagg_prefill_proxy_server.py & + sleep 1 +} + + +benchmark() { + results_folder="./results" + model="meta-llama/Meta-Llama-3.1-8B-Instruct" + dataset_name="sonnet" + dataset_path="../sonnet_4x.txt" + num_prompts=100 + qps=$1 + prefix_len=50 + input_len=1024 + output_len=$2 + tag=$3 + + python3 ../benchmark_serving.py \ + --backend vllm \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --sonnet-input-len $input_len \ + --sonnet-output-len "$output_len" \ + --sonnet-prefix-len $prefix_len \ + --num-prompts $num_prompts \ + --port 8000 \ + --save-result \ + --result-dir $results_folder \ + --result-filename "$tag"-qps-"$qps".json \ + --request-rate "$qps" + + sleep 2 + +} + + +main() { + + (which wget && which curl) || (apt-get update && apt-get install -y wget curl) + (which jq) || (apt-get -y install jq) + (which socat) || (apt-get -y install socat) + + pip install quart httpx matplotlib aiohttp + + cd "$(dirname "$0")" + + cd .. + # create sonnet-4x.txt so that we can sample 2048 tokens for input + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + cd disagg_benchmarks + + rm -rf results + mkdir results + + default_output_len=6 + + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + launch_chunked_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len chunked_prefill + done + kill_gpu_processes + + launch_disagg_prefill + for qps in 2 4 6 8; do + benchmark $qps $default_output_len disagg_prefill + done + kill_gpu_processes + + python3 visualize_benchmark_results.py + +} + + +main "$@" diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py new file mode 100644 index 0000000000000..4058b1c0a3b79 --- /dev/null +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -0,0 +1,61 @@ +import os + +import aiohttp +from quart import Quart, make_response, request + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + +app = Quart(__name__) + + +async def forward_request(url, data): + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + headers = { + "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" + } + async with session.post(url=url, json=data, + headers=headers) as response: + if response.status == 200: + # if response.headers.get('Transfer-Encoding') == 'chunked': + if True: + async for chunk_bytes in response.content.iter_chunked( + 1024): + yield chunk_bytes + else: + content = await response.read() + yield content + + +@app.route('/v1/completions', methods=['POST']) +async def handle_request(): + try: + original_request_data = await request.get_json() + + prefill_request = original_request_data.copy() + # change max_tokens = 1 to let it only do prefill + prefill_request['max_tokens'] = 1 + + # finish prefill + async for _ in forward_request('http://localhost:8100/v1/completions', + prefill_request): + continue + + # return decode + generator = forward_request('http://localhost:8200/v1/completions', + original_request_data) + response = await make_response(generator) + response.timeout = None + + return response + + except Exception as e: + import sys + import traceback + exc_info = sys.exc_info() + print("Error occurred in disagg prefill proxy server") + print(e) + print("".join(traceback.format_exception(*exc_info))) + + +if __name__ == '__main__': + app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py new file mode 100644 index 0000000000000..6eb5f63980070 --- /dev/null +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -0,0 +1,60 @@ +import asyncio +import itertools + +import aiohttp +from aiohttp import web + + +class RoundRobinProxy: + + def __init__(self, target_ports): + self.target_ports = target_ports + self.port_cycle = itertools.cycle(self.target_ports) + + async def handle_request(self, request): + target_port = next(self.port_cycle) + target_url = f"http://localhost:{target_port}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + try: + # Forward the request + async with session.request( + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, + ) as response: + # Start sending the response + resp = web.StreamResponse(status=response.status, + headers=response.headers) + await resp.prepare(request) + + # Stream the response content + async for chunk in response.content.iter_any(): + await resp.write(chunk) + + await resp.write_eof() + return resp + + except Exception as e: + return web.Response(text=f"Error: {str(e)}", status=500) + + +async def main(): + proxy = RoundRobinProxy([8100, 8200]) + app = web.Application() + app.router.add_route('*', '/{path:.*}', proxy.handle_request) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, 'localhost', 8000) + await site.start() + + print("Proxy server started on http://localhost:8000") + + # Keep the server running + await asyncio.Event().wait() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py new file mode 100644 index 0000000000000..e59d8bb0e6c8c --- /dev/null +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -0,0 +1,46 @@ +import json + +import matplotlib.pyplot as plt +import pandas as pd + +if __name__ == "__main__": + + data = [] + for name in ['disagg_prefill', 'chunked_prefill']: + for qps in [2, 4, 6, 8]: + with open(f"results/{name}-qps-{qps}.json") as f: + x = json.load(f) + x['name'] = name + x['qps'] = qps + data.append(x) + + df = pd.DataFrame.from_dict(data) + dis_df = df[df['name'] == 'disagg_prefill'] + chu_df = df[df['name'] == 'chunked_prefill'] + + plt.style.use('bmh') + plt.rcParams['font.size'] = 20 + + for key in [ + 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', + 'median_itl_ms', 'p99_itl_ms' + ]: + + fig, ax = plt.subplots(figsize=(11, 7)) + plt.plot(dis_df['qps'], + dis_df[key], + label='disagg_prefill', + marker='o', + linewidth=4) + plt.plot(chu_df['qps'], + chu_df[key], + label='chunked_prefill', + marker='o', + linewidth=4) + ax.legend() + + ax.set_xlabel('QPS') + ax.set_ylabel(key) + ax.set_ylim(bottom=0) + fig.savefig(f'results/{key}.png') + plt.close(fig) diff --git a/benchmarks/structured_schemas/structured_schema_1.json b/benchmarks/structured_schemas/structured_schema_1.json new file mode 100644 index 0000000000000..6003698469e8d --- /dev/null +++ b/benchmarks/structured_schemas/structured_schema_1.json @@ -0,0 +1,113 @@ +{ + "$schema": + "https://json-schema.org/draft/2020-12/schema", + "title": + "User Profile", + "type": + "object", + "properties": { + "userId": { + "type": "string", + "description": "Unique identifier for the user." + }, + "personalInfo": { + "type": "object", + "properties": { + "firstName": { + "type": "string", + "description": "The user's first name." + }, + "lastName": { + "type": "string", + "description": "The user's last name." + }, + "age": { + "type": "integer", + "minimum": 0, + "description": "The user's age." + }, + "phoneNumbers": { + "type": + "array", + "items": { + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["home", "work", "mobile"], + "description": "Type of phone number." + }, + "number": { + "type": "string", + "pattern": "^\\+?[1-9]\\d{1,14}$", + "description": "Phone number in E.164 format." + } + }, + "required": ["type", "number"] + }, + "description": + "List of phone numbers associated with the user." + } + }, + "required": ["firstName", "lastName"] + }, + "address": { + "type": "object", + "properties": { + "street": { + "type": "string", + "description": "Street address." + }, + "city": { + "type": "string", + "description": "City name." + }, + "state": { + "type": "string", + "description": "State or province." + }, + "postalCode": { + "type": "string", + "pattern": "^\\d{5}(-\\d{4})?$", + "description": "Postal code." + }, + "country": { + "type": "string", + "description": "Country name." + } + }, + "required": ["street", "city", "state", "postalCode", "country"] + }, + "preferences": { + "type": "object", + "properties": { + "newsletterSubscribed": { + "type": + "boolean", + "description": + "Indicates if the user is subscribed to the newsletter." + }, + "favoriteCategories": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of user's favorite categories." + } + }, + "required": ["newsletterSubscribed"] + }, + "accountStatus": { + "type": "string", + "enum": ["active", "inactive", "suspended"], + "description": "Current status of the user's account." + }, + "registrationDate": { + "type": "string", + "format": "date-time", + "description": "ISO 8601 formatted date-time of user registration." + } + }, + "required": + ["userId", "personalInfo", "address", "accountStatus", "registrationDate"] +} \ No newline at end of file diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 741cd0c82dc89..cb1a069942069 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -140,13 +140,10 @@ void paged_attention_v1_launcher( blocksparse_block_size, blocksparse_head_sliding_step); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ + if (is_block_sparse) { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 6de8d0bdd5b8d..c457bdb89008e 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -147,13 +147,10 @@ void paged_attention_v2_launcher( blocksparse_head_sliding_step); #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - switch (is_block_sparse) { \ - case true: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - break; \ - case false: \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ - break; \ + if (is_block_sparse) { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + } else { \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 498d069c05f0d..dd1e6de2e0180 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -424,7 +424,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), // (which occurs when `final_state_position` is a non-positivie index) // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it - if (final_state_position < 0 && seqlen > kWidth){ + if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){ input_t vals_load[kNElts] = {0}; if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ // chunk = n_chunks - 2, a segment of the final state sits in the last index diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index e3e35844405ac..5c80645b405ae 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -12,8 +12,9 @@ pydantic >= 2.8 torch py-cpuinfo transformers -mistral_common >= 1.3.4 +mistral_common >= 1.5.0 aiohttp starlette openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args -partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file +partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +requests diff --git a/docs/source/automatic_prefix_caching/details.md b/docs/source/automatic_prefix_caching/details.md index 2d3214e28ed93..17f806217aa65 100644 --- a/docs/source/automatic_prefix_caching/details.md +++ b/docs/source/automatic_prefix_caching/details.md @@ -25,7 +25,7 @@ With this mapping, we can add another indirection in vLLM’s KV cache managemen This design achieves automatic prefix caching without the need of maintaining a tree structure among the KV blocks. More specifically, all of the blocks are independent of each other and can be allocated and freed by itself, which enables us to manages the KV cache as ordinary caches in operating system. -# Generalized Caching Policy +## Generalized Caching Policy Keeping all the KV blocks in a hash table enables vLLM to cache KV blocks from earlier requests to save memory and accelerate the computation of future requests. For example, if a new request shares the system prompt with the previous request, the KV cache of the shared prompt can directly be used for the new request without recomputation. However, the total KV cache space is limited and we have to decide which KV blocks to keep or evict when the cache is full. diff --git a/docs/source/conf.py b/docs/source/conf.py index 96ad9a4c26b09..e9d9ac68c9560 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,11 +10,13 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. +import inspect import logging import os import sys from typing import List +import requests from sphinx.ext import autodoc logger = logging.getLogger(__name__) @@ -34,6 +36,7 @@ extensions = [ "sphinx.ext.napoleon", "sphinx.ext.viewcode", + "sphinx.ext.linkcode", "sphinx.ext.intersphinx", "sphinx_copybutton", "sphinx.ext.autodoc", @@ -94,6 +97,69 @@ def setup(app): generate_examples() +_cached_base: str = "" +_cached_branch: str = "" + + +def get_repo_base_and_branch(pr_number): + global _cached_base, _cached_branch + if _cached_base and _cached_branch: + return _cached_base, _cached_branch + + url = f"https://api.github.com/repos/vllm-project/vllm/pulls/{pr_number}" + response = requests.get(url) + if response.status_code == 200: + data = response.json() + _cached_base = data['head']['repo']['full_name'] + _cached_branch = data['head']['ref'] + return _cached_base, _cached_branch + else: + logger.error("Failed to fetch PR details: %s", response) + return None, None + + +def linkcode_resolve(domain, info): + if domain != 'py': + return None + if not info['module']: + return None + filename = info['module'].replace('.', '/') + module = info['module'] + + # try to determine the correct file and line number to link to + obj = sys.modules[module] + + # get as specific as we can + lineno: int = 0 + filename: str = "" + try: + for part in info['fullname'].split('.'): + obj = getattr(obj, part) + + if not (inspect.isclass(obj) or inspect.isfunction(obj) + or inspect.ismethod(obj)): + obj = obj.__class__ # Get the class of the instance + + lineno = inspect.getsourcelines(obj)[1] + filename = (inspect.getsourcefile(obj) + or f"{filename}.py").split("vllm/", 1)[1] + except Exception: + # For some things, like a class member, won't work, so + # we'll use the line number of the parent (the class) + pass + + if filename.startswith("checkouts/"): + # a PR build on readthedocs + pr_number = filename.split("/")[1] + filename = filename.split("/", 2)[2] + base, branch = get_repo_base_and_branch(pr_number) + if base and branch: + return f"https://github.com/{base}/blob/{branch}/{filename}#L{lineno}" + + # Otherwise, link to the source file on the main branch + return f"https://github.com/vllm-project/vllm/blob/main/{filename}#L{lineno}" + + # Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ "compressed_tensors", @@ -112,6 +178,7 @@ def setup(app): "tensorizer", "pynvml", "outlines", + "xgrammar," "librosa", "soundfile", "gguf", diff --git a/docs/source/design/arch_overview.rst b/docs/source/design/arch_overview.rst index a9e7b4bd69bc7..bc3f509f0a66e 100644 --- a/docs/source/design/arch_overview.rst +++ b/docs/source/design/arch_overview.rst @@ -42,7 +42,7 @@ Here is a sample of `LLM` class usage: sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Initialize the LLM engine with the OPT-125M model - llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct") + llm = LLM(model="facebook/opt-125m") # Generate outputs for the input prompts outputs = llm.generate(prompts, sampling_params) diff --git a/docs/source/design/multimodal/multimodal_index.rst b/docs/source/design/multimodal/multimodal_index.rst index 30f543abc20c7..c6d47f90b62d5 100644 --- a/docs/source/design/multimodal/multimodal_index.rst +++ b/docs/source/design/multimodal/multimodal_index.rst @@ -7,7 +7,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. -Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` +Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities @@ -15,9 +15,6 @@ by following :ref:`this guide `. Looking to add your own multi-modal model? Please follow the instructions listed :ref:`here `. -.. - TODO: Add usage of --limit-mm-per-prompt when multi-image input is officially supported - Guides ++++++ diff --git a/docs/source/getting_started/gaudi-installation.rst b/docs/source/getting_started/gaudi-installation.rst index 68c1a56660fa4..249e08278ff8f 100644 --- a/docs/source/getting_started/gaudi-installation.rst +++ b/docs/source/getting_started/gaudi-installation.rst @@ -4,7 +4,7 @@ Installation with Intel® Gaudi® AI Accelerators This README provides instructions on running vLLM with Intel Gaudi devices. Requirements and Installation -============================= +----------------------------- Please follow the instructions provided in the `Gaudi Installation Guide `__ @@ -13,7 +13,7 @@ please follow the methods outlined in the `Optimizing Training Platform Guide `__. Requirements ------------- +~~~~~~~~~~~~ - OS: Ubuntu 22.04 LTS - Python: 3.10 @@ -22,7 +22,7 @@ Requirements Quick start using Dockerfile ----------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: console $ docker build -f Dockerfile.hpu -t vllm-hpu-env . @@ -34,10 +34,10 @@ Quick start using Dockerfile Build from source ------------------ +~~~~~~~~~~~~~~~~~ Environment verification -~~~~~~~~~~~~~~~~~~~~~~~~ +^^^^^^^^^^^^^^^^^^^^^^^^ To verify that the Intel Gaudi software was correctly installed, run: @@ -53,7 +53,7 @@ Verification `__ @@ -107,7 +107,7 @@ Supported Features - Attention with Linear Biases (ALiBi) Unsupported Features -==================== +-------------------- - Beam search - LoRA adapters @@ -115,7 +115,7 @@ Unsupported Features - Prefill chunking (mixed-batch inferencing) Supported Configurations -======================== +------------------------ The following configurations have been validated to be function with Gaudi2 devices. Configurations that are not listed may or may not work. @@ -152,10 +152,10 @@ Gaudi2 devices. Configurations that are not listed may or may not work. with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling Performance Tuning -================== +------------------ Execution modes ---------------- +~~~~~~~~~~~~~~~ Currently in vLLM for HPU we support four execution modes, depending on selected HPU PyTorch Bridge backend (via ``PT_HPU_LAZY_MODE`` environment variable), and ``--enforce-eager`` flag. @@ -184,7 +184,7 @@ Currently in vLLM for HPU we support four execution modes, depending on selected Bucketing mechanism -------------------- +~~~~~~~~~~~~~~~~~~~ Intel Gaudi accelerators work best when operating on models with fixed tensor shapes. `Intel Gaudi Graph Compiler `__ is responsible for generating optimized binary code that implements the given model topology on Gaudi. In its default configuration, the produced binary code may be heavily dependent on input and output tensor shapes, and can require graph recompilation when encountering differently shaped tensors within the same topology. While the resulting binaries utilize Gaudi efficiently, the compilation itself may introduce a noticeable overhead in end-to-end execution. In a dynamic inference serving scenario, there is a need to minimize the number of graph compilations and reduce the risk of graph compilation occurring during server runtime. Currently it is achieved by "bucketing" model's forward pass across two dimensions - ``batch_size`` and ``sequence_length``. @@ -233,7 +233,7 @@ As an example, if a request of 3 sequences, with max sequence length of 412 come Bucketing is transparent to a client - padding in sequence length dimension is never returned to the client, and padding in batch dimension does not create new requests. Warmup ------- +~~~~~~ Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: @@ -257,7 +257,7 @@ This example uses the same buckets as in *Bucketing mechanism* section. Each out Compiling all the buckets might take some time and can be turned off with ``VLLM_SKIP_WARMUP=true`` environment variable. Keep in mind that if you do that, you may face graph compilations once executing a given bucket for the first time. It is fine to disable warmup for development, but it's highly recommended to enable it in deployment. HPU Graph capture ------------------ +~~~~~~~~~~~~~~~~~ `HPU Graphs `__ are currently the most performant execution method of vLLM on Intel Gaudi. When HPU Graphs are enabled, execution graphs will be traced (recorded) ahead of time (after performing warmup), to be later replayed during inference, significantly reducing host overheads. Recording can take large amounts of memory, which needs to be taken into account when allocating KV cache. Enabling HPU Graphs will impact the number of available KV cache blocks, but vLLM provides user-configurable variables to control memory management. @@ -321,7 +321,7 @@ Each described step is logged by vLLM server, as follows (negative values corres Recommended vLLM Parameters ---------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~~~~ - We recommend running inference on Gaudi 2 with ``block_size`` of 128 for BF16 data type. Using default values (16, 32) might lead to @@ -333,7 +333,7 @@ Recommended vLLM Parameters If you encounter out-of-memory issues, see troubleshooting section. Environment variables ---------------------- +~~~~~~~~~~~~~~~~~~~~~ **Diagnostic and profiling knobs:** @@ -380,7 +380,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM - ``PT_HPU_ENABLE_LAZY_COLLECTIVES``: required to be ``true`` for tensor parallel inference with HPU Graphs Troubleshooting: Tweaking HPU Graphs -==================================== +------------------------------------ If you experience device out-of-memory issues or want to attempt inference at higher batch sizes, try tweaking HPU Graphs by following diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index e3dbbc9affe66..9b6cb0e80d60e 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -21,7 +21,7 @@ You can install vLLM using pip: .. code-block:: console $ # (Recommended) Create a new conda environment. - $ conda create -n myenv python=3.10 -y + $ conda create -n myenv python=3.12 -y $ conda activate myenv $ # Install vLLM with CUDA 12.1. @@ -73,7 +73,7 @@ Another way to access the latest code is to use the docker images: .. code-block:: console $ export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch - $ docker pull public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${VLLM_COMMIT} + $ docker pull public.ecr.aws/q9t5s3a7/vllm-ci-postmerge-repo:${VLLM_COMMIT} These docker images are used for CI and testing only, and they are not intended for production use. They will be expired after several days. @@ -89,45 +89,24 @@ Build from source Python-only build (without compilation) --------------------------------------- -If you only need to change Python code, you can simply build vLLM without compilation. - -The first step is to install the latest vLLM wheel: - -.. code-block:: console - - pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl - -You can find more information about vLLM's wheels `above <#install-the-latest-code>`_. - -After verifying that the installation is successful, you can use `the following script `_: +If you only need to change Python code, you can build and install vLLM without compilation. Using `pip's ``--editable`` flag `_, changes you make to the code will be reflected when you run vLLM: .. code-block:: console $ git clone https://github.com/vllm-project/vllm.git $ cd vllm - $ python python_only_dev.py + $ VLLM_USE_PRECOMPILED=1 pip install --editable . -The script will: +This will download the latest nightly wheel and use the compiled libraries from there in the install. -* Find the installed vLLM package in the current environment. -* Copy built files to the current directory. -* Rename the installed vLLM package. -* Symbolically link the current directory to the installed vLLM package. - -Now, you can edit the Python code in the current directory, and the changes will be reflected when you run vLLM. - -Once you have finished editing or want to install another vLLM wheel, you should exit the development environment using `the same script `_ with the ``--quit-dev`` (or ``-q`` for short) flag: +The ``VLLM_PRECOMPILED_WHEEL_LOCATION`` environment variable can be used instead of ``VLLM_USE_PRECOMPILED`` to specify a custom path or URL to the wheel file. For example, to use the `0.6.1.post1 PyPi wheel `_: .. code-block:: console - $ python python_only_dev.py --quit-dev - -The ``--quit-dev`` flag will: - -* Remove the symbolic link from the current directory to the vLLM package. -* Restore the original vLLM package from the backup. + $ export VLLM_PRECOMPILED_WHEEL_LOCATION=https://files.pythonhosted.org/packages/4a/4c/ee65ba33467a4c0de350ce29fbae39b9d0e7fcd887cc756fa993654d1228/vllm-0.6.3.post1-cp38-abi3-manylinux1_x86_64.whl + $ pip install --editable . -If you update the vLLM wheel and rebuild from the source to make further edits, you will need to repeat the `Python-only build <#python-only-build>`_ steps again. +You can find more information about vLLM's wheels `above <#install-the-latest-code>`_. .. note:: @@ -148,9 +127,13 @@ If you want to modify C++ or CUDA code, you'll need to build vLLM from source. T .. tip:: Building from source requires a lot of compilation. If you are building from source repeatedly, it's more efficient to cache the compilation results. + For example, you can install `ccache `_ using ``conda install ccache`` or ``apt install ccache`` . As long as ``which ccache`` command can find the ``ccache`` binary, it will be used automatically by the build system. After the first build, subsequent builds will be much faster. + `sccache `_ works similarly to ``ccache``, but has the capability to utilize caching in remote storage environments. + The following environment variables can be set to configure the vLLM ``sccache`` remote: ``SCCACHE_BUCKET=vllm-build-sccache SCCACHE_REGION=us-west-2 SCCACHE_S3_NO_CREDENTIALS=1``. We also recommend setting ``SCCACHE_IDLE_TIMEOUT=0``. + Use an existing PyTorch installation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/index.rst b/docs/source/index.rst index 0692e949f1c77..86b1eed2d26ba 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -85,12 +85,8 @@ Documentation serving/deploying_with_nginx serving/distributed_serving serving/metrics - serving/env_vars - serving/usage_stats serving/integrations serving/tensorizer - serving/compatibility_matrix - serving/faq .. toctree:: :maxdepth: 1 @@ -99,12 +95,21 @@ Documentation models/supported_models models/adding_model models/enabling_multimodal_inputs - models/engine_args - models/lora - models/vlm - models/structured_outputs - models/spec_decode - models/performance + +.. toctree:: + :maxdepth: 1 + :caption: Usage + + usage/lora + usage/multimodal_inputs + usage/structured_outputs + usage/spec_decode + usage/compatibility_matrix + usage/performance + usage/faq + usage/engine_args + usage/env_vars + usage/usage_stats .. toctree:: :maxdepth: 1 diff --git a/docs/source/models/enabling_multimodal_inputs.rst b/docs/source/models/enabling_multimodal_inputs.rst index 49b5285c45590..5c1236e1a8972 100644 --- a/docs/source/models/enabling_multimodal_inputs.rst +++ b/docs/source/models/enabling_multimodal_inputs.rst @@ -3,7 +3,7 @@ Enabling Multimodal Inputs ========================== -This document walks you through the steps to extend a vLLM model so that it accepts :ref:`multi-modal ` inputs. +This document walks you through the steps to extend a vLLM model so that it accepts :ref:`multi-modal inputs `. .. seealso:: :ref:`adding_a_new_model` diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c5fbb30b24e28..c9b3fa8485ff1 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -139,6 +139,11 @@ Text Generation - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc. - ✅︎ - ✅︎ + * - :code:`GlmForCausalLM` + - GLM-4 + - :code:`THUDM/glm-4-9b-chat-hf`, etc. + - ✅︎ + - ✅︎ * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. @@ -177,7 +182,7 @@ Text Generation * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. - - + - ✅︎ - ✅︎ * - :code:`JAISLMHeadModel` - Jais @@ -352,7 +357,7 @@ Text Embedding - ✅︎ * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` - Qwen2-based - - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. + - :code:`ssmits/Qwen2-7B-Instruct-embed-base` (see note), :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. - ✅︎ - ✅︎ * - :code:`RobertaModel`, :code:`RobertaForMaskedLM` @@ -373,6 +378,10 @@ Text Embedding .. tip:: You can override the model's pooling method by passing :code:`--override-pooler-config`. +.. note:: + :code:`ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. + You should manually set mean pooling by passing :code:`--override-pooler-config '{"pooling_type": "MEAN"}'`. + .. note:: Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention. You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly. @@ -392,12 +401,21 @@ Reward Modeling - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`LlamaForCausalLM` + - Llama-based + - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc. + - ✅︎ + - ✅︎ * - :code:`Qwen2ForRewardModel` - Qwen2-based - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc. - ✅︎ - ✅︎ +.. important:: + For process-supervised reward models such as :code:`peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly, + e.g.: :code:`--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. + .. note:: As an interim measure, these models are supported in both offline and online inference via Embeddings API. @@ -453,6 +471,8 @@ Sentence Pair Scoring .. note:: These models are supported in both offline and online inference via Score API. +.. _supported_mm_models: + Multimodal Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -471,8 +491,6 @@ On the other hand, modalities separated by :code:`/` are mutually exclusive. - e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. -.. _supported_vlms: - Text Generation --------------- @@ -529,15 +547,15 @@ Text Generation - ✅︎ - * - :code:`InternVLChatModel` - - InternVL2 + - InternVL 2.5, Mono-InternVL, InternVL 2.0 - T + I\ :sup:`E+` - - :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. + - :code:`OpenGVLab/InternVL2_5-4B`, :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, etc. - - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - T + I\ :sup:`E+` - - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. + - :code:`llava-hf/llava-1.5-7b-hf`, :code:`TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. - - ✅︎ * - :code:`LlavaNextForConditionalGeneration` @@ -628,9 +646,28 @@ Text Generation | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. +.. important:: + To enable multiple multi-modal items per text prompt, you have to set :code:`limit_mm_per_prompt` (offline inference) + or :code:`--limit-mm-per-prompt` (online inference). For example, to enable passing up to 4 images per text prompt: + + .. code-block:: python + + llm = LLM( + model="Qwen/Qwen2-VL-7B-Instruct", + limit_mm_per_prompt={"image": 4}, + ) + + .. code-block:: bash + + vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4 + .. note:: vLLM currently only supports adding LoRA to the language backbone of multimodal models. +.. note:: + To use :code:`TIGER-Lab/Mantis-8B-siglip-llama3`, you have to install their GitHub repo (:code:`pip install git+https://github.com/TIGER-AI-Lab/Mantis.git`) + and pass :code:`--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` when running vLLM. + .. note:: The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 @@ -683,6 +720,9 @@ At vLLM, we are committed to facilitating the integration and support of third-p 2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results. +.. tip:: + When comparing the output of :code:`model.generate` from HuggingFace Transformers with the output of :code:`llm.generate` from vLLM, note that the former reads the model's generation config file (i.e., `generation_config.json `__) and applies the default parameters for generation, while the latter only uses the parameters passed to the function. Ensure all sampling parameters are identical when comparing outputs. + 3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback. 4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use. diff --git a/docs/source/serving/deploying_with_kubeai.rst b/docs/source/serving/deploying_with_kubeai.rst new file mode 100644 index 0000000000000..ec3c065320fd9 --- /dev/null +++ b/docs/source/serving/deploying_with_kubeai.rst @@ -0,0 +1,17 @@ +.. _deploying_with_kubeai: + +Deploying with KubeAI +===================== + +`KubeAI `_ is a Kubernetes operator that enables you to deploy and manage AI models on Kubernetes. It provides a simple and scalable way to deploy vLLM in production. Functionality such as scale-from-zero, load based autoscaling, model caching, and much more is provided out of the box with zero external dependencies. + + +Please see the Installation Guides for environment specific instructions: + +* `Any Kubernetes Cluster `_ +* `EKS `_ +* `GKE `_ + +Once you have KubeAI installed, you can +`configure text generation models `_ +using vLLM. \ No newline at end of file diff --git a/docs/source/serving/integrations.rst b/docs/source/serving/integrations.rst index f39997e0e44d9..0dd505a739863 100644 --- a/docs/source/serving/integrations.rst +++ b/docs/source/serving/integrations.rst @@ -6,6 +6,7 @@ Integrations run_on_sky deploying_with_kserve + deploying_with_kubeai deploying_with_triton deploying_with_bentoml deploying_with_cerebrium diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index c39cef85897ed..d75e90807ca1d 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -32,7 +32,7 @@ We currently support the following OpenAI APIs: - [Completions API](https://platform.openai.com/docs/api-reference/completions) - *Note: `suffix` parameter is not supported.* - [Chat Completions API](https://platform.openai.com/docs/api-reference/chat) - - [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Using VLMs](../models/vlm.rst). + - [Vision](https://platform.openai.com/docs/guides/vision)-related parameters are supported; see [Multimodal Inputs](../usage/multimodal_inputs.rst). - *Note: `image_url.detail` parameter is not supported.* - We also support `audio_url` content type for audio files. - Refer to [vllm.entrypoints.chat_utils](https://github.com/vllm-project/vllm/tree/main/vllm/entrypoints/chat_utils.py) for the exact schema. @@ -41,7 +41,7 @@ We currently support the following OpenAI APIs: - [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings) - Instead of `inputs`, you can pass in a list of `messages` (same schema as Chat Completions API), which will be treated as a single prompt to the model according to its chat template. - - This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst). + - This enables multi-modal inputs to be passed to embedding models, see [this page](../usage/multimodal_inputs.rst) for details. - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.* ## Score API for Cross Encoder Models diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/usage/compatibility_matrix.rst similarity index 100% rename from docs/source/serving/compatibility_matrix.rst rename to docs/source/usage/compatibility_matrix.rst diff --git a/docs/source/models/engine_args.rst b/docs/source/usage/engine_args.rst similarity index 100% rename from docs/source/models/engine_args.rst rename to docs/source/usage/engine_args.rst diff --git a/docs/source/serving/env_vars.rst b/docs/source/usage/env_vars.rst similarity index 100% rename from docs/source/serving/env_vars.rst rename to docs/source/usage/env_vars.rst diff --git a/docs/source/serving/faq.rst b/docs/source/usage/faq.rst similarity index 99% rename from docs/source/serving/faq.rst rename to docs/source/usage/faq.rst index 9e858e612c8bf..ce327abd5fa20 100644 --- a/docs/source/serving/faq.rst +++ b/docs/source/usage/faq.rst @@ -1,3 +1,5 @@ +.. _faq: + Frequently Asked Questions =========================== diff --git a/docs/source/models/lora.rst b/docs/source/usage/lora.rst similarity index 99% rename from docs/source/models/lora.rst rename to docs/source/usage/lora.rst index ef0177eaf2162..c2c6fa2aebfaf 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/usage/lora.rst @@ -1,7 +1,7 @@ .. _lora: -Using LoRA adapters -=================== +LoRA Adapters +============= This document shows you how to use `LoRA adapters `_ with vLLM on top of a base model. diff --git a/docs/source/models/vlm.rst b/docs/source/usage/multimodal_inputs.rst similarity index 62% rename from docs/source/models/vlm.rst rename to docs/source/usage/multimodal_inputs.rst index bcbe50a25fa09..c93f65327e31b 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/usage/multimodal_inputs.rst @@ -1,34 +1,31 @@ -.. _vlm: +.. _multimodal_inputs: -Using VLMs -========== +Multimodal Inputs +================= -vLLM provides experimental support for Vision Language Models (VLMs). See the :ref:`list of supported VLMs here `. -This document shows you how to run and serve these models using vLLM. +This page teaches you how to pass multi-modal inputs to :ref:`multi-modal models ` in vLLM. .. note:: - We are actively iterating on VLM support. See `this RFC `_ for upcoming changes, + We are actively iterating on multi-modal support. See `this RFC `_ for upcoming changes, and `open an issue on GitHub `_ if you have any feedback or feature requests. Offline Inference ----------------- -Single-image input -^^^^^^^^^^^^^^^^^^ - -The :class:`~vllm.LLM` class can be instantiated in much the same way as language-only models. - -.. code-block:: python - - llm = LLM(model="llava-hf/llava-1.5-7b-hf") - -To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: +To input multi-modal data, follow this schema in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. +Image +^^^^^ + +You can pass a single image to the :code:`'image'` field of the multi-modal dictionary, as shown in the following examples: + .. code-block:: python + llm = LLM(model="llava-hf/llava-1.5-7b-hf") + # Refer to the HuggingFace repo for the correct format to use prompt = "USER: \nWhat is the content of this image?\nASSISTANT:" @@ -41,41 +38,6 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT "multi_modal_data": {"image": image}, }) - for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) - - # Inference with image embeddings as input - image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM) - outputs = llm.generate({ - "prompt": prompt, - "multi_modal_data": {"image": image_embeds}, - }) - - for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) - - # Inference with image embeddings as input with additional parameters - # Specifically, we are conducting a trial run of Qwen2VL and MiniCPM-V with the new input format, which utilizes additional parameters. - mm_data = {} - - image_embeds = torch.load(...) # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM) - # For Qwen2VL, image_grid_thw is needed to calculate positional encoding. - mm_data['image'] = { - "image_embeds": image_embeds, - "image_grid_thw": torch.load(...) # torch.Tensor of shape (1, 3), - } - # For MiniCPM-V, image_size_list is needed to calculate details of the sliced image. - mm_data['image'] = { - "image_embeds": image_embeds, - "image_size_list": [image.size] # list of image sizes - } - outputs = llm.generate({ - "prompt": prompt, - "multi_modal_data": mm_data, - }) - for o in outputs: generated_text = o.outputs[0].text print(generated_text) @@ -102,12 +64,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptT A code example can be found in `examples/offline_inference_vision_language.py `_. -Multi-image input -^^^^^^^^^^^^^^^^^ - -Multi-image input is only supported for a subset of VLMs, as shown :ref:`here `. - -To enable multiple multi-modal items per text prompt, you have to set ``limit_mm_per_prompt`` for the :class:`~vllm.LLM` class. +To substitute multiple images inside the same text prompt, you can pass in a list of images instead: .. code-block:: python @@ -118,10 +75,6 @@ To enable multiple multi-modal items per text prompt, you have to set ``limit_mm limit_mm_per_prompt={"image": 2}, # The maximum number to accept ) -Instead of passing in a single image, you can pass in a list of images. - -.. code-block:: python - # Refer to the HuggingFace repo for the correct format to use prompt = "<|user|>\n<|image_1|>\n<|image_2|>\nWhat is the content of each image?<|end|>\n<|assistant|>\n" @@ -169,30 +122,114 @@ Multi-image input can be extended to perform video captioning. We show this with generated_text = o.outputs[0].text print(generated_text) +Video +^^^^^ + +You can pass a list of NumPy arrays directly to the :code:`'video'` field of the multi-modal dictionary +instead of using multi-image input. + +Please refer to `examples/offline_inference_vision_language.py `_ for more details. + +Audio +^^^^^ + +You can pass a tuple :code:`(array, sampling_rate)` to the :code:`'audio'` field of the multi-modal dictionary. + +Please refer to `examples/offline_inference_audio_language.py `_ for more details. + +Embedding +^^^^^^^^^ + +To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, +pass a tensor of shape :code:`(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. + +.. code-block:: python + + # Inference with image embeddings as input + llm = LLM(model="llava-hf/llava-1.5-7b-hf") + + # Refer to the HuggingFace repo for the correct format to use + prompt = "USER: \nWhat is the content of this image?\nASSISTANT:" + + # Embeddings for single image + # torch.Tensor of shape (1, image_feature_size, hidden_size of LM) + image_embeds = torch.load(...) + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": {"image": image_embeds}, + }) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + +For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings: + +.. code-block:: python + + # Construct the prompt based on your model + prompt = ... + + # Embeddings for multiple images + # torch.Tensor of shape (num_images, image_feature_size, hidden_size of LM) + image_embeds = torch.load(...) + + # Qwen2-VL + llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4}) + mm_data = { + "image": { + "image_embeds": image_embeds, + # image_grid_thw is needed to calculate positional encoding. + "image_grid_thw": torch.load(...), # torch.Tensor of shape (1, 3), + } + } + + # MiniCPM-V + llm = LLM("openbmb/MiniCPM-V-2_6", trust_remote_code=True, limit_mm_per_prompt={"image": 4}) + mm_data = { + "image": { + "image_embeds": image_embeds, + # image_size_list is needed to calculate details of the sliced image. + "image_size_list": [image.size for image in images], # list of image sizes + } + } + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + Online Inference ---------------- -OpenAI Vision API -^^^^^^^^^^^^^^^^^ +Our OpenAI-compatible server accepts multi-modal data via the `Chat Completions API `_. + +.. important:: + A chat template is **required** to use Chat Completions API. + + Although most models come with a chat template, for others you have to define one yourself. + The chat template can be inferred based on the documentation on the model's HuggingFace repo. + For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here `__. + +Image +^^^^^ -You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API `_. +Image input is supported according to `OpenAI Vision API `_. +Here is a simple example using Phi-3.5-Vision. -Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruct`` with vLLM's OpenAI-compatible API server. +First, launch the OpenAI-compatible server: .. code-block:: bash vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 -.. important:: - Since OpenAI Vision API is based on `Chat Completions API `_, - a chat template is **required** to launch the API server. - - Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it. - The chat template can be inferred based on the documentation on the model's HuggingFace repo. - For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here `_. - -To consume the server, you can use the OpenAI client like in the example below: +Then, you can use the OpenAI client as follows: .. code-block:: python @@ -252,22 +289,59 @@ A full code example can be found in `examples/openai_chat_completion_client_for_ .. note:: - By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable: + By default, the timeout for fetching images through HTTP URL is ``5`` seconds. + You can override this by setting the environment variable: .. code-block:: console $ export VLLM_IMAGE_FETCH_TIMEOUT= -Chat Embeddings API -^^^^^^^^^^^^^^^^^^^ +Video +^^^^^ + +Instead of :code:`image_url`, you can pass a video file via :code:`video_url`. + +You can use `these tests `_ as reference. + +.. note:: + + By default, the timeout for fetching videos through HTTP URL url is ``30`` seconds. + You can override this by setting the environment variable: + + .. code-block:: console + + $ export VLLM_VIDEO_FETCH_TIMEOUT= -vLLM's Chat Embeddings API is a superset of OpenAI's `Embeddings API `_, -where a list of ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models. +Audio +^^^^^ + +Instead of :code:`image_url`, you can pass an audio file via :code:`audio_url`. + +A full code example can be found in `examples/openai_chat_completion_client_for_multimodal.py `_. + +.. note:: + + By default, the timeout for fetching audios through HTTP URL is ``10`` seconds. + You can override this by setting the environment variable: + + .. code-block:: console + + $ export VLLM_AUDIO_FETCH_TIMEOUT= + +Embedding +^^^^^^^^^ + +vLLM's Embeddings API is a superset of OpenAI's `Embeddings API `_, +where a list of chat ``messages`` can be passed instead of batched ``inputs``. This enables multi-modal inputs to be passed to embedding models. .. tip:: The schema of ``messages`` is exactly the same as in Chat Completions API. + You can refer to the above tutorials for more details on how to pass each type of multi-modal data. -In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model. +Usually, embedding models do not expect chat-based input, so we need to use a custom chat template to format the text and images. +Refer to the examples below for illustration. + +Here is an end-to-end example using VLM2Vec. To serve the model: .. code-block:: bash @@ -279,10 +353,8 @@ In this example, we will serve the ``TIGER-Lab/VLM2Vec-Full`` model. Since VLM2Vec has the same model architecture as Phi-3.5-Vision, we have to explicitly pass ``--task embedding`` to run this model in embedding mode instead of text generation mode. -.. important:: - - VLM2Vec does not expect chat-based input. We use a `custom chat template `_ - to combine the text and images together. + The custom chat template is completely different from the original one for this model, + and can be found `here `__. Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level ``requests`` library: @@ -310,7 +382,7 @@ Since the request schema is not defined by OpenAI client, we post a request to t response_json = response.json() print("Embedding output:", response_json["data"][0]["embedding"]) -Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model. +Below is another example, this time using the ``MrLight/dse-qwen2-2b-mrl-v1`` model. .. code-block:: bash @@ -319,8 +391,10 @@ Here is an example for serving the ``MrLight/dse-qwen2-2b-mrl-v1`` model. .. important:: - Like with VLM2Vec, we have to explicitly pass ``--task embedding``. Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings, - which is handled by the jinja template. + Like with VLM2Vec, we have to explicitly pass ``--task embedding``. + + Additionally, ``MrLight/dse-qwen2-2b-mrl-v1`` requires an EOS token for embeddings, which is handled + by `this custom chat template `__. .. important:: diff --git a/docs/source/models/performance.rst b/docs/source/usage/performance.rst similarity index 100% rename from docs/source/models/performance.rst rename to docs/source/usage/performance.rst diff --git a/docs/source/models/spec_decode.rst b/docs/source/usage/spec_decode.rst similarity index 97% rename from docs/source/models/spec_decode.rst rename to docs/source/usage/spec_decode.rst index d57ffec53215d..f1f1917f974bb 100644 --- a/docs/source/models/spec_decode.rst +++ b/docs/source/usage/spec_decode.rst @@ -1,13 +1,16 @@ .. _spec_decode: -Speculative decoding in vLLM -============================ +Speculative decoding +==================== .. warning:: Please note that speculative decoding in vLLM is not yet optimized and does not usually yield inter-token latency reductions for all prompt datasets or sampling parameters. The work to optimize it is ongoing and can be followed in `this issue. `_ +.. warning:: + Currently, speculative decoding in vLLM is not compatible with pipeline parallelism. + This document shows how to use `Speculative Decoding `_ with vLLM. Speculative decoding is a technique which improves inter-token latency in memory-bound LLM inference. @@ -182,7 +185,7 @@ speculative decoding, breaking down the guarantees into three key areas: 3. **vLLM Logprob Stability** - vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the same request across runs. For more details, see the FAQ section - titled *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq>`_. + titled *Can the output of a prompt vary across runs in vLLM?* in the :ref:`FAQs `. **Conclusion** @@ -197,7 +200,7 @@ can occur due to following factors: **Mitigation Strategies** -For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq>`_. +For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the :ref:`FAQs `. Resources for vLLM contributors ------------------------------- diff --git a/docs/source/models/structured_outputs.rst b/docs/source/usage/structured_outputs.rst similarity index 100% rename from docs/source/models/structured_outputs.rst rename to docs/source/usage/structured_outputs.rst diff --git a/docs/source/serving/usage_stats.md b/docs/source/usage/usage_stats.md similarity index 100% rename from docs/source/serving/usage_stats.md rename to docs/source/usage/usage_stats.md diff --git a/examples/disaggregated_prefill.sh b/examples/disaggregated_prefill.sh new file mode 100644 index 0000000000000..87155273a81d1 --- /dev/null +++ b/examples/disaggregated_prefill.sh @@ -0,0 +1,109 @@ +#!/bin/bash +# This file demonstrates the example usage of disaggregated prefilling +# We will launch 2 vllm instances (1 for prefill and 1 for decode), +# and then transfer the KV cache between them. + +echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧" +sleep 1 + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'cleanup' INT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 + pkill -f python + echo "Cleanup complete. Exiting." + exit 0 +} + +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + +# install quart first -- required for disagg prefill proxy serve +if python3 -c "import quart" &> /dev/null; then + echo "Quart is already installed." +else + echo "Quart is not installed. Installing..." + python3 -m pip install quart +fi + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + + +# You can also adjust --kv-ip and --kv-port for distributed inference. + +# prefilling instance, which is the KV producer +CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8100 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' & + +# decoding instance, which is the KV consumer +CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \ + --port 8200 \ + --max-model-len 100 \ + --gpu-memory-utilization 0.8 \ + --kv-transfer-config \ + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' & + +# wait until prefill and decode instances are ready +wait_for_server 8100 +wait_for_server 8200 + +# launch a proxy server that opens the service at port 8000 +# the workflow of this proxy: +# - send the request to prefill vLLM instance (port 8100), change max_tokens +# to 1 +# - after the prefill vLLM finishes prefill, send the request to decode vLLM +# instance +# NOTE: the usage of this API is subject to change --- in the future we will +# introduce "vllm connect" to connect between prefill and decode instances +python3 ../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py & +sleep 1 + +# serve two example requests +output1=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "San Francisco is a", +"max_tokens": 10, +"temperature": 0 +}') + +output2=$(curl -X POST -s http://localhost:8000/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "meta-llama/Meta-Llama-3.1-8B-Instruct", +"prompt": "Santa Clara is a", +"max_tokens": 10, +"temperature": 0 +}') + + +# Cleanup commands +pgrep python | xargs kill -9 +pkill -f python + +echo "" + +sleep 1 + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉" +echo "" diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 7d5ef128bc8e0..ae158eef2ca4c 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -10,7 +10,7 @@ # Create an LLM. model = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True) -# Generate embedding. The output is a list of EmbeddingRequestOutputs. +# Generate embedding. The output is a list of PoolingRequestOutputs. outputs = model.encode(prompts) # Print the outputs. for output in outputs: diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index f08f22eec164a..c6a274ee5894b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -223,7 +223,7 @@ def run_internvl(question: str, modality: str): # Stop tokens for InternVL # models variants may have different stop tokens # please refer to the model card for the correct "stop words": - # https://huggingface.co/OpenGVLab/InternVL2-2B#service + # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] return llm, prompt, stop_token_ids @@ -419,6 +419,22 @@ def run_aria(question: str, modality: str): return llm, prompt, stop_token_ids +# Mantis +def run_mantis(question: str, modality: str): + assert modality == "image" + + llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501 + prompt = llama3_template.format(f"{question}\n") + + llm = LLM( + model="TIGER-Lab/Mantis-8B-siglip-llama3", + max_model_len=4096, + hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, + ) + stop_token_ids = [128009] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -441,6 +457,7 @@ def run_aria(question: str, modality: str): "glm4v": run_glm4v, "idefics3": run_idefics3, "aria": run_aria, + "mantis": run_mantis, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 788b604cfd4a0..928bbef54eab7 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -165,7 +165,7 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: # Stop tokens for InternVL # models variants may have different stop tokens # please refer to the model card for the correct "stop words": - # https://huggingface.co/OpenGVLab/InternVL2-2B#service + # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] diff --git a/examples/tool_chat_template_llama3.2_json.jinja b/examples/tool_chat_template_llama3.2_json.jinja index 39f902c1c3c40..2b290c0eede03 100644 --- a/examples/tool_chat_template_llama3.2_json.jinja +++ b/examples/tool_chat_template_llama3.2_json.jinja @@ -26,13 +26,11 @@ {%- endfor %} {%- endfor %} - {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} {%- if messages[0]['content'] is string %} {%- set system_message = messages[0]['content']|trim %} {%- else %} - {#- Support vLLM's transforming of a content string to JSON. #} {%- set system_message = messages[0]['content'][0]['text']|trim %} {%- endif %} {%- set messages = messages[1:] %} @@ -44,14 +42,8 @@ {%- endif %} {%- endif %} -{#- Including an image is not compatible with a system message #} -{%- if image_ns.has_images and not system_message == "" %} - {{- raise_exception("Prompting with images is incompatible with system messages and tool use.") }} -{%- endif %} - - -{#- System message, if there are no images #} -{%- if not image_ns.has_images %} +{#- System message if there are no images, if the user supplied one, or if tools are used (default tool system message) #} +{%- if system_message or not image_ns.has_images %} {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} {%- if tools is not none %} {{- "Environment: ipython\n" }} diff --git a/python_only_dev.py b/python_only_dev.py index 1ca0f5c30b741..f70b4984025b3 100644 --- a/python_only_dev.py +++ b/python_only_dev.py @@ -1,92 +1,14 @@ -# enable python only development -# copy compiled files to the current directory directly +msg = """Old style python only build (without compilation) is deprecated, please check https://docs.vllm.ai/en/latest/getting_started/installation.html#python-only-build-without-compilation for the new way to do python only build (without compilation). -import argparse -import os -import shutil -import subprocess -import sys -import warnings +TL;DR: -parser = argparse.ArgumentParser( - description="Development mode for python-only code") -parser.add_argument('-q', - '--quit-dev', - action='store_true', - help='Set the flag to quit development mode') -args = parser.parse_args() +VLLM_USE_PRECOMPILED=1 pip install -e . -# cannot directly `import vllm` , because it will try to -# import from the current directory -output = subprocess.run([sys.executable, "-m", "pip", "show", "vllm"], - capture_output=True) +or -assert output.returncode == 0, "vllm is not installed" +export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch +export VLLM_PRECOMPILED_WHEEL_LOCATION=https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl +pip install -e . +""" # noqa -text = output.stdout.decode("utf-8") - -package_path = None -for line in text.split("\n"): - if line.startswith("Location: "): - package_path = line.split(": ")[1] - break - -assert package_path is not None, "could not find package path" - -cwd = os.getcwd() - -assert cwd != package_path, "should not import from the current directory" - -files_to_copy = [ - "vllm/_C.abi3.so", - "vllm/_moe_C.abi3.so", - "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", - "vllm/vllm_flash_attn/flash_attn_interface.py", - "vllm/vllm_flash_attn/__init__.py", - # "vllm/_version.py", # not available in nightly wheels yet -] - -# Try to create _version.py to avoid version related warning -# Refer to https://github.com/vllm-project/vllm/pull/8771 -try: - from setuptools_scm import get_version - get_version(write_to="vllm/_version.py") -except ImportError: - warnings.warn( - "To avoid warnings related to vllm._version, " - "you should install setuptools-scm by `pip install setuptools-scm`", - stacklevel=2) - -if not args.quit_dev: - for file in files_to_copy: - src = os.path.join(package_path, file) - dst = file - print(f"Copying {src} to {dst}") - shutil.copyfile(src, dst) - - pre_built_vllm_path = os.path.join(package_path, "vllm") - tmp_path = os.path.join(package_path, "vllm_pre_built") - current_vllm_path = os.path.join(cwd, "vllm") - - print(f"Renaming {pre_built_vllm_path} to {tmp_path} for backup") - shutil.copytree(pre_built_vllm_path, tmp_path) - shutil.rmtree(pre_built_vllm_path) - - print(f"Linking {current_vllm_path} to {pre_built_vllm_path}") - os.symlink(current_vllm_path, pre_built_vllm_path) -else: - vllm_symlink_path = os.path.join(package_path, "vllm") - vllm_backup_path = os.path.join(package_path, "vllm_pre_built") - current_vllm_path = os.path.join(cwd, "vllm") - - print(f"Unlinking {current_vllm_path} to {vllm_symlink_path}") - assert os.path.islink( - vllm_symlink_path - ), f"not in dev mode: {vllm_symlink_path} is not a symbolic link" - assert current_vllm_path == os.readlink( - vllm_symlink_path - ), "current directory is not the source code of package" - os.unlink(vllm_symlink_path) - - print(f"Recovering backup from {vllm_backup_path} to {vllm_symlink_path}") - os.rename(vllm_backup_path, vllm_symlink_path) +print(msg) diff --git a/requirements-common.txt b/requirements-common.txt index f62ad66a1ecc4..72fb020a82c4e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,8 +19,9 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 +xgrammar >= 0.1.5; platform_machine == "x86_64" typing_extensions >= 4.10 -filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs pyzmq msgspec diff --git a/requirements-test.in b/requirements-test.in index 76f6de2f77c34..c0b228148ab31 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -20,13 +20,10 @@ timm # required for internvl test torch==2.5.1 transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[opencv] >= 1.4.4 # required for pixtral test +mistral_common[opencv] >= 1.5.0 # required for pixtral test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.4 # required for model evaluation test -# TODO: Add this after fully implementing llava(mantis) -# git+https://github.com/TIGER-AI-Lab/Mantis.git # required for llava(mantis) test - # quantization bitsandbytes>=0.44.0 buildkite-test-collector==0.1.9 diff --git a/requirements-test.txt b/requirements-test.txt index 65695111e4dc5..19369254dbe26 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -217,7 +217,7 @@ mbstrdecoder==1.1.3 # dataproperty # pytablewriter # typepy -mistral-common[opencv]==1.4.4 +mistral-common[opencv]==1.5.1 # via # -r requirements-test.in # mistral-common @@ -550,7 +550,7 @@ tqdm==4.66.6 # transformers tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.45.2 +transformers==4.46.3 # via # lm-eval # peft diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 3d1e80f6be620..b8f0b15469e77 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -16,8 +16,8 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.6.0.dev20241114+cpu -torchvision==0.20.0.dev20241114+cpu -torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241114-cp310-cp310-linux_x86_64.whl -jaxlib==0.4.32.dev20240829 -jax==0.4.32.dev20240829 +torch==2.6.0.dev20241126+cpu +torchvision==0.20.0.dev20241126+cpu +torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241126-cp310-cp310-linux_x86_64.whl +jaxlib==0.4.36.dev20241122 +jax==0.4.36.dev20241122 diff --git a/setup.py b/setup.py index b936589869e76..fcfaa207c176a 100644 --- a/setup.py +++ b/setup.py @@ -249,6 +249,74 @@ def run(self): self.copy_file(file, dst_file) +class repackage_wheel(build_ext): + """Extracts libraries and other files from an existing wheel.""" + default_wheel = "https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + + def run(self) -> None: + wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", + self.default_wheel) + + assert _is_cuda( + ), "VLLM_USE_PRECOMPILED is only supported for CUDA builds" + + import zipfile + + if os.path.isfile(wheel_location): + wheel_path = wheel_location + print(f"Using existing wheel={wheel_path}") + else: + # Download the wheel from a given URL, assume + # the filename is the last part of the URL + wheel_filename = wheel_location.split("/")[-1] + + import tempfile + + # create a temporary directory to store the wheel + temp_dir = tempfile.mkdtemp(prefix="vllm-wheels") + wheel_path = os.path.join(temp_dir, wheel_filename) + + print(f"Downloading wheel from {wheel_location} to {wheel_path}") + + from urllib.request import urlretrieve + + try: + urlretrieve(wheel_location, filename=wheel_path) + except Exception as e: + from setuptools.errors import SetupError + + raise SetupError( + f"Failed to get vLLM wheel from {wheel_location}") from e + + with zipfile.ZipFile(wheel_path) as wheel: + files_to_copy = [ + "vllm/_C.abi3.so", + "vllm/_moe_C.abi3.so", + "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", + "vllm/vllm_flash_attn/flash_attn_interface.py", + "vllm/vllm_flash_attn/__init__.py", + # "vllm/_version.py", # not available in nightly wheels yet + ] + file_members = filter(lambda x: x.filename in files_to_copy, + wheel.filelist) + + for file in file_members: + print(f"Extracting and including {file.filename} " + "from existing wheel") + package_name = os.path.dirname(file.filename).replace("/", ".") + file_name = os.path.basename(file.filename) + + if package_name not in package_data: + package_data[package_name] = [] + + wheel.extract(file) + if file_name.endswith(".py"): + # python files shouldn't be added to package_data + continue + + package_data[package_name].append(file_name) + + def _is_hpu() -> bool: is_hpu_available = True try: @@ -397,12 +465,15 @@ def get_vllm_version() -> str: if envs.VLLM_TARGET_DEVICE == "empty": version += f"{sep}empty" elif _is_cuda(): - cuda_version = str(get_nvcc_cuda_version()) - if cuda_version != MAIN_CUDA_VERSION: - cuda_version_str = cuda_version.replace(".", "")[:3] - # skip this for source tarball, required for pypi - if "sdist" not in sys.argv: - version += f"{sep}cu{cuda_version_str}" + if envs.VLLM_USE_PRECOMPILED: + version += ".precompiled" + else: + cuda_version = str(get_nvcc_cuda_version()) + if cuda_version != MAIN_CUDA_VERSION: + cuda_version_str = cuda_version.replace(".", "")[:3] + # skip this for source tarball, required for pypi + if "sdist" not in sys.argv: + version += f"{sep}cu{cuda_version_str}" elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() @@ -514,13 +585,18 @@ def _read_requirements(filename: str) -> List[str]: package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } -if envs.VLLM_USE_PRECOMPILED: - ext_modules = [] - package_data["vllm"].append("*.so") if _no_device(): ext_modules = [] +if not ext_modules: + cmdclass = {} +else: + cmdclass = { + "build_ext": + repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext + } + setup( name="vllm", version=get_vllm_version(), @@ -557,7 +633,7 @@ def _read_requirements(filename: str) -> List[str]: "audio": ["librosa", "soundfile"], # Required for audio processing "video": ["decord"] # Required for video processing }, - cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {}, + cmdclass=cmdclass, package_data=package_data, entry_points={ "console_scripts": [ diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 7ef502abee345..aa11524812cdd 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -7,7 +7,6 @@ from torch import nn from torch.library import Library -from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, @@ -81,6 +80,7 @@ def test_simple_piecewise_compile(): use_cudagraph=True, splitting_ops=["silly.attention"], cudagraph_copy_inputs=True, + cudagraph_capture_sizes=[1, 2], )) with set_current_vllm_config(vllm_config): model = SillyModel(vllm_config=vllm_config, prefix='') @@ -96,11 +96,10 @@ def test_simple_piecewise_compile(): 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - with set_compile_context([1, 2]): - model(inputs) + model(inputs) - model(torch.randn(2).cuda()) - model(torch.randn(1).cuda()) + model(torch.randn(2).cuda()) + model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() global global_counter diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index dbd5a3bbffeab..07c10a3a18c55 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -13,7 +13,6 @@ from torch import nn from torch.library import Library -from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, @@ -256,6 +255,7 @@ def run_model(llama_config, compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, + cudagraph_capture_sizes=[1, 2], ) if split_attn: compilation_config.splitting_ops = ["silly.attention"] @@ -273,10 +273,9 @@ def run_model(llama_config, input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() positions = torch.arange(B).cuda() - with set_compile_context([1, 2]): - model(input_ids, positions) - model(input_ids[:2], positions[:2]) - model(input_ids[:1], positions[:1]) + model(input_ids, positions) + model(input_ids[:2], positions[:2]) + model(input_ids[:1], positions[:1]) input_ids[:2].zero_() output = model(input_ids[:2], positions[:2]) @@ -379,10 +378,13 @@ def benchmark(): level=CompilationLevel.PIECEWISE, use_cudagraph=True, splitting_ops=["silly.attention"], + cudagraph_capture_sizes=cudagraph_sizes, ) else: compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, ) + level=CompilationLevel.PIECEWISE, + cudagraph_capture_sizes=cudagraph_sizes, + ) vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): @@ -396,17 +398,16 @@ def benchmark(): graphs = {} - with set_compile_context(cudagraph_sizes): - model(input_ids, positions) - for b in cudagraph_sizes[::-1]: - if not piecewise: - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, pool=pool): - output = model(input_ids[:b], positions[:b]) - graphs[b] = (graph, output) - else: + model(input_ids, positions) + for b in cudagraph_sizes[::-1]: + if not piecewise: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=pool): output = model(input_ids[:b], positions[:b]) - graphs[b] = (model, output) + graphs[b] = (graph, output) + else: + output = model(input_ids[:b], positions[:b]) + graphs[b] = (model, output) for b in cudagraph_sizes: if piecewise: # noqa is for `Function definition does not bind loop variable` diff --git a/tests/conftest.py b/tests/conftest.py index d56942d8912af..d6be8f5b00af8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -263,7 +263,6 @@ def __init__( dtype: str = "half", *, model_kwargs: Optional[Dict[str, Any]] = None, - is_embedding_model: bool = False, is_sentence_transformer: bool = False, is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, @@ -657,6 +656,7 @@ def __init__( model_name: str, task: TaskOption = "auto", tokenizer_name: Optional[str] = None, + tokenizer_mode: str = "auto", # Use smaller max model length, otherwise bigger model cannot run due # to kv cache size limit. max_model_len: int = 1024, @@ -673,6 +673,7 @@ def __init__( model=model_name, task=task, tokenizer=tokenizer_name, + tokenizer_mode=tokenizer_mode, trust_remote_code=True, dtype=dtype, swap_space=swap_space, @@ -843,6 +844,7 @@ def generate_greedy_logprobs( audios: Optional[PromptAudioInput] = None, videos: Optional[PromptVideoInput] = None, stop_token_ids: Optional[List[int]] = None, + stop: Optional[List[str]] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( @@ -850,7 +852,8 @@ def generate_greedy_logprobs( max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=num_prompt_logprobs, - stop_token_ids=stop_token_ids) + stop_token_ids=stop_token_ids, + stop=stop) return self.generate_w_logprobs(prompts, greedy_logprobs_params, diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 0bdb3965eca06..1317a6f6ae41e 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -249,9 +249,19 @@ def _compare_tp( *, method: Literal["generate", "encode"], ): - tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup - multi_node_only, trust_remote_code, tokenizer_mode, \ - load_format, hf_overrides = test_options + ( + tp_size, + pp_size, + eager_mode, + chunked_prefill, + ) = parallel_setup + ( + multi_node_only, + trust_remote_code, + tokenizer_mode, + load_format, + hf_overrides, + ) = test_options if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index fb24d6bc2c100..3e9b0e10a11d8 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -61,8 +61,8 @@ def worker_fn(): dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): tensor = pynccl_comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == pynccl_comm.world_size + torch.cuda.synchronize() + assert torch.all(tensor == pynccl_comm.world_size).cpu().item() @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -86,12 +86,12 @@ def multiple_allreduce_worker_fn(): if torch.distributed.get_rank() in [0, 1]: tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 4 + torch.cuda.synchronize() + assert torch.all(tensor == 4).cpu().item() else: tensor = pynccl_comm.all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 2 + torch.cuda.synchronize() + assert torch.all(tensor == 2).cpu().item() @pytest.mark.skipif(torch.cuda.device_count() < 4, @@ -112,12 +112,12 @@ def multiple_allreduce_with_vllm_worker_fn(): if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 4 + torch.cuda.synchronize() + assert torch.all(tensor == 4).cpu().item() else: tensor = tensor_model_parallel_all_reduce(tensor) - result = tensor.mean().cpu().item() - assert result == 2 + torch.cuda.synchronize() + assert torch.all(tensor == 2).cpu().item() @pytest.mark.skipif(torch.cuda.device_count() < 4, @@ -141,10 +141,10 @@ def worker_fn_with_cudagraph(): graph, stream=pynccl_comm.stream), pynccl_comm.change_state( enable=True): a_out = pynccl_comm.all_reduce(a) - pynccl_comm.stream.synchronize() + torch.cuda.synchronize() graph.replay() - pynccl_comm.stream.synchronize() - assert a_out.mean().cpu().item() == pynccl_comm.world_size**1 + torch.cuda.synchronize() + assert torch.all(a_out == pynccl_comm.world_size).cpu().item() @worker_fn_wrapper @@ -170,6 +170,7 @@ def all_gather_worker_fn(): with pynccl_comm.change_state(enable=True): pynccl_comm.all_gather(result, tensor) + torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -207,6 +208,7 @@ def reduce_scatter_worker_fn(): with pynccl_comm.change_state(enable=True): pynccl_comm.reduce_scatter(result, tensor) + torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -241,8 +243,8 @@ def send_recv_worker_fn(): pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) - result = tensor.mean().cpu().item() - assert result == 1 + torch.cuda.synchronize() + assert torch.all(tensor == 1).cpu().item() @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -280,11 +282,11 @@ def multiple_send_recv_worker_fn(): pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) - result = tensor.mean().cpu().item() + torch.cuda.synchronize() if torch.distributed.get_rank() in [0, 2]: - assert result == 1 + assert torch.all(tensor == 1).cpu().item() else: - assert result == 2 + assert torch.all(tensor == 2).cpu().item() @pytest.mark.skipif(torch.cuda.device_count() < 4, @@ -293,6 +295,38 @@ def test_pynccl_multiple_send_recv(): distributed_run(multiple_send_recv_worker_fn, 4) +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_broadcast(): + distributed_run(broadcast_worker_fn, 4) + + +@worker_fn_wrapper +def broadcast_worker_fn(): + # Test broadcast for every root rank. + # Essentially this is an all-gather operation. + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + recv_tensors = [ + torch.empty(16, + 1024, + 1024, + dtype=torch.float32, + device=pynccl_comm.device) + for i in range(pynccl_comm.world_size) + ] + recv_tensors[pynccl_comm.rank] = torch.ones( + 16, 1024, 1024, dtype=torch.float32, + device=pynccl_comm.device) * pynccl_comm.rank + + for i in range(pynccl_comm.world_size): + pynccl_comm.broadcast(recv_tensors[i], src=i) + # the broadcast op might be launched in a different stream + # need to synchronize to make sure the tensor is ready + torch.cuda.synchronize() + assert torch.all(recv_tensors[i] == i).cpu().item() + + def test_ncclGetUniqueId(): lib = NCCLLibrary() unique_id = lib.ncclGetUniqueId() diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 5b0e76fe53685..4e269de9fc40b 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -50,15 +50,34 @@ def test_compilation_config(): args = parser.parse_args(["-O=3"]) assert args.compilation_config.level == 3 - # set to json - args = parser.parse_args(["--compilation-config", '{"level": 3}']) + # set to string form of a dict + args = parser.parse_args(["--compilation-config", "{'level': 3}"]) assert args.compilation_config.level == 3 - # set to json - args = parser.parse_args(['--compilation-config={"level": 3}']) + # set to string form of a dict + args = parser.parse_args(["--compilation-config={'level': 3}"]) assert args.compilation_config.level == 3 +def test_prefix_cache_default(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args([]) + + engine_args = EngineArgs.from_cli_args(args=args) + assert (not engine_args.enable_prefix_caching + ), "prefix caching defaults to off." + + # with flag to turn it on. + args = parser.parse_args(["--enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.enable_prefix_caching + + # with disable flag to turn it off. + args = parser.parse_args(["--no-enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert not engine_args.enable_prefix_caching + + def test_valid_pooling_config(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([ diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py index e7ef5637c8ccb..0f7d15e1d85aa 100644 --- a/tests/entrypoints/conftest.py +++ b/tests/entrypoints/conftest.py @@ -69,6 +69,37 @@ def sample_json_schema(): } +@pytest.fixture +def sample_complex_json_schema(): + return { + "type": "object", + "properties": { + "score": { + "type": "integer", + "minimum": 0, + "maximum": 100 # Numeric range + }, + "grade": { + "type": "string", + "pattern": "^[A-D]$" # Regex pattern + }, + "email": { + "type": "string", + "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$" + }, + "tags": { + "type": "array", + "items": { + "type": "string", + "pattern": + "^[a-z]{1,10}$" # Combining length and pattern restrictions + } + } + }, + "required": ["score", "grade", "email", "tags"] + } + + @pytest.fixture def sample_guided_choice(): return [ diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index 4c9f796e5ed71..41163809237e9 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -3,7 +3,7 @@ import pytest -from vllm import LLM, EmbeddingRequestOutput, PoolingParams +from vllm import LLM, PoolingParams, PoolingRequestOutput from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @@ -43,8 +43,8 @@ def llm(): cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: List[EmbeddingRequestOutput], - o2: List[EmbeddingRequestOutput]): +def assert_outputs_equal(o1: List[PoolingRequestOutput], + o2: List[PoolingRequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 67c79415f322a..de6257cfc551c 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -76,6 +76,34 @@ def test_guided_json_completion(sample_json_schema, llm): jsonschema.validate(instance=output_json, schema=sample_json_schema) +@pytest.mark.skip_global_cleanup +def test_guided_complex_json_completion(sample_complex_json_schema, llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(json=sample_complex_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an assignment grade " + f"that fits this schema: {sample_complex_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt + + generated_text = output.outputs[0].text + assert generated_text is not None + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + output_json = json.loads(generated_text) + jsonschema.validate(instance=output_json, + schema=sample_complex_json_schema) + + @pytest.mark.skip_global_cleanup def test_guided_choice_completion(sample_guided_choice, llm): sampling_params = SamplingParams( @@ -159,3 +187,30 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm): sampling_params=sampling_params, use_tqdm=True, guided_options_request=dict(guided_regex=sample_regex)) + + +@pytest.mark.skip_global_cleanup +def test_guided_json_object(llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=100, + guided_decoding=GuidedDecodingParams(json_object=True)) + + outputs = llm.generate( + prompts=("Generate a JSON object describing a person with name " + "and age for John Smith who is 31 years old."), + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) diff --git a/tests/entrypoints/openai/test_async_tokenization.py b/tests/entrypoints/openai/test_async_tokenization.py new file mode 100644 index 0000000000000..fcce8b46c4344 --- /dev/null +++ b/tests/entrypoints/openai/test_async_tokenization.py @@ -0,0 +1,137 @@ +import asyncio +import contextlib +import random +import time +from typing import Callable + +import openai +import pytest +import pytest_asyncio +import requests + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + "--max-num-seqs", + "128", + "--load-format", + "dummy", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ids=["completion", "chat"], + argnames=["create_func_gen", "content_body"], + argvalues=[ + (lambda x: x.completions.create, { + "prompt": " ".join(['A'] * 10_000) + }), + (lambda x: x.chat.completions.create, { + "messages": [{ + "role": "user", + "content": " ".join(['A'] * 10_000) + }] + }), + ], +) +async def test_with_and_without_truncate( + server: RemoteOpenAIServer, + client: openai.AsyncOpenAI, + create_func_gen: Callable, + content_body: dict, +): + create_func = create_func_gen(client) + body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} + + num_requests = 10 + truncate_prompt_tokens = ([1000] * (num_requests // 2) + [None] * + (num_requests - num_requests // 2)) + random.shuffle(truncate_prompt_tokens) + + bodies = [{ + **body, "extra_body": { + 'truncate_prompt_tokens': t + } + } for t in truncate_prompt_tokens] + + async def get_status_code(**kwargs): + try: + await create_func(**kwargs) + return 200 + except openai.APIStatusError as e: + return e.status_code + + responses = await asyncio.gather(*[get_status_code(**b) for b in bodies]) + assert 500 not in responses + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ids=["single completion", "multiple completions", "chat"], + argnames=["create_func_gen", "content_body"], + argvalues=[ + (lambda x: x.completions.create, { + "prompt": " ".join(['A'] * 300_000) + }), + (lambda x: x.completions.create, { + "prompt": [" ".join(['A'] * 300_000)] * 2 + }), + (lambda x: x.chat.completions.create, { + "messages": [{ + "role": "user", + "content": " ".join(['A'] * 300_000) + }] + }), + ], +) +async def test_healthcheck_response_time( + server: RemoteOpenAIServer, + client: openai.AsyncOpenAI, + create_func_gen: Callable, + content_body: dict, +): + num_requests = 50 + + create_func = create_func_gen(client) + body = {"model": MODEL_NAME, **content_body, "max_tokens": 10} + + def get_response_time(url): + start_time = time.monotonic() + res = requests.get(url) + end_time = time.monotonic() + assert res.status_code == 200 + return end_time - start_time + + no_load_response_time = get_response_time(server.url_for("health")) + tasks = [ + asyncio.create_task(create_func(**body)) for _ in range(num_requests) + ] + await asyncio.sleep(1) # give the tasks a chance to start running + load_response_time = get_response_time(server.url_for("health")) + + with contextlib.suppress(openai.APIStatusError): + await asyncio.gather(*tasks) + + assert load_response_time < 100 * no_load_response_time + assert load_response_time < 0.1 diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index f9b11018288be..51be2425d7dd7 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -149,13 +149,14 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) +@pytest.mark.parametrize("has_initial_state", [True, False]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize( 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]) @pytest.mark.parametrize('dim', [64]) @pytest.mark.parametrize('batch', [1]) def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, - itype): + has_initial_state, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -167,11 +168,18 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None - initial_states = torch.randn(batch, - dim, - width - 1, - device=device, - dtype=itype) + if has_initial_state: + initial_states = torch.randn(batch, + dim, + width - 1, + device=device, + dtype=itype) + has_initial_state_tensor = torch.ones(batch, + dtype=torch.bool, + device=x.device) + else: + initial_states = None + has_initial_state_tensor = None x_ref = x.clone() weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None @@ -183,9 +191,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, bias, activation=activation, conv_states=initial_states, - has_initial_state=torch.ones(batch, - dtype=torch.bool, - device=x.device)) + has_initial_state=has_initial_state_tensor) out_ref, final_states_ref = causal_conv1d_ref( x_ref, weight_ref, @@ -193,11 +199,12 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, initial_states=initial_states_ref, return_final_states=True, activation=activation) - assert initial_states is not None and final_states_ref is not None - assert torch.allclose(initial_states, - final_states_ref, - rtol=rtol, - atol=atol) + if has_initial_state: + assert initial_states is not None and final_states_ref is not None + assert torch.allclose(initial_states, + final_states_ref, + rtol=rtol, + atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) causal_conv1d_opcheck_fn(x, @@ -205,9 +212,7 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, bias, activation=activation, conv_states=initial_states, - has_initial_state=torch.ones(batch, - dtype=torch.bool, - device=x.device)) + has_initial_state=has_initial_state_tensor) @pytest.mark.parametrize("itype", [torch.bfloat16]) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index a20c73345218f..1ae78d7b46c5b 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -71,6 +71,7 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) +@pytest.mark.parametrize("use_out", [True, False]) @pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @@ -81,6 +82,7 @@ def ref_paged_attn( @pytest.mark.parametrize("sliding_window", [None, 256]) @torch.inference_mode() def test_flash_attn_with_paged_kv( + use_out: bool, kv_lens: List[int], num_heads: Tuple[int, int], head_size: int, @@ -116,17 +118,22 @@ def test_flash_attn_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) + q = query.unsqueeze(1) + out = torch.empty_like(q) if use_out else None output = flash_attn_with_kvcache( - q=query.unsqueeze(1), + q=q, k_cache=key_cache, v_cache=value_cache, + out=out, softmax_scale=scale, causal=True, block_table=block_tables, cache_seqlens=kv_lens_tensor, softcap=soft_cap if soft_cap is not None else 0, window_size=window_size, - ).squeeze(1) + ) + output = output if not use_out else out + output = output.squeeze(1) ref_output = ref_paged_attn(query=query, key_cache=key_cache, @@ -141,7 +148,10 @@ def test_flash_attn_with_paged_kv( f"{torch.max(torch.abs(output - ref_output))}" -@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("use_out", [True, False]) +@pytest.mark.parametrize("seq_lens", + [[(1, 1328), (5, 18), + (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @@ -151,6 +161,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @torch.inference_mode() def test_varlen_with_paged_kv( + use_out: bool, seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], head_size: int, @@ -197,10 +208,12 @@ def test_varlen_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) + out = torch.empty_like(query) if use_out else None output = flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, + out=out, cu_seqlens_q=cu_query_lens, cu_seqlens_k=cu_kv_lens, max_seqlen_q=max_query_len, @@ -211,6 +224,7 @@ def test_varlen_with_paged_kv( block_table=block_tables, softcap=soft_cap if soft_cap is not None else 0, ) + output = output if not use_out else out ref_output = ref_paged_attn( query=query, diff --git a/tests/kv_transfer/disagg_test.py b/tests/kv_transfer/disagg_test.py new file mode 100644 index 0000000000000..adc6150edece6 --- /dev/null +++ b/tests/kv_transfer/disagg_test.py @@ -0,0 +1,119 @@ +import os +import subprocess +import sys +import time +from subprocess import Popen + +import pytest +import requests +import torch + + +# Fixture to set up environment variables and teardown servers after tests +@pytest.fixture(scope="module", autouse=True) +def setup_servers(): + if torch.cuda.device_count() < 4: + pytest.skip("Skipping test: fewer than 4 GPUs available") + + # Set up environment variables + VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", + shell=True).decode().strip() + os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP + + # Start prefill instance + prefill_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8100", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ + '"kv_rank":0,"kv_parallel_size":2}', + ] + prefill_env = os.environ.copy() + prefill_env["CUDA_VISIBLE_DEVICES"] = "0" + prefill_proc = Popen(prefill_cmd, env=prefill_env) + + # Start decode instance + decode_cmd = [ + sys.executable, + "-m", + "vllm.entrypoints.openai.api_server", + "--model", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "--port", + "8200", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "1000", + "--kv-transfer-config", + '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ + '"kv_rank":1,"kv_parallel_size":2}', + ] + decode_env = os.environ.copy() + decode_env["CUDA_VISIBLE_DEVICES"] = "1" + decode_proc = Popen(decode_cmd, env=decode_env) + + # Wait for servers to be ready + assert wait_for_server(8100), "Prefill server did not start in time" + assert wait_for_server(8200), "Decode server did not start in time" + + # Yield to the test function and handle teardown after tests + yield + + # Cleanup: kill the processes + prefill_proc.terminate() + decode_proc.terminate() + + # Additional cleanup if needed + prefill_proc.wait() + decode_proc.wait() + + +# Helper function to wait for server +def wait_for_server(port, timeout=240): + start_time = time.time() + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/v1/completions") + if response.status_code in [200, 405]: + return True + except requests.ConnectionError: + time.sleep(1) + return False + + +# Test function to send curl requests and validate responses +@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) +def test_disaggregated_prefilling(prompt): + # Send to prefill + response = requests.post("http://localhost:8100/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 1, + "temperature": 0 + }) + assert response.status_code == 200 + + # Send to decode + response = requests.post("http://localhost:8200/v1/completions", + headers={"Content-Type": "application/json"}, + json={ + "model": + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "prompt": prompt, + "max_tokens": 10, + "temperature": 0 + }) + assert response.status_code == 200 diff --git a/tests/kv_transfer/module_test.py b/tests/kv_transfer/module_test.py new file mode 100644 index 0000000000000..355461919cd7c --- /dev/null +++ b/tests/kv_transfer/module_test.py @@ -0,0 +1,64 @@ +import subprocess +import sys + +import pytest +import torch + + +def run_python_script(script_name, timeout): + script_name = f'kv_transfer/{script_name}' + try: + # Start both processes asynchronously using Popen + process0 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "0"}, # Set the RANK environment variable for process 0 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + process1 = subprocess.Popen( + [sys.executable, script_name], + env={"RANK": + "1"}, # Set the RANK environment variable for process 1 + stdout=sys.stdout, # Pipe stdout to current stdout + stderr=sys.stderr, # Pipe stderr to current stderr + ) + + # Wait for both processes to complete, with a timeout + process0.wait(timeout=timeout) + process1.wait(timeout=timeout) + + # Check the return status of both processes + if process0.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=0, {process0.returncode}") + if process1.returncode != 0: + pytest.fail( + f"Test {script_name} failed for RANK=1, {process1.returncode}") + + except subprocess.TimeoutExpired: + # If either process times out, terminate both and fail the test + process0.terminate() + process1.terminate() + pytest.fail(f"Test {script_name} timed out") + except Exception as e: + pytest.fail(f"Test {script_name} failed with error: {str(e)}") + + +# Define the test cases using pytest's parametrize +@pytest.mark.parametrize( + "script_name,timeout", + [ + ("test_lookup_buffer.py", + 60), # Second test case with a 60-second timeout + ("test_send_recv.py", 120) # First test case with a 120-second timeout + ]) +def test_run_python_script(script_name, timeout): + # Check the number of GPUs + if torch.cuda.device_count() < 2: + pytest.skip( + f"Skipping test {script_name} because <2 GPUs are available") + + # Run the test if there are at least 2 GPUs + run_python_script(script_name, timeout) diff --git a/tests/kv_transfer/test_lookup_buffer.py b/tests/kv_transfer/test_lookup_buffer.py new file mode 100644 index 0000000000000..96b0e58713332 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.py @@ -0,0 +1,160 @@ +import os +import random + +import torch +from tqdm import tqdm + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe + +# TODO: the test depends on a lot of fields in the current implementation. +# We should have standard interface instead direct field access + + +def test_run(my_rank, buffer, device): + + # buffer should be empty in the beginning + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + print("My rank: %d, device: %s" % (my_rank, device)) + + # insert + tokens = torch.tensor([1, 2, 3]).to(device) + roi = (tokens > 0) + if my_rank == 0: + key = 2.0 * torch.ones([5, 6]).to(device) + value = 3.0 * torch.ones([5, 6]).to(device) + + placeholder = torch.tensor([1]).to(device) + + buffer.insert(tokens, roi, key, value, placeholder) + + torch.distributed.barrier() + + # drop_select + if my_rank == 1: + tok, roi_, key, value, hidden = buffer.drop_select(tokens, roi) + assert torch.allclose(tokens, tok) + assert torch.allclose(roi, roi_) + assert torch.allclose(key, 2.0 * torch.ones([5, 6], device=device)) + assert torch.allclose(value, 3.0 * torch.ones([5, 6], device=device)) + torch.distributed.barrier() + + if my_rank == 0: + assert buffer.buffer_size == 0 + assert len(buffer.buffer) == 0 + + print("Test run passed!") + + +def stress_test(my_rank, buf, device): + + torch.distributed.barrier() + torch.manual_seed(100) + + reqs = [ + ( + torch.rand(100).to(device), # tokens + torch.ones(100).bool().to(device), # roi + torch.rand(100).to(device), # key + torch.rand(100).to(device), # value + torch.rand(100).to(device), # hidden + ) for i in tqdm(range(200)) + ] + + random.seed(my_rank) + random.shuffle(reqs) + + torch.distributed.barrier() + + n = 0 + + # the buffer size can only store 100 reqs + # so the sender will occasionally block to wait for the receiver. + for req in tqdm(reqs): + if my_rank == 0: + buf.insert(*req) + else: + tok, roi, k, v, h = req + tok_, roi_, k_, v_, h_ = buf.drop_select(tok, roi) + + if tok_ is None: + assert roi_ is None + assert k_ is None + assert v_ is None + assert h_ is None + n += 1 + else: + assert torch.allclose(tok, tok_) + assert torch.allclose(roi, roi_) + assert torch.allclose(k, k_) + assert torch.allclose(v, v_) + assert torch.allclose(h, h_) + print('Rank %d done' % my_rank) + torch.distributed.barrier() + + if my_rank == 0: + x = torch.tensor([0]) + torch.distributed.recv(x, 1) + # the # of None received is the kv that are not selected + assert x.item() == len(buf.buffer) + # and the size of the buffer should be 2000 * buffer len + print(buf.buffer_size) + assert buf.buffer_size == 1700 * len(buf.buffer) + else: + torch.distributed.send(torch.tensor([n]), 0) + + print("Passed stress test!") + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12398', + world_size=2, + rank=my_rank, + ) + + print("initialized! My rank is %d" % my_rank) + + config = KVTransferConfig( + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + data_pipe = PyNcclPipe( + local_rank=my_rank, + config=config, + device="cuda", + port_offset=0, + ) + cpu_pipe = PyNcclPipe( + local_rank=my_rank, + config=config, + device="cpu", + port_offset=1, + ) + + buffer = SimpleBuffer(cpu_pipe, data_pipe, 170000) + + test_run(my_rank, buffer, data_pipe.device) + + stress_test(my_rank, buffer, data_pipe.device) + + buffer.close() + data_pipe.close() + cpu_pipe.close() + print('Done') diff --git a/tests/kv_transfer/test_lookup_buffer.sh b/tests/kv_transfer/test_lookup_buffer.sh new file mode 100644 index 0000000000000..09d7ee018c3f4 --- /dev/null +++ b/tests/kv_transfer/test_lookup_buffer.sh @@ -0,0 +1,3 @@ +#!/bin/bash +RANK=0 python test_lookup_buffer.py & +RANK=1 python test_lookup_buffer.py & \ No newline at end of file diff --git a/tests/kv_transfer/test_send_recv.py b/tests/kv_transfer/test_send_recv.py new file mode 100644 index 0000000000000..65973bf10a4d7 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.py @@ -0,0 +1,155 @@ +import os +import time +from typing import List + +import torch +from tqdm import tqdm + +from vllm.config import KVTransferConfig +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe + + +def test_run(my_rank, pipe): + # test run + x = torch.tensor([1]).to(pipe.device) + y = torch.tensor([[2., 3., 4., 8.]]).to(pipe.device) + if my_rank == 0: + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + + else: + x2 = pipe.recv_tensor() + print("received x2 = ", x2) + y2 = pipe.recv_tensor() + print("received y2 = ", x2) + pipe.send_tensor(x) + print("sent tensor x") + pipe.send_tensor(y) + print("sent tensor y") + + assert torch.allclose(x, x2) + assert torch.allclose(y, y2) + + +def stress_test(my_rank, pipe): + + torch.distributed.barrier() + + tensors: List[torch.Tensor] = [] + + torch.manual_seed(0) + + for i in tqdm(range(500)): + mean = torch.rand(1).item() * 100 + std = torch.rand(1).item() * 100 + size = torch.randint(900, 1000, (2, )) + x = torch.normal(mean * 1.0, std * 1.0, + size=size.tolist()).to(pipe.device) + + # 5% probability of sending a None + if torch.rand(1).item() < 0.05: + tensors.append(None) + tensors.append(None) + tensors.append(None) + else: + tensors.append(x) + tensors.append(x.mean().unsqueeze(0)) + tensors.append(x.std().unsqueeze(0)) + + torch.distributed.barrier() + + for i in tqdm(range(500)): + if my_rank == int((i % 10) > 3): + pipe.send_tensor(tensors[3 * i]) + pipe.send_tensor(tensors[3 * i + 1]) + pipe.send_tensor(tensors[3 * i + 2]) + else: + x = pipe.recv_tensor() + mean = pipe.recv_tensor() + std = pipe.recv_tensor() + + if x is None: + assert mean is None + assert std is None + else: + assert torch.allclose(x, tensors[3 * i]) + assert x.mean() == mean[0] + assert x.std() == std[0] + + torch.distributed.barrier() + + +def latency_test(my_rank, pipe, nelement, ntensor): + + latencies = [] + + torch.distributed.barrier() + + for i in tqdm(range(500)): + + tensors = [] + + if my_rank == 0: + # create tensor + tensors = [ + torch.rand(nelement).to(pipe.device) for _ in range(ntensor) + ] + + torch.distributed.barrier() + + if my_rank == 0: + t = torch.tensor([time.time()], + dtype=torch.float64).to(pipe.device) + for tensor in tensors: + pipe.send_tensor(tensor) + pipe.send_tensor(t) + else: + for _ in range(ntensor): + pipe.recv_tensor() + t = pipe.recv_tensor() + latencies.append(time.time() - t.item()) + + torch.distributed.barrier() + + print('Latency test passed.') + print('Latency:', torch.tensor(latencies).mean().item() * 1000, 'ms') + + +if __name__ == "__main__": + + my_rank = int(os.environ['RANK']) + + torch.distributed.init_process_group( + backend='gloo', + init_method='tcp://localhost:12398', + world_size=2, + rank=my_rank, + ) + + config = KVTransferConfig( + kv_connector='PyNcclConnector', + kv_buffer_device='cuda', + kv_buffer_size=1e9, + kv_rank=my_rank, + kv_role="kv_both", # this arg doesn't matter in this test + kv_parallel_size=2, + kv_ip="127.0.0.1", + kv_port=12345, + ) + + pipe = PyNcclPipe( + local_rank=my_rank, + config=config, + ) + + test_run(my_rank, pipe) + stress_test(my_rank, pipe) + + # Use this function if you want to test the latency of pipe impl. + # latency_test(my_rank, pipe, 1024 * 8 * 128, 80) diff --git a/tests/kv_transfer/test_send_recv.sh b/tests/kv_transfer/test_send_recv.sh new file mode 100644 index 0000000000000..1e89e246b4992 --- /dev/null +++ b/tests/kv_transfer/test_send_recv.sh @@ -0,0 +1,3 @@ +#!/bin/bash +RANK=0 python3 test_send_recv.py & +RANK=1 python3 test_send_recv.py & \ No newline at end of file diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 15e576cb065c7..a113e3f7abc1e 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -565,7 +565,9 @@ def _pretest(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("stage", STAGES) -def test_linear_replicated(dist_init, num_loras, device, stage) -> None: +@pytest.mark.parametrize("bias_enabled", [True, False]) +def test_linear_replicated(dist_init, num_loras, device, stage, + bias_enabled) -> None: torch.cuda.set_device(device) torch.set_default_device(device) @@ -573,7 +575,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None: max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + bias_enabled=bias_enabled) def create_random_linear_replicated_layer(): @@ -585,7 +588,12 @@ def create_random_linear_replicated_layer(): lora_linear = ReplicatedLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) - + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == 1) + if bias_enabled: + assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices + else: + assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(10): @@ -669,8 +677,9 @@ def create_random_linear_replicated_layer(): @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("stage", STAGES) +@pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage) -> None: + device, stage, bias_enabled) -> None: torch.cuda.set_device(device) torch.set_default_device(device) @@ -679,7 +688,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + bias_enabled=bias_enabled) def create_random_linear_parallel_layer(): if orientation == "row": @@ -700,7 +710,12 @@ def create_random_linear_parallel_layer(): if not fully_shard else ColumnParallelLinearWithShardedLoRA(linear)) lora_linear.create_lora_weights(max_loras, lora_config) - + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == 1) + if bias_enabled: + assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices + else: + assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(10): @@ -784,8 +799,9 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("stage", STAGES) +@pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage) -> None: + device, stage, bias_enabled) -> None: torch.cuda.set_device(device) torch.set_default_device(device) @@ -794,7 +810,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, fully_sharded_loras=fully_shard, - lora_dtype=torch.float16) + lora_dtype=torch.float16, + bias_enabled=bias_enabled) def create_column_parallel_packed_layer(): if repeats == 2: @@ -832,10 +849,16 @@ class FakeConfig: num_key_value_heads = 32 num_attention_heads = 32 + n_slices = repeats lora_linear.create_lora_weights(max_loras, lora_config, model_config=FakeConfig()) - + assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( + lora_linear.lora_b_stacked) == n_slices) + if bias_enabled: + assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices + else: + assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(10): @@ -911,7 +934,6 @@ class FakeConfig: 512, lora_config.lora_extra_vocab_size, ) - # lora_linear.set_mapping(*mapping_info) lora_result = lora_linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0] diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index aae6310a2a213..d3ca7f878191a 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts -@fork_new_process_for_each_test -def test_llama_lora(sql_lora_files): - - llm = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=1) - +def generate_and_test(llm, sql_lora_files): print("lora adapter created") assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT @@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files): print("removing lora") +@fork_new_process_for_each_test +def test_llama_lora(sql_lora_files): + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1) + generate_and_test(llm, sql_lora_files) + + @fork_new_process_for_each_test def test_llama_lora_warmup(sql_lora_files): """Test that the LLM initialization works with a warmup LORA path and @@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files): max_loras=4, tensor_parallel_size=4, ) - - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT - - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT - - print("removing lora") + generate_and_test(llm, sql_lora_files) @multi_gpu_test(num_gpus=4) @@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): tensor_parallel_size=4, fully_sharded_loras=True, ) - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + generate_and_test(llm, sql_lora_files) - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files): - print("removing lora") + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + fully_sharded_loras=True, + enable_lora_bias=True, + ) + generate_and_test(llm, sql_lora_files) diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py index daa39b2a3dba1..d225a3f7d6c06 100644 --- a/tests/lora/test_tokenizer_group.py +++ b/tests/lora/test_tokenizer_group.py @@ -17,6 +17,7 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): tokenizer_id="gpt2", enable_lora=True, max_num_seqs=1, + max_loras=1, max_input_length=None, ) lora_request = LoRARequest("1", 1, sql_lora_files) @@ -53,3 +54,22 @@ def test_get_lora_tokenizer(sql_lora_files, tmp_path): lora_request = LoRARequest("1", 1, str(tmp_path)) tokenizer = get_lora_tokenizer(lora_request) assert not tokenizer + + +@pytest.mark.parametrize("enable_lora", [True, False]) +@pytest.mark.parametrize("max_num_seqs", [1, 2]) +@pytest.mark.parametrize("max_loras", [1, 2]) +def test_lora_tokenizers(enable_lora, max_num_seqs, max_loras): + tokenizer_group = get_tokenizer_group( + get_tokenizer_pool_config(None), + tokenizer_id="gpt2", + enable_lora=enable_lora, + max_num_seqs=max_num_seqs, + max_loras=max_loras, + max_input_length=None, + ) + if enable_lora: + assert tokenizer_group.lora_tokenizers.capacity == max( + max_num_seqs, max_loras) + else: + assert tokenizer_group.lora_tokenizers.capacity == 0 diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 45fab8e96b968..9f4d81b583141 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio -@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +@pytest.mark.parametrize("backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) async def test_guided_logits_processor_black_box(backend: str, sample_regex, sample_json_schema): tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 87a05b3011393..cae25ae9fa2c8 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,8 +1,8 @@ import pytest from tests.utils import multi_gpu_test +from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams -from vllm.worker.model_runner import _get_graph_batch_size from ...utils import check_outputs_equal @@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + while len(example_prompts) == VllmConfig.get_graph_batch_size( + len(example_prompts)): example_prompts.append(example_prompts[0]) try: diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 01e208347bff4..35018c3c14dee 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -5,8 +5,8 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer +from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams -from vllm.worker.model_runner import _get_graph_batch_size from ...utils import check_outputs_equal @@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + while len(example_prompts) == VllmConfig.get_graph_batch_size( + len(example_prompts)): example_prompts.append(example_prompts[0]) try: diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 3f6d8ef42cd5f..ed8f34a677f84 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -6,8 +6,8 @@ from typing import Type import pytest -import transformers from transformers import AutoModelForVision2Seq +from transformers.utils import is_flash_attn_2_available from vllm.platforms import current_platform from vllm.utils import cuda_device_count_stateless, identity @@ -34,7 +34,7 @@ "dtype": "half", "max_tokens": 5, "tensor_parallel_size": 2, - "model_kwargs": {"device_map": "auto"}, + "hf_model_kwargs": {"device_map": "auto"}, "image_size_factors": [(.25, 0.5, 1.0)], "distributed_executor_backend": ( "ray", @@ -108,7 +108,7 @@ "cherry_blossom": "What is in the picture?", }), auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, @@ -134,6 +134,35 @@ marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), #### Extended model tests + "aria": VLMTestInfo( + models=["rhymes-ai/Aria"], + tokenizer_mode="slow", + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + ), + dtype="bfloat16", + prompt_formatter=lambda img_prompt: f"<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n ", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|img|>\n", + max_model_len=4096, + max_num_seqs=2, + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "Please describe the image shortly.", + "cherry_blossom": "Please infer the season with reason.", + }), + multi_image_prompt="Describe the two images shortly.", # noqa: E501 + postprocess_inputs=model_utils.cast_dtype_post_processor("pixel_values"), + stop_str=["<|im_end|>"], + image_size_factors=[(0.10, 0.15)], + max_tokens=64, + marks=[ + pytest.mark.skipif( + not is_flash_attn_2_available(), + reason="Model needs flash-attn for numeric convergence.", + ), + large_gpu_mark(min_gb=64), + ], + ), "blip2": VLMTestInfo( models=["Salesforce/blip2-opt-2.7b"], test_type=VLMTestType.IMAGE, @@ -148,7 +177,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), # For chameleon, we only compare the sequences @@ -157,12 +186,6 @@ comparator=check_outputs_equal, max_tokens=8, dtype="bfloat16", - marks=[ - pytest.mark.skipif( - transformers.__version__ < "4.46.2", - reason="Model broken in HF, see huggingface/transformers#34379" - ), - ] ), "fuyu": VLMTestInfo( models=["adept/fuyu-8b"], @@ -213,13 +236,7 @@ max_model_len=8192, max_num_seqs=2, auto_cls=AutoModelForVision2Seq, - marks=[ - pytest.mark.skipif( - transformers.__version__ < "4.46.0", - reason="Model introduced in HF >= 4.46.0" - ), - large_gpu_mark(min_gb=48), - ], + marks=[large_gpu_mark(min_gb=48)], ), "intern_vl": VLMTestInfo( models=[ @@ -264,7 +281,7 @@ prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 num_video_frames=16, max_model_len=16384, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values_videos" ), auto_cls=AutoModelForVision2Seq, @@ -288,23 +305,44 @@ auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], - marks=[ - pytest.mark.skipif( - transformers.__version__ < "4.46.2", - reason="Model broken with changes in transformers 4.46" - ) - ], ), - "minicpmv": VLMTestInfo( - models=["openbmb/MiniCPM-Llama3-V-2_5"], + "mantis": VLMTestInfo( + models=["TIGER-Lab/Mantis-8B-siglip-llama3"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + max_model_len=4096, + postprocess_inputs=model_utils.cast_dtype_post_processor( + "pixel_values" + ), + vllm_runner_kwargs={"hf_overrides": {"architectures": ["MantisForConditionalGeneration"]}}, # noqa: E501 + get_stop_token_ids=lambda tok: [128009], + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.mantis_vllm_to_hf_output, + patch_hf_runner=model_utils.mantis_patch_hf_runner, + ), + "minicpmv_25": VLMTestInfo( + models=["openbmb/MiniCPM-Llama3-V-2_5"], + test_type=VLMTestType.IMAGE, prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 img_idx_to_prompt=lambda idx: "(./)\n", max_model_len=4096, max_num_seqs=2, get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id], postprocess_inputs=model_utils.wrap_inputs_post_processor, - hf_output_post_proc=model_utils.minicmpv_trunc_hf_output, + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, + ), + "minicpmv_26": VLMTestInfo( + models=["openbmb/MiniCPM-V-2_6"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "(./)\n", + max_model_len=4096, + max_num_seqs=2, + get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 + postprocess_inputs=model_utils.ignore_inputs_post_processor( + "image_sizes" + ), + hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, ), # Tests for phi3v currently live in another file because of a bug in # transformers. Once this issue is fixed, we can enable them here instead. @@ -318,7 +356,7 @@ # max_num_seqs=2, # task="generate", # # use eager mode for hf runner since phi3v didn't work with flash_attn - # model_kwargs={"_attn_implementation": "eager"}, + # hf_model_kwargs={"_attn_implementation": "eager"}, # use_tokenizer_eos=True, # vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output, # num_logprobs=10, @@ -349,7 +387,7 @@ prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], @@ -361,10 +399,6 @@ cuda_device_count_stateless() < 2, reason="Need at least 2 GPUs to run the test.", ), - pytest.mark.skipif( - transformers.__version__ < "4.46.2", - reason="Model broken in HF, see huggingface/transformers#34379" - ) ], **COMMON_BROADCAST_SETTINGS # type: ignore ), @@ -418,7 +452,7 @@ test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, max_num_seqs=2, - postprocess_inputs=model_utils.get_key_type_post_processor( + postprocess_inputs=model_utils.cast_dtype_post_processor( "pixel_values" ), auto_cls=AutoModelForVision2Seq, diff --git a/tests/models/decoder_only/vision_language/test_pixtral.py b/tests/models/decoder_only/vision_language/test_pixtral.py index 6233860747b9c..90c0fab99054c 100644 --- a/tests/models/decoder_only/vision_language/test_pixtral.py +++ b/tests/models/decoder_only/vision_language/test_pixtral.py @@ -228,7 +228,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None: name_1="output") -@large_gpu_test(min_gb=24) +@large_gpu_test(min_gb=48) @pytest.mark.parametrize( "prompt,expected_ranges", [(_create_engine_inputs_hf(IMG_URLS[:1]), [{ diff --git a/tests/models/decoder_only/vision_language/vlm_utils/core.py b/tests/models/decoder_only/vision_language/vlm_utils/core.py index 7e8c6dabb15af..54b7b0733210f 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/core.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/core.py @@ -3,9 +3,11 @@ import torch from PIL.Image import Image -from transformers import AutoTokenizer, BatchEncoding +from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config import TaskOption + from .....conftest import HfRunner, VllmRunner from .types import RunnerOutput @@ -28,11 +30,15 @@ def run_test( use_tokenizer_eos: bool, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding], comparator: Callable[..., None], - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]], + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]], + stop_str: Optional[List[str]], + tokenizer_mode: str, limit_mm_per_prompt: Dict[str, int], - model_kwargs: Optional[Dict[str, Any]], + vllm_runner_kwargs: Optional[Dict[str, Any]], + hf_model_kwargs: Optional[Dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], - task: str = "auto", + task: TaskOption = "auto", runner_mm_key: str = "images", distributed_executor_backend: Optional[str] = None, tensor_parallel_size: int = 1, @@ -50,11 +56,17 @@ def run_test( # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). - vllm_kwargs = {} + vllm_kwargs: Dict[str, Any] = {} if get_stop_token_ids is not None: vllm_kwargs["stop_token_ids"] = get_stop_token_ids(tokenizer) + if stop_str: + vllm_kwargs["stop"] = stop_str + + if vllm_runner_kwargs is None: + vllm_runner_kwargs = {} with vllm_runner(model, + tokenizer_mode=tokenizer_mode, max_model_len=max_model_len, max_num_seqs=max_num_seqs, dtype=dtype, @@ -62,7 +74,8 @@ def run_test( tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=enforce_eager, - task=task) as vllm_model: + task=task, + **vllm_runner_kwargs) as vllm_model: for prompts, media in vllm_inputs: vllm_kwargs[runner_mm_key] = media vllm_output = vllm_model.generate_greedy_logprobs( @@ -73,7 +86,7 @@ def run_test( dtype=dtype, auto_cls=auto_cls, postprocess_inputs=postprocess_inputs, - model_kwargs=model_kwargs) + model_kwargs=hf_model_kwargs) # Some models need to patch things like the model processor, e.g., internvl if patch_hf_runner is not None: @@ -85,6 +98,8 @@ def run_test( hf_kwargs = {} if use_tokenizer_eos: hf_kwargs["eos_token_id"] = tokenizer.eos_token_id + if stop_str: + hf_kwargs["stop_strings"] = stop_str with hf_model, torch.no_grad(): for prompts, media in inputs: @@ -138,4 +153,4 @@ def process_runner_outputs( def process_outputs(output_processor, model, outputs_per_image): """Applies a model specific post-processor function to a runner's output""" return [[output_processor(res, model) for res in outputs] - for outputs in outputs_per_image] + for outputs in outputs_per_image] \ No newline at end of file diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 849857b4232e7..3eca8fb9dcb1a 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -126,6 +126,16 @@ def llava_onevision_vllm_to_hf_output(vllm_output: RunnerOutput, return hf_output_ids, hf_output_str, out_logprobs +def mantis_vllm_to_hf_output(vllm_output: RunnerOutput, + model: str) -> RunnerOutput: + """Sanitize vllm output [mantis] to compare with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|eot_id|>" + + return output_ids, hf_output_str, out_logprobs + + def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: """Sanitize vllm output [phi3v] to be comparable with hf output.""" @@ -170,7 +180,7 @@ def paligemma_vllm_to_hf_output(vllm_output: RunnerOutput, ####### Post-processors for HF outputs -def minicmpv_trunc_hf_output(hf_output: RunnerOutput, +def minicpmv_trunc_hf_output(hf_output: RunnerOutput, model: str) -> RunnerOutput: output_ids, output_str, out_logprobs = hf_output if output_str.endswith("<|eot_id|>"): @@ -184,7 +194,7 @@ def get_llava_embeddings(image_assets: _ImageAssets): ####### postprocessors to run on HF BatchEncoding -def get_key_type_post_processor( +def cast_dtype_post_processor( hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: """Gets a handle to a post processor which converts a given key into a target data type.""" @@ -197,6 +207,17 @@ def process(hf_inputs: BatchEncoding, dtype: str): return process +def ignore_inputs_post_processor( + hf_inp_key: str) -> Callable[[BatchEncoding, str], BatchEncoding]: + """Gets a handle to a post processor which ignores a given key.""" + + def process(hf_inputs: BatchEncoding, dtype: str): + del hf_inputs[hf_inp_key] + return hf_inputs + + return process + + def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str): return {"model_inputs": hf_inputs} @@ -407,3 +428,26 @@ def _internvl_generate( ) return outputs + + +def mantis_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + from mantis.models.mllava import MLlavaProcessor + + hf_model.processor = MLlavaProcessor.from_pretrained(hf_model.model_name) + + orig_generate = hf_model.model.generate + tokenizer = hf_model.processor.tokenizer + + def _generate(self, *args, **kwargs): + return orig_generate( + *args, + **kwargs, + eos_token_id=[ + tokenizer.eos_token_id, + tokenizer.convert_tokens_to_ids("<|eot_id|>"), + ], + ) + + hf_model.model.generate = types.MethodType(_generate, hf_model.model) + + return hf_model diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index 8459476dc2d07..e2e0c6390fcb9 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -7,9 +7,11 @@ import torch from PIL.Image import Image from pytest import MarkDecorator -from transformers import AutoModelForCausalLM, AutoTokenizer, BatchEncoding +from transformers import (AutoModelForCausalLM, BatchEncoding, + PreTrainedTokenizerBase) from transformers.models.auto.auto_factory import _BaseAutoModelClass +from vllm.config import TaskOption from vllm.sequence import SampleLogprobs from vllm.utils import identity @@ -66,7 +68,7 @@ class ImageSizeWrapper(NamedTuple): class VLMTestInfo(NamedTuple): """Holds the configuration for 1+ tests for one model architecture.""" - models: Union[List[str]] + models: List[str] test_type: Union[VLMTestType, Iterable[VLMTestType]] # Should be None only if this is a CUSTOM_INPUTS test @@ -92,15 +94,20 @@ class VLMTestInfo(NamedTuple): enforce_eager: bool = True max_model_len: int = 1024 max_num_seqs: int = 256 - task: str = "auto" + task: TaskOption = "auto" tensor_parallel_size: int = 1 + vllm_runner_kwargs: Optional[Dict[str, Any]] = None # Optional callable which gets a list of token IDs from the model tokenizer - get_stop_token_ids: Optional[Callable[[AutoTokenizer], List[int]]] = None + get_stop_token_ids: Optional[Callable[[PreTrainedTokenizerBase], + List[int]]] = None + # Optional list of strings to stop generation, useful when stop tokens are + # not special tokens in the tokenizer + stop_str: Optional[List[str]] = None # Exposed options for HF runner - model_kwargs: Optional[Dict[str, Any]] = None - # Indicates we should explicitly pass the EOS from the tokeniezr + hf_model_kwargs: Optional[Dict[str, Any]] = None + # Indicates we should explicitly pass the EOS from the tokenizer use_tokenizer_eos: bool = False auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM # Callable to pass to the HF runner to run on inputs; for now, we also pass @@ -148,6 +155,8 @@ class VLMTestInfo(NamedTuple): marks: Optional[List[MarkDecorator]] = None + tokenizer_mode: str = "auto" + def get_non_parametrized_runner_kwargs(self): """Returns a dictionary of expandable kwargs for items that are used in all test types, which are NOT used when creating the parametrized @@ -159,6 +168,7 @@ def get_non_parametrized_runner_kwargs(self): "max_num_seqs": self.max_num_seqs, "task": self.task, "tensor_parallel_size": self.tensor_parallel_size, + "vllm_runner_kwargs": self.vllm_runner_kwargs, "hf_output_post_proc": self.hf_output_post_proc, "vllm_output_post_proc": self.vllm_output_post_proc, "auto_cls": self.auto_cls, @@ -166,8 +176,10 @@ def get_non_parametrized_runner_kwargs(self): "postprocess_inputs": self.postprocess_inputs, "comparator": self.comparator, "get_stop_token_ids": self.get_stop_token_ids, - "model_kwargs": self.model_kwargs, + "hf_model_kwargs": self.hf_model_kwargs, + "stop_str": self.stop_str, "patch_hf_runner": self.patch_hf_runner, + "tokenizer_mode": self.tokenizer_mode } diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index 36b1e5887981c..5ef8540265d14 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -4,6 +4,8 @@ """ import pytest +from vllm.config import PoolerConfig + from ..utils import check_embeddings_close @@ -33,6 +35,9 @@ def test_models( dtype: str, ) -> None: vllm_extra_kwargs = {} + if model == "ssmits/Qwen2-7B-Instruct-embed-base": + vllm_extra_kwargs["override_pooler_config"] = \ + PoolerConfig(pooling_type="MEAN") if model == "Alibaba-NLP/gte-Qwen2-7B-instruct": vllm_extra_kwargs["hf_overrides"] = {"is_causal": False} diff --git a/tests/models/registry.py b/tests/models/registry.py index a93bfe907e0d7..a89518820045f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -63,6 +63,7 @@ class _HfExamplesInfo: "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "GemmaForCausalLM": _HfExamplesInfo("google/gemma-2b"), "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"), + "GlmForCausalLM": _HfExamplesInfo("THUDM/glm-4-9b-chat-hf"), "GPT2LMHeadModel": _HfExamplesInfo("gpt2"), "GPTBigCodeForCausalLM": _HfExamplesInfo("bigcode/starcoder"), "GPTJForCausalLM": _HfExamplesInfo("EleutherAI/gpt-j-6b"), @@ -175,6 +176,7 @@ class _HfExamplesInfo: "LlavaNextForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-v1.6-mistral-7b-hf"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": _HfExamplesInfo("llava-hf/LLaVA-NeXT-Video-7B-hf"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + "MantisForConditionalGeneration": _HfExamplesInfo("TIGER-Lab/Mantis-8B-siglip-llama3"), # noqa: E501 "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True), "MolmoForCausalLM": _HfExamplesInfo("allenai/Molmo-7B-D-0924", diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index b8312c2d9b7cc..3b728f2744fca 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -1,7 +1,6 @@ from unittest.mock import patch import pytest -import transformers from transformers import PretrainedConfig from vllm import LLM @@ -11,10 +10,6 @@ @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) def test_can_initialize(model_arch): - if (model_arch == "Idefics3ForConditionalGeneration" - and transformers.__version__ < "4.46.0"): - pytest.skip(reason="Model introduced in HF >= 4.46.0") - model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) if not model_info.is_available_online: pytest.skip("Model is not available online") diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 289ea66b5ebc5..b5368aab3ecf1 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -3,14 +3,11 @@ import pytest import torch.cuda -from vllm.model_executor.models import (is_embedding_model, +from vllm.model_executor.models import (is_pooling_model, is_text_generation_model, supports_multimodal) -# yapf conflicts with isort for this block -# yapf: disable -from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS, - _EMBEDDING_MODELS, - _MULTIMODAL_MODELS, +from vllm.model_executor.models.adapters import as_embedding_model +from vllm.model_executor.models.registry import (_MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, _TEXT_GENERATION_MODELS, ModelRegistry) @@ -26,18 +23,18 @@ def test_registry_imports(model_arch): model_cls, _ = ModelRegistry.resolve_model_cls(model_arch) if model_arch in _SPECULATIVE_DECODING_MODELS: - pass # Ignore these models which do not have a unified format - else: - assert is_text_generation_model(model_cls) is ( - model_arch in _TEXT_GENERATION_MODELS - or model_arch in _MULTIMODAL_MODELS) - - embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS} - assert is_embedding_model(model_cls) is (model_arch - in embedding_models) - - assert supports_multimodal(model_cls) is (model_arch - in _MULTIMODAL_MODELS) + return # Ignore these models which do not have a unified format + + if (model_arch in _TEXT_GENERATION_MODELS + or model_arch in _MULTIMODAL_MODELS): + assert is_text_generation_model(model_cls) + + # All vLLM models should be convertible to an embedding model + embed_model = as_embedding_model(model_cls) + assert is_pooling_model(embed_model) + + if model_arch in _MULTIMODAL_MODELS: + assert supports_multimodal(model_cls) @fork_new_process_for_each_test diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index 13ad4a7966b9d..71832acbd17b8 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from transformers import CLIPImageProcessor, LlavaNextImageProcessor +from transformers import LlavaNextImageProcessor from vllm.config import ModelConfig from vllm.multimodal import MultiModalRegistry @@ -14,49 +14,6 @@ def mm_registry(): return MultiModalRegistry() -@pytest.mark.parametrize("dtype", ["half", "float"]) -@pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) -def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - - hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) - assert isinstance(hf_processor, CLIPImageProcessor) - - model_config = ModelConfig( - model=MODEL_NAME, - task="auto", - tokenizer=MODEL_NAME, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype=dtype, - revision=None, - limit_mm_per_prompt={"image": 1}, - ) - - mm_registry.init_mm_limits_per_prompt(model_config) - - for asset in image_assets: - image = rescale_image_size(asset.pil_image, size_factor) - - hf_result = hf_processor.preprocess( - image, - return_tensors="pt", - ) - vllm_result = mm_registry.map_input( - model_config, - {"image": image}, - ) - - assert hf_result.keys() == vllm_result.keys() - for key, hf_tensor in hf_result.items(): - hf_arr: np.ndarray = hf_tensor.numpy() - vllm_arr: np.ndarray = vllm_result[key].numpy() - - assert hf_arr.shape == vllm_arr.shape, f"Failed for key={key}" - assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" - - @pytest.mark.parametrize("dtype", ["half", "float"]) @pytest.mark.parametrize("size_factor", [0.25, 0.5, 1.0]) def test_llava_next_image_processor(image_assets, mm_registry, dtype, @@ -107,7 +64,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype, (2, 1, False), (2, 2, True)], ) def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" model_config = ModelConfig( model=MODEL_NAME, @@ -138,7 +95,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): # NOTE: We don't test zero images since the HF processor doesn't support it @pytest.mark.parametrize("num_images", [1, 2]) def test_image_mapper_multi(image_assets, mm_registry, num_images): - MODEL_NAME = "llava-hf/llava-1.5-7b-hf" + MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" model_config = ModelConfig( model=MODEL_NAME, diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index b2367060c6c1b..ae668d1dd56c8 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -3,50 +3,15 @@ import pytest from transformers import BatchFeature -from vllm.multimodal.processing import (PromptReplacement, find_text_matches, - find_token_matches, iter_token_matches, - iter_token_runs, replace_text_matches) +from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo, + find_text_matches, find_token_matches, + iter_placeholders, iter_token_matches, + replace_text_matches, + replace_token_matches) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import full_groupby -# yapf: disable -@pytest.mark.parametrize( - ("token_ids", "expected"), - [ - ([], []), - ( - [32000, 32000, 32000], - [{ "token_id": 32000, "start_idx": 0, "length": 3 }], - ), - ( - [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], - [ - { "token_id": 9833, "start_idx": 0, "length": 1 }, - { "token_id": 28747, "start_idx": 1, "length": 1 }, - { "token_id": 32000, "start_idx": 2, "length": 3 }, - { "token_id": 9833, "start_idx": 5, "length": 1 }, - { "token_id": 28747, "start_idx": 6, "length": 1 }, - { "token_id": 32000, "start_idx": 7, "length": 2 }, - { "token_id": 918, "start_idx": 9, "length": 1 }, - ], - ), - ], -) -# yapf: enable -def test_iter_token_runs(token_ids, expected): - result = list(iter_token_runs(token_ids)) - - # Only displayed on error - print("result:", result) - - # Manually constructed results - assert [item._asdict() for item in result] == expected - - # Invariants - assert sum(run_info.length for run_info in result) == len(token_ids) - - # yapf: disable @pytest.mark.parametrize( ("token_ids", "match_ids", "expected"), @@ -170,13 +135,11 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - result = find_token_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_token_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -279,13 +242,11 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - result = find_text_matches( - prompt, - [ - PromptReplacement(target, [], 0).bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], - ) + prompt_repls = [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + result = find_text_matches(prompt, prompt_repls) # Only displayed on error print("result:", result) @@ -303,7 +264,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # yapf: disable @pytest.mark.parametrize( - ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"), + ("prompt", "target_by_key", "repl_by_key"), [ ( "Image:Image:!", @@ -322,49 +283,201 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): # Test multiple repl_count "pattern_3": ("?", 2), }, - { - # Test no replacement - 0: "Image:Image:!", - # Test single replacement - 1: "Image:??", - # Test repeated replacement - 2: "??", - }, ), ] ) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, "Image:Image:!"), + (1, "Image:??"), + (2, "??"), + ] +) # yapf: enable def test_find_replace_text( prompt, target_by_key, repl_by_key, - expected_by_mm_count, + mm_count, + expected, ): # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - matches = find_text_matches( + prompt_repls = [ + PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_text_matches(prompt, prompt_repls) + + result = replace_text_matches( prompt, - [ - PromptReplacement(target, *repl_by_key[key]) \ - .bind(key, mock_tokenizer) - for key, target in target_by_key.items() - ], + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), ) - result_by_mm_count = { - mm_count: replace_text_matches( - prompt, - matches, - {key: list(range(mm_count)) - for key in repl_by_key}, - BatchFeature(), - ) - for mm_count in expected_by_mm_count - } # Only displayed on error print("matches:", matches) - print("result_by_mm_count:", result_by_mm_count) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "repl_by_key"), + [ + # Tokenized test cases of `test_find_replace_text` + # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + { + # We use `` before `Image:` to test matches that + # occur out of order + "pattern_1": [32000], + "pattern_2": [9833, 28747], + "pattern_3": [918], + }, + { + # Test whether target is confused with repl_unit + "pattern_1": ([32000, 32000], 1), + # Test empty repl_unit + "pattern_2": ([], 1), + # Test multiple repl_count + "pattern_3": ([1550], 2), + }, + ), + ] +) +@pytest.mark.parametrize( + ("mm_count", "expected"), + [ + (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]), + (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]), + (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]), + ] +) +# yapf: enable +def test_find_replace_tokens( + prompt, + target_by_key, + repl_by_key, + mm_count, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ] + matches = find_token_matches(prompt, prompt_repls) + + result = replace_token_matches( + prompt, + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), + ) + + # Only displayed on error + print("matches:", matches) + print("result:", result) + + # Manually constructed results + assert result == expected + + +# yapf: disable +@pytest.mark.parametrize( + "repl_by_key", + [ + { + "pattern_1": ([32000, 32000], 1), + "pattern_2": ([], 1), + "pattern_3": ([1550], 2), + }, + ], +) +@pytest.mark.parametrize( + ("prompt", "expected"), + [ + ( + [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=6, + unit=[32000, 32000], + unit_count=1, + ), + ], + ), + ( + [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_1", + start_idx=5, + unit=[32000, 32000], + unit_count=1, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=7, + unit=[1550], + unit_count=2, + ), + ], + ), + ( + [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550], + [ + _PlaceholderInfo( + modality="pattern_1", + start_idx=1, + unit=[32000, 32000], + unit_count=2, + ), + _PlaceholderInfo( + modality="pattern_3", + start_idx=6, + unit=[1550], + unit_count=2, + ), + ], + ), + ] +) +def test_iter_placeholders( + repl_by_key, + prompt, + expected, +): + # Should not be used since there is nothing to convert to tokens + mock_tokenizer = cast(AnyTokenizer, object()) + + prompt_repls = [ + PromptReplacement([], *repl).bind(key, mock_tokenizer) + for key, repl in repl_by_key.items() + ] + + result = list(iter_placeholders(prompt_repls, prompt)) + + # Only displayed on error + print("result:", result) # Manually constructed results - assert result_by_mm_count == expected_by_mm_count + assert result == expected diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 21958b1640204..d676eacffb056 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -1,13 +1,34 @@ -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch +import torch.nn as nn from vllm.attention import AttentionMetadata -from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel -from vllm.sequence import IntermediateTensors +from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.models.gemma2 import Gemma2Model +from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput -class MyGemma2Embedding(Gemma2EmbeddingModel): +class MyGemma2Embedding(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.model = Gemma2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + self._pooler = Pooler.from_config_with_defaults( + vllm_config.model_config.pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False, + ) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -18,7 +39,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = super().forward( + hidden_states = self.model( input_ids, positions, kv_caches, @@ -32,3 +53,17 @@ def forward( # Return all-zero embeddings return torch.zeros_like(hidden_states) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + return self.model.load_weights(weights) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py index 3ebd7864b8fc8..2f4194a63fc25 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py @@ -2,19 +2,15 @@ import torch -from vllm.inputs import INPUT_REGISTRY from vllm.model_executor.models.llava import (LlavaForConditionalGeneration, - dummy_data_for_llava, - get_max_llava_image_tokens, - input_processor_for_llava) + LlavaProcessor, + get_max_llava_image_tokens) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class MyLlava(LlavaForConditionalGeneration): def compute_logits( diff --git a/tests/test_lazy_torch_compile.py b/tests/standalone_tests/lazy_torch_compile.py similarity index 100% rename from tests/test_lazy_torch_compile.py rename to tests/standalone_tests/lazy_torch_compile.py diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh new file mode 100644 index 0000000000000..f00895c0997f1 --- /dev/null +++ b/tests/standalone_tests/python_only_compile.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# This script tests if the python only compilation works correctly +# for users who do not have any compilers installed on their system + +set -e +set -x + +cd /vllm-workspace/ + +# uninstall vllm +pip3 uninstall -y vllm +# restore the original files +mv test_docs/vllm ./vllm + +# remove all compilers +apt remove --purge build-essential -y +apt autoremove -y + +echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py + +VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . + +# Run the script +python3 -c 'import vllm' + +# Check if the clangd log file was created +if [ ! -f /tmp/changed.file ]; then + echo "changed.file was not created, python only compilation failed" + exit 1 +fi diff --git a/tests/test_config.py b/tests/test_config.py index 3cf90297ce177..45b0b938af215 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -26,8 +26,7 @@ def test_auto_task(model_id, expected_task): @pytest.mark.parametrize(("model_id", "bad_task"), [ - ("facebook/opt-125m", "embedding"), - ("intfloat/e5-mistral-7b-instruct", "generate"), + ("Qwen/Qwen2.5-Math-RM-72B", "generate"), ]) def test_incorrect_task(model_id, bad_task): with pytest.raises(ValueError, match=r"does not support the .* task"): diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 83bfbb6ade8d7..b44d3e5cb0678 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -23,7 +23,8 @@ def test_prefill(): manager = KVCacheManager( block_size=16, num_gpu_blocks=10, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -121,7 +122,8 @@ def test_decode(): manager = KVCacheManager( block_size=16, num_gpu_blocks=10, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -172,7 +174,8 @@ def test_evict(): manager = KVCacheManager( block_size=16, num_gpu_blocks=10, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=16, ) @@ -220,7 +223,8 @@ def test_hash_block_correct_reuse(): manager = KVCacheManager( block_size=block_size, num_gpu_blocks=1, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -256,7 +260,8 @@ def test_computed_blocks_not_evicted(): manager = KVCacheManager( block_size=block_size, num_gpu_blocks=2, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) @@ -303,7 +308,8 @@ def test_basic_prefix_caching_disabled(): manager = KVCacheManager( block_size=block_size, num_gpu_blocks=4, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=False, num_preallocate_tokens=0, ) @@ -342,7 +348,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): manager = KVCacheManager( block_size=block_size, num_gpu_blocks=10, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=num_preallocate_tokens, ) @@ -370,7 +377,8 @@ def test_cache_blocks(): manager = KVCacheManager( block_size=block_size, num_gpu_blocks=5, - sliding_window=False, + max_model_len=8192, + sliding_window=None, enable_caching=True, num_preallocate_tokens=0, ) diff --git a/tests/v1/engine/test_engine_args.py b/tests/v1/engine/test_engine_args.py index 69cfdf5a395c1..ac5e7dde525a7 100644 --- a/tests/v1/engine/test_engine_args.py +++ b/tests/v1/engine/test_engine_args.py @@ -4,6 +4,7 @@ from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser if not envs.VLLM_USE_V1: pytest.skip( @@ -12,6 +13,24 @@ ) +def test_prefix_caching_from_cli(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + args = parser.parse_args([]) + engine_args = EngineArgs.from_cli_args(args=args) + assert (engine_args.enable_prefix_caching + ), "V1 turns on prefix caching by default." + + # Turn it off possible with flag. + args = parser.parse_args(["--no-enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert not engine_args.enable_prefix_caching + + # Turn it on with flag. + args = parser.parse_args(["--enable-prefix-caching"]) + engine_args = EngineArgs.from_cli_args(args=args) + assert engine_args.enable_prefix_caching + + def test_defaults(): engine_args = EngineArgs(model="facebook/opt-125m") diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index bd11ff1877064..fef44ac29c41f 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -27,9 +27,8 @@ def make_request() -> EngineCoreRequest: request_id=uuid.uuid4(), prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, - mm_data=None, + mm_inputs=None, mm_placeholders=None, - mm_processor_kwargs=None, sampling_params=SamplingParams(), eos_token_id=None, arrival_time=time.time(), diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 582192196aaf9..4e003a25e91d2 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -29,9 +29,8 @@ def make_request(params: SamplingParams) -> EngineCoreRequest: request_id=str(uuid.uuid4()), prompt=PROMPT, prompt_token_ids=PROMPT_TOKENS, - mm_data=None, + mm_inputs=None, mm_placeholders=None, - mm_processor_kwargs=None, sampling_params=params, eos_token_id=None, arrival_time=time.time(), diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 9e166ae64dbfb..5289c91f201cd 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -4,12 +4,12 @@ import pytest import torch +from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -from vllm.worker.model_runner import _get_graph_batch_size BATCH_SIZES = [1, 4, 16, 64, 256] @@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. - graph_batch_size = _get_graph_batch_size(expanded_batch_size) + graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size) cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index b36e8bfe73ff3..309854e6babf3 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -8,10 +8,10 @@ from vllm.attention.backends.utils import CommonAttentionState from vllm.model_executor import SamplingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.worker.embedding_model_runner import ( - ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.multi_step_model_runner import StatefulModelInput +from vllm.worker.pooling_model_runner import ( + ModelInputForGPUWithPoolingMetadata) class MockAttentionBackend(AttentionBackend): diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 433a9b30ba57a..4055524f3e0c7 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -3,13 +3,14 @@ import pytest import torch +from vllm.config import VllmConfig from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port -from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size +from vllm.worker.model_runner import ModelRunner def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: @@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size): model_input.attn_metadata, model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) - expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) + expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.num_prefills == 0 diff --git a/vllm/__init__.py b/vllm/__init__.py index 8f477ea84756d..a10f6d3128cb6 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -7,8 +7,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry -from vllm.outputs import (CompletionOutput, EmbeddingOutput, - EmbeddingRequestOutput, RequestOutput) +from vllm.outputs import (CompletionOutput, PoolingOutput, + PoolingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -25,8 +25,8 @@ "SamplingParams", "RequestOutput", "CompletionOutput", - "EmbeddingOutput", - "EmbeddingRequestOutput", + "PoolingOutput", + "PoolingRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", @@ -34,3 +34,26 @@ "initialize_ray_cluster", "PoolingParams", ] + + +def __getattr__(name: str): + import warnings + + if name == "EmbeddingOutput": + msg = ("EmbeddingOutput has been renamed to PoolingOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingOutput + + if name == "EmbeddingRequestOutput": + msg = ("EmbeddingRequestOutput has been renamed to " + "PoolingRequestOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingRequestOutput + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 5be2d83346d00..aed04361e5fb4 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -247,5 +247,6 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9e54c3b40c54e..99cb84346d84e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -360,6 +360,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -448,5 +449,6 @@ def forward( blocksparse_head_sliding_step=self.head_sliding_step, ) + assert output is not None # Reshape the output tensor. return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 32738d1043b1d..c69e12ad78c44 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -638,24 +638,27 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] NOTE: kv_cache will be an empty tensor with shape [0] for profiling run. attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] + NOTE: It in-place updates the output tensor. """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") + assert output is not None, "Output tensor must be provided." + if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " @@ -666,23 +669,12 @@ def forward( "requires setting cross-attention " "metadata attributes.") - num_heads: int = self.num_heads - head_size: int = self.head_size - num_kv_heads: int = self.num_kv_heads kv_cache_dtype: str = self.kv_cache_dtype softmax_scale: float = self.scale window_size = self.sliding_window alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes logits_soft_cap: Optional[float] = self.logits_soft_cap - num_tokens, hidden_size = query.shape - - # Reshape the query, key, and value tensors. - query = query.view(-1, num_heads, head_size) - if (key is not None) and (value is not None): - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - if kv_cache.numel() > 0: key_cache = kv_cache[0] value_cache = kv_cache[1] @@ -721,13 +713,13 @@ def forward( num_decode_query_tokens) = \ get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) decode_query = query[num_prefill_query_tokens:] + decode_output = output[num_prefill_query_tokens:] # QKV for prefill. query = query[:num_prefill_query_tokens] + prefill_output = output[:num_prefill_query_tokens] assert query.shape[0] == num_prefill_query_tokens assert decode_query.shape[0] == num_decode_query_tokens - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. if (kv_cache.numel() == 0 or prefill_meta.block_tables is None @@ -741,7 +733,7 @@ def forward( key = key[:num_prefill_kv_tokens] value = value[:num_prefill_kv_tokens] - prefill_output = flash_attn_varlen_func( + flash_attn_varlen_func( q=query, k=key, v=value, @@ -754,6 +746,7 @@ def forward( window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, + out=prefill_output, ) else: # prefix-enabled attention @@ -761,7 +754,7 @@ def forward( "Only decoder-only models support prefix caching") assert prefill_meta.seq_lens is not None max_seq_len = max(prefill_meta.seq_lens) - prefill_output = flash_attn_varlen_func( # noqa + flash_attn_varlen_func( # noqa q=query, k=key_cache, v=value_cache, @@ -775,6 +768,7 @@ def forward( alibi_slopes=alibi_slopes, block_table=prefill_meta.block_tables, softcap=logits_soft_cap, + out=prefill_output, ) if decode_meta := attn_metadata.decode_metadata: @@ -788,7 +782,7 @@ def forward( assert attn_type == AttentionType.DECODER, ( "Only decoder-only models support max_decode_query_len > 1" ) - decode_output = flash_attn_varlen_func( + flash_attn_varlen_func( q=decode_query, k=key_cache, v=value_cache, @@ -802,6 +796,7 @@ def forward( alibi_slopes=alibi_slopes, softcap=logits_soft_cap, block_table=decode_meta.block_tables, + out=decode_output, ) else: # Use flash_attn_with_kvcache for normal decoding. @@ -810,7 +805,7 @@ def forward( _, block_tables_arg, ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - decode_output = flash_attn_with_kvcache( + flash_attn_with_kvcache( q=decode_query.unsqueeze(1), k_cache=key_cache, v_cache=value_cache, @@ -821,20 +816,8 @@ def forward( window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_query_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_query_tokens, hidden_size) - - assert decode_meta is not None - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) - + out=decode_output.unsqueeze(1), + ) return output diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 1a2024705eb04..e367468d05d26 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -774,7 +774,11 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: + + # TODO: directly write to output tensor + if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " "encoder/decoder cross-attention " diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 5359941d41fde..2c62e565c04c7 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -145,6 +145,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 3b0d51ea4a3d8..21949874bea47 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -173,6 +173,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 5988be0e6b687..9809aed0e66f9 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -151,6 +151,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 6a494f4e73cb4..19daeb729ee61 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -415,6 +415,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -429,7 +430,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 16e044b618c40..86e952a903f36 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -341,7 +341,11 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) else: block_tables = torch.tensor([]) - seq_lens_tensor = torch.tensor([]) + seq_lens_tensor = torch.tensor( + input_data.seq_lens[:input_data.num_prefills], + dtype=torch.int32, + device="cpu", + ) # For multi-modal models placeholder_index_maps = None @@ -427,6 +431,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 292575a8736bc..e2e989efb020c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -417,6 +417,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: str = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 17157617248f7..05d997279893b 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn +import torch.nn.functional as F -import vllm.envs as envs from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config @@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.platforms import current_platform +from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op @@ -97,14 +97,23 @@ def __init__( self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads self.backend = backend_name_to_enum(attn_backend.get_name()) # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = envs.VLLM_USE_V1 or not ( - current_platform.is_cuda_alike() or current_platform.is_cpu()) + self.use_direct_call = not current_platform.is_cuda_alike( + ) and not current_platform.is_cpu() + + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. + self.use_output = self.backend == _Backend.FLASH_ATTN or \ + self.backend == _Backend.FLASH_ATTN_VLLM_V1 compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") @@ -130,6 +139,22 @@ def forward( self._k_scale, self._v_scale, attn_type=attn_type) + elif self.use_output: + output = torch.empty_like(query) + hidden_size = query.size(-1) + # Reshape the query, key, and value tensors. + # NOTE(woosuk): We do this outside the custom op to minimize the + # CPU overheads from the non-CUDA-graph regions. + query = query.view(-1, self.num_heads, self.head_size) + output = output.view(-1, self.num_heads, self.head_size) + if key is not None: + key = key.view(-1, self.num_kv_heads, self.head_size) + if value is not None: + value = value.view(-1, self.num_kv_heads, self.head_size) + torch.ops.vllm.unified_attention_with_output( + query, key, value, output, kv_cache, attn_type, + self.layer_name) + return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, kv_cache, attn_type, @@ -144,6 +169,68 @@ def extra_repr(self) -> str: return s +class MultiHeadAttention(nn.Module): + """Multi-headed attention without any cache, used for ViT.""" + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + ): + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = scale + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(head_size, + dtype, + kv_cache_dtype=None, + block_size=16, + is_attention_free=False) + if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}: + attn_backend = _Backend.XFORMERS + + self.attn_backend = attn_backend if attn_backend in { + _Backend.TORCH_SDPA, _Backend.XFORMERS + } else _Backend.TORCH_SDPA + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + """Input shape: batch_size x seq_len x hidden_size""" + # TODO(Isotr0py): Use existing backend implementations and support FA2 + bsz, q_len, _ = query.size() + kv_len = key.size(1) + + query = query.view(bsz, q_len, self.num_heads, self.head_size) + key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) + + if self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + + out = xops.memory_efficient_attention_forward(query, + key, + value, + scale=self.scale) + elif self.attn_backend == _Backend.TORCH_SDPA: + query, key, value = (x.transpose(1, 2) + for x in (query, key, value)) + out = F.scaled_dot_product_attention(query, + key, + value, + scale=self.scale) + out = out.transpose(1, 2) + return out.view(bsz, q_len, -1) + + def unified_attention( query: torch.Tensor, key: torch.Tensor, @@ -183,3 +270,47 @@ def unified_attention_fake( fake_impl=unified_attention_fake, dispatch_key=current_platform.dispatch_key, ) + + +def unified_attention_with_output( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.dynamic_forward_context + self = forward_context.static_forward_context[layer_name] + self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type, + output=output) + + +def unified_attention_with_output_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> None: + return + + +direct_register_custom_op( + op_name="unified_attention_with_output", + op_func=unified_attention_with_output, + mutates_args=["kv_cache", "output"], + fake_impl=unified_attention_with_output_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 464bc2af8fd6d..1206424ae1e3f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,5 +1,6 @@ import copy import dataclasses +import time from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -14,6 +15,7 @@ from .counter import compilation_counter from .inductor_pass import InductorPass +from .monitor import end_monitoring_torch_compile from .pass_manager import PostGradPassManager logger = init_logger(__name__) @@ -22,31 +24,54 @@ def wrap_inductor(graph, example_inputs, additional_inductor_config, - do_logging=False, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, runtime_shape: Optional[int] = None, use_inductor: bool = True): + if graph_index == 0: + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + if not use_inductor: return graph compilation_counter.num_inductor_compilations += 1 - if do_logging: - if runtime_shape is None: - logger.info("Compiling a graph for general shape") - else: - logger.info("Compiling a graph for shape %s", runtime_shape) - from torch._inductor import config - current_config = config.shallow_copy_dict() + current_config = config.get_config_copy() from torch._inductor.compile_fx import compile_fx if additional_inductor_config is not None: current_config.update(additional_inductor_config) + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + current_config["max_autotune"] = True + current_config["coordinate_descent_tuning"] = True + # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 graph = copy.deepcopy(graph) - return compile_fx(graph, example_inputs, config_patches=current_config) + compiled_graph = compile_fx(graph, + example_inputs, + config_patches=current_config) + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info("Compiling a graph for general shape takes %.2f s", + elapsed) + else: + logger.info("Compiling a graph for shape %s takes %.2f s", + runtime_shape, elapsed) + + return compiled_graph @dataclasses.dataclass @@ -108,6 +133,8 @@ def split_graph(graph: fx.GraphModule, # we share the global graph pool among all the backends global_graph_pool = None +compilation_start_time = 0.0 + class PiecewiseCompileInterpreter(torch.fx.Interpreter): """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. @@ -151,12 +178,15 @@ def call_module(self, target: torch.fx.node.Target, sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] + global compilation_start_time compiled_graph_for_general_shape = wrap_inductor( submod, args, self.compilation_configs.inductor_compile_config, + self.compilation_configs, + graph_index=index, + num_graphs=len(self.compile_submod_names), runtime_shape=None, - do_logging=index == 0, use_inductor=self.compilation_configs.use_inductor) self.module.__dict__[target] = PiecewiseBackend( @@ -242,10 +272,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: assert not self._called, "VllmBackend can only be called once" self.graph = graph - # config is updated now, because only here can - # we get the sizes to capture for cudagraph - # from compilation context - self.compilation_configs.init_during_runtime() self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( @@ -377,6 +403,8 @@ def __init__(self, graph: fx.GraphModule, # the entries for different shapes that we need to either # compile or capture cudagraph self.concrete_size_entries: Dict[int, ConcreteSizeEntry] = {} + self.to_be_compiled_sizes: Set[int] = self.compile_sizes.union( + self.capture_sizes) for shape in self.compile_sizes.union(self.capture_sizes): self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, @@ -387,6 +415,9 @@ def __init__(self, graph: fx.GraphModule, def __call__(self, *args) -> Any: if not self.first_run_finished: self.first_run_finished = True + # no specific sizes to compile + if self.is_last_graph and not self.to_be_compiled_sizes: + end_monitoring_torch_compile(self.compilation_configs) return self.compiled_graph_for_general_shape(*args) runtime_shape = args[self.sym_shape_indices[0]] @@ -401,15 +432,22 @@ def __call__(self, *args) -> Any: if entry.need_to_compile and not entry.compiled: entry.compiled = True + self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments entry.runnable = wrap_inductor( self.graph, args, self.compilation_configs.inductor_compile_config, + self.compilation_configs, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, runtime_shape=runtime_shape, - do_logging=self.is_first_graph, use_inductor=self.compilation_configs.use_inductor) + # finished compilations for all required shapes + if self.is_last_graph and not self.to_be_compiled_sizes: + end_monitoring_torch_compile(self.compilation_configs) + if not entry.use_cudagraph: return entry.runnable(*args) diff --git a/vllm/compilation/compile_context.py b/vllm/compilation/compile_context.py deleted file mode 100644 index 29db3d4c637b9..0000000000000 --- a/vllm/compilation/compile_context.py +++ /dev/null @@ -1,23 +0,0 @@ -from contextlib import contextmanager -from typing import Any - -_compile_context: Any = None - - -def get_compile_context() -> Any: - """Get the current compile context.""" - return _compile_context - - -@contextmanager -def set_compile_context(context: Any): - """A context manager that stores the current compile context, - usually it is a list of sizes to specialize. - """ - global _compile_context - prev_context = _compile_context - _compile_context = context - try: - yield - finally: - _compile_context = prev_context diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 8b81a29936989..a32dced57e5b3 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,7 +1,8 @@ import inspect -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, TypeVar, Union, overload import torch +import torch.nn as nn from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher @@ -10,12 +11,31 @@ from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo +from .monitor import start_monitoring_torch_compile + logger = init_logger(__name__) +_T = TypeVar("_T", bound=type[nn.Module]) + + +@overload +def support_torch_compile( + *, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]], +) -> Callable[[_T], _T]: + ... + + +@overload +def support_torch_compile(cls: _T) -> _T: + ... + def support_torch_compile( - cls: Optional[type] = None, - dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None): + cls: Optional[_T] = None, + *, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None, +) -> Union[Callable[[_T], _T], _T]: """ A decorator to add support for compiling the forward method of a class. @@ -66,7 +86,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): computation graph. """ - def cls_decorator_helper(cls: type): + def cls_decorator_helper(cls: _T) -> _T: # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` # to avoid too much indentation for `_support_torch_compile`` if not hasattr(cls, 'forward'): @@ -105,8 +125,10 @@ def cls_decorator_helper(cls: type): return cls_decorator_helper -def _support_torch_compile(cls: type, - dynamic_arg_dims: Dict[str, Union[int, List[int]]]): +def _support_torch_compile( + cls: _T, + dynamic_arg_dims: Dict[str, Union[int, List[int]]], +) -> _T: """ A decorator to add support for compiling the forward method of a class. """ @@ -119,7 +141,7 @@ def _support_torch_compile(cls: type, # other than TorchCompileWrapperWithCustomDispatcher cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) - old_init = cls.__init__ # type: ignore + old_init = cls.__init__ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) @@ -135,7 +157,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) - cls.__init__ = __init__ # type: ignore + if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: + start_monitoring_torch_compile(vllm_config.compilation_config) + + cls.__init__ = __init__ def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation @@ -180,5 +205,5 @@ def __call__(self, *args, **kwargs): model_output = self.forward(*args, **kwargs) return model_output - cls.__call__ = __call__ # type: ignore + cls.__call__ = __call__ return cls diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py new file mode 100644 index 0000000000000..f718e46423212 --- /dev/null +++ b/vllm/compilation/monitor.py @@ -0,0 +1,14 @@ +from vllm.config import CompilationConfig, CompilationLevel +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def start_monitoring_torch_compile(compilation_config: CompilationConfig): + pass + + +def end_monitoring_torch_compile(compilation_config: CompilationConfig): + if compilation_config.level == CompilationLevel.PIECEWISE: + logger.info("graph compilation takes %.2f s in total", + compilation_config.compilation_time) diff --git a/vllm/config.py b/vllm/config.py index f7448f31f3649..74e39c5de6a00 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,3 +1,4 @@ +import ast import copy import enum import hashlib @@ -92,6 +93,8 @@ class ModelConfig: the default version. max_model_len: Maximum length of a sequence (including prompt and output). If None, will be derived from the model. + spec_target_max_model_len: Specify the the maximum length for spec + decoding draft models. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. quantization_param_path: Path to JSON file containing scaling factors. @@ -108,6 +111,7 @@ class ModelConfig: to eager mode. Additionally for encoder-decoder models, if the sequence length of the encoder input is larger than this, we fall back to the eager mode. + max_logprobs: Maximum number of log probabilities. Defaults to 20. disable_sliding_window: Whether to disable sliding window. If True, we will disable the sliding window functionality of the model. If the model does not support sliding window, this argument is @@ -120,6 +124,8 @@ class ModelConfig: the model name will be the same as `model`. limit_mm_per_prompt: Maximum number of data items per modality per prompt. Only applicable for multimodal models. + use_async_output_proc: Whether to use async output processor. + Defaults to True. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. hf_overrides: If a dictionary, contains arguments to be forwarded to the @@ -131,7 +137,7 @@ class ModelConfig: override default neuron config that are specific to Neuron devices, this argument will be used to configure the neuron config that can not be gathered from the vllm arguments. - override_pooling_config: Initialize non default pooling config or + override_pooler_config: Initialize non default pooling config or override default pooling config for the embedding model. """ @@ -360,7 +366,7 @@ def _resolve_task( # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them "generate": ModelRegistry.is_text_generation_model(architectures), - "embedding": ModelRegistry.is_embedding_model(architectures), + "embedding": ModelRegistry.is_pooling_model(architectures), } supported_tasks_lst: List[_Task] = [ task for task, is_supported in task_support.items() if is_supported @@ -371,6 +377,31 @@ def _resolve_task( selected_task = next(iter(supported_tasks_lst)) if len(supported_tasks) > 1: + suffix_to_preferred_task: List[Tuple[str, _Task]] = [ + # Hardcode the models that are exceptions + ("AquilaModel", "generate"), + ("ChatGLMModel", "generate"), + # Other models follow this pattern + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("ChatModel", "generate"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "embedding"), + ("RewardModel", "embedding"), + ("ForSequenceClassification", "embedding"), + ] + info, arch = ModelRegistry.inspect_model_cls(architectures) + + for suffix, pref_task in suffix_to_preferred_task: + if arch.endswith(suffix) and pref_task in supported_tasks: + selected_task = pref_task + break + else: + if (arch.endswith("Model") + and info.architecture.endswith("ForCausalLM") + and "embedding" in supported_tasks): + selected_task = "embedding" + logger.info( "This model supports multiple tasks: %s. " "Defaulting to '%s'.", supported_tasks, selected_task) @@ -394,17 +425,11 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = QUANTIZATION_METHODS - rocm_supported_quantization = [ - "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8", "gguf" - ] optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", "experts_int8" ] - tpu_supported_quantization = ["tpu_int8"] - neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -439,32 +464,12 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") - if current_platform.is_rocm( - ) and self.quantization not in rocm_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in ROCm.") - if current_platform.is_tpu( - ) and self.quantization not in tpu_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in TPU Backend.") + current_platform.verify_quantization(self.quantization) if self.quantization not in optimized_quantization_methods: logger.warning( "%s quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.", self.quantization) - if (self.quantization == "awq" and current_platform.is_rocm() - and not envs.VLLM_USE_TRITON_AWQ): - logger.warning( - "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" - " is not set, enabling VLLM_USE_TRITON_AWQ.") - envs.VLLM_USE_TRITON_AWQ = True - if current_platform.is_neuron( - ) and self.quantization not in neuron_supported_quantization: - raise ValueError( - f"{self.quantization} quantization is currently not " - f"supported in Neuron Backend.") def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: @@ -506,7 +511,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"): logger.warning( @@ -522,7 +527,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if device_config.device_type == "cuda" and self.enforce_eager: logger.warning( @@ -537,7 +542,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, if self.task == "embedding": self.use_async_output_proc = False - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if speculative_config: logger.warning("Async output processing is not supported with" @@ -744,8 +749,13 @@ class CacheConfig: vLLM execution. swap_space: Size of the CPU swap space per GPU (in GiB). cache_dtype: Data type for kv cache storage. + is_attention_free: Whether the model is attention-free. num_gpu_blocks_override: Number of GPU blocks to use. This overrides the profiled num_gpu_blocks if specified. Does nothing if None. + sliding_window: Sliding window size for the KV cache. Can not work with + prefix caching enabled. + enable_prefix_caching: Whether to enable prefix caching. + cpu_offload_gb: Size of the CPU offload buffer in GiB. """ def __init__( @@ -914,6 +924,7 @@ class LoadConfig: "tensorizer" will use CoreWeave's tensorizer library for fast weight loading. "bitsandbytes" will load nf4 type weights. + model_loader_extra_config: The extra config for the model loader. ignore_patterns: The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints. @@ -930,7 +941,9 @@ def __post_init__(self): if isinstance(model_loader_extra_config, str): self.model_loader_extra_config = json.loads( model_loader_extra_config) - self._verify_load_format() + if isinstance(self.load_format, str): + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: logger.info( @@ -939,25 +952,6 @@ def __post_init__(self): else: self.ignore_patterns = ["original/**/*"] - def _verify_load_format(self) -> None: - if not isinstance(self.load_format, str): - return - - load_format = self.load_format.lower() - self.load_format = LoadFormat(load_format) - - rocm_not_supported_load_format: List[str] = [] - if current_platform.is_rocm( - ) and load_format in rocm_not_supported_load_format: - rocm_supported_load_format = [ - f for f in LoadFormat.__members__ - if (f not in rocm_not_supported_load_format) - ] - raise ValueError( - f"load format '{load_format}' is not supported in ROCm. " - f"Supported load formats are " - f"{rocm_supported_load_format}") - @dataclass class ParallelConfig: @@ -1720,7 +1714,7 @@ def verify_with_model_config(self, model_config: ModelConfig): model_config.quantization) def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if scheduler_config.chunked_prefill_enabled: raise ValueError("LoRA is not supported with chunked prefill yet.") @@ -1788,15 +1782,15 @@ class PoolerConfig: step_tag_id: Optional[int] = None """ - If set, only the score corresponding to the ``step_tag_id`` in the + If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ returned_token_ids: Optional[List[int]] = None """ - A list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of ``good_token`` and ``bad_token`` in the + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the ``math-shepherd-mistral-7b-prm`` model. """ @@ -2030,11 +2024,12 @@ def get_served_model_name(model: str, class DecodingConfig: """Dataclass which contains the decoding strategy of the engine""" - # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' - guided_decoding_backend: str = 'outlines' + # Which guided decoding algo to use. + # 'outlines' / 'lm-format-enforcer' / 'xgrammar' + guided_decoding_backend: str = 'xgrammar' def __post_init__(self): - valid_guided_backends = ['outlines', 'lm-format-enforcer'] + valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] backend = self.guided_decoding_backend if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}," @@ -2062,6 +2057,88 @@ def __post_init__(self): f"installed. Original error:\n{otel_import_error_traceback}") +class KVTransferConfig(BaseModel): + """Configuration for distributed KV cache transfer.""" + + # The KV connector for vLLM to transmit KV caches between vLLM instances. + kv_connector: Optional[str] = None + + # The device used by kv connector to buffer the KV cache. + # Currently only support 'cuda'. + kv_buffer_device: Optional[str] = "cuda" + + # The buffer size for TorchDistributedConnector. Measured in number of + # bytes. Recommended value: 1e9 (about 1GB). + kv_buffer_size: float = 1e9 + + # Whether this vLLM instance produces, consumes KV cache, or both. Choices + # are 'kv_producer', 'kv_consumer', and 'both'. + kv_role: Optional[str] = None + + # The rank of this vLLM instance in the KV cache transfer. Typical value: + # 0 for prefill instance, 1 for decode instance. + # Currently only 1P1D is supported. + kv_rank: Optional[int] = None + + # The number of parallel instances for KV cache transfer. For + # PyNcclConnector, this should be 2. + kv_parallel_size: int = 1 + + # The KV connector ip, used to build distributed connection + kv_ip: str = "127.0.0.1" + + # The KV connector port, used to build distributed connection + kv_port: int = 14579 + + @classmethod + def from_cli(cls, cli_value: str) -> "KVTransferConfig": + """Parse the CLI value for the kv cache transfer config.""" + return KVTransferConfig.model_validate_json(cli_value) + + def model_post_init(self, __context: Any) -> None: + if all([ + self.kv_connector is not None, + self.kv_connector != "PyNcclConnector" + ]): + raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. " + f"Supported connectors are " + f"`PyNcclConnector`.") + + if self.kv_role is not None and self.kv_role not in [ + "kv_producer", "kv_consumer", "kv_both" + ]: + raise ValueError( + f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are `kv_producer`, `kv_consumer`, " + f"and `kv_both`") + + if self.kv_connector is not None and self.kv_role is None: + raise ValueError("Please specify kv_disagg_role when kv_connector " + "is set, supported roles are `kv_producer`, " + "`kv_consumer`, and `kv_both`") + + @property + def is_kv_transfer_instance(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + + @property + def need_kv_parallel_group(self) -> bool: + # for those database-based connector, vLLM does not need to create + # parallel group, and in that case the kv parallel size will be 1. + return self.kv_connector is not None and self.kv_parallel_size > 1 + + @property + def is_kv_producer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_producer", "kv_both"] + + @property + def is_kv_consumer(self) -> bool: + return self.kv_connector is not None and \ + self.kv_role in ["kv_consumer", "kv_both"] + + class CompilationLevel: # constants for the levels of the compilation process NO_COMPILATION = 0 @@ -2123,14 +2200,10 @@ class CompilationConfig(BaseModel): - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for different sizes specified - in inductor_compile_sizes, using configurations + is compiled. In addition, compile for cudagraph sizes that are + in candidate_compile_sizes, using configurations in inductor_compile_config. - - inductor_compile_sizes: sizes to compile for inductor. - - inductor_specialize_for_cudagraph_no_more_than: an optional integer - to specialize inductor for cudagraph sizes no more than the - specified size. It is useful when we want to specialize inductor - with a subset of cudagraph sizes. + - candidate_compile_sizes: sizes to compile for inductor. - inductor_compile_config: additional configurations for inductor. - None: use default configurations. - inductor_passes: additional passes for inductor. It is a dictionary @@ -2139,7 +2212,7 @@ class CompilationConfig(BaseModel): from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - custom inductor passes: see PassConfig for more details - + Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used for the same size. We need to capture all the sizes we want to use. @@ -2155,12 +2228,11 @@ class CompilationConfig(BaseModel): custom_ops: List[str] = Field(default_factory=list) splitting_ops: List[str] = Field(default_factory=lambda: [ "vllm.unified_attention", - "vllm.unified_v1_flash_attention", + "vllm.unified_attention_with_output", ]) use_inductor: bool = True - inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None - inductor_compile_sizes: Optional[List[int]] = Field(default=None) + candidate_compile_sizes: Optional[List[int]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -2214,6 +2286,7 @@ def model_post_init(self, __context: Any) -> None: # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr + compilation_time: float = PrivateAttr # Per-model forward context # Mainly used to store attention cls @@ -2225,7 +2298,9 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - return CompilationConfig.model_validate_json(cli_value) + # do not use `eval`, it is dangerous and can execute arbitrary code + dict_value = ast.literal_eval(cli_value) + return CompilationConfig.model_validate(dict_value) def model_post_init(self, __context: Any) -> None: @@ -2252,6 +2327,7 @@ def model_post_init(self, __context: Any) -> None: self.enabled_custom_ops = Counter() self.disabled_custom_ops = Counter() self.static_forward_context = {} + self.compilation_time = 0.0 def init_backend(self) -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: @@ -2274,15 +2350,10 @@ def init_backend(self) -> Union[str, Callable]: from vllm.compilation.backends import VllmBackend return VllmBackend(self) - def init_during_runtime(self): + def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): """To complete the initialization of config, - we need to know the compile context, which is only available - during the first run of the model. - """ - from vllm.compilation.compile_context import get_compile_context - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context + we need to know the cudagraph sizes.""" + if self.cudagraph_capture_sizes is None: self.capture_sizes = sizes_to_specialize else: @@ -2290,18 +2361,35 @@ def init_during_runtime(self): logger.info(("cudagraph sizes specified by model runner" " %s is overridden by config %s"), sizes_to_specialize, self.cudagraph_capture_sizes) - if self.inductor_specialize_for_cudagraph_no_more_than is not None: - assert self.inductor_compile_sizes is None, ( - "inductor_compile_sizes should be None when " - "inductor_specialize_for_cudagraph_no_more_than is not None") - self.compile_sizes = [ - x for x in self.capture_sizes - if x <= self.inductor_specialize_for_cudagraph_no_more_than - ] - else: - if self.inductor_compile_sizes is None: - self.inductor_compile_sizes = [] - self.compile_sizes = self.inductor_compile_sizes + + if self.candidate_compile_sizes is None: + self.candidate_compile_sizes = [] + self.compile_sizes = [ + x for x in self.candidate_compile_sizes if x in self.capture_sizes + ] + ignored_sizes = [ + x for x in self.candidate_compile_sizes + if x not in self.capture_sizes + ] + if ignored_sizes: + logger.warning(("candidate_compile_sizes %s are ignored " + "because they are not cudagraph capture sizes."), + ignored_sizes) + + # sort to make sure cudagraph capture sizes are in descending order + self.capture_sizes.sort(reverse=True) + + +_BATCH_SIZE_ALIGNMENT = 8 +# all the token sizes that **can** be captured by cudagraph. +# they can be arbitrarily large. +# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. +# the actual sizes to capture will be determined by the model, +# depending on the model's max_num_seqs. +# NOTE: get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) +] @dataclass @@ -2327,6 +2415,44 @@ class VllmConfig: quant_config: Optional[QuantizationConfig] = None compilation_config: CompilationConfig = field(default=None, init=True) # type: ignore + kv_transfer_config: KVTransferConfig = field(default=None, + init=True) # type: ignore + instance_id: str = "" + + @staticmethod + def get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + @staticmethod + def get_max_graph_batch_size(max_num_seqs: int) -> int: + """ + max_num_seqs: Maximum number of sequences in a batch. + _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. + + pad the max_num_seqs if necessary by calling get_graph_batch_size, + which will deal with some edge cases like 1, 2, 4. + + if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded + size. if not, it means the padded size is larger than the largest size + in _BATCH_SIZES_TO_CAPTURE, return the largest size in + _BATCH_SIZES_TO_CAPTURE. + """ + padded_size = VllmConfig.get_graph_batch_size(max_num_seqs) + if padded_size in _BATCH_SIZES_TO_CAPTURE: + return padded_size + assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] @staticmethod def _get_quantization_config( @@ -2356,7 +2482,15 @@ def _get_quantization_config( return quant_config return None - def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig": + def with_hf_config( + self, + hf_config: PretrainedConfig, + architectures: Optional[list[str]] = None, + ) -> "VllmConfig": + if architectures is not None: + hf_config = copy.deepcopy(hf_config) + hf_config.architectures = architectures + model_config = copy.deepcopy(self.model_config) model_config.hf_config = hf_config @@ -2411,6 +2545,28 @@ def __post_init__(self): self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE + if not envs.VLLM_USE_V1: + max_batchsize_to_capture = 0 + if self.scheduler_config is not None and \ + self.model_config is not None and \ + not self.model_config.enforce_eager: + max_batchsize_to_capture = \ + self.get_max_graph_batch_size( + self.scheduler_config.max_num_seqs) + batch_size_capture_list = [ + size for size in _BATCH_SIZES_TO_CAPTURE + if size <= max_batchsize_to_capture + ] + else: + batch_size_capture_list = [] + if self.model_config is not None and \ + not self.model_config.enforce_eager: + batch_size_capture_list = [1, 2, 4 + ] + [i for i in range(8, 513, 8)] + + self.compilation_config.init_with_cudagraph_sizes( + batch_size_capture_list) + if self.cache_config is not None and \ self.cache_config.cpu_offload_gb > 0 and \ self.compilation_config.level != CompilationLevel.NO_COMPILATION: @@ -2427,6 +2583,9 @@ def __post_init__(self): current_platform.check_and_update_config(self) + if not self.instance_id: + self.instance_id = random_uuid()[:5] + def __str__(self): return ("model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index d4e3f81747038..a6800f93f167b 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -197,6 +197,25 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + @contextmanager def change_state(self, enable: Optional[bool] = None, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index ff88f72470b27..7dea61b6a09f1 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -189,6 +189,15 @@ class NCCLLibrary: ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -312,6 +321,13 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)) + def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) diff --git a/vllm/distributed/kv_transfer/README.md b/vllm/distributed/kv_transfer/README.md new file mode 100644 index 0000000000000..dab2d10c4c9d0 --- /dev/null +++ b/vllm/distributed/kv_transfer/README.md @@ -0,0 +1,30 @@ + +# Distributed KV cache transfer + +This folder implements distributed KV cache transfer across vLLM instances. +Currently the main usecase is for disaggregated prefilling. + +## Abstractions + +The KV cache transfer contains three layer of abstractions: + +- KV pipe: a FIFO pipe for torch.tensor transmission. Key APIs: `send_tensor` and `recv_tensor`. +- KV lookup buffer: a lookup buffer for KV caches. Key: the tokens, value: the KV caches (and/or hidden states). Key APIs: `insert` and `drop_select` (similar to SQL semantics). +- KV connector: a connector that connects the KV pipe and KV lookup buffer to vLLM. Key APIs: `send_kv_caches_and_hidden_states` and `recv_kv_caches_and_hidden_states`. + +Why we need KV lookup buffer: FIFO pipe itself is not enough as prefill vLLM worker may process requests in a different order compared to decode vLLM worker. Say the QPS is really high, prefill worker may handle requests in order A -> B -> C, but the decode worker may process request C first. This is not the case that can be naturally handled by FIFO pipe, so we provide KV lookup buffer to help translate a FIFO pipe to a lookup buffer. + +NOTE: KV pipe layer is bypassible: you can skip this layer if your distributed +communication service already supports key-value-based lookup (like redis or +RDMA database). + +NOTE: If you want to not only transfer KV caches, but adjust the model execution flow of vLLM as well (for example, allow vLLM to receive KV caches on some tokens and do prefill on the remaining tokens), you can bypass both KV pipe layer and KV lookup buffer layer, and directly implement on KV connector layer. Bear in mind that as vLLM's model input is constantly changing, this implementation will likely be broken when vLLM has new updates. + +## Disaggregated prefilling + +The example usage is in [this file](../../../examples/disaggregated_prefill.sh). + +Here is the diagram of how we run disaggretgated prefilling. + +![Disaggregated prefill workflow](./disagg_prefill_workflow.jpg) + diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg new file mode 100644 index 0000000000000..a25ec5ef52491 Binary files /dev/null and b/vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg differ diff --git a/vllm/distributed/kv_transfer/kv_connector/__init__.py b/vllm/distributed/kv_transfer/kv_connector/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py new file mode 100644 index 0000000000000..6089e3babac3e --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/base.py @@ -0,0 +1,122 @@ +""" +KVConnectorBase Class for Distributed KV Cache & Hidden State communication + +The class provides two primary abstract methods: +1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states +2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, List, Tuple, Union + +import torch + +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + + +class KVConnectorBase(ABC): + """ + Abstract base class for a KV connector. + + The class provides two primary abstract methods: + 1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states + 2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + connector when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + """ + Send KV caches and hidden states to the connector. + + This method processes the input tokens, KV caches, and + hidden/intermediate states for a given model and sends the data to the + decode instance. + + Args: + model_executable (torch.nn.Module): The model executable containing + start and end layer information. + model_input (ModelInputForGPUWithSamplingMetadata): The input + metadata from vLLM. + kv_caches (List[torch.Tensor]): List of KV caches (keys and values) + for each layer. + hidden_or_intermediate_states (Union[torch.Tensor, + IntermediateTensors]): + The hidden or intermediate states associated with the tokens. + + Returns: + None + + """ + + raise NotImplementedError + + @abstractmethod + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + """ + Receive KV caches and hidden states from the connector. + + This method attempts to retrieve KV caches and hidden states for input + tokens. If all required KV caches and hidden states are received, it + will bypass model input, else it will fall back to normal vLLM model + forwarding. + + Args: + model_executable (torch.nn.Module): + The model executable from vLLM modelrunner. + model_input (ModelInputForGPUWithSamplingMetadata): + The model input from vLLM modelrunner. + kv_caches (List[torch.Tensor]): + List of KV caches for each layer. + + Returns: + - hidden_or_intermediate_states (torch.Tensor or + IntermediateTensors): + Concatenated hidden states if all required data is retrieved, + otherwise `None`. + - bypass_model_exec (bool): + Indicates whether the model execution can be skipped (True) or + needs to be redone (False). + - model_input (ModelInputForGPUWithSamplingMetadata): + Optionally adjusted input metadata for re-execution when + `bypass_model_exec=False`. + + """ + + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py new file mode 100644 index 0000000000000..015f892cec933 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from .base import KVConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + +class KVConnectorFactory: + + @staticmethod + def create_connector(rank: int, local_rank: int, + config: "VllmConfig") -> KVConnectorBase: + if config.kv_transfer_config.kv_connector == 'PyNcclConnector': + from .simple_connector import SimpleConnector + return SimpleConnector(rank, local_rank, config) + else: + raise ValueError(f"Unsupported connector type: " + f"{config.kv_connector}") diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py new file mode 100644 index 0000000000000..5870070a54c75 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py @@ -0,0 +1,261 @@ +""" +Simple KV Cache Connector for Distributed Machine Learning Inference + +The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache +producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe. + +But the logic can be extended to support other pipe and lookup buffer. +""" +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( + SimpleBuffer) +from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + + +class SimpleConnector(KVConnectorBase): + + def __init__( + self, + rank: int, + local_rank: int, + config: VllmConfig, + ): + + self.config = config.kv_transfer_config + + logger.info("Initializing PyNcclConfig under kv_transfer_config %s", + self.config) + + self.lookup_buffer_size = self.config.kv_buffer_size + + self.producer_buffer: Optional[SimpleBuffer] = None + self.consumer_buffer: Optional[SimpleBuffer] = None + + # 2 pipes for every rank in the world + port_offset_base = 2 * rank + + # In disaggregated prefill, the prefill vLLM only uses send pipe + # and the decode vLLM only uses recv pipe + if self.config.is_kv_producer: + + self.producer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.producer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, + self.producer_data_pipe, + self.config.kv_buffer_size) + + else: + + # the current vLLM instance is KV consumer, so it needs to connect + # its recv pipe to the send pipe of KV producder + self.consumer_data_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base, + ) + self.consumer_signal_pipe = PyNcclPipe( + local_rank=local_rank, + config=self.config, + port_offset=port_offset_base + 1, + device="cpu", + ) + self.consumer_buffer = SimpleBuffer( + self.consumer_signal_pipe, + self.consumer_data_pipe, + self.config.kv_buffer_size, + ) + + def select(self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.consumer_buffer is not None, "Please initialize the "\ + "consumer buffer before calling select." + return self.consumer_buffer.drop_select(input_tokens, roi) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + assert self.producer_buffer is not None, "Please initialize the "\ + "producer buffer before calling insert." + + self.producer_buffer.insert(input_tokens, roi, key, value, hidden) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() + start_layer = model_executable.model.start_layer + end_layer = model_executable.model.end_layer + + # query_lens contains new KV caches that are added to vLLM. + # so we will send them to decode instance + # FIXME(Kuntai): This assume that all requests are prefill. + for idx, slen in enumerate(seq_lens): + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + + keys, values = [], [] + + for layer_id in range(start_layer, end_layer): + kv_cache = kv_caches[layer_id - start_layer] + + _, _, num_heads, head_size = kv_cache[0].shape + + key_cache = kv_cache[0].reshape(-1, num_heads, head_size) + value_cache = kv_cache[1].reshape(-1, num_heads, head_size) + + current_slot_mapping = slot_mapping_flat[start_pos:end_pos] + + keys.append(key_cache[current_slot_mapping].unsqueeze(0)) + values.append(value_cache[current_slot_mapping].unsqueeze(0)) + + keys = torch.cat(keys, dim=0) + values = torch.cat(values, dim=0) + + self.insert(current_tokens, + torch.ones_like(current_tokens, + dtype=bool), keys, values, + hidden_or_intermediate_states[start_pos:end_pos]) + + logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + # When bypass_model_exec is set to False, it means that at least for one + # request its corresponding KV cache or hidden state is missing. + # In this case we need to do prefilling to recompute missing KV cache + # and hidden states. + bypass_model_exec = True + + input_tokens_tensor = model_input.input_tokens + seq_lens = model_input.attn_metadata.seq_lens + slot_mapping = model_input.attn_metadata.slot_mapping.flatten() + + hidden_or_intermediate_states_for_one_req = [] + + input_tokens_list = [] + num_computed_tokens_list = [] + start_pos_list = [] + + # enumerate different requests + # FIXME(Kuntai): This impl assumes that all requests are prefill. + for idx, slen in enumerate(seq_lens): + + start_pos = sum(seq_lens[:idx]) + end_pos = start_pos + slen + current_tokens = input_tokens_tensor[start_pos:end_pos] + num_tokens = slen + + # collecting data for rebuilding the input + input_tokens_list.append(current_tokens) + start_pos_list.append(start_pos) + + ret = self.select(current_tokens, + torch.ones_like(current_tokens, dtype=bool)) + if ret[0] is None: + # didn't find any match. + bypass_model_exec = False + num_computed_tokens_list.append(0) + continue + + roi: torch.Tensor = ret[1] + keys: torch.Tensor = ret[2] + values: torch.Tensor = ret[3] + hidden: torch.Tensor = ret[4] + + num_computed_tokens = roi.shape[0] + num_computed_tokens_list.append(num_computed_tokens) + + # check if both KV cache and the hidden states are received + # If not, need to redo the forwarding to compute missing states + if not all([(num_computed_tokens == num_tokens), hidden is not None + ]): + bypass_model_exec = False + + # update the end position based on how many tokens are cached. + end_pos = start_pos + num_computed_tokens + + # put received KV caches into paged memory + for i in range(model_executable.model.start_layer, + model_executable.model.end_layer): + + kv_cache = kv_caches[i - model_executable.model.start_layer] + layer = model_executable.model.layers[i] + + key_cache, value_cache = kv_cache[0], kv_cache[1] + ops.reshape_and_cache_flash( + keys[i - model_executable.model.start_layer].to( + key_cache.device), + values[i - model_executable.model.start_layer].to( + value_cache.device), + key_cache, + value_cache, + slot_mapping[start_pos:end_pos], + layer.self_attn.attn.kv_cache_dtype, + layer.self_attn.attn._k_scale, + layer.self_attn.attn._v_scale, + ) + + hidden_or_intermediate_states_for_one_req.append(hidden) + + if not bypass_model_exec: + # Some of the KV cache is not retrieved + # Here we will fall back to normal model forwarding + # But optionally you can adjust model_input so that you only do + # prefilling on those tokens that are missing KV caches. + logger.debug( + "[rank%d]: Failed to receive all KVs and hidden " + "states, redo model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = None + + else: + logger.debug( + "[rank%d]: Successfully received all KVs and hidden " + "states, skip model forwarding.", torch.distributed.get_rank()) + hidden_or_intermediate_states = torch.cat( + hidden_or_intermediate_states_for_one_req, dim=0) + + return hidden_or_intermediate_states, bypass_model_exec, model_input + + def close(self): + self.producer_data_pipe.close() + self.producer_signal_pipe.close() + self.consumer_data_pipe.close() + self.consumer_signal_pipe.close() diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py new file mode 100644 index 0000000000000..bad119a1aa929 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/base.py @@ -0,0 +1,108 @@ +""" +This file contains a new class `KVLookupBufferBase` that allows developers to +think of KV cache operations as inserting new KV cache entries (`insert`) +into the lookup buffer and querying existing KV caches (`drop_select`) +from the lookup buffer. + +All distributed communications are abstracted behind this class. +""" + +from abc import ABC, abstractmethod +from typing import List, Optional + +import torch + + +class KVLookupBufferBase(ABC): + """ + Abstract base class for a lookup buffer. + + This class provides an abstraction for a key-value (KV) cache lookup buffer. + + The key of the lookup buffer: + - input_tokens: token IDs of the request + - roi: a binary mask on top of input_tokens. + - Purpose of roi: Since KV cache may only be available for a subset of + tokens in the input (for example, when vLLM is connected to an external + KV cache service), roi specifies the subset of tokens that the KV cache + is associated with. + - NOTE: roi can be further extended to describe which part of KV the + current process is holding (each process may only hold a part of KV + due to TP and PP). This is not implemented for now. + + The value of the lookup buffer: + - key: the key tensor in the KV cache + - value: the value tensor in the KV cache + - hidden: the final hidden state generated by model forwarding. This allows + vLLM to bypass further model forwarding by transmitting the hidden state. + """ + + @abstractmethod + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + """Insert into the lookup buffer. + + The functionality is similar to the following python statement + ``` + buffer[input_tokens, roi] = [key, value, hidden] + ``` + + FIXME: in the future, we should only have two arguments, key and value, + where key is a tensor dict and value is a tensor dict. + + FIXME: we should transmit both sampler outputs and the hidden states. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + key (torch.Tensor): The key tensor in the KV cache. + value (torch.Tensor): The value tensor in the KV cache. + hidden (torch.Tensor): The final hidden state tensor generated + during model forwarding to bypass model + forwarding. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + """Select and *drop* KV cache entries from the lookup buffer. + + The functionality is similar to the following python statements + ``` + ret = buffer.pop(input_tokens, roi) + return ret + ``` + + If `input_tokens` and `roi` is `None`, it means selecting any of the + KV caches in the buffer, return, and remove it from the buffer, useful + when offloading KV cache to KV cache storage service. + + Args: + input_tokens (torch.Tensor): token IDs. + roi (torch.Tensor): A binary mask on top of the input tokens + + Returns: + List[Optional[torch.Tensor]]: A list of tensors. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the buffer and release resources. + + This method is responsible for cleaning up resources related to the + lookup buffer when it is no longer needed. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py new file mode 100644 index 0000000000000..fe8d8d7375f36 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -0,0 +1,242 @@ +""" + Implements a distributed key-value (KV) cache transfer mechanism. + + Key Features: + - Distributed KV cache transmission using PyNccl pipes. + - Non-blocking `insert`, blocking `drop_select`. + - Use CPU signal pipe to avoid racing condition + - Handles buffer size constraints and provide backpressure mechanism to + stop the prefill instance when the decode instance is slow. +""" +import threading +import time +from collections import deque +from typing import Deque, List, Optional, Union + +import torch + +from vllm.distributed.kv_transfer.kv_lookup_buffer.base import ( + KVLookupBufferBase) +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class SimpleBuffer(KVLookupBufferBase): + + def __init__(self, signal_pipe: KVPipeBase, data_pipe: KVPipeBase, + buffer_size_thresh: float): + """ + signal_pipe: on CPU + + NOTE: on-device recv will block all threads in the process, making the + KV cache producer unable to listen to new request while transmitting + KV cache. Luckily CPU recv only blocks the current thread so we use + CPU recv to listen to new request. + + data_pipe: on device (e.g. GPU) + """ + + self.buffer: Deque[List[torch.Tensor]] = deque() + + self.buffer_size = 0 + self.buffer_size_threshold = buffer_size_thresh + self.buffer_lock = threading.Lock() + self.signal_pipe = signal_pipe + self.data_pipe = data_pipe + self.request_handling_thread: Optional[threading.Thread] = None + + self.normal_signal = torch.tensor([0], device="cpu") + self.end_signal = None + + def _matches(self, tokens_roi_sender: List[torch.Tensor], + tokens_roi_recver: List[torch.Tensor]): + + # tokens_roi_sender: tokens and roi of the producer (in the buffer) + # tokens_roi_recver: tokens and roi of the consumer (query) + + tokens_sender = tokens_roi_sender[0] + tokens_recver = tokens_roi_recver[0] + roi_sender = tokens_roi_sender[1] + roi_recver = tokens_roi_recver[1] + + if tokens_recver is None: + # consumer sends an empty request + # semantics: DROP SELECT * LIMIT 1 + # so any of the data in the buffer can be drop-selected + return True + + # Assuming that roi is a binary mask on tokens + tokens_sender = tokens_sender[roi_sender] + tokens_recver = tokens_recver[roi_recver] + + # simple common prefix matching + min_length = min(len(tokens_sender), len(tokens_recver)) + if torch.allclose(tokens_sender[:min_length], + tokens_recver[:min_length]): + return min_length + + return 0 + + def _send_tensor_and_dec_size(self, + tensor: Optional[torch.Tensor]) -> None: + + assert tensor is not None, "Use self.data_pipe.send(None) instead" + self.buffer_size -= tensor.element_size() * tensor.numel() + if tensor.dtype == torch.bool: + tensor = tensor.float() + self.data_pipe.send_tensor(tensor) + + def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): + + if isinstance(data, torch.Tensor): + return data.element_size() * data.numel() + if not data: + # cannot perform `not data` on a tensor + # so this check needs to go after the check above + return 0 + + raise AssertionError(f"Unknown data type {type(data)}") + + def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor): + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone() + if isinstance(key, torch.Tensor): + key = key.clone() + if isinstance(value, torch.Tensor): + value = value.clone() + if isinstance(hidden, torch.Tensor): + hidden = hidden.clone() + + buffer_item = [input_tokens, roi, key, value, hidden] + + with self.buffer_lock: + for data in buffer_item: + self.buffer_size += self._get_element_size(data) + self.buffer.append(buffer_item) + + def _is_end_signal(self, signal): + return signal is None + + def drop_select_handler(self): + + try: + + while True: + signal = self.signal_pipe.recv_tensor() + if self._is_end_signal(signal): + logger.info("Received end signal!") + break + + input_tokens = self.data_pipe.recv_tensor() + + roi = self.data_pipe.recv_tensor() + assert roi is not None, "Please provide the roi when sending "\ + "drop-select request" + roi = (roi > 0.5) + tokens_roi_recver = [input_tokens, roi] + + matched_length = 0 + + # perform input tokens and roi matching + # FIXME: this matching is O(n), ideally it should be O(1) + # but this buffer size won't (and shouldn't) be too large so + # the fix is not urgent. + with self.buffer_lock: + + for _ in range(len(self.buffer)): + + temp_length = self._matches(self.buffer[0], + tokens_roi_recver) + if temp_length > 0: + matched_length = temp_length + break + # rotate the element we just accessed to the end + self.buffer.rotate(-1) + + if matched_length > 0: + # need to clone the tensor + # in case the tensor is freed before sending finishes + matched_item = self.buffer.popleft() + for tensor in matched_item: + self._send_tensor_and_dec_size(tensor) + + else: + # no match, just send None + for _ in range(5): + self.data_pipe.send_tensor(None) + + except RuntimeError as e: + if 'Connection closed by peer' not in str(e): + raise e + + logger.debug("Closing drop_select_handler") + + def drop_select( + self, input_tokens: Optional[torch.Tensor], + roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: + + assert self.request_handling_thread is None, \ + "drop_select should be called by the KV cache consumer "\ + "(e.g. the decode vLLM instance)" + + if isinstance(input_tokens, torch.Tensor): + input_tokens = input_tokens.clone() + if isinstance(roi, torch.Tensor): + roi = roi.clone().float() + + self.signal_pipe.send_tensor(self.normal_signal) + self.data_pipe.send_tensor(input_tokens) + self.data_pipe.send_tensor(roi) + + input_tokens = self.data_pipe.recv_tensor() + roi = self.data_pipe.recv_tensor() + if roi is not None: + # convert from float tensor to bool tensor + # as PyNccl does not support sending bool tensor + roi = (roi > 0.5) + key = self.data_pipe.recv_tensor() + value = self.data_pipe.recv_tensor() + hidden = self.data_pipe.recv_tensor() + + return [input_tokens, roi, key, value, hidden] + + def full_handler(self): + time.sleep(0.001) + + def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + hidden: torch.Tensor) -> None: + + if self.buffer_size > self.buffer_size_threshold: + # log outside the while loop to avoid this message being logged + # repeatedly. + logger.debug("KV transfer buffer is full. Handling...") + while self.buffer_size > self.buffer_size_threshold: + self.full_handler() + + self._add_to_buffer(input_tokens, roi, key, value, hidden) + + # when calling the insert, the current process is a sender + # need to launch the request handler and start listening to request. + if self.request_handling_thread is None: + self.request_handling_thread = threading.Thread( + target=self.drop_select_handler) + self.request_handling_thread.start() + + def close(self): + + if hasattr(self, "request_handling_thread" + ) and self.request_handling_thread is not None: + self.request_handling_thread.join() + + else: + # TODO: have a explicit close signal and have a explicit way to + # check if it's requester + self.signal_pipe.send_tensor(self.end_signal) diff --git a/vllm/distributed/kv_transfer/kv_pipe/__init__.py b/vllm/distributed/kv_transfer/kv_pipe/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py new file mode 100644 index 0000000000000..4b0cb44cc5b81 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/base.py @@ -0,0 +1,65 @@ +""" +This file defines an interface `KVPipeBase` +that provides an abstraction for sending and receiving tensors, or None, via +distributed communications. + +All classes instantiated from this interface are assumed to be a FIFO pipe. + +If your distributed communication platform already supports key-value lookup, +you can bypass this interface and directly start from `kv_lookup_buffer`. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + + +class KVPipeBase(ABC): + """ + This class provides an interface for sending and receiving tensors, or + None, by distributed communications. + """ + + @abstractmethod + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """Send a tensor, or None, via the pipe. + + Need to support sending None -- important for error handling. + + TODO: add a `key` argument so that we can use traditional + key-value database as the distributed communication mechanism behind + the pipe. + + Args: + tensor (Optional[torch.Tensor]): The tensor to be sent. Can be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def recv_tensor(self) -> Optional[torch.Tensor]: + """Receive a tensor (can be None) from the pipeline. + + Returns: + Optional[torch.Tensor]: The tensor received from the pipeline. Can + be None. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the pipeline and release resources. + + This method is responsible for closing the communication pipeline + and releasing any resources associated with it. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + """ + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py new file mode 100644 index 0000000000000..98222fa67e492 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -0,0 +1,276 @@ +""" + This module implements a PyNccl pipe for sending and receiving + Optional[torch.Tensor] between distributed ranks with advanced + communication features. + + Key Features: + - Supports sending and receiving tensors with metadata + - Handles both CUDA and CPU device communications + - Implements a non-blocking tensor transfer mechanism + - Manages buffer size and provides backpressure control + - Supports distributed process groups with configurable parameters +""" + +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Dict, Optional, Tuple + +import torch + +from vllm.config import KVTransferConfig +from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator +from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase +from vllm.distributed.utils import StatelessProcessGroup +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class BrokenPipeException(Exception): + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +Metadata = Dict[str, Optional[torch.Tensor]] + + +class PyNcclPipe(KVPipeBase): + + METADATA_LENGTH = 16 + MAX_TENSOR_DIMENSIONS = 14 + METADATA_DTYPE = torch.int64 + + def __init__(self, + local_rank: int, + config: KVTransferConfig, + device: Optional[str] = None, + port_offset: int = 0): + self.config = config + self.local_rank = local_rank + self.kv_rank = self.config.kv_rank + self.kv_parallel_size = self.config.kv_parallel_size + if device is None: + self.device = self._select_device(self.config.kv_buffer_device) + else: + self.device = self._select_device(device) + + # build distributed connection and send/recv implementation + self.group = StatelessProcessGroup.create( + host=self.config.kv_ip, + port=self.config.kv_port + port_offset, + rank=self.kv_rank, + world_size=self.kv_parallel_size, + ) + # add a barrier to make sure the connection is initiated properly + self.group.barrier() + impl = self._get_device_send_recv_impl(self.group) + self.device_send_func, self.device_recv_func = impl + # set target rank + self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size + self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size + + # transportation-related variables + self.transport_thread: Optional[ThreadPoolExecutor] = None + self.buffer_size = 0 + self.buffer_size_lock = threading.Lock() + self.buffer_size_thresh = self.config.kv_buffer_size + + def _get_device_send_recv_impl( + self, group: StatelessProcessGroup + ) -> Tuple[Callable[[torch.Tensor, int], None], Callable[ + [torch.Tensor, int], None]]: + + send: Callable[[torch.Tensor, int], None] + recv: Callable[[torch.Tensor, int], None] + if self.device.type == "cuda": + # use PyNCCL for send / recv + comm = PyNcclCommunicator(group, device=self.local_rank) + comm.disabled = False + send, recv = comm.send, comm.recv # type: ignore + else: + # This send / recv implementation here is NOT intended to transfer + # KV caches (and should NOT be repurposed to transfer KV caches). + # Currently it is only used to transmit control-plane messages + # for PyNcclBuffer. + send = group.send_obj + + def my_recv(x, src): + x[...] = group.recv_obj(src) + + recv = my_recv + + return send, recv + + def _select_device(self, device: str): + logger.info("Selecting device: %s", device) + if device == "cuda": + return torch.device(f"cuda:{self.local_rank}") + else: + return torch.device("cpu") + + def _make_metadata(self, tensor: Optional[torch.Tensor]) -> Metadata: + """ + Create the metadata as a dictionary based on the input tensor. + + Parameters: + - tensor: The input tensor or None if no tensor is provided. + + Returns: + - metadata: A dictionary with the following keys: + - "dtype": The data type of the tensor or None. + - "shape": The shape of the tensor or None. + """ + if tensor is None: + return {"dtype": None, "shape": None} + else: + return {"dtype": tensor.dtype, "shape": tensor.shape} + + def _prepare_recv_buffer(self, metadata: Metadata) -> torch.Tensor: + """ + Create a buffer to receive the tensor based on the provided metadata. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape", describing + the tensor's data type and shape. + + Returns: + - buffer: A tensor of the specified type and shape, allocated on + self.device. + """ + return torch.empty(metadata["shape"], + dtype=metadata["dtype"], + device=self.device) + + def _send_metadata(self, metadata: Metadata): + """ + Send the metadata dictionary to the target rank. + + Parameters: + - metadata: A dictionary with keys "dtype" and "shape". + """ + self.group.send_obj(metadata, self.target_rank_for_send) + + def _recv_metadata(self) -> Metadata: + """ + Receive the metadata dictionary from the target rank. + + Returns: + - metadata: A dictionary with keys "dtype" and "shape" describing + the tensor. + """ + return self.group.recv_obj(self.target_rank_for_recv) + + def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: + """ + The actual implementation of sending the tensor and its metadata to the + target rank. + + Parameters: + - tensor: The input tensor to be sent, or None if no tensor is + being sent. + """ + metadata = self._make_metadata(tensor) + self._send_metadata(metadata) + if tensor is not None: + self.device_send_func(tensor.to(self.device), + self.target_rank_for_send) + + def _recv_impl(self) -> Optional[torch.Tensor]: + """ + The actual implementation of receiving a tensor and its metadata from + the target rank. + + Returns: + - buffer: The received tensor, or None if no tensor is received. + """ + metadata = self._recv_metadata() + if metadata["dtype"] is None: + return None + buffer = self._prepare_recv_buffer(metadata) + self.device_recv_func(buffer, self.target_rank_for_recv) + + return buffer + + def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], + tensor_size: int) -> None: + """ + Wrapper for _send_impl to handle exceptions and update buffer size. + """ + try: + self._send_impl(tensor) + + with self.buffer_size_lock: + self.buffer_size -= tensor_size + except Exception as e: + logger.error("[rank%d]: Exception when trying to send %s, msg: %s", + torch.distributed.get_rank(), str(tensor), str(e)) + import traceback + traceback.print_exc() + + def block_if_full(self): + """ + Block the current thread if the buffer size is larger than the + threshold. + """ + while self.buffer_size > self.buffer_size_thresh: + logger.debug("KV cache transfer pipe is full. Waiting...") + time.sleep(0.05) + + def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: + """ + Sends a tensor and its metadata to the destination rank in a + non-blocking way. + + Parameters: + - tensor: The tensor to send, or None if no tensor is being sent. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + if tensor is not None: + tensor_size = tensor.element_size() * tensor.numel() + else: + tensor_size = 0 + + self.block_if_full() + + with self.buffer_size_lock: + self.buffer_size += tensor_size + + self.transport_thread.submit(self.send_tensor_wrapper, tensor, + tensor_size) + + def recv_tensor(self) -> Optional[torch.Tensor]: + """ + Receives a tensor and its metadata from the source rank. Blocking call. + + Returns: + - tensor: The received tensor, or None if no tensor is received. + """ + if self.transport_thread is None: + self.transport_thread = ThreadPoolExecutor(max_workers=1) + + future = self.transport_thread.submit(self._recv_impl) + + try: + tensor = future.result() + except Exception as e: + logger.error("Encountering exception in KV receiving thread") + logger.error("%s", e) + logger.error("My device: %s", self.device) + import traceback + traceback.print_exc() + raise e + + return tensor + + def close(self): + """ + Close the pipe and release associated resources. + """ + if hasattr(self, + "transport_thread") and self.transport_thread is not None: + self.transport_thread.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py new file mode 100644 index 0000000000000..9ce97851dc849 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py @@ -0,0 +1,75 @@ +"""A centralized entrypoint to perform distributed KV cache transfer. + +This implementation is a shim wrapper on two APIs exposed by `kv_connector`: +1. `send_kv_caches_and_hidden_states` +2. `recv_kv_caches_and_hidden_states +""" +from typing import TYPE_CHECKING, List, Tuple, Union + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + from vllm.config import VllmConfig + +import torch + +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +class KVTransferAgent: + """ + A class designated for distributed KV transfer + + Target use cases: + 1. Disaggregated prefill + 2. Remote KV cache storage + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + + self.config = config + + if config.kv_transfer_config is None: + raise ValueError("KVTransferConfig is not set in the VllmConfig," + " cannot initialize KVConnector.") + + assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ + "TransferAgent should only be used when kv_connector is set." + + self.connector = KVConnectorFactory.create_connector( + rank, local_rank, config) + + def send_kv_caches_and_hidden_states( + self, + model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor], + hidden_or_intermediate_states: Union[torch.Tensor, + IntermediateTensors], + ) -> None: + + self.connector.send_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches, + hidden_or_intermediate_states) + + def close(self) -> None: + self.connector.close() + + def recv_kv_caches_and_hidden_states( + self, model_executable: torch.nn.Module, + model_input: "ModelInputForGPUWithSamplingMetadata", + kv_caches: List[torch.Tensor] + ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, + "ModelInputForGPUWithSamplingMetadata"]: + + return self.connector.recv_kv_caches_and_hidden_states( + model_executable, model_input, kv_caches) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ccbe00386c5da..34815d7f0aa78 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -27,18 +27,23 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from multiprocessing import shared_memory -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) from unittest.mock import patch import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op +if TYPE_CHECKING: + from vllm.config import VllmConfig + @dataclass class GraphCaptureContext: @@ -904,6 +909,14 @@ def get_pp_group() -> GroupCoordinator: # kept for backward compatibility get_pipeline_model_parallel_group = get_pp_group +_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None + + +def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: + assert _KV_TRANSFER is not None, ( + "disaggregated KV cache transfer parallel group is not initialized") + return _KV_TRANSFER + @contextmanager def graph_capture(): @@ -1052,6 +1065,26 @@ def initialize_model_parallel( group_name="pp") +def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: + """ + Initialize KV cache transfer parallel group. + """ + + global _KV_TRANSFER + + if vllm_config.kv_transfer_config is None: + return + + if all([ + vllm_config.kv_transfer_config.need_kv_parallel_group, + _KV_TRANSFER is None + ]): + _KV_TRANSFER = kv_transfer.KVTransferAgent( + rank=get_world_group().rank, + local_rank=get_world_group().local_rank, + config=vllm_config) + + def ensure_model_parallel_initialized( tensor_model_parallel_size: int, pipeline_model_parallel_size: int, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 90b4798f17a13..3db069ec64ee4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,10 +9,10 @@ import vllm.envs as envs from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, - DecodingConfig, DeviceConfig, HfOverrides, LoadConfig, - LoadFormat, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, PoolerConfig, - PromptAdapterConfig, SchedulerConfig, + DecodingConfig, DeviceConfig, HfOverrides, + KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PoolerConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TaskOption, TokenizerPoolConfig, VllmConfig) from vllm.executor.executor_base import ExecutorBase @@ -108,6 +108,7 @@ class EngineArgs: # notice. distributed_executor_backend: Optional[Union[str, Type[ExecutorBase]]] = None + # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 max_parallel_loading_workers: Optional[int] = None @@ -167,7 +168,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: Optional[bool] = None - guided_decoding_backend: str = 'outlines' + guided_decoding_backend: str = 'xgrammar' # Speculative decoding configuration. speculative_model: Optional[str] = None speculative_model_quantization: Optional[str] = None @@ -194,6 +195,8 @@ class EngineArgs: compilation_config: Optional[CompilationConfig] = None worker_cls: str = "auto" + kv_transfer_config: Optional[KVTransferConfig] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -206,12 +209,9 @@ def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object - if isinstance(self.compilation_config, (int)): + if isinstance(self.compilation_config, (int, dict)): self.compilation_config = CompilationConfig.from_cli( str(self.compilation_config)) - elif isinstance(self.compilation_config, (dict)): - self.compilation_config = CompilationConfig.from_cli( - json.dumps(self.compilation_config)) # Setup plugins from vllm.plugins import load_general_plugins @@ -361,11 +361,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--guided-decoding-backend', type=str, - default='outlines', - choices=['outlines', 'lm-format-enforcer'], + default='xgrammar', + choices=['outlines', 'lm-format-enforcer', 'xgrammar'], help='Which engine will be used for guided decoding' ' (JSON schema / regex etc) by default. Currently support ' - 'https://github.com/outlines-dev/outlines and ' + 'https://github.com/outlines-dev/outlines,' + 'https://github.com/mlc-ai/xgrammar, and ' 'https://github.com/noamgat/lm-format-enforcer.' ' Can be overridden per request via guided_decoding_backend' ' parameter.') @@ -416,15 +417,20 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'tokens. This is ignored on neuron devices and ' 'set to max-model-len') - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='Enables automatic prefix caching.') + parser.add_argument( + "--enable-prefix-caching", + action=argparse.BooleanOptionalAction, + default=EngineArgs.enable_prefix_caching, + help="Enables automatic prefix caching. " + "Use --no-enable-prefix-caching to disable explicitly.", + ) parser.add_argument('--disable-sliding-window', action='store_true', help='Disables sliding window, ' 'capping to sliding window size') parser.add_argument('--use-v2-block-manager', action='store_true', + default=True, help='[DEPRECATED] block manager v1 has been ' 'removed and SelfAttnBlockSpaceManager (i.e. ' 'block manager v2) is now the default. ' @@ -904,6 +910,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'compilers, using -O without space is also ' 'supported. -O3 is equivalent to -O 3.') + parser.add_argument('--kv-transfer-config', + type=KVTransferConfig.from_cli, + default=None, + help='The configurations for distributed KV cache ' + 'transfer. Should be a JSON string.') + parser.add_argument( '--worker-cls', type=str, @@ -1038,9 +1050,12 @@ def create_engine_config(self, # long context (> 32K) models. This is to avoid OOM errors in the # initial memory profiling phase. - # Chunked prefill is currently disabled for multimodal models by - # default. - if use_long_context and not model_config.is_multimodal_model: + # For multimodal models, chunked prefill is disabled by default in + # V0, but enabled by design in V1 + if model_config.is_multimodal_model: + self.enable_chunked_prefill = bool(envs.VLLM_USE_V1) + + elif use_long_context: is_gpu = device_config.device_type == "cuda" use_sliding_window = (model_config.get_sliding_window() is not None) @@ -1097,7 +1112,7 @@ def create_engine_config(self, disable_logprobs=self.disable_logprobs_during_spec_decoding, ) - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if self.num_scheduler_steps > 1: if speculative_config is not None: @@ -1197,6 +1212,7 @@ def create_engine_config(self, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, compilation_config=self.compilation_config, + kv_transfer_config=self.kv_transfer_config, ) if envs.VLLM_USE_V1: @@ -1228,12 +1244,9 @@ def _override_v1_engine_config(self, engine_config: VllmConfig) -> None: Override the EngineConfig's configs based on the usage context for V1. """ assert envs.VLLM_USE_V1, "V1 is not enabled" - # TODO (ywang96): Enable APC by default when VLM supports it. if engine_config.model_config.is_multimodal_model: - logger.warning( - "Prefix caching is currently not supported for multimodal " - "models and has been disabled.") - engine_config.cache_config.enable_prefix_caching = False + # TODO (ywang96): Enable APC by default when VLM supports it. + assert not engine_config.cache_config.enable_prefix_caching @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 31a15b04314d5..60dccd7a0812c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,4 +1,5 @@ import asyncio +import copy import time import weakref from functools import partial @@ -6,6 +7,8 @@ List, Mapping, Optional, Set, Tuple, Type, Union, overload) from weakref import ReferenceType +from typing_extensions import deprecated + import vllm.envs as envs from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig) @@ -25,7 +28,7 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -74,7 +77,7 @@ def _log_task_completion(task: asyncio.Task, class AsyncStream: - """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + """A stream of RequestOutputs or PoolingRequestOutputs for a request that can be iterated over asynchronously via an async generator.""" def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: @@ -83,7 +86,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + def put(self, item: Union[RequestOutput, PoolingRequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -103,7 +106,7 @@ def finished(self) -> bool: async def generator( self - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: try: while True: result = await self._queue.get() @@ -154,7 +157,7 @@ def propagate_exception(self, def process_request_output(self, request_output: Union[RequestOutput, - EmbeddingRequestOutput], + PoolingRequestOutput], *, verbose: bool = False) -> None: """Process a request output from the engine.""" @@ -265,7 +268,7 @@ def __init__(self, *args, **kwargs): async def step_async( self, virtual_engine: int - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. The workers are ran asynchronously if possible. @@ -422,7 +425,8 @@ async def get_tokenizer_async(self, return await ( self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") async def add_request_async( self, request_id: str, @@ -504,7 +508,8 @@ async def add_request_async( sampling_params=params, tokenizer=await self.get_tokenizer_async(lora_request), default_guided_backend=self.decoding_config. - guided_decoding_backend) + guided_decoding_backend, + model_config=self.model_config) self._add_processed_request( request_id=request_id, @@ -525,22 +530,30 @@ async def check_health_async(self) -> None: async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, - default_guided_backend: str) -> SamplingParams: + default_guided_backend: str, + model_config: ModelConfig) -> SamplingParams: """Constructs logits processors based on the guided_decoding, logits_bias, and allowed_token_ids fields in sampling_params. Deletes those fields and adds the constructed logits processors to the logits_processors field. Modifies sampling params in-place and returns the modified sampling params.""" - if (guided_decoding := sampling_params.guided_decoding) is None: + if sampling_params.guided_decoding is None: return sampling_params + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + logger.debug("Building guided decoding logits processor. " "Params: %s", guided_decoding) guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=model_config) if processor: if sampling_params.logits_processors is None: @@ -894,7 +907,8 @@ async def run_engine_loop(engine_ref: ReferenceType): # This method does not need to be async, but kept that way # for backwards compatibility. - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def add_request( self, request_id: str, @@ -907,7 +921,7 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: + RequestOutput, PoolingRequestOutput], None]]: ... @overload @@ -922,7 +936,7 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Coroutine[None, None, AsyncGenerator[Union[ - RequestOutput, EmbeddingRequestOutput], None]]: + RequestOutput, PoolingRequestOutput], None]]: ... @deprecate_kwargs( @@ -941,7 +955,7 @@ async def add_request( priority: int = 0, *, inputs: Optional[PromptType] = None, # DEPRECATED - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: if inputs is not None: prompt = inputs assert prompt is not None and params is not None @@ -1070,7 +1084,7 @@ async def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model. Generate outputs for a request. This method is a coroutine. It adds the @@ -1088,7 +1102,7 @@ async def encode( Only applicable with priority scheduling. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `PoolingRequestOutput` objects from the LLMEngine for the request. Details: @@ -1141,7 +1155,7 @@ async def encode( trace_headers=trace_headers, priority=priority, ): - yield LLMEngine.validate_output(output, EmbeddingRequestOutput) + yield LLMEngine.validate_output(output, PoolingRequestOutput) async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ecc222f692c41..26a8c94099a11 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,4 @@ +import copy import time from collections import Counter as collectionsCounter from collections import deque @@ -10,7 +11,7 @@ from typing import Set, Type, Union, cast, overload import torch -from typing_extensions import TypeVar +from typing_extensions import TypeVar, deprecated import vllm.envs as envs from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, @@ -40,7 +41,7 @@ get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry -from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, +from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest @@ -80,7 +81,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) +_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) @dataclass @@ -112,7 +113,7 @@ class SchedulerContext: def __init__(self, multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] + PoolingRequestOutput]] = [] self.seq_group_metadata_list: Optional[ List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None @@ -472,6 +473,7 @@ def _initialize_kv_caches(self) -> None: The workers will determine the number of blocks in both the GPU cache and the swap CPU cache. """ + start = time.time() num_gpu_blocks, num_cpu_blocks = ( self.model_executor.determine_num_available_blocks()) @@ -487,6 +489,9 @@ def _initialize_kv_caches(self) -> None: self.cache_config.num_cpu_blocks = num_cpu_blocks self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + elapsed = time.time() - start + logger.info(("init engine (profile, create kv cache, " + "warmup model) took %.2f seconds"), elapsed) @classmethod def _get_executor_cls(cls, @@ -619,7 +624,7 @@ def _init_tokenizer(self) -> BaseTokenizerGroup: model_config=self.model_config, scheduler_config=self.scheduler_config, parallel_config=self.parallel_config, - enable_lora=bool(self.lora_config)) + lora_config=self.lora_config) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -719,7 +724,8 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def add_request( self, request_id: str, @@ -1023,9 +1029,9 @@ def _update_num_computed_tokens_for_multi_step_prefill( This function updates num_computed_tokens for prompt sequences when Multi-Step is enabled. - seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group: SequenceGroup to update the num_computed_tokens for. seq_group_meta: Metadata of the given SequenceGroup. - is_first_step_output: Optional[bool] - + is_first_step_output: Optional[bool] - When available, is_first_step_output indicates if the appended output token is the output of the first-step in multi-step. A value of None indicates that outputs from all steps in @@ -1314,7 +1320,7 @@ def _advance_to_next_step( else: seq.append_token_id(sample.output_token, sample.logprobs) - def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. .. figure:: https://i.imgur.com/sv2HssD.png @@ -2035,7 +2041,11 @@ def _build_logits_processors( logits_processors = [] - if (guided_decoding := sampling_params.guided_decoding) is not None: + if sampling_params.guided_decoding is not None: + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding logger.debug( "Building guided decoding logits processor in " @@ -2046,7 +2056,9 @@ def _build_logits_processors( self.decoding_config.guided_decoding_backend processor = get_local_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=self.model_config) if processor: logits_processors.append(processor) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 5bfd6a9f4b386..a5ae21c3966a7 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -473,13 +473,13 @@ def log(self, stats: Stats) -> None: ) if (stats.cpu_prefix_cache_hit_rate >= 0 or stats.gpu_prefix_cache_hit_rate >= 0): - logger.info( + log_fn( "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", stats.gpu_prefix_cache_hit_rate * 100, stats.cpu_prefix_cache_hit_rate * 100, ) if self.spec_decode_metrics is not None: - logger.info( + log_fn( self._format_spec_decode_metrics_str( self.spec_decode_metrics)) @@ -599,9 +599,9 @@ def _log_prometheus(self, stats: Stats) -> None: stats.time_queue_requests) self._log_histogram(self.metrics.histogram_inference_time_request, stats.time_inference_requests) - self._log_histogram(self.metrics.histogram_decode_time_request, - stats.time_prefill_requests) self._log_histogram(self.metrics.histogram_prefill_time_request, + stats.time_prefill_requests) + self._log_histogram(self.metrics.histogram_decode_time_request, stats.time_decode_requests) self._log_histogram(self.metrics.histogram_time_in_queue_request, stats.time_in_queue_requests) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 34c161e9395ae..7020012e8bb86 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -2,6 +2,8 @@ from enum import Enum from typing import List, Mapping, Optional, Union, overload +from typing_extensions import deprecated + from vllm import PoolingParams from vllm.inputs import PromptType from vllm.lora.request import LoRARequest @@ -32,7 +34,8 @@ class RPCProcessRequest: prompt_adapter_request: Optional[PromptAdapterRequest] = None priority: int = 0 - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def __init__( self, *, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index fe21c58c775fe..7e4f81b2cf8e2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -9,6 +9,7 @@ import psutil import zmq import zmq.asyncio +from typing_extensions import deprecated from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket @@ -35,7 +36,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs @@ -93,8 +94,7 @@ def __init__(self, ipc_path: str, engine_config: VllmConfig, model_config=self.model_config, scheduler_config=engine_config.scheduler_config, parallel_config=engine_config.parallel_config, - enable_lora=bool(engine_config.lora_config), - ) + lora_config=engine_config.lora_config) self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer) @@ -414,7 +414,8 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return ENGINE_DEAD_ERROR(self._errored_with) - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def generate( self, *, @@ -472,8 +473,8 @@ def generate( trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request to use for generation, if any. - priority: Priority of the request (lower means earlier handling). - Any priority other than 0 will lead to an error if the + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the scheduling policy is not "priority". """ if inputs is not None: @@ -485,7 +486,8 @@ def generate( lora_request, trace_headers, prompt_adapter_request, priority) - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def encode( self, *, @@ -495,7 +497,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: ... @overload @@ -507,7 +509,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: ... @deprecate_kwargs( @@ -524,7 +526,7 @@ def encode( priority: int = 0, *, inputs: Optional[PromptType] = None # DEPRECATED - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model. Generate outputs for a request. This method is a coroutine. It adds the @@ -540,7 +542,7 @@ def encode( trace_headers: OpenTelemetry trace headers. Yields: - The output `EmbeddingRequestOutput` objects from the LLMEngine + The output `PoolingRequestOutput` objects from the LLMEngine for the request. """ if inputs is not None: @@ -549,7 +551,7 @@ def encode( and request_id is not None) return cast( - AsyncGenerator[EmbeddingRequestOutput, None], + AsyncGenerator[PoolingRequestOutput, None], self._process_request(prompt, pooling_params, request_id, @@ -567,7 +569,7 @@ async def _process_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - EmbeddingRequestOutput, None]]: + PoolingRequestOutput, None]]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" # If already dead, error out. @@ -586,6 +588,7 @@ async def _process_request( default_guided_backend=(self.decoding_config.guided_decoding_backend if self.decoding_config else DecodingConfig.guided_decoding_backend), + model_config=self.model_config ) # 1) Create output queue for this requests. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 7a6ebb430541f..a9b638ed02a1e 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -65,7 +65,7 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, @staticmethod @functools.lru_cache def _log_prompt_logprob_unsupported_warning_once(): - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid logger.warning( "Prompt logprob is not supported by multi step workers. " diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e15395d75c91f..4079de7d36793 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -11,8 +11,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, - RequestOutput) +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -209,7 +208,7 @@ def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from an embedding model.""" ... diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1551a9a998160..8de30ccd18a11 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,11 +1,11 @@ import itertools -import json import warnings from contextlib import contextmanager from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, Union, cast, overload) from tqdm import tqdm +from typing_extensions import deprecated from vllm import envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -26,7 +26,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest, LLMGuidedOptions) -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, @@ -185,12 +185,9 @@ def __init__( kwargs["disable_log_stats"] = True if compilation_config is not None: - if isinstance(compilation_config, (int)): + if isinstance(compilation_config, (int, dict)): compilation_config_instance = CompilationConfig.from_cli( str(compilation_config)) - elif isinstance(compilation_config, (dict)): - compilation_config_instance = CompilationConfig.from_cli( - json.dumps(compilation_config)) else: compilation_config_instance = compilation_config else: @@ -256,6 +253,7 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: str, @@ -268,6 +266,7 @@ def generate( ... @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: List[str], @@ -280,6 +279,7 @@ def generate( ... @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: Optional[str] = None, @@ -293,6 +293,7 @@ def generate( ... @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: Optional[List[str]] = None, @@ -306,6 +307,7 @@ def generate( ... @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: None, @@ -671,6 +673,7 @@ def chat( ) @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: str, @@ -679,10 +682,11 @@ def encode( prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: List[str], @@ -691,10 +695,11 @@ def encode( prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: Optional[str] = None, @@ -704,10 +709,11 @@ def encode( prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: Optional[List[str]] = None, @@ -717,10 +723,11 @@ def encode( prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: None, @@ -728,7 +735,7 @@ def encode( prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @overload @@ -741,7 +748,7 @@ def encode( Sequence[PoolingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: ... @deprecate_kwargs( @@ -759,7 +766,7 @@ def encode( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: """Generates the completions for the input prompts. This class automatically batches the given prompts, considering @@ -778,7 +785,7 @@ def encode( generation, if any. Returns: - A list of ``EmbeddingRequestOutput`` objects containing the + A list of ``PoolingRequestOutput`` objects containing the generated embeddings in the same order as the input prompts. Note: @@ -821,7 +828,7 @@ def encode( outputs = self._run_engine(use_tqdm=use_tqdm) return self.engine_class.validate_outputs(outputs, - EmbeddingRequestOutput) + PoolingRequestOutput) def score( self, @@ -832,7 +839,7 @@ def score( use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[EmbeddingRequestOutput]: + ) -> List[PoolingRequestOutput]: """Generates similarity scores for all pairs . The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case @@ -854,7 +861,7 @@ def score( generation, if any. Returns: - A list of ``EmbeddingRequestOutput`` objects containing the + A list of ``PoolingRequestOutput`` objects containing the generated scores in the same order as the input prompts. """ task = self.llm_engine.model_config.task @@ -943,7 +950,7 @@ def ensure_str(prompt: SingletonPrompt): outputs = self._run_engine(use_tqdm=use_tqdm) return self.engine_class.validate_outputs(outputs, - EmbeddingRequestOutput) + PoolingRequestOutput) def start_profile(self) -> None: self.llm_engine.start_profile() @@ -1085,7 +1092,7 @@ def _add_guided_params( def _run_engine( self, *, use_tqdm: bool - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> List[Union[RequestOutput, PoolingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -1098,7 +1105,7 @@ def _run_engine( ) # Run the engine. - outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + outputs: List[Union[RequestOutput, PoolingRequestOutput]] = [] total_in_toks = 0 total_out_toks = 0 while self.llm_engine.has_unfinished_requests(): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6bc31ef83ded4..c7bc30040279c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -175,8 +175,8 @@ async def build_async_engine_client_from_engine_args( # Select random path for IPC. ipc_path = get_open_zmq_ipc_path() - logger.info("Multiprocessing frontend to use %s for IPC Path.", - ipc_path) + logger.debug("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) # Start RPCServer in separate process (holds the LLMEngine). # the current process might have CUDA context, @@ -249,8 +249,8 @@ def mount_metrics(app: FastAPI): prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) if prometheus_multiproc_dir_path is not None: - logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", - prometheus_multiproc_dir_path) + logger.debug("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 936aae8f1c267..fc1c4908d6650 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -101,7 +101,7 @@ async def create_completion( tokenizer = await self.engine_client.get_tokenizer(lora_request) - request_prompts, engine_prompts = self._preprocess_completion( + request_prompts, engine_prompts = await self._preprocess_completion( request, tokenizer, request.prompt, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index c84a7d2d8e13e..2cbb252610e39 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -18,14 +18,14 @@ ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.logger import init_logger -from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput +from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) def _get_embedding( - output: EmbeddingOutput, + output: PoolingOutput, encoding_format: Literal["float", "base64"], ) -> Union[List[float], str]: if encoding_format == "float": @@ -40,7 +40,7 @@ def _get_embedding( def request_output_to_embedding_response( - final_res_batch: List[EmbeddingRequestOutput], request_id: str, + final_res_batch: List[PoolingRequestOutput], request_id: str, created_time: int, model_name: str, encoding_format: Literal["float", "base64"]) -> EmbeddingResponse: data: List[EmbeddingResponseData] = [] @@ -156,19 +156,20 @@ async def create_embedding( add_special_tokens=request.add_special_tokens, ) else: - request_prompts, engine_prompts = self._preprocess_completion( - request, - tokenizer, - request.input, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.input, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] try: pooling_params = request.to_pooling_params() @@ -206,7 +207,7 @@ async def create_embedding( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch: List[Optional[PoolingRequestOutput]] final_res_batch = [None] * num_prompts try: async for i, res in result_generator: @@ -214,7 +215,7 @@ async def create_embedding( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch_checked = cast(List[PoolingRequestOutput], final_res_batch) response = request_output_to_embedding_response( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index cae2877ea7e99..8232c6116c1bd 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,5 +1,6 @@ import json import pathlib +from concurrent.futures.thread import ThreadPoolExecutor from dataclasses import dataclass from http import HTTPStatus from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping, @@ -46,7 +47,7 @@ from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import AtomicCounter, is_list_of +from vllm.utils import AtomicCounter, is_list_of, make_async logger = init_logger(__name__) @@ -140,6 +141,14 @@ def __init__( self.request_logger = request_logger self.return_tokens_as_token_ids = return_tokens_as_token_ids + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self._tokenize_prompt_input_async = make_async( + self._tokenize_prompt_input, executor=self._tokenizer_executor) + self._tokenize_prompt_input_or_inputs_async = make_async( + self._tokenize_prompt_input_or_inputs, + executor=self._tokenizer_executor) + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -368,7 +377,7 @@ def _tokenize_prompt_input_or_inputs( input_or_inputs: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Iterator[TextTokensPrompt]: + ) -> List[TextTokensPrompt]: """ Tokenize/detokenize depending on the input format. @@ -376,45 +385,41 @@ def _tokenize_prompt_input_or_inputs( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. """ - for prompt_input in parse_and_batch_prompt(input_or_inputs): - # Although our type checking is based on mypy, - # VSCode Pyright extension should still work properly - # "is True" is required for Pyright to perform type narrowing - # See: https://github.com/microsoft/pyright/issues/7672 - if prompt_input["is_tokens"] is False: - yield self._normalize_prompt_text_to_input( - request, - tokenizer, - prompt=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) - else: - yield self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_ids=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - ) + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is True" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + return [ + self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + if prompt_input["is_tokens"] is False else + self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + for prompt_input in parse_and_batch_prompt(input_or_inputs) + ] - def _preprocess_completion( + async def _preprocess_completion( self, request: CompletionLikeRequest, tokenizer: AnyTokenizer, input_or_inputs: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]: - request_prompts = [ - request_prompt - for request_prompt in self._tokenize_prompt_input_or_inputs( - request, - tokenizer, - input_or_inputs, - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens, - ) - ] + ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]: + request_prompts = await self._tokenize_prompt_input_or_inputs_async( + request, + tokenizer, + input_or_inputs, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) engine_prompts = [ TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"]) @@ -493,7 +498,7 @@ async def _preprocess_chat( request=request) if isinstance(request_prompt, str): - prompt_inputs = self._tokenize_prompt_input( + prompt_inputs = await self._tokenize_prompt_input_async( request, tokenizer, request_prompt, diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 156fea6f47982..a1f14449ba9c3 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -13,15 +13,15 @@ from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger -from vllm.outputs import EmbeddingRequestOutput +from vllm.outputs import PoolingRequestOutput from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer -from vllm.utils import merge_async_iterators, random_uuid +from vllm.utils import make_async, merge_async_iterators, random_uuid logger = init_logger(__name__) def request_output_to_score_response( - final_res_batch: List[EmbeddingRequestOutput], request_id: str, + final_res_batch: List[PoolingRequestOutput], request_id: str, created_time: int, model_name: str) -> ScoreResponse: data: List[ScoreResponseData] = [] score = None @@ -133,7 +133,7 @@ async def create_score( return self.create_error_response(str(e)) # Schedule the request and get the result generator. - generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + generators: List[AsyncGenerator[PoolingRequestOutput, None]] = [] input_pairs = make_pairs(request.text_1, request.text_2) @@ -145,9 +145,11 @@ async def create_score( tokenization_kwargs["truncation"] = True tokenization_kwargs["max_length"] = truncate_prompt_tokens - prompt_inputs = tokenizer(text=q, - text_pair=t, - **tokenization_kwargs) + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + prompt_inputs = await tokenize_async(text=q, + text_pair=t, + **tokenization_kwargs) engine_prompt = TokensPrompt( prompt_token_ids=prompt_inputs["input_ids"], token_type_ids=prompt_inputs.get("token_type_ids")) @@ -192,7 +194,7 @@ async def create_score( num_prompts = len(engine_prompts) # Non-streaming response - final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch: List[Optional[PoolingRequestOutput]] final_res_batch = [None] * num_prompts try: @@ -201,7 +203,7 @@ async def create_score( assert all(final_res is not None for final_res in final_res_batch) - final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch_checked = cast(List[PoolingRequestOutput], final_res_batch) response = request_output_to_score_response( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 59b3b1311f881..9c3dc2c98b2dd 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -81,12 +81,13 @@ async def create_tokenize( add_special_tokens=request.add_special_tokens, ) else: - request_prompts, engine_prompts = self._preprocess_completion( - request, - tokenizer, - request.prompt, - add_special_tokens=request.add_special_tokens, - ) + (request_prompts, + engine_prompts) = await self._preprocess_completion( + request, + tokenizer, + request.prompt, + add_special_tokens=request.add_special_tokens, + ) except ValueError as e: logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) @@ -134,7 +135,7 @@ async def create_detokenize( # Silently ignore prompt adapter since it does not affect tokenization # (Unlike in Embeddings API where an error is raised) - prompt_input = self._tokenize_prompt_input( + prompt_input = await self._tokenize_prompt_input_async( request, tokenizer, request.tokens, diff --git a/vllm/envs.py b/vllm/envs.py index c896770e5f6bc..ab12a7b48dc53 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -8,7 +8,6 @@ VLLM_RPC_BASE_PATH: str = tempfile.gettempdir() VLLM_USE_MODELSCOPE: bool = False VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 - VLLM_INSTANCE_ID: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None VLLM_USE_TRITON_FLASH_ATTN: bool = False @@ -113,7 +112,8 @@ def get_default_config_root(): # If set, vllm will use precompiled binaries (*.so) "VLLM_USE_PRECOMPILED": - lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), + lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool( + os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" @@ -174,11 +174,6 @@ def get_default_config_root(): "VLLM_USE_MODELSCOPE": lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", - # Instance id represents an instance of the VLLM. All processes in the same - # instance should have the same instance id. - "VLLM_INSTANCE_ID": - lambda: os.environ.get("VLLM_INSTANCE_ID", None), - # Interval in seconds to log a warning message when the ring buffer is full "VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 336f9bc8efb20..2816b5c5c1f88 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -10,8 +10,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest -from vllm.utils import (get_distributed_init_method, get_open_port, - get_vllm_instance_id, make_async) +from vllm.utils import get_distributed_init_method, get_open_port, make_async from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -23,7 +22,7 @@ class CPUExecutor(ExecutorBase): def _init_executor(self) -> None: assert self.device_config.device_type == "cpu" - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid assert self.lora_config is None, "cpu backend doesn't support LoRA" @@ -31,9 +30,6 @@ def _init_executor(self) -> None: # Environment variables for CPU executor # - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index a6c05a71d2b6f..c450209f0eb91 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -16,7 +16,7 @@ from vllm.triton_utils.importing import HAS_TRITON from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, cuda_is_initialized, get_distributed_init_method, - get_open_port, get_vllm_instance_id, make_async, + get_open_port, make_async, update_environment_variables) if HAS_TRITON: @@ -37,9 +37,6 @@ def _init_executor(self) -> None: world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers - os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() - # Disable torch async compiling which won't work with daemonic processes os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 31e6fdc3ab1bb..a9efc4f9a801c 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -29,11 +29,13 @@ def _init_worker(self): wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = wrapper.init_worker( + wrapper.init_worker( vllm_config=self.vllm_config, local_rank=0, rank=0, - distributed_init_method=distributed_init_method) + distributed_init_method=distributed_init_method, + ) + self.driver_worker = wrapper.worker self.driver_worker.init_device() self.driver_worker.load_model() diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index db0070ce510ee..057a32364e512 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -36,7 +36,7 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = wrapper.init_worker( + wrapper.init_worker( ov_core=ov.Core(), vllm_config=self.vllm_config, local_rank=0, @@ -45,6 +45,7 @@ def _init_worker(self): kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) + self.driver_worker = wrapper.worker self.driver_worker.init_device() self.driver_worker.load_model() diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6542b18ae70b1..4263fb27265f6 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -15,8 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, get_vllm_instance_id, - make_async) + get_ip, get_open_port, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -189,8 +188,14 @@ def sort_by_driver_then_worker_ip(worker): self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -220,14 +225,10 @@ def sort_by_driver_then_worker_ip(worker): " environment variable, make sure it is unique for" " each node.") - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ "CUDA_VISIBLE_DEVICES": ",".join(map(str, node_gpus[node_id])), - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), **({ @@ -334,7 +335,6 @@ def _run_workers( async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: @@ -394,18 +394,10 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - ] - else: - assert self.driver_dummy_worker is not None - driver_worker_output = [ - ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) - ] + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] # Get the results of the ray workers. if self.workers: diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index a74328e5aa272..f3025cb537ab8 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -15,8 +15,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, - get_ip, get_open_port, get_vllm_instance_id, - make_async) + get_ip, get_open_port, make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -164,9 +163,14 @@ def sort_by_driver_then_worker_ip(worker): # node will be placed first. self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) - # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) # node id -> list of worker ranks node_gpus = defaultdict(list) # node id -> list of gpu ids @@ -196,12 +200,8 @@ def sort_by_driver_then_worker_ip(worker): "environment variable, make sure it is unique for" " each node.") - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] @@ -301,7 +301,6 @@ def _run_workers( async_run_tensor_parallel_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, **kwargs, ) -> Any: @@ -361,18 +360,10 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = [ - self.driver_worker.execute_method(method, *driver_args, - **driver_kwargs) - ] - else: - assert self.driver_dummy_worker is not None - driver_worker_output = [ - ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) - ] + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] # Get the results of the ray workers. if self.workers: diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index c227b5e283c68..5118c13934f0d 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - get_vllm_instance_id, make_async) + make_async) if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -137,19 +137,21 @@ def sort_by_driver_then_worker_ip(worker): self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) # Get the set of TPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore node_workers = defaultdict(list) for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): node_workers[node_id].append(i) - VLLM_INSTANCE_ID = get_vllm_instance_id() - # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for _ in worker_node_and_gpu_ids] @@ -203,7 +205,6 @@ def _run_workers( async_run_remote_workers_only: bool = False, all_args: Optional[List[Tuple[Any, ...]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None, - use_dummy_driver: bool = False, max_concurrent_workers: Optional[int] = None, use_ray_compiled_dag: bool = False, **kwargs, @@ -245,14 +246,8 @@ def _run_workers( driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # Start the driver worker after all the ray workers. - if not use_dummy_driver: - driver_worker_output = self.driver_worker.execute_method( - method, *driver_args, **driver_kwargs) - else: - assert self.driver_dummy_worker is not None - driver_worker_output = ray.get( - self.driver_dummy_worker.execute_method.remote( - method, *driver_args, **driver_kwargs)) + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) # Get the results of the ray workers. if self.workers: ray_worker_outputs = ray.get(ray_worker_outputs) diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index 2b1cdc09b0a9f..d2086f5fef26c 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -1,11 +1,13 @@ import asyncio from typing import List, Optional +import ray + import vllm.envs as envs from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync from vllm.executor.xpu_executor import XPUExecutor from vllm.logger import init_logger -from vllm.utils import get_vllm_instance_id, make_async +from vllm.utils import make_async logger = init_logger(__name__) @@ -14,15 +16,16 @@ class RayXPUExecutor(RayGPUExecutor, XPUExecutor): def _get_env_vars_to_be_updated(self): # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) - - VLLM_INSTANCE_ID = get_vllm_instance_id() + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote())) # type: ignore # Set environment variables for the driver and workers. all_args_to_update_environment_variables = [({ - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (_, _) in worker_node_and_gpu_ids] diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 54fbd7a321a6f..d4402e77a3886 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -38,34 +38,3 @@ "InputProcessingContext", "InputRegistry", ] - - -def __getattr__(name: str): - import warnings - - if name == "PromptInput": - msg = ("PromptInput has been renamed to PromptType. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PromptType - - if name == "LLMInputs": - msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return DecoderOnlyInputs - - if name == "EncoderDecoderLLMInputs": - msg = ( - "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return EncoderDecoderInputs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index fb7dbbebd7b90..85aaaa776907f 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -7,7 +7,8 @@ from typing_extensions import NotRequired, TypedDict, TypeVar, assert_never if TYPE_CHECKING: - from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict + from vllm.multimodal import (MultiModalDataDict, MultiModalKwargs, + MultiModalPlaceholderDict) from vllm.multimodal.inputs import MultiModalInputsV2 @@ -150,6 +151,12 @@ class TokenInputs(TypedDict): if the model supports it. """ + multi_modal_inputs: NotRequired["MultiModalKwargs"] + """ + Optional multi-modal inputs to pass to the model, + if the model supports it. + """ + multi_modal_placeholders: NotRequired["MultiModalPlaceholderDict"] """ Placeholder ranges for the multi-modal data. @@ -169,6 +176,7 @@ def token_inputs( token_type_ids: Optional[List[int]] = None, prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, + multi_modal_inputs: Optional["MultiModalKwargs"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, mm_processor_kwargs: Optional[Dict[str, Any]] = None, ) -> TokenInputs: @@ -181,6 +189,8 @@ def token_inputs( inputs["token_type_ids"] = token_type_ids if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data + if multi_modal_inputs is not None: + inputs["multi_modal_inputs"] = multi_modal_inputs if multi_modal_placeholders is not None: inputs["multi_modal_placeholders"] = multi_modal_placeholders if mm_processor_kwargs is not None: @@ -273,6 +283,18 @@ def multi_modal_data(self) -> "MultiModalDataDict": assert_never(inputs) + @cached_property + def multi_modal_inputs(self) -> Union[Dict, "MultiModalKwargs"]: + inputs = self.inputs + + if inputs["type"] == "token": + return inputs.get("multi_modal_inputs", {}) + + if inputs["type"] == "multimodal": + return inputs.get("mm_kwargs", {}) + + assert_never(inputs) + @cached_property def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict": inputs = self.inputs @@ -358,34 +380,3 @@ def to_enc_dec_tuple_list( return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] - - -def __getattr__(name: str): - import warnings - - if name == "PromptInput": - msg = ("PromptInput has been renamed to PromptType. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PromptType - - if name == "LLMInputs": - msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return DecoderOnlyInputs - - if name == "EncoderDecoderLLMInputs": - msg = ( - "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return EncoderDecoderInputs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 68b4756331e6d..646554c72481a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -11,8 +11,8 @@ from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, - resolve_mm_processor_kwargs) +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, + print_warning_once, resolve_mm_processor_kwargs) from .data import ProcessorInputs, SingletonInputs from .parse import is_encoder_decoder_inputs @@ -136,12 +136,12 @@ class InputRegistry: """ def __init__(self) -> None: - self._dummy_factories_by_model_type: Dict[Type[nn.Module], - DummyDataFactory] = {} - self._dummy_encoder_factories_by_model_type: Dict[ - Type[nn.Module], DummyDataFactory] = {} - self._input_processors_by_model_type: Dict[Type[nn.Module], - InputProcessor] = {} + self._dummy_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._dummy_encoder_factories_by_model_type = \ + ClassRegistry[nn.Module, DummyDataFactory]() + self._input_processors_by_model_type = \ + ClassRegistry[nn.Module, InputProcessor]() def _default_dummy_data_factory( self, @@ -232,19 +232,35 @@ def dummy_data_for_profiling( """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture - - model_cls, _ = get_model_architecture(model_config) - if is_encoder_data: - dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.utils import cached_get_tokenizer + + if mm_registry.has_processor(model_config): + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + processor = mm_registry.create_processor(model_config, tokenizer) + + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_max_tokens = mm_registry.get_max_tokens_by_modality( + model_config) + + dummy_data = processor.get_dummy_data(seq_len, mm_counts, + mm_max_tokens) else: - dummy_factory = self._get_dummy_data_factory(model_cls) - mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) - mm_processor_kwargs = get_allowed_kwarg_only_overrides( - dummy_factory, overrides=model_config.mm_processor_kwargs) + model_cls, _ = get_model_architecture(model_config) + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + dummy_factory, overrides=model_config.mm_processor_kwargs) - dummy_data = dummy_factory(InputContext(model_config), seq_len, - _MultiModalCounts(mm_counts), - **mm_processor_kwargs) + dummy_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = dummy_data.seq_data.prompt_token_ids @@ -257,7 +273,9 @@ def dummy_data_for_profiling( raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " f"but found {len(num_tokens)} tokens instead.") - if dummy_data.multi_modal_data is not None: + + if (dummy_data.multi_modal_data is not None and + not isinstance(dummy_data.multi_modal_data, MultiModalKwargs)): for k, v in dummy_data.multi_modal_data.items(): num_items = len(v) if isinstance(v, list) else 1 num_expected = mm_counts[k] diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index f5c2eced9d2bb..545ec21ca74c1 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -1,5 +1,5 @@ # pylint: disable=unused-argument -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast import torch import torch.nn as nn @@ -32,6 +32,44 @@ def dec(*args, **kwargs): return dec +def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): + """ + For `ColumnParallelLinearWithLoRA` or classes that inherit from + `ColumnParallelLinearWithLoRA`, they share the same `apply` logic. + """ + assert (layer.n_slices == len(layer.lora_a_stacked) == len( + layer.lora_b_stacked) == len(layer.output_slices)) + if layer.lora_bias_stacked is not None: + assert layer.n_slices == len(layer.lora_bias_stacked) + + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + + # Since communication is needed, the buffer is directly initialized as a + # tensor rather than a tuple of tensor. + buffers = torch.zeros( + (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + + layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + buffers = tensor_model_parallel_all_gather(buffers) + layer.punica_wrapper.add_expand(output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + # these layers are based on the tensor parallelism strategy given in # Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, # https://arxiv.org/abs/2311.03285. @@ -51,40 +89,15 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): # gather operation. def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked.shape[2] + shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size lora_a = lora_a[:, start_idx:start_idx + shard_size] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros( - (x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device, - ) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) - buffer = tensor_model_parallel_all_gather(buffer) - self.punica_wrapper.add_expand(output, - buffer, - self.lora_b_stacked, - add_input=True) - # now have column partitioned output - - if self.bias_stacked is not None: - self.bias_stacked = self.bias_stacked.view( - -1, self.bias_stacked.shape[-1]) - self.bias_stacked = self.bias_stacked[ - self.punica_wrapper.token_lora_indices] - output += self.bias_stacked - - output = output.view(*out_orig_shape) - return output + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -105,59 +118,6 @@ def can_replace_layer( ) -def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): - """ - MergedColumnParallelLinearWithShardedLoRA and - MergedQKVParallelLinearWithShardedLora share the same - LoRa weight application method. - - The main difference is the step by shard_size for lora_b which can - vary for MergedQKVParallelLinearWithShardedLora but is constant for - MergedColumnParallelLinearWithShardedLoRA. - """ - # expecting 2 for column parallel and 3 for qkv - n = len(layer.lora_a_stacked) - output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape - buffers = torch.zeros( - (n, x.shape[0], layer.lora_a_stacked[0].shape[2]), - dtype=torch.float32, - device=x.device, - ) - for idx in range(n): - layer.punica_wrapper.add_shrink(buffers[idx], x, - layer.lora_a_stacked[idx], 1.0) - - buffers = tensor_model_parallel_all_gather(buffers) - left_offset = 0 - for idx in range(n): - shard_size = layer.lora_b_stacked[idx].shape[2] - - if layer.bias_stacked is not None: - bias = layer.bias_stacked[idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[layer.punica_wrapper.token_lora_indices] - bias[layer.punica_wrapper.token_lora_indices == -1] = 0 - output[:, left_offset:left_offset + shard_size] += bias - - layer.punica_wrapper.add_expand_slice( - output, - buffers[idx], - layer.lora_b_stacked[idx], - left_offset, - shard_size, - add_input=True, - ) - left_offset += shard_size - - output = output.view(*out_orig_shape) - # now have column partitioned and packed output - return output - - class MergedColumnParallelLinearWithShardedLoRA( MergedColumnParallelLinearWithLoRA): """ @@ -181,8 +141,9 @@ def slice_lora_a( ] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -214,30 +175,15 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() - shard_size = self.lora_a_stacked.shape[2] + shard_size = self.lora_a_stacked[0].shape[2] start_idx = tp_rank * shard_size lora_a = lora_a[:, start_idx:start_idx + shard_size] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - - x = x.view(-1, x.shape[-1]) - output, out_orig_shape = output.view(-1, - output.shape[-1]), output.shape - buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), - dtype=torch.float32, - device=x.device) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) - buffer = tensor_model_parallel_all_gather(buffer) - self.punica_wrapper.add_expand(output, - buffer, - self.lora_b_stacked, - add_input=True) - # now have column partitioned output - output = output.view(*out_orig_shape) - return output + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return _mcp_apply(x, bias, self) @classmethod @_fully_sharded_can_replace @@ -278,8 +224,9 @@ def slice_lora_a( ] return lora_a - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: return _mcp_apply(x, bias, self) @classmethod @@ -312,7 +259,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): """ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - shard_size = self.lora_b_stacked.shape[2] + shard_size = self.lora_b_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size lora_b = lora_b[:, start_idx:end_idx] @@ -321,20 +268,24 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: if bias is None: return bias - shard_size = self.bias_stacked.shape[2] + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + shard_size = self.lora_bias_stacked[0].shape[2] start_idx = self.tp_rank * shard_size end_idx = (self.tp_rank + 1) * shard_size bias = bias[start_idx:end_idx] return bias - def apply(self, x: torch.Tensor) -> torch.Tensor: + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) x = x.view(-1, x.shape[-1]) output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape buffer = torch.zeros( - (x.shape[0], self.lora_a_stacked.shape[2]), + (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]), dtype=torch.float32, device=x.device, ) @@ -348,18 +299,18 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: # remains is a standard all_reduce. User should be aware though that # the output is not the same as a normal row_parallel, it should be # reduced before being used - shard_size = self.lora_b_stacked.shape[2] - start_idx = self.tp_rank * shard_size - - if self.bias_stacked is not None: - bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1]) - bias = bias[self.punica_wrapper.token_lora_indices] - bias[self.punica_wrapper.token_lora_indices == -1] = 0 - output += bias - - self.punica_wrapper.add_expand_slice(output, buffer, - self.lora_b_stacked, start_idx, - shard_size) + # NOTE offset are based on the rank. + shard_size = self.lora_b_stacked[0].shape[2] + offset_start = self.tp_rank * shard_size + self.punica_wrapper.add_expand( + output, + buffer, + self.lora_b_stacked, + self.lora_bias_stacked, + self.output_slices, + offset_start=offset_start, + add_input=True, + ) output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3701988ff692f..3e9c2ceb83eac 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument import math from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import torch import torch.nn as nn @@ -18,11 +18,14 @@ tensor_model_parallel_gather) from vllm.distributed.utils import divide from vllm.lora.punica import PunicaWrapper +# yapf: disable from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearBase, MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) +# yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import ( LinearScalingRotaryEmbedding, RotaryEmbedding) @@ -67,63 +70,6 @@ def dec(*args, **kwargs): return dec -def apply_bias( - indices: torch.Tensor, - output: torch.Tensor, - bias_stacked: torch.Tensor, -): - """Applies bias to output - - Input shapes: - bias_stacked: (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, output_dim) - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1]) - bias_stacked = bias_stacked[indices] - bias_stacked[indices == -1] = 0 - output += bias_stacked - - return output.view_as(org_output) - - -def apply_bias_packed_nslice( - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], - bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], -): - """Applies bias to output - - Input shapes: - bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias - - offset_left += slice - - return output.view_as(org_output) - - @dataclass class LoRAMapping(AdapterMapping): is_prefill: bool = False @@ -306,12 +252,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[1], -1, ) - - # Embedding layer only need expand op - self.punica_wrapper.add_expand(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + self.punica_wrapper.add_lora_embedding(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) return full_output.view_as(full_output_org) @classmethod @@ -325,14 +269,19 @@ def can_replace_layer( return type(source_layer) is VocabParallelEmbedding -class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): +class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): - def __init__(self, base_layer: ReplicatedLinear) -> None: + def __init__(self, base_layer: LinearBase): super().__init__() self.base_layer = base_layer self.input_size = self.base_layer.input_size - self.output_size = self.base_layer.output_size self.device = _get_lora_device(self.base_layer) + self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None + + self.output_slices: Tuple[int, ...] + self.tp_size: int + self.output_size: int + self.n_slices: int def create_lora_weights( self, @@ -341,39 +290,64 @@ def create_lora_weights( model_config: Optional[PretrainedConfig] = None, ) -> None: self.lora_config = lora_config - lora_a_output_size = lora_config.max_lora_rank - self.lora_a_stacked = torch.zeros( - max_loras, - 1, - lora_a_output_size, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - max_loras, - 1, - self.output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) - if lora_config.bias_enabled: - self.bias_stacked = torch.zeros( + # + if isinstance(self.base_layer, ReplicatedLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, ColumnParallelLinear): + lora_a_out_size = (lora_config.max_lora_rank if + not lora_config.fully_sharded_loras else divide( + lora_config.max_lora_rank, self.tp_size)) + lora_b_out_size = self.output_size + + elif isinstance(self.base_layer, RowParallelLinear): + lora_a_out_size = lora_config.max_lora_rank + lora_b_out_size = (self.output_size if + not lora_config.fully_sharded_loras else divide( + self.output_size, self.tp_size)) + else: + raise NotImplementedError + + self.lora_a_stacked = tuple( + torch.zeros( max_loras, 1, - self.output_size, + lora_a_out_size, + self.input_size, dtype=lora_config.lora_dtype, device=self.device, - ) - else: - self.bias_stacked = None + ) for _ in range(self.n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_b_out_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + if lora_config.bias_enabled: + lora_bias_out_size = lora_b_out_size + self.lora_bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_bias_out_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(self.n_slices)) + self.output_slices = (self.lora_b_stacked[0].shape[2], ) def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[index] = 0 + for s_index in range(self.n_slices): + self.lora_a_stacked[s_index][index] = 0 + self.lora_b_stacked[s_index][index] = 0 + if self.lora_config.bias_enabled: + # Make mypy happy + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + self.lora_bias_stacked[s_index][index] = 0 def set_lora( self, @@ -381,35 +355,56 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + lora_bias: Optional[torch.Tensor] = None, ): - self.reset_lora(index) + # Except for QKVParallelLinearWithLora and + # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers + # store weights in a tuple of size 1. These two layers will + # override this function. + assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) == + self.n_slices == 1) - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + self.lora_a_stacked[0][index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[0][index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if lora_bias is not None: + + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + assert len(self.lora_bias_stacked) + self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_( + lora_bias.T, non_blocking=True) + + def apply(self, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) - self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, + self.lora_b_stacked, + self.lora_bias_stacked, 1.0, + self.output_slices) return output + +class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__(base_layer, ) + # To ensure interface compatibility, set to 1 always. + self.tp_size = 1 + self.output_size = self.base_layer.output_size + self.n_slices = 1 + def forward(self, input_): """Forward of ReplicatedLinearWithLoRA @@ -442,73 +437,26 @@ def can_replace_layer( return type(source_layer) is ReplicatedLinear -class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): +class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): """ LoRA on top of ColumnParallelLinear layer. - LoRA B is sliced for tensor parallelism. + There are two types for the `base_layer`: + 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`. + 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`. """ def __init__(self, base_layer: ColumnParallelLinear) -> None: - super().__init__() + super().__init__(base_layer) # The base_layer type is ColumnParallelLinear or # MergedColumnParallelLinear, their weight sharding logic is # inconsistent when TP is greater than 1. self.is_merged_col_linear = type( base_layer) is MergedColumnParallelLinear - - self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() - self.input_size = self.base_layer.input_size self.output_size = self.base_layer.output_size_per_partition - self.device = _get_lora_device(self.base_layer) - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config - self.tp_size = get_tensor_model_parallel_world_size() - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - self.lora_a_stacked = torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - self.lora_b_stacked = torch.zeros( - max_loras, - 1, - self.output_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ) - - if lora_config.bias_enabled: - self.bias_stacked = torch.zeros( - max_loras, - 1, - self.output_size, - dtype=lora_config.lora_dtype, - device=self.device, - ) - else: - self.bias_stacked = None - - self.output_dim = self.lora_b_stacked.shape[2] - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[index] = 0 + # There is only one LoRA layer + self.n_slices = 1 def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: return lora_a @@ -547,46 +495,6 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: bias = bias[start_idx:end_idx] return bias - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - bias = self.slice_bias(bias) - - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) - self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) - return output - def forward(self, input_): """Forward of ColumnParallelLinear @@ -634,8 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): Both slices must have the same size. """ - def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + def __init__( + self, base_layer: Union[MergedColumnParallelLinear, + QKVParallelLinear]) -> None: super().__init__(base_layer) + # There are two LoRA layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + # the output_sizes in MergedColumnParallelLinear is not sharded by tp + # we need to divide it by the tp_size to get correct slices size + output_sizes = self.base_layer.output_sizes + self.output_slices = tuple( + divide(output_size, self.tp_size) for output_size in output_sizes) + self.n_slices = len(self.output_slices) + self.output_ids = (self.tp_rank, ) * self.n_slices def create_lora_weights( self, @@ -643,16 +563,11 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: + """ + The main reason for overriding this function is to enhance code + maintainability. + """ self.lora_config = lora_config - n_slices = 2 - if not (len(self.base_layer.output_sizes) == n_slices - and self.base_layer.output_sizes[0] - == self.base_layer.output_sizes[1]): - raise ValueError( - "LoRAColumnParallelLinear2Slice requires 2 slices with " - "the same size.") - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() lora_a_output_size_per_partition = ( lora_config.max_lora_rank if not lora_config.fully_sharded_loras @@ -666,38 +581,25 @@ def create_lora_weights( self.input_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(n_slices)) + ) for _ in range(self.n_slices)) self.lora_b_stacked = tuple( torch.zeros( max_loras, 1, - self.output_size // 2, + output_size, lora_config.max_lora_rank, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(n_slices)) + ) for output_size in self.output_slices) if lora_config.bias_enabled: - self.bias_stacked = tuple( + self.lora_bias_stacked = tuple( torch.zeros( max_loras, 1, - self.output_size // 2, + output_size, dtype=lora_config.lora_dtype, device=self.device, - ) for _ in range(n_slices)) - else: - self.bias_stacked = None - - self.output_dim = self.lora_b_stacked[0].shape[2] - - def reset_lora(self, index: int): - self.lora_a_stacked[0][index] = 0 - self.lora_a_stacked[1][index] = 0 - self.lora_b_stacked[0][index] = 0 - self.lora_b_stacked[1][index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[0][index] = 0 - self.bias_stacked[1][index] = 0 + ) for output_size in self.output_slices) def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] @@ -707,27 +609,21 @@ def slice_lora_a( def slice_lora_b( self, lora_b: List[Union[torch.Tensor, None]] ) -> List[Union[torch.Tensor, None]]: - #NOTE: lora_b contains 2 subloras, and each sublora could be None. - shard_size = self.output_dim - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - lora_b = [ - lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None, - lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None, - ] + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (lora_b_i := lora_b[i]) is not None: + lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size * + (shard_id + 1)] return lora_b def slice_bias( self, bias: List[Union[torch.Tensor, None]]) -> List[Union[torch.Tensor, None]]: - # NOTE : each bias could be None. - shard_size = self.output_dim - start_idx = self.tp_rank * shard_size - end_idx = (self.tp_rank + 1) * shard_size - bias = [ - bias[0][start_idx:end_idx] if bias[0] is not None else None, - bias[1][start_idx:end_idx] if bias[1] is not None else None - ] + for i, (shard_id, shard_size) in enumerate( + zip(self.output_ids, self.output_slices)): + if (bias_i := bias[i]) is not None: + bias[i] = bias_i[shard_size * shard_id:shard_size * + (shard_id + 1)] return bias def set_lora( @@ -736,54 +632,35 @@ def set_lora( lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, + lora_bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) - - if lora_a[0] is not None: - self.lora_a_stacked[0][ - index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( - lora_a[0].T, non_blocking=True) - self.lora_b_stacked[0][ - index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( - lora_b[0].T, non_blocking=True) - if bias is not None and bias[0] is not None: - self.bias_stacked[0][index, - 0, :bias[0].shape[0]].copy_(bias[0].T, - non_blocking=True) - if lora_a[1] is not None: - self.lora_a_stacked[1][ - index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( - lora_a[1].T, non_blocking=True) - self.lora_b_stacked[1][ - index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( - lora_b[1].T, non_blocking=True) - if bias is not None and bias[1] is not None: - self.bias_stacked[1][index, - 0, :bias[1].shape[0]].copy_(bias[1].T, - non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias_packed_nslice( - self.indices, - output, - (self.output_dim, self.output_dim), - self.bias_stacked, - ) - self.punica_wrapper.add_lora_packed_nslice( - output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, - (self.output_dim, self.output_dim)) - return output + if lora_bias is not None: + lora_bias = self.slice_bias(lora_bias) + + for i in range(self.n_slices): + if (lora_a_i := lora_a[i]) is not None: + self.lora_a_stacked[i][ + index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_( + lora_a_i.T, non_blocking=True) + if (lora_b_i := lora_b[i]) is not None: + self.lora_b_stacked[i][ + index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_( + lora_b_i.T, non_blocking=True) + + if lora_bias is not None: + self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], + self.lora_bias_stacked) + for i in range(self.n_slices): + if (lora_bias_i := lora_bias[i]) is not None: + self.lora_bias_stacked[i][index, + 0, :lora_bias_i.shape[0]].copy_( + lora_bias_i.T, + non_blocking=True) @classmethod @_not_fully_sharded_can_replace @@ -813,7 +690,6 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): def __init__(self, base_layer: QKVParallelLinear) -> None: super().__init__(base_layer) - self.tp_size = get_tensor_model_parallel_world_size() self.q_proj_total_size = (self.base_layer.total_num_heads * self.base_layer.head_size) self.q_proj_shard_size = (self.base_layer.num_heads * @@ -822,6 +698,8 @@ def __init__(self, base_layer: QKVParallelLinear) -> None: self.base_layer.head_size) self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * self.base_layer.head_size) + # There is only one LoRA layer + self.n_slices = 1 def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() @@ -856,32 +734,6 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: bias = torch.cat([bias_q, bias_k, bias_v], dim=1) return bias - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) - - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - @classmethod @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, @@ -891,8 +743,8 @@ def can_replace_layer(cls, source_layer: nn.Module, packed_modules_list) == 1 -class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): - """ColumnParallelLinear layer that is composed of 3 sublayers (slices) +class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA): + """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj). @@ -904,16 +756,11 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): def __init__(self, base_layer: QKVParallelLinear) -> None: super().__init__(base_layer) - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config + # There are three LoRA layer. + self.n_slices = len(self.base_layer.output_sizes) self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) self.kv_proj_shard_size = (self.base_layer.num_kv_heads * @@ -921,227 +768,28 @@ def create_lora_weights( self.q_shard_id = self.tp_rank self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas - lora_a_output_size_per_partition = ( - lora_config.max_lora_rank if not lora_config.fully_sharded_loras - else divide(lora_config.max_lora_rank, self.tp_size)) - # q, k, v - self.lora_a_stacked = ( - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - lora_a_output_size_per_partition, - self.input_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - self.lora_b_stacked = ( - torch.zeros( - max_loras, - 1, - self.q_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - if lora_config.bias_enabled: - self.bias_stacked = ( - torch.zeros( - max_loras, - 1, - self.q_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - dtype=lora_config.lora_dtype, - device=self.device, - ), - ) - else: - self.bias_stacked = None - self.output_slices = ( self.q_proj_shard_size, self.kv_proj_shard_size, self.kv_proj_shard_size, ) - self.packed_indices: Optional[torch.Tensor] = None - self.standard_indices: Optional[torch.Tensor] = None - # lazily initialized. - self.indices: torch.Tensor - self.indices_len: List[int] - - def reset_lora(self, index: int): - self.lora_a_stacked[0][index] = 0 - self.lora_b_stacked[0][index] = 0 - self.lora_a_stacked[1][index] = 0 - self.lora_b_stacked[1][index] = 0 - self.lora_a_stacked[2][index] = 0 - self.lora_b_stacked[2][index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[0][index] = 0 - self.bias_stacked[1][index] = 0 - self.bias_stacked[2][index] = 0 - - def slice_lora_a( - self, lora_a: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: - return lora_a - - def slice_lora_b( - self, lora_b: List[Union[torch.Tensor, None]] - ) -> List[Union[torch.Tensor, None]]: - lora_b_q, lora_b_k, lora_b_v = None, None, None - if lora_b[0] is not None: - lora_b_q = lora_b[0][:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1), ] - if lora_b[1] is not None: - lora_b_k = lora_b[1][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1), ] - if lora_b[2] is not None: - lora_b_v = lora_b[2][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1), ] - lora_b = [lora_b_q, lora_b_k, lora_b_v] - return lora_b - - def slice_bias( - self, bias: List[Union[torch.Tensor, - None]]) -> List[Union[torch.Tensor, None]]: - bias_q, bias_k, bias_v = bias - if bias_q is not None: - bias_q = bias_q[self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - if bias_k is not None: - bias_k = bias_k[self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - if bias_v is not None: - bias_v = bias_v[self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - bias = [bias_q, bias_k, bias_v] - return bias + self.output_ids = ( + self.q_shard_id, + self.kv_shard_id, + self.kv_shard_id, + ) - def set_lora( + def create_lora_weights( self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) - - if lora_b[0] is not None: - lora_b_q = lora_b[0] - self.lora_b_stacked[0][ - index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( - lora_b_q.T, non_blocking=True) - if lora_b[1] is not None: - lora_b_k = lora_b[1] - self.lora_b_stacked[1][ - index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( - lora_b_k.T, non_blocking=True) - if lora_b[2] is not None: - lora_b_v = lora_b[2] - self.lora_b_stacked[2][ - index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( - lora_b_v.T, non_blocking=True) - - if lora_a[0] is not None: - self.lora_a_stacked[0][ - index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( - lora_a[0].T, non_blocking=True) - if lora_a[1] is not None: - self.lora_a_stacked[1][ - index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( - lora_a[1].T, non_blocking=True) - if lora_a[2] is not None: - self.lora_a_stacked[2][ - index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( - lora_a[2].T, non_blocking=True) - - if bias is not None: - if bias[0] is not None: - self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_( - bias[0].T, non_blocking=True) - if bias[1] is not None: - self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_( - bias[1].T, non_blocking=True) - if bias[2] is not None: - self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_( - bias[2].T, non_blocking=True) - - def apply(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias_packed_nslice( - self.indices, - output, - self.output_slices, - self.bias_stacked, - ) - self.punica_wrapper.add_lora_packed_nslice(output, x, - self.lora_a_stacked, - self.lora_b_stacked, 1.0, - self.output_slices) - return output + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """ + The main reason for overloading this function is to handle inconsistent + weight dimensions in qkv lora. + """ + super().create_lora_weights(max_loras, lora_config, model_config) @classmethod @_not_fully_sharded_can_replace @@ -1156,76 +804,25 @@ def can_replace_layer( and len(packed_modules_list) == 3) -class RowParallelLinearWithLoRA(BaseLayerWithLoRA): +class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: - super().__init__() - self.base_layer = base_layer + super().__init__(base_layer) + + self.tp_size = get_tensor_model_parallel_world_size() + # reset input_size self.input_size = self.base_layer.input_size_per_partition self.output_size = self.base_layer.output_size - self.device = _get_lora_device(self.base_layer) - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - self.lora_config = lora_config self.tp_rank = get_tensor_model_parallel_rank() - self.lora_a_stacked = torch.zeros( - ( - max_loras, - 1, - lora_config.max_lora_rank, - self.input_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - tp_size = get_tensor_model_parallel_world_size() - lora_b_output_size_per_partition = ( - self.output_size if not lora_config.fully_sharded_loras else - divide(self.output_size, tp_size)) - - self.lora_b_stacked = torch.zeros( - ( - max_loras, - 1, - lora_b_output_size_per_partition, - lora_config.max_lora_rank, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - - if lora_config.bias_enabled: - self.bias_stacked = torch.zeros( - ( - max_loras, - 1, - self.output_size, - ), - dtype=lora_config.lora_dtype, - device=self.device, - ) - else: - self.bias_stacked = None - # Lazily initialized - self.indices: torch.Tensor - self.indices_len: List[int] - - def reset_lora(self, index: int): - self.lora_a_stacked[index] = 0 - self.lora_b_stacked[index] = 0 - if self.lora_config.bias_enabled: - self.bias_stacked[index] = 0 + # There is only one LoRA layer. + self.n_slices = 1 def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.input_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size lora_a = lora_a[start_idx:end_idx, :] return lora_a @@ -1235,46 +832,6 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: return bias - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - self.reset_lora(index) - - if self.base_layer.tp_size > 1: - lora_a = self.slice_lora_a(lora_a) - lora_b = self.slice_lora_b(lora_b) - if bias is not None: - bias = self.slice_bias(bias) - - self.lora_a_stacked[index, - 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( - lora_a.T, non_blocking=True) - self.lora_b_stacked[index, - 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( - lora_b.T, non_blocking=True) - if bias is not None: - self.bias_stacked[index, - 0, :bias.shape[0]].copy_(bias.T, - non_blocking=True) - - def apply(self, x: torch.Tensor) -> torch.Tensor: - output = self.base_layer.quant_method.apply(self.base_layer, x) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) - self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) - return output - def forward(self, input_): """Forward of RowParallelLinear @@ -1292,10 +849,9 @@ def forward(self, input_): input_parallel = input_ else: # TODO: simplify code below - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.base_layer.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. output_parallel = self.apply(input_parallel) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 2ffefe61427e3..9855b57d0c9c9 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -555,17 +555,17 @@ def create_dummy_lora( input_dim, output_dim, rank, - module.lora_a_stacked.dtype, + module.lora_a_stacked[0].dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim, bias_enabled=bias_enabled) else: lora = LoRALayerWeights.create_dummy_lora_weights( module_name, - module.lora_a_stacked.shape[-1], - module.lora_b_stacked.shape[-2], + module.lora_a_stacked[0].shape[-1], + module.lora_b_stacked[0].shape[-2], rank, - module.lora_a_stacked.dtype, + module.lora_a_stacked[0].dtype, "cpu", bias_enabled=bias_enabled, ) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 082041f390750..563d1181d6fcb 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -362,7 +362,7 @@ def long_lora_indices(self) -> torch.Tensor: long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] - def shrink_prefill( + def _shrink_prefill( self, y: torch.Tensor, x: torch.Tensor, @@ -380,7 +380,7 @@ def shrink_prefill( scale, ) - def shrink_decode( + def _shrink_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -389,7 +389,7 @@ def shrink_decode( ): bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - def expand_prefill( + def _expand_prefill( self, y: torch.Tensor, x: torch.Tensor, @@ -407,7 +407,7 @@ def expand_prefill( add_input, ) - def expand_decode( + def _expand_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -416,7 +416,7 @@ def expand_decode( ): bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) - def expand_slice_prefill( + def _expand_slice_prefill( self, y: torch.Tensor, x: torch.Tensor, @@ -438,7 +438,7 @@ def expand_slice_prefill( add_input, ) - def expand_slice_decode( + def _expand_slice_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -450,7 +450,57 @@ def expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_input) - def add_shrink( + def _apply_expand(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + + expand_slice_fun: Callable = (self._expand_slice_prefill + if self.is_prefill else + self._expand_slice_decode) + expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + + def _apply_shrink( self, y: torch.Tensor, x: torch.Tensor, @@ -461,151 +511,215 @@ def add_shrink( Perform the ` y+=x@w_t_all` computation, which is suitable for the GEMM of lora'a. When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the shrink_decode function + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function should be called. """ - shrink_fun: Callable = (self.shrink_prefill - if self.is_prefill else self.shrink_decode) + y_org = y + y = y.view(-1, y.shape[-1]) + shrink_fun: Callable = (self._shrink_prefill + if self.is_prefill else self._shrink_decode) shrink_fun(y, x, w_t_all, scale) + y = y.view_as(y_org) - def add_expand( + def add_shrink( self, - y: torch.Tensor, + y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, - w_t_all: torch.Tensor, - add_input: bool = True, + lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, ): """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'b. - When `is_prefill` is true, it indicates that it is currently the - prefill stage, and the `expand_prefill` function should be called. - Otherwise, it is the decode stage, and the expand_decode function + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function should be called. - """ + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ - expand_fun: Callable = (self.expand_prefill - if self.is_prefill else self.expand_decode) - expand_fun(y, x, w_t_all, add_input) - - def add_expand_slice(self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_input: bool = True): + x = x.view(-1, x.shape[-1]) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand( + self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_input=True, + ) -> None: """ - Similar to `add_expand` + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_input (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_input=add_input, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_input: bool = True, + ): """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. - expand_slice_fun: Callable = (self.expand_slice_prefill - if self.is_prefill else - self.expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + Semantics: + y += x @ lora_b_stacked - def add_lora(self, - y: torch.Tensor, - x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, - scale: float, - y_offset: Optional[int] = None, - y_slice_size: Optional[int] = None, - *, - buffer: Optional[torch.Tensor] = None) -> None: + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_input (bool): Default to True. + """ + + # Embedding layer only need expand op + expand_fun: Callable = (self._expand_prefill + if self.is_prefill else self._expand_decode) + expand_fun(y, x, lora_b_stacked, add_input) + + def add_lora_linear( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None: + """ + Applicable to linear-related lora. + Semantics: - y[i] += ( - x[i].unsqueeze(0) - @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) - * scale - ).squeeze(0) + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + Args: - y (torch.Tensor): Output tensor. Will be changed in-place. + y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor - wa_t_all (torch.Tensor): lora_a's weight - wb_t_all (torch.Tensor): lora_b's weight + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. scale (float): Scaling factor. - y_offset (Optional[int], optional): Offset to apply to the starting - column of y. - y_slice_size (Optional[int], optional): Size of the y column slice. - buffer (Optional[torch.Tensor], optional): Defaults to None. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - r = wb_t_all.size(-1) + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self.token_lora_indices, y, output_slices, + lora_bias_stacked) + if buffer is None: + r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = torch.zeros((x.size(0), r), - dtype=torch.float32, - device=x.device) - - self.add_shrink(buffer, x, wa_t_all, scale) - if y_offset is None and y_slice_size is None: - self.add_expand(y, buffer, wb_t_all, add_input=True) - else: - self.add_expand_slice(y, - buffer, - wb_t_all, - y_offset, - y_slice_size, - add_input=True) - y = y.view_as(y_org) - - def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - scale: float, - output_slices: Tuple[int, ...]) -> None: - """ - Applies lora to each input. Similar to add_lora, This method is - used for layers that are composed of multiple sublayers - (slices) packed together. - """ - y_org = y - x = x.view(-1, x.shape[-1]) - y = y.view(-1, y.shape[-1]) - offset_left = 0 - # TODO fuse these kernels - for slice_idx in range(len(output_slices)): - self.add_lora(y, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], scale, offset_left, - output_slices[slice_idx]) - offset_left += output_slices[slice_idx] - - y = y.view_as(y_org) + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_input=True) def add_lora_logits(self, y: torch.Tensor, x: torch.Tensor, - wa_t_all: torch.Tensor, - wb_t_all: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, scale, *, buffer: Optional[torch.Tensor] = None) -> None: """ - LogitsProcessorWithLoRA always using bgmv - """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ y_org = y y = y.view(-1, y.shape[-1]) x = x.view(-1, x.shape[-1]) - r = wb_t_all.size(-1) + r = lora_b_stacked.size(-1) if buffer is None: # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - - bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) - bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + # LogitsProcessorWithLoRA always using bgmv. + bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale) + bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) y = y.view_as(y_org) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index d7b67425fcbc0..e631aec928ec5 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,14 +1,96 @@ -from typing import Optional +from __future__ import annotations -from vllm.logits_process import LogitsProcessor -from vllm.sampling_params import GuidedDecodingParams +from typing import TYPE_CHECKING + +from vllm.logger import init_logger +from vllm.platforms import CpuArchEnum, current_platform + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.logits_process import LogitsProcessor + from vllm.sampling_params import GuidedDecodingParams + +logger = init_logger(__name__) + + +def has_xgrammar_unsupported_json_features(schema: dict) -> bool: + """Check if JSON schema contains features unsupported by xgrammar.""" + + def check_object(obj: dict) -> bool: + if not isinstance(obj, dict): + return False + + # Check for pattern restrictions + if "pattern" in obj: + return True + + # Check for numeric ranges + if obj.get("type") in ("integer", "number") and any( + key in obj for key in [ + "minimum", "maximum", "exclusiveMinimum", + "exclusiveMaximum", "multipleOf" + ]): + return True + + # Recursively check all nested objects and arrays + for value in obj.values(): + if isinstance(value, dict): + if check_object(value): + return True + elif isinstance(value, list): + for item in value: + if isinstance(item, dict) and check_object(item): + return True + + return False + + return check_object(schema) + + +def maybe_backend_fallback( + guided_params: GuidedDecodingParams) -> GuidedDecodingParams: + # lm-format-enforce doesn't support grammar, fallback to xgrammar + if (guided_params.backend == "lm-format-enforcer" + and guided_params.grammar is not None): + logger.warning( + "lm-format-enforcer does not support grammar guided decoding. " + "Falling back to use xgrammar instead.") + guided_params.backend = "xgrammar" + + if guided_params.backend == "xgrammar": + # xgrammar only has x86 wheels for linux, fallback to outlines + if current_platform.get_cpu_architecture() is not CpuArchEnum.X86: + logger.warning("xgrammar is only supported on x86 CPUs. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + + # xgrammar doesn't support regex or choice, fallback to outlines + if guided_params.regex is not None or guided_params.choice is not None: + logger.warning( + "xgrammar only supports json or grammar guided decoding. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + + # xgrammar doesn't support some JSON schema features + elif (guided_params.json is not None + and has_xgrammar_unsupported_json_features(guided_params.json)): + logger.warning( + "xgrammar does not support advanced JSON schema features like " + "patterns or numeric ranges. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + + return guided_params async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer) -> Optional[LogitsProcessor]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines' or guided_params.grammar: + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) @@ -19,17 +101,23 @@ async def get_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer) -> Optional[LogitsProcessor]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines' or guided_params.grammar: + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) @@ -40,7 +128,12 @@ def get_local_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py new file mode 100644 index 0000000000000..b59a2269d2cd5 --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -0,0 +1,266 @@ +# noqa: UP007 +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, NamedTuple + +import torch +from transformers import PreTrainedTokenizerFast + +try: + import xgrammar as xgr + from xgrammar.base import _core as xgr_core +except ImportError: + pass + +from vllm.model_executor.guided_decoding.xgrammar_utils import ( + convert_lark_to_gbnf, grammar_is_likely_lark) + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.sampling_params import GuidedDecodingParams + + +# TODO: passing batch size to max threads here +def get_local_xgrammar_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + max_threads: int = 8): + config = GrammarConfig.from_guided_params(guided_params=guided_params, + model_config=model_config, + tokenizer=tokenizer, + max_threads=max_threads) + return XGrammarLogitsProcessor(config) + + +class TokenizerData(NamedTuple): + """Immutable container for cached tokenizer data.""" + encoded_vocab: list[str] + stop_token_ids: list[int] | None + backend_str: str + + +class TokenizerDataCache: + """Cache manager for tokenizer data to avoid repeated processing.""" + _cache: dict[int, TokenizerData] = {} + + @classmethod + def get_tokenizer_data(cls, + tokenizer: PreTrainedTokenizer) -> TokenizerData: + tokenizer_hash = hash(tokenizer) + + if tokenizer_hash not in cls._cache: + # Vendored from xgrammar logic since we cannot pickle the tokenizer + # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501 + try: + encoded_vocab = [ + token for token, _ in sorted(tokenizer.get_vocab().items(), + key=lambda x: x[1]) + ] + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. The tokenizer should have a " + "get_vocab method.") from e + + stop_token_ids = None + backend_str = xgr.VocabType.RAW + if isinstance(tokenizer, PreTrainedTokenizerFast): + backend_str = tokenizer.backend_tokenizer.to_str() + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + + cls._cache[tokenizer_hash] = TokenizerData( + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str) + + return cls._cache[tokenizer_hash] + + +class GrammarCompilerCache: + """ + Cache for GrammarCompiler instances based on tokenizer. + + This cache reduces the overhead of creating new compiler instances when + using the same tokenizer configuration. + """ + _cache: dict[str, xgr.GrammarCompiler] = {} + + @classmethod + def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: + cache_key = str(config.tokenizer_hash) + + if cache_key not in cls._cache: + assert config.encoded_vocab is not None + tokenizer_info = xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface( + config.encoded_vocab, config.backend_str, + config.vocab_size, config.stop_token_ids)) + cls._cache[cache_key] = xgr.GrammarCompiler( + tokenizer_info, max_threads=config.max_threads) + + return cls._cache[cache_key] + + +@dataclass +class GrammarConfig: + """Serializable configuration for grammar compilation""" + tokenizer_hash: int + vocab_size: int + json_str: str | None = None + grammar_str: str | None = None + json_object: bool | None = None + max_threads: int = 8 + # Only populated if tokenizer_hash not in cache + encoded_vocab: list[str] | None = None + stop_token_ids: list[int] | None = None + backend_str: str | None = None + + @classmethod + def from_guided_params(cls, + guided_params: GuidedDecodingParams, + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, + max_threads: int = 8) -> GrammarConfig: + + tokenizer_hash = hash(tokenizer) + # Only get tokenizer data if not already cached + if tokenizer_hash in TokenizerDataCache._cache: + encoded_vocab = None + stop_token_ids = None + backend_str = None + else: + tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) + encoded_vocab = tokenizer_data.encoded_vocab + stop_token_ids = tokenizer_data.stop_token_ids + backend_str = tokenizer_data.backend_str + + if guided_params.json: + if not isinstance(guided_params.json, str): + json_str = json.dumps(guided_params.json) + else: + json_str = guided_params.json + return cls(json_str=json_str, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + elif guided_params.grammar: + # XGrammar only supports GBNF grammars, so we must convert Lark + if grammar_is_likely_lark(guided_params.grammar): + try: + grammar_str = convert_lark_to_gbnf(guided_params.grammar) + except ValueError as e: + raise ValueError( + "Failed to convert the grammar from Lark to GBNF. " + "Please either use GBNF grammar directly or specify" + " --guided-decoding-backend=outlines.\n" + f"Conversion error: {str(e)}") from e + else: + grammar_str = guided_params.grammar + return cls(grammar_str=grammar_str, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + elif guided_params.json_object: + return cls(json_object=True, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + else: + raise ValueError( + "Currently only support JSON and EBNF grammar mode for xgrammar" + ) + + +@dataclass +class XGrammarLogitsProcessor: + """Wrapper class to support pickle protocol""" + config: GrammarConfig + + ctx: xgr.CompiledGrammar | None = None + token_bitmask: torch.Tensor = None # type: ignore[assignment] + matchers: list[xgr.GrammarMatcher] = field(default_factory=list) + batch_size: int = field(default=1) + prefilled: bool = field(default=False) + + def __getstate__(self) -> dict[str, Any]: + return {'config': self.config} + + def __setstate__(self, state: dict[str, Any]): + self.config = state['config'] + + self.ctx = None + self.matchers = [] + self.batch_size = 1 + self.token_bitmask = None # type: ignore[assignment] + self.prefilled = False + + def _ensure_ctx(self): + """Lazily initialize the processor in the worker process""" + if self.ctx is None: + compiler = GrammarCompilerCache.get_compiler(self.config) + if self.config.json_str is not None: + self.ctx = compiler.compile_json_schema(self.config.json_str) + elif self.config.grammar_str is not None: + self.ctx = compiler.compile_grammar(self.config.grammar_str) + elif self.config.json_object: + self.ctx = compiler.compile_builtin_json_grammar() + else: + raise ValueError( + "Invalid configuration for xgrammar logits processor") + + def __call__(self, input_ids: list[int], + scores: torch.Tensor) -> torch.Tensor: + if self.ctx is None: + self._ensure_ctx() + + if len(self.matchers) == 0: + self.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + self.token_bitmask = xgr.allocate_token_bitmask( + self.batch_size, self.config.vocab_size) + + if not self.prefilled: + # Have not sampled a token yet + self.prefilled = True + else: + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + sampled_token = input_ids[-1] + assert self.matchers[i].accept_token(sampled_token) + + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + # @ubospica: ideally, fill_next_token_bitmask should be + # parallelized with model decoding + # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303 + matcher.fill_next_token_bitmask(self.token_bitmask, i) + + # token_bitmask is a CPU tensor for use with accept_token and + # fill_next_token_bitmask so we move it to the device of scores + device_type = scores.device.type + if device_type != "cuda": + scores = scores.to("cpu") + xgr.apply_token_bitmask_inplace(scores, + self.token_bitmask.to(scores.device)) + if device_type != "cuda": + scores = scores.to(device_type) + + return scores diff --git a/vllm/model_executor/guided_decoding/xgrammar_utils.py b/vllm/model_executor/guided_decoding/xgrammar_utils.py new file mode 100644 index 0000000000000..12b42245f4e3d --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_utils.py @@ -0,0 +1,162 @@ +import re + + +def grammar_is_likely_lark(grammar_str: str) -> bool: + """ + Check if grammar appears to use Lark syntax. + + Args: + grammar_str: Input grammar string + + Returns: + bool: True if grammar appears to be in Lark format, False otherwise + + Examples: + >>> grammar_is_likely_lark("rule: 'abc'") + True + >>> grammar_is_likely_lark("rule ::= 'abc'") + False + """ + if not grammar_str or not isinstance(grammar_str, str): + return False + + for line in grammar_str.split('\n'): + # Remove both comment styles + line = re.sub(r'(#|//).*$', '', line).strip() + if not line: + continue + + # Look for Lark-style rule definitions + if ':' in line and '::=' not in line: + return True + + # Look for Lark-specific features + if any(pattern in line for pattern in ['?start:', '|', '~']): + return True + + return False + + +def convert_lark_to_gbnf(grammar_str: str) -> str: + """ + Convert a Lark grammar string to GBNF format. + + GBNF reference: + https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Lark grammar reference: + https://lark-parser.readthedocs.io/en/latest/grammar.html + + Args: + grammar_str: Input grammar in Lark format + + Returns: + str: Converted grammar in GBNF format + + Examples: + >>> print(convert_lark_to_gbnf("rule: 'hello'")) + root ::= rule + rule ::= "hello" + """ + if not isinstance(grammar_str, str): + raise ValueError(f"Grammar must be a string, got {type(grammar_str)}") + if not grammar_str.strip(): + raise ValueError("Grammar string cannot be empty") + + defined_rules = set() + referenced_rules = set() + output_lines = [] + + def clean_line(line: str) -> str: + """Remove comments and whitespace from line.""" + return re.sub(r'(#|//).*$', '', line).strip() + + def check_quotes(text: str, rule_name: str, line_num: int) -> None: + """Validate quote matching in text.""" + if text.count("'") % 2 != 0 or text.count('"') % 2 != 0: + raise ValueError( + f"Mismatched quotes in {rule_name} on line {line_num}") + + def extract_references(text: str) -> set: + """Extract rule references from text.""" + # Remove quoted strings and special characters + text = re.sub(r'"[^"]*"', '', text) + text = re.sub(r'[+*?()|\[\]{}]', ' ', text) + return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text)) + + # First pass: Find root rule and validate rule definitions + lines = [clean_line(line) for line in grammar_str.split('\n')] + first_rule = None + + for line_num, line in enumerate(lines, 1): + if not line or line.startswith('|'): + continue + + if ':' in line: + try: + name = line.split(':', 1)[0].strip().strip('?') + defined_rules.add(name) + if first_rule is None: + first_rule = name + if name == 'start': + first_rule = 'start' + except IndexError as e: + raise ValueError(f"Invalid rule format on line {line_num}. " + "Expected 'rule_name: definition'") from e + + if not defined_rules: + raise ValueError("No valid rules found in grammar") + + # Add root rule + output_lines.append(f"root ::= {first_rule}") + + # Second pass: Process rule definitions and alternatives + current_rule = None + current_definition = [] + + for line_num, line in enumerate(lines, 1): + if not line: + continue + + try: + if ':' in line and not line.startswith('|'): + # Save previous rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Process new rule + name, definition = line.split(':', 1) + current_rule = name.strip().strip('?') + + check_quotes(definition, f"rule '{current_rule}'", line_num) + definition = re.sub(r"'([^']*)'", r'"\1"', definition) + referenced_rules.update(extract_references(definition)) + current_definition = [definition.strip()] + + elif line.startswith('|'): + if not current_rule: + raise ValueError(f"Alternative '|' on line {line_num} " + "without a preceding rule definition") + + alt_def = line[1:].strip() + check_quotes(alt_def, f"alternative for rule '{current_rule}'", + line_num) + alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def) + referenced_rules.update(extract_references(alt_def)) + current_definition.append(alt_def) + + except ValueError as e: + raise ValueError(f"Error on line {line_num}: {str(e)}") from e + + # Add final rule if exists + if current_rule: + output_lines.append( + f"{current_rule} ::= {' | '.join(current_definition)}") + + # Validate all rules are defined + undefined_rules = referenced_rules - defined_rules - {'root'} + if undefined_rules: + raise ValueError("Referenced rules are not defined: " + f"{', '.join(sorted(undefined_rules))}") + + return '\n'.join(output_lines) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 5570771ac917b..8c6f7c6e06515 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -242,7 +242,7 @@ def _load_per_tensor_weight_scale(self, shard_id: str, def _load_model_weight_or_group_weight_scale(self, shard_dim: int, expert_data: torch.Tensor, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int): # Load grouped weight scales for group quantization # or model weights @@ -261,7 +261,7 @@ def _load_model_weight_or_group_weight_scale(self, shard_dim: int, def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str, - loaded_weight: torch.tensor, + loaded_weight: torch.Tensor, tp_rank: int): # for per channel weight quantization if shard_id == "w2": @@ -274,7 +274,7 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, tp_rank=tp_rank) def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim @@ -292,7 +292,7 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, expert_data.copy_(loaded_weight) def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim @@ -311,7 +311,7 @@ def _load_single_value(self, param: torch.nn.Parameter, param_data[expert_id] = loaded_weight def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, - shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): if shard_id == "w2": self._load_w2(shard_id=shard_id, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index f9437b4112ceb..e0d42e30ebef3 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -60,9 +60,7 @@ def from_config_with_defaults( softmax: bool, step_tag_id: Optional[int] = None, returned_token_ids: Optional[List[int]] = None, - ) -> Optional["Pooler"]: - if pooler_config is None: - return None + ) -> "Pooler": return cls( pooling_type=PoolingType[pooler_config.pooling_type] if pooler_config.pooling_type is not None else pooling_type, diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 37c2d789030b6..fdc4c6305bd5e 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -6,9 +6,9 @@ import glob import inspect import itertools -import json import math import os +import warnings from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast @@ -17,7 +17,7 @@ import huggingface_hub import numpy as np import torch -from huggingface_hub import HfApi, hf_hub_download +from huggingface_hub import HfApi from torch import nn from transformers import AutoModelForCausalLM from transformers.utils import SAFE_WEIGHTS_INDEX_NAME @@ -97,22 +97,29 @@ def device_loading_context(module: torch.nn.Module, logger = init_logger(__name__) -def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: +def _initialize_model( + vllm_config: VllmConfig, + *, + prefix: str = "", +) -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config model_class, _ = get_model_architecture(model_config) + signatures = inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class with set_current_vllm_config(vllm_config): return model_class(vllm_config=vllm_config, prefix=prefix) + msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. " "Check https://docs.vllm.ai/en/latest/design/arch_overview.html " "for the design and update the model class accordingly.") - logger.warning(msg) + warnings.warn(msg, DeprecationWarning, stacklevel=2) + logger.warning( "Trying to guess the arguments for old-style model class %s", model_class, @@ -356,7 +363,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module: weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( self._get_all_weights(model_config, model)) - # We only enable strict check for non-quantiized models + # We only enable strict check for non-quantized models # that have loaded weights tracking currently. if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights @@ -694,51 +701,9 @@ def __init__(self, load_config: LoadConfig): self.unsharded_weights_modules: List[str] = [] # Save the module names that are sharded by column. self.column_sharded_weights_modules: List[str] = [] - # we don't need to quantize the whole model, only the target modules - # that are specified in the adapter config file. If the adapter config - # file is not provided, we will quantize the default modules. - if (not load_config.model_loader_extra_config - or "qlora_adapter_name_or_path" - not in load_config.model_loader_extra_config): - self.target_modules = [] - return - - qlora_adapter = load_config.model_loader_extra_config[ - "qlora_adapter_name_or_path"] - - config_file_path = self._get_config_file(qlora_adapter) - - with open(config_file_path) as f: - config = json.load(f) - self.target_modules = config["target_modules"] - # TODO: target_modules could be either a list or a regex string. - # We need to handle both cases. - assert isinstance(self.target_modules, - list), "Unsupported target_modules: " - f"{self.target_modules}" - - def _get_config_file(self, qlora_adapter: str) -> str: - is_local = os.path.isdir(qlora_adapter) - config_file_path = None - if is_local: - for file in self.possible_config_file_names: - config_file_path = os.path.join(qlora_adapter, file) - if os.path.exists(config_file_path): - break - else: - hf_api = HfApi() - repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) - for file in self.possible_config_file_names: - if file in repo_files: - config_file_path = hf_hub_download(repo_id=qlora_adapter, - filename=file) - break - - if not config_file_path: - raise ValueError( - f"Cannot find adapter config file in {qlora_adapter}") - - return config_file_path + # Store all module names (from transformers) that support + # BNB quantization. + self.target_modules: List[str] = [] def _get_weight_files( self, @@ -1020,25 +985,16 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None: inverse_stacked_mapping[packed] = [] inverse_stacked_mapping[packed].insert(idx, orig) - linear_module_lst = [] for name, module in model.named_modules(): if isinstance(module, (LinearBase, )): last_name = name.split(".")[-1] if sub_modules := inverse_stacked_mapping.get(last_name, []): # Map vllm's names to transformers' names. for sub_name in sub_modules: - linear_module_lst.append( + self.target_modules.append( name.replace(last_name, sub_name)) else: - linear_module_lst.append(name) - if self.target_modules: - # Update self.target_modules - self.target_modules = [ - qual_name for qual_name in linear_module_lst - if any(t in qual_name for t in self.target_modules) - ] - else: - self.target_modules = linear_module_lst + self.target_modules.append(name) assert (self.target_modules ), "vllm currently does not support BNB quantization for" f" {type(model).__name__}" @@ -1110,7 +1066,14 @@ def _load_weights(self, model_config: ModelConfig, model_config.revision, pre_quant, load_8bit)) - model.load_weights(qweight_iterator) + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(qweight_iterator) + # Some models may have weights loading tracker unimplemented. + if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") torch.cuda.empty_cache() @@ -1142,9 +1105,10 @@ def _load_weights(self, model_config: ModelConfig, shard_name, weight_name) break + # Models like Clip/Siglip may skip some layers in initialization, + # causing unused quant_param_name in state_dict. if quant_param_name not in param_dict: - raise ValueError( - f"Parameter {quant_param_name} not found in the model.") + continue if quant_param_name not in stacked_quant_state_dict: stacked_quant_state_dict[quant_param_name] = {} diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index b95c0b7cd0612..cfb89e0f336bc 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -7,6 +7,7 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.models.adapters import as_embedding_model @contextlib.contextmanager @@ -21,6 +22,7 @@ def set_default_torch_dtype(dtype: torch.dtype): def get_model_architecture( model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = [ @@ -32,7 +34,11 @@ def get_model_architecture( and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - return ModelRegistry.resolve_model_cls(architectures) + model_cls, arch = ModelRegistry.resolve_model_cls(architectures) + if model_config.task == "embedding": + model_cls = as_embedding_model(model_cls) + + return model_cls, arch def get_architecture_class_name(model_config: ModelConfig) -> str: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d66373512b95e..a3ef9adad16d9 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,15 +1,14 @@ from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, SupportsPP, has_inner_state, supports_lora, supports_multimodal, supports_pp) -from .interfaces_base import (VllmModelForEmbedding, - VllmModelForTextGeneration, is_embedding_model, - is_text_generation_model) +from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration, + is_pooling_model, is_text_generation_model) from .registry import ModelRegistry __all__ = [ "ModelRegistry", - "VllmModelForEmbedding", - "is_embedding_model", + "VllmModelForPooling", + "is_pooling_model", "VllmModelForTextGeneration", "is_text_generation_model", "HasInnerState", @@ -20,4 +19,4 @@ "supports_multimodal", "SupportsPP", "supports_pp", -] \ No newline at end of file +] diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py new file mode 100644 index 0000000000000..9cc43ae9181b9 --- /dev/null +++ b/vllm/model_executor/models/adapters.py @@ -0,0 +1,98 @@ +from collections.abc import Iterable +from typing import Any, TypeVar + +import torch +import torch.nn as nn + +from .interfaces_base import VllmModelForPooling, is_pooling_model + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def as_embedding_model(cls: _T) -> _T: + """Subclass an existing vLLM model to support embeddings.""" + # Avoid modifying existing embedding models + if is_pooling_model(cls): + return cls + + # Lazy import + from vllm.config import VllmConfig + from vllm.model_executor.layers.pooler import (Pooler, PoolerOutput, + PoolingType) + from vllm.model_executor.pooling_metadata import PoolingMetadata + + from .utils import AutoWeightsLoader, WeightsMapper + + class ModelForEmbedding(cls, VllmModelForPooling): + + def __init__( + self, + *, + vllm_config: "VllmConfig", + prefix: str = "", + **kwargs: Any, + ) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + + # These are not used in embedding models + for attr in ("lm_head", "logits_processor"): + if hasattr(self, attr): + delattr(self, attr) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + # If the model already defines a pooler instance, don't overwrite it + if not getattr(self, "_pooler", None): + self._pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.LAST, + normalize=True, + softmax=False, + ) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + # TODO: Support uninitialized params tracking + + # We have deleted this attribute, so don't load it + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) + + # If `*ForCausalLM` defines `load_weights` on the inner model + # and there are no other inner modules with parameters, + # we support loading from both `*Model` and `*ForCausalLM` + if hasattr(self, "model") and hasattr(self.model, "load_weights"): + # Whether only `self.model` contains parameters + model_is_only_param = all( + name == "model" or next(child.parameters(), None) is None + for name, child in self.named_children()) + + if model_is_only_param: + mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = mapper.apply(weights) + + self.model.load_weights(weights) + return + + # For most other models + if hasattr(cls, "load_weights"): + cls.load_weights(self, weights) # type: ignore + # Fallback + else: + loader = AutoWeightsLoader(self) + loader.load_weights(weights) + + ModelForEmbedding.__name__ = cls.__name__ \ + .removesuffix("ForCausalLM") \ + .removesuffix("ForConditionalGeneration") \ + .removesuffix("ChatModel") \ + .removesuffix("LMHeadModel") + "ForEmbedding" + + return ModelForEmbedding # type: ignore diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index fa6b95f5481ad..dd4b0c75cb84d 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -32,9 +32,8 @@ maybe_prefix, merge_multimodal_embeddings) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors @@ -451,7 +450,7 @@ def get_max_multimodal_tokens(ctx): def input_mapper_for_aria(ctx, data): - return MultiModalInputs(data) + return MultiModalKwargs(data) def input_processor(ctx, llm_inputs): diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 1fff72b3490e9..053d838432885 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -443,6 +443,8 @@ def pooler( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) weights = hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) self.model.load_weights(weights) def _build_model(self, diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 6af59697160a0..42a239cadac46 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -4,11 +4,10 @@ import torch import torch.nn as nn -import torch.nn.functional as F from PIL import Image from transformers import Blip2VisionConfig, BlipVisionConfig -from vllm.attention.selector import _Backend +from vllm.attention.layer import MultiHeadAttention from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -22,8 +21,6 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData -from .utils import get_vit_attn_backend - def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: assert image_size % patch_size == 0 @@ -205,11 +202,8 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend(support_fa=False) - if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: - raise RuntimeError( - f"BLIP does not support {self.attn_backend} backend now.") + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -220,41 +214,10 @@ def forward( hidden_states: torch.Tensor, ): """Input shape: Batch x Time x Channel""" - bsz, tgt_len, _ = hidden_states.size() qkv_states, _ = self.qkv(hidden_states) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) - query_states = query_states.view(bsz, tgt_len, - self.num_heads_per_partition, - self.head_dim) - key_states = key_states.view(bsz, tgt_len, - self.num_heads_per_partition, - self.head_dim) - value_states = value_states.view(bsz, tgt_len, - self.num_heads_per_partition, - self.head_dim) - - if self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) - elif self.attn_backend == _Backend.TORCH_SDPA: - query_states, key_states, value_states = (x.transpose(1, 2) - for x in (query_states, - key_states, - value_states)) - out = F.scaled_dot_product_attention(query_states, - key_states, - value_states, - dropout_p=self.dropout, - scale=self.scale) - out = out.transpose(1, 2) - - out = out.view(bsz, tgt_len, -1) + out = self.attn(query_states, key_states, value_states) attn_output, _ = self.projection(out) return attn_output, None diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index d2592016aff34..76b8505ee1c2a 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -512,9 +512,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index cd89519e95986..a5300dfd986f3 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -5,11 +5,10 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from PIL import Image from transformers import CLIPVisionConfig -from vllm.attention.selector import _Backend +from vllm.attention.layer import MultiHeadAttention from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -25,8 +24,6 @@ resolve_visual_encoder_outputs) from vllm.sequence import SequenceData -from .utils import get_vit_attn_backend - def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: assert image_size % patch_size == 0 @@ -235,11 +232,8 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - # Detect attention implementation. - self.attn_backend = get_vit_attn_backend(support_fa=False) - if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: - raise RuntimeError( - f"CLIP does not support {self.attn_backend} backend now.") + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, @@ -250,42 +244,10 @@ def forward( hidden_states: torch.Tensor, ): """Input shape: Batch x Time x Channel""" - bsz, tgt_len, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) - - query_states = query_states.view(bsz, tgt_len, - self.num_heads_per_partition, - self.head_dim) - key_states = key_states.view(bsz, tgt_len, - self.num_heads_per_partition, - self.head_dim) - value_states = value_states.view(bsz, tgt_len, - self.num_heads_per_partition, - self.head_dim) - - if self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) - elif self.attn_backend == _Backend.TORCH_SDPA: - query_states, key_states, value_states = (x.transpose(1, 2) - for x in (query_states, - key_states, - value_states)) - out = F.scaled_dot_product_attention(query_states, - key_states, - value_states, - dropout_p=self.dropout, - scale=self.scale) - out = out.transpose(1, 2) - - out = out.view(bsz, tgt_len, -1) + out = self.attn(query_states, key_states, value_states) attn_output, _ = self.out_proj(out) return attn_output, None diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 5ca26d53a17e7..0398f0943a70a 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -473,10 +473,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index c93223c740272..4664aa53ea092 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -30,19 +30,17 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, +from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -455,53 +453,3 @@ def load_weights(self, weights: Iterable[Tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -class Gemma2EmbeddingModel(nn.Module, SupportsPP): - """ - A model that uses Gemma2 with additional embedding functionalities. - - This class encapsulates the Gemma2Model and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of Gemma2Model used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - self.model = Gemma2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - vllm_config.model_config.pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - weights = hf_to_vllm_mapper.apply(weights) - self.model.load_weights(weights) diff --git a/vllm/model_executor/models/glm.py b/vllm/model_executor/models/glm.py new file mode 100644 index 0000000000000..942d1e14baed1 --- /dev/null +++ b/vllm/model_executor/models/glm.py @@ -0,0 +1,21 @@ +"""Inference-only HF format GLM-4 model compatible with THUDM weights.""" +from vllm.config import VllmConfig +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import PPMissingLayer + + +class GlmForCausalLM(LlamaForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Hack Llama model to fit HF format GLM implementation + # Attention difference between GLM and Llama: + # 1. Half partial rotary_dim and no Neox style. + # 2. There is no bias for o_proj in attention + for layer in self.model.layers: + if not isinstance(layer, PPMissingLayer): + layer.self_attn.rotary_emb.rotary_dim //= 2 + layer.self_attn.rotary_emb.is_neox_style = False + layer.self_attn.o_proj.bias = None + layer.self_attn.o_proj.skip_bias_add = True diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index f37ab0f82d52a..39a5736eb199b 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -8,6 +8,7 @@ from torch import nn from torch.nn import LayerNorm +from vllm.attention.layer import MultiHeadAttention from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -77,27 +78,16 @@ def __init__( quant_config=quant_config, ) + self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim, + self.scale) self.output_dropout = torch.nn.Dropout(config.dropout_prob) def forward(self, x: torch.Tensor) -> torch.Tensor: - B, L, _ = x.shape qkv, _ = self.query_key_value(x) # B, L, 3 * H * D q, k, v = qkv.chunk(3, dim=-1) - q = q.reshape(B, L, self.num_heads_per_rank, - self.head_dim).permute(0, 2, 1, 3) # B, H, L, D - k = k.reshape(B, L, self.num_heads_per_rank, - self.head_dim).permute(0, 2, 1, 3) # B, H, L, D - v = v.reshape(B, L, self.num_heads_per_rank, - self.head_dim).permute(0, 2, 1, 3) # B, H, L, D - - out = torch.nn.functional.scaled_dot_product_attention(q, - k, - v, - attn_mask=None, - dropout_p=0., - is_causal=False) - - output, _ = self.dense(out.transpose(1, 2).view(B, L, -1)) + + out = self.attn(q, k, v) + output, _ = self.dense(out) output = self.output_dropout(output) return output diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index bd2394e71c973..f9e0443b9a508 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -400,16 +400,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) - if hasattr(config, "logits_scaling"): logit_scale /= config.logits_scaling + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, scale=logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.sampler = get_sampler() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 16192928beb1f..e430a158d869a 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -21,8 +21,8 @@ from torch import nn from transformers.models.idefics2.configuration_idefics2 import ( Idefics2Config, Idefics2VisionConfig) -from xformers import ops as xops +from vllm.attention.layer import MultiHeadAttention from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -141,35 +141,18 @@ def __init__( ) self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.is_causal = False + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: - batch_size, q_len, _ = hidden_states.size() qkv, _ = self.qkv_proj( hidden_states ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim query_states, key_states, value_states = qkv.chunk(3, dim=-1) - query_states = query_states.view(batch_size, q_len, - self.num_heads_per_partition, - self.head_dim) - key_states = key_states.view(batch_size, q_len, - self.num_heads_per_partition, - self.head_dim) - value_states = value_states.view(batch_size, q_len, - self.num_heads_per_partition, - self.head_dim) - # see: https://facebookresearch.github.io/xformers/components/ops.html - out = xops.memory_efficient_attention_forward( - query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale, - ) - out = out.view(batch_size, q_len, -1) + out = self.attn(query_states, key_states, value_states) attn_output, _ = self.out_proj(out) return attn_output diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 014e27bc869d4..e5d2edbd81eb1 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -267,54 +267,56 @@ def input_processor_for_idefics3(ctx: InputContext, n_images_in_text = [] text = inputs.get("prompt") - if text is not None: - if isinstance(text, str): - text = [text] - elif not isinstance(text, list) and not isinstance(text[0], str): - raise ValueError("Invalid input text. Please provide a string, " - "or a list of strings") - - fake_image_token = processor.fake_image_token.content - image_token = processor.image_token.content - global_img_token = processor.global_image_tag - - prompt_strings = [] - for sample, sample_rows, sample_cols in zip(text, image_rows, - image_cols): - n_images_in_text.append(sample.count(image_token)) - - # Replace the image token with fake tokens around the expanded - # image token sequence of length `image_seq_len` - image_prompt_strings = [] - for n_rows, n_cols in zip(sample_rows, sample_cols): - image_prompt_string = _get_image_prompt_string( - n_rows, - n_cols, - processor.image_seq_len, - image_token=image_token, - fake_token_around_image=fake_image_token, - global_img_token=global_img_token, - ) - image_prompt_strings.append(image_prompt_string) - - split_sample = sample.split(image_token) - if len(split_sample) == 0: - raise ValueError( - "The image token should be present in the text.") + if text is None: + prompt_token_ids = inputs.get("prompt_token_ids", []) + assert prompt_token_ids + text = tokenizer.decode(prompt_token_ids) + + if isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError("Invalid input text. Please provide a string, " + "or a list of strings") + + fake_image_token = processor.fake_image_token.content + image_token = processor.image_token.content + global_img_token = processor.global_image_tag + + prompt_strings = [] + for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): + n_images_in_text.append(sample.count(image_token)) + + # Replace the image token with fake tokens around the expanded + # image token sequence of length `image_seq_len` + image_prompt_strings = [] + for n_rows, n_cols in zip(sample_rows, sample_cols): + image_prompt_string = _get_image_prompt_string( + n_rows, + n_cols, + processor.image_seq_len, + image_token=image_token, + fake_token_around_image=fake_image_token, + global_img_token=global_img_token, + ) + image_prompt_strings.append(image_prompt_string) - # Place in the image prompt strings where the image tokens are - sample = split_sample[0] - for i, image_prompt_string in enumerate(image_prompt_strings): - sample += image_prompt_string + split_sample[i + 1] - prompt_strings.append(sample) + split_sample = sample.split(image_token) + if len(split_sample) == 0: + raise ValueError("The image token should be present in the text.") - prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + # Place in the image prompt strings where the image tokens are + sample = split_sample[0] + for i, image_prompt_string in enumerate(image_prompt_strings): + sample += image_prompt_string + split_sample[i + 1] + prompt_strings.append(sample) - return token_inputs( - prompt_token_ids=prompt_token_ids, - prompt=prompt_strings[0], - multi_modal_data=multi_modal_data, - ) + prompt_token_ids = tokenizer(text=prompt_strings[0]).input_ids + + return token_inputs( + prompt_token_ids=prompt_token_ids, + prompt=prompt_strings[0], + multi_modal_data=multi_modal_data, + ) def _get_max_num_image_patch(image_processor: Idefics3ImageProcessor) -> int: diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 1545ce332309f..c3979eab905db 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from vllm.utils import supports_kw -from .interfaces_base import is_embedding_model +from .interfaces_base import is_pooling_model if TYPE_CHECKING: from vllm.attention import AttentionMetadata @@ -36,6 +36,11 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[T]: """ Returns multimodal embeddings generated from multimodal kwargs to be merged with text embeddings. + + The output embeddings must be one of the following formats: + - A list or tuple of 2D tensors, where each tensor corresponds to + each input image. + - A single 3D tensor, with the batch dimension grouping the 2D tensors. """ ... @@ -389,4 +394,4 @@ def _supports_cross_encoding( def supports_cross_encoding( model: Union[Type[object], object], ) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: - return is_embedding_model(model) and _supports_cross_encoding(model) + return is_pooling_model(model) and _supports_cross_encoding(model) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 957a5a6e26b5c..de733b6d49a53 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -141,7 +141,7 @@ def is_text_generation_model( @runtime_checkable -class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): +class VllmModelForPooling(VllmModel[C_co, T], Protocol[C_co, T]): def pooler( self, @@ -153,23 +153,22 @@ def pooler( @overload -def is_embedding_model( - model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]: +def is_pooling_model(model: Type[object]) -> TypeIs[Type[VllmModelForPooling]]: ... @overload -def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]: +def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ... -def is_embedding_model( +def is_pooling_model( model: Union[Type[object], object], -) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: +) -> Union[TypeIs[Type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]: if not is_vllm_model(model): return False if isinstance(model, type): - return isinstance(model, VllmModelForEmbedding) + return isinstance(model, VllmModelForPooling) - return isinstance(model, VllmModelForEmbedding) + return isinstance(model, VllmModelForPooling) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index c4346fcb3bd2a..7ff68bd60e8ad 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig -from vllm.attention.selector import _Backend +from vllm.attention.layer import MultiHeadAttention from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -25,8 +25,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from .utils import get_vit_attn_backend - NORM2FN = { 'rms_norm': RMSNorm, 'layer_norm': nn.LayerNorm, @@ -183,10 +181,8 @@ def __init__( prefix=f"{prefix}.proj", ) - self.attn_backend = get_vit_attn_backend(support_fa=False) - if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: - raise RuntimeError( - f"InternViT does not support {self.attn_backend} backend now.") + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): if self.tp_size > 1: @@ -209,23 +205,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.qk_normalization: q, k = self._apply_qk_norm(q, k) - q = q.view(B, N, self.num_heads_per_partition, self.head_dim) - k = k.view(B, N, self.num_heads_per_partition, self.head_dim) - v = v.view(B, N, self.num_heads_per_partition, self.head_dim) - - if self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - - out = xops.memory_efficient_attention_forward(q, - k, - v, - scale=self.scale) - elif self.attn_backend == _Backend.TORCH_SDPA: - q, k, v = (x.transpose(1, 2) for x in (q, k, v)) - out = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - out = out.transpose(1, 2) - - out = out.view(B, N, -1) + out = self.attn(q, k, v) out, _ = self.proj(out) return out diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 906128940ff76..41b9f110d771f 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -27,7 +27,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -319,7 +319,21 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module, SupportsPP): +class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "wqkv": ["wqkv"], + "gate_up_proj": ["w1", "w3"], + } + + # LoRA specific attributes + supported_lora_modules = [ + "wqkv", + "wo", + "gate_up_proj", + "w2", + ] + embedding_modules = {} + embedding_padding_modules = [] def __init__(self, *, @@ -329,8 +343,12 @@ def __init__(self, super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config self.quant_config = quant_config + self.lora_config = lora_config + self.model = model_type(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) self.output = ParallelLMHead(config.vocab_size, diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b1c0065afbf30..42c769f79e202 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -26,7 +26,7 @@ InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -52,12 +52,18 @@ class InternVLImagePixelInputs(TypedDict): Shape: `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` """ + patches_per_image: List[int] + """ + List of number of total patches for each image in the batch. + """ class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + data: NestedTensors + """ + A tensor of shape `(num_images, total_image_feature_size, hidden_size)` + or a list of tensors of shape `(total_image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -349,10 +355,32 @@ def input_processor( new_prompt = self._expand_image_prompt(prompt, image_feature_sizes, num_patches) new_prompt_token_ids = tokenizer.encode(new_prompt) + img_context_token_id = tokenizer.encode(self.img_context_token, + add_special_tokens=False) + assert len(img_context_token_id) == 1, \ + (f"Invalid image token '{self.img_context_token}': A valid image " + f"token encodes to a single token ID, got {img_context_token_id}.") + img_context_token_id = img_context_token_id[0] + + # Get precise tracking of placeholder positions + token_idx = image_idx = 0 + placeholder_ranges = [] + while token_idx < len(new_prompt_token_ids): + if new_prompt_token_ids[token_idx] == img_context_token_id: + curr_image_featue_size = image_feature_sizes[image_idx] + placeholder_ranges.append( + PlaceholderRange(offset=token_idx, + length=curr_image_featue_size)) + image_idx += 1 + token_idx += curr_image_featue_size + else: + token_idx += 1 - return token_inputs(prompt=prompt, - prompt_token_ids=new_prompt_token_ids, - multi_modal_data=multi_modal_data) + return token_inputs( + prompt=prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) def input_mapper( self, @@ -474,13 +502,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.mlp1 = self._init_mlp1(config) self.img_context_token_id = None + self.visual_token_mask = None self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -612,35 +642,54 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + patches_per_image = [] + for request_pixel_values in pixel_values: + for image_pixel_values in request_pixel_values: + patches_per_image.append(image_pixel_values.shape[0]) # We need to flatten (B, N, P) to (B*N*P), # so we call flatten_bn twice. return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values( flatten_bn(flatten_bn(pixel_values), concat=True)), - ) + patches_per_image=patches_per_image) raise AssertionError("This line should be unreachable.") def _process_image_input( self, image_input: InternVLImageInputs, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor]: if image_input["type"] == "image_embeds": return image_input["data"] assert self.vision_model is not None + image_embeds = self.extract_feature(image_input["data"]) + patches_per_image = image_input["patches_per_image"] + if len(patches_per_image) == 1: + image_embeds = image_embeds.unsqueeze(0) + return image_embeds + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in patches_per_image + ] + image_embeds = image_embeds.split(image_feature_sizes) return image_embeds - def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor: if self.is_mono: - visual_token_mask = ( + self.visual_token_mask = ( input_ids == self.img_context_token_id).reshape(-1, 1) else: - visual_token_mask = None - return visual_token_mask + self.visual_token_mask = None def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) @@ -657,6 +706,7 @@ def get_input_embeddings( inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: assert self.img_context_token_id is not None + self._set_visual_token_mask(input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, self.img_context_token_id) @@ -673,7 +723,6 @@ def forward( **kwargs: object, ) -> Union[SamplerOutput, IntermediateTensors]: - visual_token_mask = None if intermediate_tensors is not None: input_ids = None inputs_embeds = None @@ -694,15 +743,12 @@ def forward( "intermediate_tensors": intermediate_tensors, "inputs_embeds": inputs_embeds, } - if self.img_context_token_id is not None: - visual_token_mask = self._get_visual_token_mask(input_ids) - # We always overwrite it back to None after computing visual token - # mask so that this doesn't need to depend on encoder output - self.img_context_token_id = None - - if self.is_mono: - forward_kwargs.update({"visual_token_mask": visual_token_mask}) + # Only required if the model is mono-architecture + if self.visual_token_mask is not None: + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None hidden_states = self.language_model.model(**forward_kwargs) return hidden_states diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 48c97a28dcac8..0f7739351eb64 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.fused_moe import FusedMoE @@ -26,9 +26,6 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.utils import LayerBlockType -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size) from .interfaces import HasInnerState, SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, @@ -434,7 +431,7 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (_get_graph_batch_size( + max_batch_size = (VllmConfig.get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index fffb3fe53b94c..733b1bc7d80ac 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,7 +37,6 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) @@ -47,13 +46,12 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -114,6 +112,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + layer_idx = extract_layer_index(prefix) self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads @@ -168,6 +167,18 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) + + if hasattr(config, "interleaved_sliding_window"): + if isinstance(config.interleaved_sliding_window, int): + sliding_window = config.interleaved_sliding_window + elif isinstance(config.interleaved_sliding_window, list): + sw_idx = layer_idx % len(config.interleaved_sliding_window) + sliding_window = config.interleaved_sliding_window[sw_idx] + else: + raise ValueError(f"{type(sliding_window)} is not supported.") + else: + sliding_window = None + self.attn = Attention( self.num_heads, self.head_dim, @@ -175,6 +186,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + per_layer_sliding_window=sliding_window, prefix=f"{prefix}.attn", ) @@ -497,11 +509,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config - self.model = self._init_model(vllm_config=vllm_config, prefix=prefix) + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -527,16 +540,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.STEP, - normalize=False, - softmax=False) def _init_model(self, vllm_config: VllmConfig, prefix: str = ""): return LlamaModel(vllm_config=vllm_config, prefix=prefix) @@ -567,14 +577,6 @@ def compute_logits( sampling_metadata) return logits - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - logits = self.compute_logits(hidden_states, None) - return self._pooler(logits, pooling_metadata) - def sample(self, logits: torch.Tensor, sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) @@ -625,76 +627,3 @@ def permute(w: torch.Tensor, n_heads: int): name = name.replace(item, mapping[item]) return name, loaded_weight - - -class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): - """ - A model that uses Llama with additional embedding functionalities. - - This class encapsulates the LlamaModel and provides an interface for - embedding operations and customized pooling functions. - - Attributes: - model: An instance of LlamaModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - packed_modules_mapping = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"] - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens" - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - } - embedding_padding_modules = [] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - pooler_config = vllm_config.model_config.pooler_config - - self.model = LlamaModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def forward( - self, - input_ids: Optional[torch.Tensor], - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - return self.model(input_ids, positions, kv_caches, attn_metadata, - intermediate_tensors, inputs_embeds) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) - weights = hf_to_vllm_mapper.apply(weights) - self.model.load_weights(weights) - - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - self.model.load_kv_cache_scales(quantization_param_path) - - # LRUCacheWorkerLoRAManager instantiation requires model config. - @property - def config(self): - return self.model.config diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e7757b3c7d405..65c6bd07bfff0 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,37 +1,41 @@ from functools import cached_property +from types import MethodType from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set, Tuple, TypedDict, Union) import torch import torch.nn as nn -from PIL import Image -from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, - PretrainedConfig, SiglipVisionConfig) +from PIL.Image import Image +from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig, + PixtralVisionConfig, PretrainedConfig, + ProcessorMixin, SiglipVisionConfig) +from transformers.models.pixtral import PixtralProcessor from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext) +from vllm.inputs import InputContext from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, + ModalityProcessingMetadata, + MultiModalProcessingMetadata, + PromptReplacement) from vllm.sequence import IntermediateTensors -from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, - dummy_seq_data_for_clip, get_max_clip_image_tokens, - input_processor_for_clip) + get_max_clip_image_tokens) from .interfaces import SupportsMultiModal, SupportsPP from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf, - dummy_seq_data_for_pixtral_hf, - get_max_pixtral_hf_image_tokens, - input_processor_for_pixtral_hf) + get_max_pixtral_hf_image_tokens) from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - dummy_seq_data_for_siglip, get_max_siglip_image_tokens, - input_processor_for_siglip) + get_max_siglip_image_tokens) from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) @@ -59,25 +63,32 @@ class LlavaImageEmbeddingInputs(TypedDict): LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs] -# TODO(xwjiang): Run benchmark and decide if TP. class LlavaMultiModalProjector(nn.Module): - def __init__(self, vision_hidden_size: int, text_hidden_size: int, - projector_hidden_act: str): + def __init__(self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() - self.linear_1 = nn.Linear(vision_hidden_size, - text_hidden_size, - bias=True) + self.linear_1 = ColumnParallelLinear(vision_hidden_size, + text_hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_1") self.act = get_act_fn(projector_hidden_act) - self.linear_2 = nn.Linear(text_hidden_size, - text_hidden_size, - bias=True) + self.linear_2 = RowParallelLinear(text_hidden_size, + text_hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_2") def forward(self, image_features: torch.Tensor) -> torch.Tensor: - hidden_states = self.linear_1(image_features) + hidden_states, _ = self.linear_1(image_features) hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) return hidden_states @@ -104,102 +115,115 @@ def get_max_llava_image_tokens(ctx: InputContext): raise ValueError(f"Unexpected select feature strategy: {strategy}") -def dummy_data_for_llava(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext, + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config num_images = mm_counts["image"] - image_feature_size = get_max_llava_image_tokens(ctx) - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_clip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_clip(vision_config, num_images) elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - - mm_data = dummy_image_for_siglip(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + data = dummy_image_for_siglip(vision_config, num_images) elif isinstance(vision_config, PixtralVisionConfig): - seq_data, ranges = dummy_seq_data_for_pixtral_hf( - vision_config, - seq_len, - num_images, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, + data = dummy_image_for_pixtral_hf(vision_config, num_images) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + hf_processor = ctx.get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt") + is_pixtral = isinstance(hf_processor, PixtralProcessor) + + return MultiModalKwargs( + **hf_inputs, + is_pixtral=torch.tensor(is_pixtral), + ) + + +def create_metadata_for_llava( + ctx: InputProcessingContext) -> MultiModalProcessingMetadata: + hf_config = ctx.get_hf_config(LlavaConfig) + image_token_id = hf_config.image_token_index + + def get_repl_count( + mm_items: list[Image], + hf_inputs: BatchFeature, + item_idx: int, + ) -> int: + return get_max_llava_image_tokens(ctx) + + return { + "image": + ModalityProcessingMetadata(prompt_repls=[ + PromptReplacement(target=[image_token_id], + repl_unit=[image_token_id], + repl_count=get_repl_count), + ]), + } + + +class LlavaProcessor(BaseMultiModalProcessor): + + def __init__(self, ctx: InputProcessingContext) -> None: + super().__init__( + ctx=ctx, + metadata=create_metadata_for_llava(ctx), ) - mm_data = dummy_image_for_pixtral_hf(vision_config, num_images) - return DummyData(seq_data, mm_data, ranges) + def _patch_pixtral_processor(self, hf_processor: PixtralProcessor): + if getattr(hf_processor, "__is_patched__", False): + return # Already patched - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + image_processor = hf_processor.image_processor # type: ignore + orig_preprocess = image_processor.preprocess + def preprocess(__self, *args, **kwargs): + hf_inputs = orig_preprocess(*args, **kwargs) + hf_inputs["is_pixtral"] = torch.tensor(True) + return hf_inputs -def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs + image_processor.preprocess = MethodType(preprocess, image_processor) - model_config = ctx.model_config - hf_config = ctx.get_hf_config(LlavaConfig) - vision_config = hf_config.vision_config + hf_processor.__is_patched__ = True # type: ignore - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - image_feature_size = get_max_llava_image_tokens(ctx) - elif is_list_of(image_data, Image.Image): - image_feature_size = [get_max_llava_image_tokens(ctx) - ] * len(image_data) - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") + def _get_hf_processor(self) -> ProcessorMixin: + hf_processor = self.ctx.get_hf_processor() - if isinstance(vision_config, CLIPVisionConfig): - return input_processor_for_clip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, SiglipVisionConfig): - return input_processor_for_siglip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, - ) - elif isinstance(vision_config, PixtralVisionConfig): - # We ignore image_feature_size_override since we have non-uniform - # image sizes for Pixtral - return input_processor_for_pixtral_hf( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - ) + if isinstance(hf_processor, PixtralProcessor): + self._patch_pixtral_processor(hf_processor) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return hf_processor + + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + hf_config = self.ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + if isinstance(vision_config, CLIPVisionConfig): + data = dummy_image_for_clip(vision_config, num_images) + elif isinstance(vision_config, SiglipVisionConfig): + data = dummy_image_for_siglip(vision_config, num_images) + elif isinstance(vision_config, PixtralVisionConfig): + data = dummy_image_for_pixtral_hf(vision_config, num_images) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + hf_processor = self._get_hf_processor() + image_processor = hf_processor.image_processor # type: ignore + hf_inputs = image_processor.preprocess(data['image'], + return_tensors="pt") + is_pixtral = isinstance(hf_processor, PixtralProcessor) + + return MultiModalKwargs( + **hf_inputs, + is_pixtral=torch.tensor(is_pixtral), + ) class LlavaLikeConfig(Protocol): @@ -282,11 +306,18 @@ def init_vision_tower_for_llava( raise NotImplementedError(msg) -@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor) class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() @@ -316,12 +347,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, - projector_hidden_act=config.projector_hidden_act) + projector_hidden_act=config.projector_hidden_act, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -346,38 +380,10 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: return data - def _validate_image_sizes(self, images: List[torch.Tensor], - sizes: List[torch.Tensor]) -> List[torch.Tensor]: - if not isinstance(sizes, list): - sizes = [sizes] - - total_images = sum(size.numel() // 2 for size in sizes) - if total_images != len(images): - raise ValueError("Mismatch in number of images. " - f"Expected {total_images}, got {len(images)}") - img_idx = 0 - for size in sizes: - # Flatten the size tensor to a list of (height, width) pairs - size = size.view(-1, 2).tolist() - for expected_h, expected_w in size: - if img_idx >= len(images): - raise ValueError("Ran out of images before sizes. " - f"{img_idx} >= {len(images)}") - img = images[img_idx] - if img.shape[-2:] != (expected_h, expected_w): - raise ValueError( - "Image size mismatch. Expected " - f"{(expected_h, expected_w)}, got {img.shape[-2:]}") - if img.shape[-3] != 3: - raise ValueError("Image channel mismatch. Expected 3, " - f"got {img.shape[-3]}") - img_idx += 1 - return images - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) - image_sizes = kwargs.pop("image_sizes", None) + is_pixtral = kwargs.pop("is_pixtral", torch.tensor([False])) image_embeds = kwargs.pop("image_embeds", None) if pixel_values is None and image_embeds is None: @@ -388,9 +394,8 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Case for models like PixtralHF that have dynamic image sizes - # so we need to produce a list of tensors - if image_sizes is not None: + assert isinstance(is_pixtral, torch.Tensor) + if is_pixtral.any(): images = pixel_values def flatten_to_3d_tensors(item): @@ -413,7 +418,7 @@ def flatten_to_3d_tensors(item): return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_image_sizes(images, image_sizes), + data=images, ) return LlavaImagePixelInputs( @@ -581,3 +586,28 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + +class MantisProcessor(LlavaProcessor): + + def _get_hf_processor(self) -> ProcessorMixin: + try: + from mantis.models.mllava import MLlavaProcessor + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "You need to `pip install " + "git+https://github.com/TIGER-AI-Lab/Mantis.git` " + "to use this model") from exc + + processor = MLlavaProcessor.from_pretrained( + self.ctx.model_config.tokenizer) + assert isinstance(processor, ProcessorMixin) + return processor + + +# To use this model, please use +# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'` +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) +@MULTIMODAL_REGISTRY.register_processor(MantisProcessor) +class MantisForConditionalGeneration(LlavaForConditionalGeneration): + pass diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index e113f5862830d..a39f2f4124d05 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -14,13 +14,11 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -286,7 +284,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config vision_feature_layer = config.vision_feature_layer @@ -321,17 +318,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) - - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -678,13 +669,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader(self) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index b130791808924..0de9d8c5ea572 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -275,9 +275,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.make_empty_intermediate_tensors = ( self.language_model.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 3166737d61582..0bebc1c745e2b 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -422,9 +422,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: prefix=maybe_prefix(prefix, "vision_tower")) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index c893ca6a0f7c7..d5bda01638421 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.model_executor.layers.layernorm import RMSNorm @@ -24,8 +24,6 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size) from .utils import (make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -205,7 +203,7 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (_get_graph_batch_size( + max_batch_size = (VllmConfig.get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) self.mamba_cache = MambaCacheManager( diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index c9a573278a136..5a0f202364f26 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -52,7 +52,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -378,6 +378,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size, org_num_embeddings=config.vocab_size, ) + self.num_experts = getattr(self.config, "num_experts", 0) self._init_layers(prefix, config, cache_config, quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( @@ -437,6 +438,73 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -466,6 +534,16 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -480,8 +558,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.cache_config = cache_config self.quant_config = quant_config - self.num_experts = getattr(self.config, "num_experts", 0) - self._init_model(vllm_config=vllm_config, prefix=prefix) + self.model = self._init_model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -506,8 +585,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model.make_empty_intermediate_tensors) def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): - self.model = MiniCPMModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + return MiniCPMModel(vllm_config=vllm_config, prefix=prefix) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -546,72 +624,9 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", expert_id) - for expert_id in range(self.num_experts) - for weight_name in ["w1", "w2", "w3"] - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for param_name, weight_name, expert_id in expert_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - weight_name, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index c38c31a0d4953..e9d7eada1d16c 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -40,7 +40,7 @@ MiniCPMForCausalLM, MiniCPMModel) -from .utils import make_layers, maybe_prefix +from .utils import make_layers class MiniCPM3Attention(nn.Module): @@ -241,6 +241,11 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): # `embedding_modules` and `embedding_padding_modules` # are inherited from MiniCPMForCausalLM + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): - self.model = MiniCPM3Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index aacce477e0460..1e8f9bd4cf418 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -22,7 +22,7 @@ """Inference-only MiniCPM-V model compatible with HuggingFace weights.""" import math import re -from functools import partial +from functools import cached_property, partial from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -37,19 +37,15 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.llama import LlamaModel -from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.models.utils import LLMWrapper +from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.image import cached_get_image_processor @@ -58,11 +54,7 @@ from .idefics2_vision_model import Idefics2VisionTransformer from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import is_pp_missing_parameter, maybe_prefix - -_KEYS_TO_MODIFY_MAPPING = { - "llm.lm_head": "lm_head", -} +from .utils import AutoWeightsLoader, maybe_prefix RawImageType = Union[Image.Image, torch.Tensor] @@ -297,10 +289,9 @@ def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs): def get_placeholder(image_size: Tuple[int, int], num_image: int): if version == (2, 0) or version == (2, 5): - return image_processor. \ - get_slice_image_placeholder(image_size) - return image_processor. \ - get_slice_image_placeholder(image_size, num_image) + return image_processor.get_slice_image_placeholder(image_size) + return image_processor.get_slice_image_placeholder( + image_size, num_image) prompt = inputs.get("prompt") token_ids = inputs.get("prompt_token_ids") @@ -400,37 +391,32 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vpm = self.init_vision_module(config, quant_config, prefix=maybe_prefix(prefix, "vpm")) - param_dtype = torch.get_default_dtype() - self.vpm.to(dtype=param_dtype) self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else self.vpm.embeddings.embed_dim) self.embed_dim = self.config.hidden_size + self.resampler = self.init_resampler(self.embed_dim, self.vision_dim, quant_config=quant_config, prefix=maybe_prefix( prefix, "resampler")) - self.resampler.to(device="cuda", dtype=param_dtype) - # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "llm.lm_head")) - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) + @cached_property + def sampler(self): + if hasattr(self.llm, "sampler"): + return self.llm.sampler + + return get_sampler() + def get_embedding( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: - vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) - if hasattr(self.config, "scale_emb"): - vlm_embedding *= self.config.scale_emb + vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) if image_inputs is None: # No image vision_hidden_states = torch.tensor([], device=input_ids.device) @@ -575,7 +561,7 @@ def forward( # for `torch.compile` integration input_ids = None - output = self.llm( + output = self.llm.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -590,9 +576,7 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.llm.compute_logits(hidden_states, sampling_metadata) def sample( self, @@ -604,52 +588,8 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - use_default_weight_loading = False - if self.is_default_weight_loading(name): - use_default_weight_loading = True - else: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - use_default_weight_loading = True - if use_default_weight_loading: - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) def get_mm_mapping(self) -> MultiModelKeys: """ @@ -693,9 +633,6 @@ def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError - def is_default_weight_loading(self, name: str) -> bool: - raise NotImplementedError - class MiniCPMV2_0(MiniCPMVBaseModel): @@ -708,8 +645,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix), - name="model") + return MiniCPMForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -717,11 +653,12 @@ def init_vision_module( quant_config: Optional[QuantizationConfig], prefix: str = "", ) -> nn.Module: - # TODO :refactor this vision model + # TODO: refactor this vision model try: import timm except ImportError: raise ImportError("Please install timm==0.9.10") from ImportError + with set_default_torch_dtype(torch.float16): model = timm.create_model( "vit_so400m_patch14_siglip_384.webli", @@ -731,6 +668,8 @@ def init_vision_module( dynamic_img_pad=True, ) + model = model.to(dtype=torch.get_default_dtype()) + if (isinstance(model, timm.models.VisionTransformer) and model.attn_pool is not None): model.attn_pool = torch.nn.Identity() @@ -759,7 +698,7 @@ def init_resampler(self, quant_config=quant_config, prefix=prefix) - return resampler + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -790,9 +729,6 @@ def get_vision_hidden_states(self, return self.get_vision_embedding(pixel_values) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name or "vpm" in name - class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): packed_modules_mapping = { @@ -843,8 +779,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix), - name="model") + return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -871,7 +806,8 @@ def init_resampler(self, kv_dim=vision_dim, quant_config=quant_config, prefix=prefix) - return resampler + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -913,9 +849,6 @@ def get_vision_hidden_states(self, return self.get_vision_embedding(all_pixel_values.type(dtype), patch_attn_mask, tgt_sizes) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name - class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): packed_modules_mapping = { @@ -966,8 +899,7 @@ def init_llm( vllm_config: VllmConfig, prefix: str = "", ) -> nn.Module: - return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix), - name="model") + return Qwen2ForCausalLM(vllm_config=vllm_config, prefix=prefix) def init_vision_module( self, @@ -995,7 +927,8 @@ def init_resampler(self, kv_dim=vision_dim, quant_config=quant_config, prefix=prefix) - return resampler + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) def get_vision_embedding( self, @@ -1043,9 +976,6 @@ def get_vision_hidden_states(self, return self.resampler(vision_embedding, tgt_sizes) - def is_default_weight_loading(self, name: str) -> bool: - return "resampler" in name - _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index acedddd84d7cb..a328b5a2aeea7 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -3,7 +3,7 @@ from array import array from dataclasses import dataclass from functools import lru_cache, partial -from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict +from typing import Iterable, List, Mapping, Optional, Set, Tuple, TypedDict import torch from einops import rearrange @@ -13,6 +13,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata +from vllm.attention.layer import MultiHeadAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -36,22 +37,25 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer -from vllm.platforms import _Backend from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) from vllm.transformers_utils.processor import get_processor from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (get_vit_attn_backend, +from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) + maybe_prefix, merge_multimodal_embeddings) # TODO: hard-coded for now. Consider making it configurable. VIT_LAYERS = [-2, -9] NUM_PREFIX_TOKENS = 1 ADDITIONAL_VOCAB_SIZE = 128 +DEFAULT_IMAGE_PATCH_TOKEN_ID = 152066 +DEFAULT_IM_START_TOKEN_ID = 152067 +DEFAULT_IM_END_TOKEN_ID = 152064 +DEFAULT_IM_COL_TOKEN_ID = 152065 class MolmoImageInputs(TypedDict): @@ -75,6 +79,11 @@ class MolmoImageInputs(TypedDict): `(batch_size, num_crops, num_patch)` """ + image_start_end: Tuple[int, int] + """Starting and ending index of placeholder + tokens + """ + @dataclass class VisionBackboneConfig: @@ -187,13 +196,11 @@ def __init__( quant_config=quant_config, ) - # Detect attention implementation. - self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) - if self.attn_backend not in { - _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS - }: - raise RuntimeError( - f"Molmo does not support {self.attn_backend} backend now.") + self.scale = self.head_dim**-0.5 + self.attn = MultiHeadAttention(self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads) def forward(self, inputs_q: torch.Tensor, @@ -209,25 +216,8 @@ def forward(self, xq, _ = self.wq(inputs_q) xk, _ = self.wk(inputs_k) xv, _ = self.wv(inputs_v) - q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim) - kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim) - xq = xq.view(*q_shape) - xk = xk.view(*kv_shape) - xv = xv.view(*kv_shape) - - if self.attn_backend == _Backend.FLASH_ATTN: - from flash_attn import flash_attn_func - output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False) - elif self.attn_backend == _Backend.TORCH_SDPA: - xq, xk, xv = (rearrange(x, "b s h d -> b h s d") - for x in (xq, xk, xv)) - output = F.scaled_dot_product_attention(xq, xk, xv) - output = rearrange(output, "b h s d -> b s h d ") - elif self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0) - - output = rearrange(output, "b s h d -> b s (h d)").contiguous() + + output = self.attn(xq, xk, xv) output, _ = self.wo(output) return output @@ -720,6 +710,42 @@ def forward( # image_features: (batch_size, num_image, num_patch, d_model) return image_features + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + @support_torch_compile class MolmoModel(nn.Module): @@ -804,6 +830,28 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "gate_up_proj" in name: + up_proj, gate_proj = loaded_weight.chunk(2, dim=0) + loaded_weight = torch.cat([gate_proj, up_proj], dim=0) + + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + cached_get_processor = lru_cache(get_processor) @@ -879,6 +927,8 @@ def image_input_mapper_for_molmo( ctx: InputContext, data: object, ): + if isinstance(data, list): + data = data[0] return MultiModalKwargs(data) @@ -928,7 +978,22 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int, if "image_masks" in out: dummy_imgdata["image_masks"] = out["image_masks"] dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) - return DummyData(dummy_seqdata, {"image": dummy_imgdata}) + size = 0 + offset = -1 + for i in range(len(token_ids)): + if token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + dummy_imgdata["image_start_end"] = (offset, offset + size) + return DummyData(seq_data=dummy_seqdata, + multi_modal_data={"image": dummy_imgdata}, + multi_modal_placeholders={ + "image": + [PlaceholderRange(offset=offset, length=size)] + }) def pad_images( @@ -1016,19 +1081,34 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): if image_masks is not None: image_data["image_masks"] = image_masks - image_data["seq_len"] = torch.tensor(len(out["input_ids"]), + new_prompt_token_ids = out["input_ids"].tolist() + image_data["seq_len"] = torch.tensor(len(new_prompt_token_ids), dtype=torch.long) multi_modal_data = dict(image=image_data) + size = 0 + offset = -1 + for i in range(len(new_prompt_token_ids)): + if new_prompt_token_ids[i] in (DEFAULT_IMAGE_PATCH_TOKEN_ID, + DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, + DEFAULT_IM_COL_TOKEN_ID): + if offset < 0: + offset = i + size += 1 + image_data["image_start_end"] = (offset, offset + size) prompt = inputs.get("prompt") if prompt is None: - prompt = tokenizer.decode(out["input_ids"]) + prompt = tokenizer.decode(new_prompt_token_ids) return token_inputs( - prompt_token_ids=out["input_ids"], + prompt_token_ids=new_prompt_token_ids, prompt=prompt, multi_modal_data=multi_modal_data, + multi_modal_placeholders={ + "image": [PlaceholderRange(offset=offset, length=size)] + }, ) @@ -1074,6 +1154,7 @@ def _parse_and_validate_image_input( ) -> Optional[MolmoImageInputs]: images = kwargs.pop("images", None) image_masks = kwargs.pop("image_masks", None) + image_start_end = kwargs.pop("image_start_end", None) if images is None: return None @@ -1091,6 +1172,7 @@ def _parse_and_validate_image_input( image_input_idx=image_input_idx, seq_len=seq_len, image_masks=image_masks, + image_start_end=image_start_end, ) def _process_image_input( @@ -1139,9 +1221,16 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # Note: In this original implementation from AI2, the final # vision_embeddings will be always be the same length - # of input embedddings, which is not very efficient. - # TODO(ywang96): see if this can be optimized. + # of input embeddings. vision_embeddings = torch.einsum('nd,nm->md', image_features, mat) + + # Split by the sizes of the input sequences. For each full embedding, + # extract the actual vision embeddings to be merged. + vision_embeddings = list(vision_embeddings.split(seq_len.tolist())) + for i in range(len(vision_embeddings)): + start, end = image_input['image_start_end'][i] + vision_embeddings[i] = vision_embeddings[i][start:end] + return vision_embeddings def get_input_embeddings( @@ -1151,7 +1240,11 @@ def get_input_embeddings( ) -> torch.Tensor: inputs_embeds = self.model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: - inputs_embeds = inputs_embeds + multimodal_embeddings + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, [ + DEFAULT_IMAGE_PATCH_TOKEN_ID, DEFAULT_IM_START_TOKEN_ID, + DEFAULT_IM_END_TOKEN_ID, DEFAULT_IM_COL_TOKEN_ID + ]) return inputs_embeds def forward( @@ -1200,103 +1293,53 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - - params_mapping = [ - ("model.transformer.ln_f.weight", "model.norm.weight"), - ("attn_out", "self_attn.o_proj"), - ("att_proj", "self_attn.qkv_proj"), - ("q_norm", "self_attn.q_norm"), - ("k_norm", "self_attn.k_norm"), - ("attn_norm", "input_layernorm"), - ("ff_norm", "post_attention_layernorm"), - ] - - params_dict = dict(self.named_parameters(remove_duplicate=False)) - - embedding_weight = dict() - projector_weight = dict() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - - if "wte.embedding" in name: - embedding_weight["embedding"] = loaded_weight - continue - - if "wte.new_embedding" in name: - embedding_weight["new_embedding"] = loaded_weight - continue - - if "vision_backbone" in name: - if name.startswith("model"): - name = name[len("model."):] - if 'image_projector' in name: - if 'w1' in name: - projector_weight['gate_proj'] = loaded_weight - elif 'w3' in name: - projector_weight['up_proj'] = loaded_weight - elif 'w2' in name: - projector_weight['down_proj'] = loaded_weight - else: - raise ValueError( - f"Unexpected projector weight: {name}") - continue - else: - if "transformer.blocks" in name: - name = name.replace("transformer.blocks", "layers") - - if "ff_proj" in name: - name = name.replace("ff_proj", "mlp.gate_up_proj") - assert 'weight' in name - up_weight, gate_weight = loaded_weight.chunk(2, dim=0) - loaded_weight = torch.cat([gate_weight, up_weight], dim=0) - - elif "ff_out" in name: - if "layers" in name: - name = name.replace("ff_out", "mlp.down_proj") - else: - # lm head - name = name.replace("model.transformer.ff_out", - "lm_head") - - else: - for (param_name, weight_name) in params_mapping: - if param_name in name: - name = name.replace(param_name, weight_name) - break - - try: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - except KeyError: - raise ValueError(f"Unexpected weight: {name}") from None - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - gate_up_proj_weight = torch.cat( - [projector_weight["gate_proj"], projector_weight["up_proj"]], - dim=0) - name = "vision_backbone.image_projector.gate_up_proj.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, gate_up_proj_weight) - - down_proj_weight = projector_weight["down_proj"] - name = "vision_backbone.image_projector.down_proj.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, down_proj_weight) - - embedding_weight = torch.cat( - [embedding_weight["embedding"], embedding_weight["new_embedding"]], - dim=0) - name = "model.embed_tokens.weight" - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, embedding_weight) + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + # vision backbone mapping + "image_projector.w1.": "image_projector.gate_proj.", + "image_projector.w3.": "image_projector.up_proj.", + "image_projector.w2.": "image_projector.down_proj.", + # language backbone mapping + "att_proj": "self_attn.qkv_proj", + "attn_out": "self_attn.o_proj", + "q_norm": "self_attn.q_norm", + "k_norm": "self_attn.k_norm", + "ff_proj": "mlp.gate_up_proj", + "ff_out": "mlp.down_proj", + "attn_norm": "input_layernorm", + "ff_norm": "post_attention_layernorm", + }, + orig_to_new_prefix={ + # vision backbone mapping + "model.vision_backbone.": "vision_backbone.", + # language backbone mapping + "model.transformer.blocks.": "model.layers.", + "model.transformer.ln_f.": "model.norm.", + # lm_head is renamed to model.transformer.mlp.down_proj firstly, + # we need to run a second renaming for it + "model.transformer.mlp.down_proj.": "lm_head.", + }, + ) + loader = AutoWeightsLoader(self) + weights = _get_weights_with_merged_embedding(weights) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) + + +def _get_weights_with_merged_embedding( + weights: Iterable[Tuple[str, torch.Tensor]] +) -> Iterable[Tuple[str, torch.Tensor]]: + embedding_weights = {} + for name, weight in weights: + if "wte.embedding" in name: + embedding_weights["embedding"] = weight + elif "wte.new_embedding" in name: + embedding_weights["new_embedding"] = weight + else: + yield (name, weight) + # this is compatible with most of quantization, + # because they won't quantize embed_tokens + embedding_weights = torch.cat( + [embedding_weights["embedding"], embedding_weights["new_embedding"]], + dim=0, + ) + yield ("model.embed_tokens.weight", embedding_weights) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index c7b4c22b6896b..34cb9981c167b 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -435,9 +435,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 2e5b6bee784e7..253e689e50a3b 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -151,9 +151,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config config.text_config.architectures = ["GemmaForCausalLM"] self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) logit_scale = getattr(config, "logit_scale", 1.0) self.language_model.logits_processor.scale *= logit_scale diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4cb874a13e0c1..eef23029a2aca 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -29,24 +29,22 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal, SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix, +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -536,7 +534,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config self.config = config self.multimodal_config = multimodal_config @@ -556,18 +553,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config, prefix=maybe_prefix(prefix, "model.vision_embed_tokens")) - # The prefix is empty intentionally because default prefix of - # LlamaForCausalLM is "model" - self.language_model = LlamaForCausalLM(vllm_config=vllm_config, - prefix="") - - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + # The prefix is empty intentionally because default prefix of + # LlamaForCausalLM is "model" + prefix="", + # We don't directly initialize vLLM's LlamaForCausalLM so we + # can automatically apply embedding wrapper if this model is + # initialized as an embedding model + architectures=["LlamaForCausalLM"], + ) + self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) @@ -739,13 +735,6 @@ def sample( ) -> Optional[SamplerOutput]: return self.language_model.sample(logits, sampling_metadata) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: hf_to_vllm_mapper = WeightsMapper( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 45171c1a04b17..c6786c363ab4a 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -48,6 +48,9 @@ except ImportError: USE_XFORMERS_OPS = False +PIXTRAL_IMAGE_BREAK_ID = 12 +PIXTRAL_IMAGE_END_ID = 13 + def get_max_pixtral_image_tokens(ctx: InputContext): tokenizer = cached_get_tokenizer( @@ -68,7 +71,6 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - patch_size = mm_encoder.mm_config.image_patch_size image_token_id = mm_encoder.special_ids.img mm_config = ctx.model_config.multimodal_config @@ -78,8 +80,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, size = 256 image = Image.new("RGB", (size, size), color=0) - image_feature_size = (size**2) // (patch_size**2) - + encoding = tokenizer.instruct.mm_encoder(ImageChunk(image=image)) + image_feature_size = len(encoding.tokens) num_image_tokens = image_feature_size * num_images seq_data = SequenceData.from_prompt_token_counts( (image_token_id, num_image_tokens), @@ -101,14 +103,13 @@ def input_mapper_for_pixtral(ctx: InputContext, Args: ctx: Context of the loaded model. - data: data potentially containing image/image embeddings to be mapped - to pixel_values in .forward() for a visual QWenLMHeadModel model. + data: data potentially containing PIL images to be processed + and mapped to `images`. Returns: MultiModalKwargs containing the stacked normalized images tensor or image embeddings. """ - # Early exit if we have provided an image to a language only Qwen model model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) @@ -116,35 +117,67 @@ def input_mapper_for_pixtral(ctx: InputContext, data_list = data if isinstance(data, list) else [data] images = [] + image_tokens_list = [] for image_data in data_list: image = ImageChunk(image=image_data) encoding = tokenizer.instruct.mm_encoder(image) image = torch.from_numpy(encoding.image).to(device="cuda", dtype=torch.float16) images.append(image) + image_tokens_list.append(encoding.tokens) - return MultiModalKwargs({"images": images}) + image_tokens = torch.tensor([ + token_id for image_tokens in image_tokens_list + for token_id in image_tokens + ]) + return MultiModalKwargs({"images": images, "image_tokens": image_tokens}) def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is not None and "image" in multi_modal_data: - tokenizer = cached_get_tokenizer( - ctx.model_config.tokenizer, - tokenizer_mode=ctx.model_config.tokenizer_mode) - - mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder - image_token_id = mm_encoder.special_ids.img + if multi_modal_data is None or "image" not in multi_modal_data: + return inputs - if image_token_id not in inputs['prompt_token_ids']: - raise ValueError( - f"You've passed {inputs=} without {image_token_id=}" - " Make sure to process your input via mistral_common's" - " tokenizer or pass a chat completion request. For more" - " For more info, see: " - "https://github.com/vllm-project/vllm/issues/8411.") + prompt_token_ids = inputs.get("prompt_token_ids") + prompt = inputs.get("prompt") + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) - return inputs + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + image_token_id = mm_encoder.special_ids.img + image_break_id = mm_encoder.special_ids.img_break + image_end_id = mm_encoder.special_ids.img_end + + if image_token_id not in inputs['prompt_token_ids']: + raise ValueError( + f"You've passed {inputs=} without {image_token_id=}" + " Make sure to process your input via mistral_common's" + " tokenizer or pass a chat completion request. For more" + " For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.") + + # Get precise tracking of placeholder positions + placeholder_ranges = [] + curr_offset = -1 + curr_length = 0 + for i in range(len(prompt_token_ids)): + if prompt_token_ids[i] in (image_token_id, image_break_id): + if curr_offset < 0: + curr_offset = i + curr_length += 1 + elif prompt_token_ids[i] == image_end_id: + curr_length += 1 + placeholder_ranges.append( + PlaceholderRange(offset=curr_offset, length=curr_length)) + curr_offset = -1 + curr_length = 0 + else: + pass + return token_inputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) @@ -172,9 +205,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # init MistralForCausalLM self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) self.vision_encoder = VisionTransformer(self.vision_args) self.vision_language_adapter = VisionLanguageAdapter( @@ -191,11 +225,29 @@ def sampler(self): return get_sampler() def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: - image_input = self._parse_and_validate_image_input(**kwargs) + image_input, image_tokens = self._parse_and_validate_image_input( + **kwargs) if image_input is None: return None + vision_embeddings = self._process_image_input(image_input) - return vision_embeddings + + # NOTE: We patch the outputs of the vision encoder with embeddings + # from `[IMG_BREAK]` and `[IMG_END]` tokens. + image_embeds = self.language_model.get_input_embeddings(image_tokens) + image_token_mask = image_tokens == self.vision_args.image_token_id + image_embeds[image_token_mask] = vision_embeddings + + # NOTE: Image embeddings are split into separate tensors for each image + # by the indices of `[IMG_END]` token. + split_indices = torch.where( + image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1 + if len(split_indices) <= 1: + # Do not split, return as tensor of shape [1, fs, hs] + return image_embeds.unsqueeze(0) + + image_embeds = image_embeds.tensor_split(split_indices.cpu()) + return image_embeds def get_input_embeddings( self, @@ -205,8 +257,10 @@ def get_input_embeddings( inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, multimodal_embeddings, - self.vision_args.image_token_id) + input_ids, inputs_embeds, multimodal_embeddings, [ + self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID, + PIXTRAL_IMAGE_BREAK_ID + ]) return inputs_embeds def forward( @@ -244,10 +298,11 @@ def forward( def _parse_and_validate_image_input( self, images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], - torch.Tensor]] = None + torch.Tensor]] = None, + image_tokens: Optional[torch.Tensor] = None, ) -> Optional[List[torch.Tensor]]: if images is None: - return None + return None, None if isinstance(images, torch.Tensor): # if passed as batch take all images @@ -266,7 +321,16 @@ def _parse_and_validate_image_input( images = flatten_images - return images + if isinstance(image_tokens, torch.Tensor): + # image_tokens are batched + image_tokens = image_tokens.flatten() + elif isinstance(image_tokens, list): + # image_tokens are of different lengths thus passed as a list + image_tokens = torch.cat(image_tokens) + + assert image_tokens.dim() == 1 + + return images, image_tokens def _process_image_input(self, image_input: List[torch.Tensor]) -> torch.Tensor: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 9f706610a129a..3ce4eb5869f21 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -31,6 +31,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -55,6 +56,8 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +logger = init_logger(__name__) + class Qwen2MLP(nn.Module): @@ -433,7 +436,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config lora_config = vllm_config.lora_config - pooler_config = vllm_config.model_config.pooler_config self.config = config self.lora_config = lora_config @@ -442,26 +444,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) + self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = get_sampler() - # The same model class supports both language generation and embedding - # because the architecture name is the same - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) - self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -499,13 +496,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: loader = AutoWeightsLoader( @@ -553,6 +543,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Qwen2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + # TODO: Replace this model class with for_embedding(Qwen2ForCausalLM), + # after changing the default pooling method + if pooler_config.pooling_type is None: + logger.warning( + "This embedding model will default to last-token pooling in " + "an upcoming version. To avoid breaking changes, you should " + "pass `--override-pooler-config '{\"pooling_type\": \"MEAN\"}'`" + " explicitly.") + self._pooler = Pooler.from_config_with_defaults( pooler_config, pooling_type=PoolingType.MEAN, @@ -580,4 +579,6 @@ def pooler( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) weights = hf_to_vllm_mapper.apply(weights) + weights = ((name, data) for name, data in weights + if not name.startswith("lm_head.")) self.model.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index a0605fee82aca..48a2d470414b9 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" -from functools import lru_cache +from functools import cached_property, lru_cache from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, Union) @@ -34,12 +34,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import NestedTensors @@ -47,15 +42,11 @@ from vllm.sequence import IntermediateTensors, SequenceData from .interfaces import SupportsMultiModal, SupportsPP -from .utils import merge_multimodal_embeddings +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) logger = init_logger(__name__) -_KEYS_TO_MODIFY_MAPPING = { - "language_model.lm_head": "lm_head", - "language_model.model": "language_model", -} - # # === Audio Inputs === # class Qwen2AudioInputs(TypedDict): @@ -281,25 +272,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.quant_config = quant_config - self.language_model = Qwen2Model( - vllm_config=vllm_config.with_hf_config(config.text_config), - prefix=prefix) - self.unpadded_vocab_size = config.text_config.vocab_size - if config.text_config.tie_word_embeddings: - self.lm_head = self.language_model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.text_config.vocab_size, - config.text_config.hidden_size, - quant_config=quant_config) - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = get_sampler() + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor, List[torch.Tensor]], @@ -414,72 +403,30 @@ def forward( multimodal_embeddings) input_ids = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if (self.config.text_config.tie_word_embeddings - and "lm_head.weight" in name): - continue - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name or 'audio' in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 7956a98b21569..cfc90cdab01e4 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" -from functools import partial +from functools import cached_property, partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Set, Tuple, Type, TypedDict, Union) @@ -40,7 +40,7 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed import get_pp_group, parallel_state +from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) @@ -49,31 +49,25 @@ from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.qwen2 import Qwen2Model -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict, MultiModalKwargs, NestedTensors) from vllm.multimodal.utils import cached_get_tokenizer from vllm.platforms import _Backend -from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData +from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.config import uses_mrope from vllm.transformers_utils.processor import cached_get_processor from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP -from .utils import (PPMissingLayer, get_vit_attn_backend, - is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, maybe_prefix) +from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend, + init_vllm_registered_model, maybe_prefix) logger = init_logger(__name__) @@ -508,6 +502,8 @@ def __init__( mlp_ratio: float = vision_config.mlp_ratio self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim self.patch_embed = Qwen2VisionPatchEmbed( patch_size=patch_size, @@ -597,6 +593,53 @@ def forward( x = self.merger(x) return x + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith("qkv.weight"): + visual_num_heads = self.num_heads + visual_embed_dim = self.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size, + visual_embed_dim) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) + elif name.endswith("qkv.bias"): + visual_num_heads = self.num_heads + visual_embed_dim = self.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1) + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + # === Vision input helpers === # @@ -1070,7 +1113,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config - pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config assert not cache_config.enable_prefix_caching, \ "Qwen2-VL currently does not support prefix caching" @@ -1085,31 +1127,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=maybe_prefix(prefix, "visual"), ) - self.model = Qwen2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) - if get_pp_group().is_last_rank: - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens - else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=True, - softmax=False) self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): # GPTQ configs do not have a list of ignored modules, however AutoGPTQ @@ -1268,7 +1300,7 @@ def get_input_embeddings( multimodal_embeddings: Optional[List[Tuple[NestedTensors, str]]] = None, ) -> torch.Tensor: - inputs_embeds = self.model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.get_input_embeddings(input_ids) if multimodal_embeddings is not None: for embeddings, modality in multimodal_embeddings: if modality == "image": @@ -1337,7 +1369,7 @@ def forward( multimodal_embeddings) input_ids = None - hidden_states = self.model( + hidden_states = self.language_model.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, @@ -1347,87 +1379,28 @@ def forward( ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "up_proj", 1), - ("gate_up_proj", "gate_proj", 0), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - if "visual" in name and name.endswith("qkv.weight"): - visual_num_heads = self.config.vision_config.num_heads - visual_embed_dim = self.config.vision_config.embed_dim - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size, - visual_embed_dim) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) - elif "visual" in name and name.endswith("qkv.bias"): - visual_num_heads = self.config.vision_config.num_heads - visual_embed_dim = self.config.vision_config.embed_dim - head_size = visual_embed_dim // visual_num_heads - loaded_weight = loaded_weight.view(3, visual_num_heads, - head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - try: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - except KeyError: - raise ValueError(f"Unexpected weight: {name}") from None - - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4462f6ed55a9c..e69596aa915b5 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -20,10 +20,11 @@ from vllm.logger import init_logger from vllm.platforms import current_platform +from .adapters import as_embedding_model from .interfaces import (has_inner_state, is_attention_free, supports_cross_encoding, supports_multimodal, supports_pp) -from .interfaces_base import is_embedding_model, is_text_generation_model +from .interfaces_base import is_pooling_model, is_text_generation_model logger = init_logger(__name__) @@ -48,6 +49,7 @@ "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "GlmForCausalLM": ("glm", "GlmForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), @@ -92,7 +94,7 @@ "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), "SolarForCausalLM": ("solar", "SolarForCausalLM"), "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"), - "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + "XverseForCausalLM": ("llama", "LlamaForCausalLM"), # [Encoder-decoder] "BartModel": ("bart", "BartForConditionalGeneration"), "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), @@ -106,14 +108,15 @@ "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), - "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), - "LlamaModel": ("llama", "LlamaEmbeddingModel"), + "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), + "GlmForCausalLM": ("glm", "GlmForCausalLM"), + "LlamaModel": ("llama", "LlamaForCausalLM"), **{ # Multiple models share the same architecture, so we include them all k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items() if arch == "LlamaForCausalLM" }, - "MistralModel": ("llama", "LlamaEmbeddingModel"), + "MistralModel": ("llama", "LlamaForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), @@ -123,7 +126,7 @@ # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501, + "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 } _CROSS_ENCODER_MODELS = { @@ -149,6 +152,7 @@ "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "MantisForConditionalGeneration": ("llava", "MantisForConditionalGeneration"), # noqa: E501 "MiniCPMV": ("minicpmv", "MiniCPMV"), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), @@ -206,8 +210,9 @@ @dataclass(frozen=True) class _ModelInfo: + architecture: str is_text_generation_model: bool - is_embedding_model: bool + is_pooling_model: bool supports_cross_encoding: bool supports_multimodal: bool supports_pp: bool @@ -216,9 +221,19 @@ class _ModelInfo: @staticmethod def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": + is_pooling_model_ = is_pooling_model(model) + if not is_pooling_model_: + try: + as_embedding_model(model) + except Exception: + pass + else: + is_pooling_model_ = True + return _ModelInfo( + architecture=model.__name__, is_text_generation_model=is_text_generation_model(model), - is_embedding_model=is_embedding_model(model), + is_pooling_model=is_pooling_model_, supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), @@ -397,13 +412,13 @@ def _normalize_archs( def inspect_model_cls( self, architectures: Union[str, List[str]], - ) -> _ModelInfo: + ) -> Tuple[_ModelInfo, str]: architectures = self._normalize_archs(architectures) for arch in architectures: model_info = self._try_inspect_model_cls(arch) if model_info is not None: - return model_info + return (model_info, arch) return self._raise_for_unsupported(architectures) @@ -424,39 +439,50 @@ def is_text_generation_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).is_text_generation_model + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_text_generation_model - def is_embedding_model( + def is_pooling_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).is_embedding_model + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_pooling_model def is_cross_encoder_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_cross_encoding + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_cross_encoding def is_multimodal_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_multimodal + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal def is_pp_supported_model( self, architectures: Union[str, List[str]], ) -> bool: - return self.inspect_model_cls(architectures).supports_pp + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_pp - def model_has_inner_state(self, architectures: Union[str, - List[str]]) -> bool: - return self.inspect_model_cls(architectures).has_inner_state + def model_has_inner_state( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.has_inner_state - def is_attention_free_model(self, architectures: Union[str, - List[str]]) -> bool: - return self.inspect_model_cls(architectures).is_attention_free + def is_attention_free_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.is_attention_free ModelRegistry = _ModelRegistry({ diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index deaed0ba7e4ce..6fb9e2cc4584f 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -6,12 +6,11 @@ import numpy as np import torch -import torch.nn.functional as F from PIL import Image from torch import nn from transformers import SiglipVisionConfig -from vllm.attention.selector import _Backend +from vllm.attention.layer import MultiHeadAttention from vllm.config import ModelConfig from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import DecoderOnlyInputs, token_inputs @@ -29,8 +28,6 @@ resolve_visual_encoder_outputs) from vllm.sequence import SequenceData -from .utils import get_vit_attn_backend - def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: # Since interpolation is applied, the image size need not be divisible @@ -291,52 +288,18 @@ def __init__( self.tp_size = get_tensor_model_parallel_world_size() self.num_heads_per_partition = divide(self.num_heads, self.tp_size) - self.attn_backend = get_vit_attn_backend(support_fa=False) - if self.attn_backend not in {_Backend.TORCH_SDPA, _Backend.XFORMERS}: - raise RuntimeError( - f"SIGLIP does not support {self.attn_backend} backend now.") + self.attn = MultiHeadAttention(self.num_heads_per_partition, + self.head_dim, self.scale) def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - batch_size, q_len, _ = hidden_states.size() - qkv_states, _ = self.qkv_proj(hidden_states) query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) - query_states = query_states.view(batch_size, q_len, - self.num_heads_per_partition, - self.head_dim) - key_states = key_states.view(batch_size, q_len, - self.num_heads_per_partition, - self.head_dim) - value_states = value_states.view(batch_size, q_len, - self.num_heads_per_partition, - self.head_dim) - - if self.attn_backend == _Backend.XFORMERS: - from xformers import ops as xops - - out = xops.memory_efficient_attention_forward(query_states, - key_states, - value_states, - p=self.dropout, - scale=self.scale) - elif self.attn_backend == _Backend.TORCH_SDPA: - query_states, key_states, value_states = (x.transpose(1, 2) - for x in (query_states, - key_states, - value_states)) - out = F.scaled_dot_product_attention(query_states, - key_states, - value_states, - dropout_p=self.dropout, - scale=self.scale) - out = out.transpose(1, 2) - - out = out.view(batch_size, q_len, -1) + out = self.attn(query_states, key_states, value_states) attn_output, _ = self.out_proj(out) return attn_output, None diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index f58710d215056..caae0b65d7d10 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -443,10 +443,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index b61deccde45b7..ea1e5401d42c0 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -360,9 +360,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): )) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( - config.text_config, vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "language_model")) + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) if config.text_model_id is not None: # this prefix is not for initialization, but for loading weights # note the trailing dot diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 4c13cbc953273..269b66806adf4 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,7 +1,7 @@ import itertools from dataclasses import dataclass, field -from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, - Optional, Protocol, Set, Tuple, Union, overload) +from typing import (Callable, Dict, Iterable, List, Literal, Mapping, Optional, + Protocol, Set, Tuple, Union, overload) import torch import torch.nn as nn @@ -17,7 +17,7 @@ from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors from vllm.platforms import _Backend, current_platform from vllm.sequence import IntermediateTensors -from vllm.utils import is_pin_memory_available +from vllm.utils import is_pin_memory_available, print_warning_once logger = init_logger(__name__) @@ -173,8 +173,15 @@ def _load_module( module_load_weights = getattr(module, "load_weights", None) if callable(module_load_weights): loaded_params = module_load_weights(weights) - yield from map(lambda x: self._get_qualname(base_prefix, x), - loaded_params) + if loaded_params is None: + logger.warning( + "Unable to collect loaded parameters " + "for module %s", module) + else: + yield from map( + lambda x: self._get_qualname(base_prefix, x), + loaded_params, + ) child_modules = dict(module.named_children()) child_params = dict(module.named_parameters(recurse=False)) @@ -232,17 +239,27 @@ def load_weights( def init_vllm_registered_model( - hf_config: PretrainedConfig, vllm_config: VllmConfig, + *, prefix: str = "", + hf_config: Optional[PretrainedConfig] = None, + architectures: Optional[list[str]] = None, ) -> nn.Module: """ Helper function to initialize an inner model registered to vLLM, based on the arguments passed to the outer vLLM model. """ from vllm.model_executor.model_loader.loader import _initialize_model - vllm_config = vllm_config.with_hf_config(hf_config) - return _initialize_model(vllm_config, prefix) + + if hf_config is None and architectures is not None: + # So that the architectures field is overridden + hf_config = vllm_config.model_config.hf_config + + if hf_config is not None: + vllm_config = vllm_config.with_hf_config(hf_config, + architectures=architectures) + + return _initialize_model(vllm_config=vllm_config, prefix=prefix) @overload @@ -392,16 +409,42 @@ def merge_multimodal_embeddings( input_ids: torch.Tensor, inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors, - placeholder_token_id: int, + placeholder_token_id: Union[int, List[int]], ) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the positions in ``inputs_embeds`` corresponding to placeholder tokens in ``input_ids``. + + ``placeholder_token_id`` can be a list of token ids (e.g, token ids + of img_start, img_break, and img_end tokens) when needed: This means + the order of these tokens in the ``input_ids`` MUST MATCH the order of + their embeddings in ``multimodal_embeddings`` since we need to + slice-merge instead of individually scattering. + + For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where + - T is text token + - S is image start token + - I is image embedding token + - B is image break token + - E is image end token. + + Then the image embeddings (that correspond to I's) from vision encoder + must be padded with embeddings of S, B, and E in the same order of + input_ids for a correct embedding merge. Note: This updates ``inputs_embeds`` in place. """ + if isinstance(placeholder_token_id, list): + placeholder_token_id = torch.tensor(placeholder_token_id, + device=input_ids.device) + return _merge_multimodal_embeddings( + inputs_embeds, + torch.isin(input_ids, placeholder_token_id), + multimodal_embeddings, + ) + return _merge_multimodal_embeddings( inputs_embeds, (input_ids == placeholder_token_id), @@ -560,30 +603,6 @@ def make_empty_intermediate_tensors( return make_empty_intermediate_tensors -class LLMWrapper(nn.Module): - """ - To align with the key names of LoRA trained with PEFT, we need to add an - additional layer to the llm's implementation. - """ - - def __init__(self, llm: nn.Module, name: str) -> None: - super().__init__() - self.model_name = name - setattr(self, name, llm) - - def __getattr__(self, key: str): - llm = super().__getattr__(self.model_name) - if key == self.model_name: - return llm - - return getattr(llm, key) - - # We need to explicitly override this - def __call__(self, *args: Any, **kwargs: Any) -> Any: - llm = super().__getattr__(self.model_name) - return llm(*args, **kwargs) - - def get_vit_attn_backend(support_fa: bool = False) -> _Backend: """ Get the available attention backend for Vision Transformer. @@ -602,7 +621,7 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: if is_flash_attn_2_available(): selected_backend = _Backend.FLASH_ATTN else: - logger.warning( + print_warning_once( "Current `vllm-flash-attn` has a bug inside vision module, " "so we use xformers backend instead. You can run " "`pip install flash-attn` to use flash-attention backend.") diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py deleted file mode 100644 index 25a0d474e2863..0000000000000 --- a/vllm/model_executor/models/xverse.py +++ /dev/null @@ -1,423 +0,0 @@ -# Adapted from -# https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union - -import torch -from torch import nn -from transformers import PretrainedConfig - -from vllm.attention import Attention, AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsLoRA, SupportsPP -from .utils import (is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -class XverseMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config) - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - def forward(self, x): - gate, _ = self.gate_up_proj(x) - x = self.act_fn(gate) - x, _ = self.down_proj(x) - return x - - -class XverseAttention(nn.Module): - - def __init__( - self, - hidden_size: int, - num_heads: int, - num_kv_heads: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, - bias: bool = False, - cache_config: Optional[CacheConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size - self.total_num_kv_heads = num_kv_heads - # partition the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.qkv_proj = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - ) - self.o_proj = RowParallelLinear( - self.total_num_heads * self.head_dim, - hidden_size, - bias=bias, - quant_config=quant_config, - ) - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn") - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) - output, _ = self.o_proj(attn_output) - return output - - -class XverseDecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - self.self_attn = XverseAttention( - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - bias=getattr(config, "bias", False), - cache_config=cache_config, - prefix=f"{prefix}.self_attn", - ) - self.mlp = XverseMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - residual: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Self Attention - if residual is None: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - - -@support_torch_compile -class XverseModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - self.config = config - self.padding_idx = config.pad_token_id - lora_vocab = (lora_config.lora_extra_vocab_size * - (lora_config.max_loras or 1)) if lora_config else 0 - self.vocab_size = config.vocab_size + lora_vocab - self.org_vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - ) - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: XverseDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers", - ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - else: - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - for i in range(self.start_layer, self.end_layer): - layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) - if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] - embedding_modules = { - "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", - } - embedding_padding_modules = ["lm_head"] - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - - self.config = config - self.lora_config = lora_config - - self.quant_config = quant_config - self.model = XverseModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - if ("rotary_emb.inv_freq" in name - or "rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 84f35f75a0c32..1df8f84ed4093 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -454,6 +454,7 @@ def from_sampling_metadata( if do_penalties: for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params if (seq_group.is_prompt and sampling_params.prompt_logprobs is not None): prefill_len = len(seq_group.prompt_logprob_indices) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 03a5f3a91f7a1..928c31a2f2843 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -27,18 +27,3 @@ "MULTIMODAL_REGISTRY", "MultiModalRegistry", ] - - -def __getattr__(name: str): - import warnings - - if name == "MultiModalInputs": - msg = ("MultiModalInputs has been renamed to MultiModalKwargs. " - "The original name will take another meaning in an upcoming " - "version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return MultiModalKwargs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 6eec660e42ac4..7dba94b885b6d 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -7,7 +7,7 @@ from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import (get_allowed_kwarg_only_overrides, +from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, resolve_mm_processor_kwargs) if TYPE_CHECKING: @@ -54,8 +54,8 @@ class MultiModalPlugin(ABC): """ def __init__(self) -> None: - self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {} - self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {} + self._input_mappers = ClassRegistry[nn.Module, MultiModalInputMapper]() + self._max_mm_tokens = ClassRegistry[nn.Module, MultiModalTokensCalc]() @abstractmethod def get_data_key(self) -> str: @@ -226,16 +226,16 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ # Avoid circular import from vllm.model_executor.model_loader import get_model_architecture + from vllm.model_executor.models import supports_multimodal model_cls, _ = get_model_architecture(model_config) - if model_cls not in self._input_mappers: + if not supports_multimodal(model_cls): return 0 max_mm_tokens = self._max_mm_tokens.get(model_cls) if max_mm_tokens is None: - raise KeyError(f"No maximum number of multi-modal tokens is given " - f"for model class {model_cls.__name__} in {self}.") + return 0 if callable(max_mm_tokens): mm_processor_kwargs = get_allowed_kwarg_only_overrides( @@ -326,26 +326,47 @@ def from_seq_group( src_ranges = [] dest_ranges = [] """ - if (not seq_group.multi_modal_data - or not seq_group.multi_modal_placeholders): - return seq_group.multi_modal_data, {} + seq_mm_data = seq_group.multi_modal_data + seq_mm_placeholders = seq_group.multi_modal_placeholders + + if not seq_mm_data or not seq_mm_placeholders: + return seq_mm_data, {} + + # For merged processor, we directly use mm_kwargs as mm_data + if isinstance(seq_mm_data, MultiModalKwargs): + placeholder_maps = dict[str, MultiModalPlaceholderMap]() + + for modality, placeholders in seq_mm_placeholders.items(): + placeholder_map = MultiModalPlaceholderMap() + + if positions: + placeholder_map.append_items_from_seq_group( + positions, + # Dummy, since we don't care about intersecting items + [None] * len(placeholders), + placeholders, + ) + + placeholder_maps[modality] = placeholder_map + + return seq_mm_data, placeholder_maps - mm_data = {**seq_group.multi_modal_data} - placeholder_maps: Dict[str, MultiModalPlaceholderMap] = defaultdict( + mm_data = {**seq_mm_data} + placeholder_maps = defaultdict[str, MultiModalPlaceholderMap]( MultiModalPlaceholderMap) - for ( - modality, - placeholders, - ) in seq_group.multi_modal_placeholders.items(): + for modality, placeholders in seq_mm_placeholders.items(): mm_items = mm_data.pop(modality) if not isinstance(mm_items, list): mm_items = [mm_items] if positions: - intersecting_items = placeholder_maps[ - modality].append_items_from_seq_group( - positions, mm_items, placeholders) + intersecting_items = placeholder_maps[modality] \ + .append_items_from_seq_group( + positions, + mm_items, + placeholders, + ) if intersecting_items: mm_data[modality] = intersecting_items @@ -433,18 +454,3 @@ def index_map(self) -> "IndexMap": return MultiModalPlaceholderMap.IndexMap(src=src_indices, dest=dest_indices) - - -def __getattr__(name: str): - import warnings - - if name == "MultiModalInputs": - msg = ("MultiModalInputs has been renamed to MultiModalKwargs. " - "The original name will take another meaning in an upcoming " - "version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return MultiModalKwargs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 640c7c04b8817..229a8fbdf5831 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -96,7 +96,8 @@ class PlaceholderRange(TypedDict): """The length of the placeholder.""" -NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor, + Tuple[torch.Tensor, ...]] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 28c8dda581982..c3a95d60e6fe6 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -3,14 +3,13 @@ from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass from functools import lru_cache -from itertools import groupby from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union -import numpy as np -from transformers import BatchFeature +import torch +from transformers import BatchFeature, ProcessorMixin from typing_extensions import TypeAlias, TypedDict -from vllm.inputs import InputProcessingContext +from vllm.inputs import DummyData, InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import flatten_2d_lists, full_groupby, is_list_of @@ -256,63 +255,6 @@ def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: return multi_data -class _TokenRun(NamedTuple): - token_id: int - - start_idx: int - length: int - - -def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]: - """ - Yield the starting index and length of each run of tokens that are the same. - """ - start_idx = 0 - - for token_id, it in groupby(token_ids): - length = sum(1 for _ in it) - yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length) - - start_idx += length - - -class _PlaceholderInfo(NamedTuple): - modality: str - offset: int - length: int - - def to_range(self) -> PlaceholderRange: - return PlaceholderRange(offset=self.offset, length=self.length) - - -def iter_placeholders( - prompt_repls: Sequence[_BoundPromptReplacement[Any]], - token_ids: list[int], - *, - min_placeholder_count: int, -) -> Iterable[_PlaceholderInfo]: - """Yield each set of placeholder tokens found in :code:`token_ids`.""" - placeholder_ids_by_modality = { - modality: { - token_id - for prompt_repl in repls - for token_id in prompt_repl.repl_unit.token_ids - } - for modality, repls in full_groupby_modality(prompt_repls) - } - - for run_info in iter_token_runs(token_ids): - if run_info.length > min_placeholder_count: - for (modality, - placeholder_ids) in placeholder_ids_by_modality.items(): - if run_info.token_id in placeholder_ids: - yield _PlaceholderInfo( - modality=modality, - offset=run_info.start_idx, - length=run_info.length, - ) - - class _TokenMatch(NamedTuple): start_idx: int end_idx: int @@ -353,13 +295,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: raise NotImplementedError + @property @abstractmethod - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> _S: + def repl_unit(self) -> _S: raise NotImplementedError def __repr__(self) -> str: @@ -380,15 +318,9 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end_idx - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> list[int]: - prompt_repl = self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.token_ids * count + @property + def repl_unit(self) -> list[int]: + return self.prompt_repl.repl_unit.token_ids @dataclass(repr=False) @@ -404,15 +336,26 @@ def start_idx(self) -> int: def end_idx(self) -> int: return self.match.end() - def get_repl( - self, - mm_items: list[_T], - hf_inputs: BatchFeature, - item_idx: int, - ) -> str: - prompt_repl = self.prompt_repl - count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) - return prompt_repl.repl_unit.text * count + @property + def repl_unit(self) -> str: + return self.prompt_repl.repl_unit.text + + +class _PlaceholderInfo(NamedTuple): + modality: str + start_idx: int + unit: list[int] + unit_count: int + + @property + def length(self) -> int: + return len(self.unit) * self.unit_count + + def to_range(self) -> PlaceholderRange: + return PlaceholderRange( + offset=self.start_idx, + length=self.length, + ) def find_token_matches( @@ -447,15 +390,17 @@ def _resolve_matches( Resolve :code:`matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. """ - num_matches_by_idx = np.zeros(len(prompt), dtype=int) + seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \ + = [None] * len(prompt) + for match in matches: - num_matches_by_idx[match.start_idx:match.end_idx] += 1 + for idx in range(match.start_idx, match.end_idx): + if seen_matches[idx] is not None: + raise ValueError("Found overlapping matches " + f"({seen_matches[idx]} and {match}) " + f"at index={idx} of prompt={prompt}") - duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1) - if len(duplicate_matches_idxs) > 0: - raise ValueError("Unable to find a unique replacement " - f"at indices={duplicate_matches_idxs} " - f"of prompt={prompt}") + seen_matches[idx] = match return sorted(matches, key=lambda x: x.start_idx) @@ -480,9 +425,12 @@ def _replace_matches( start_idx = match.start_idx end_idx = match.end_idx - repl_ids = match.get_repl(mm_items, hf_inputs, item_idx) + repl_unit = match.repl_unit + repl_info = match.prompt_repl + repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx) - out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids) + out_seqs.append(prompt[prev_end_idx:start_idx] + + repl_unit * repl_count) prev_end_idx = end_idx next_idx_by_modality[modality] += 1 @@ -531,9 +479,59 @@ def replace_text_matches( return "".join(texts) -class MultiModalProcessor: +def _merge_placeholder_matches( + matches: Iterable[_PromptReplacementTokenMatch], +) -> Iterable[_PromptReplacementTokenMatch]: + current_match = None + + for match in sorted(matches, key=lambda x: x.start_idx): + if current_match is None: + current_match = match + elif (current_match.prompt_repl == match.prompt_repl + and current_match.end_idx == match.start_idx): + current_match = _PromptReplacementTokenMatch( + current_match.prompt_repl, + match=_TokenMatch(current_match.start_idx, match.end_idx), + ) + else: + yield current_match + current_match = match + + if current_match is not None: + yield current_match + + +def iter_placeholders( + prompt_repls: Sequence[_BoundPromptReplacement[Any]], + prompt: list[int], + *, + min_unit_count: int = 1, +) -> Iterable[_PlaceholderInfo]: + """Yield each set of placeholder tokens found in :code:`token_ids`.""" + if min_unit_count <= 0: + raise ValueError("`min_unit_count` must be a positive integer") + + matches = (_PromptReplacementTokenMatch(prompt_repl, match) + for prompt_repl in prompt_repls + if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0 + for match in iter_token_matches(prompt, repl_unit)) + + for match in _merge_placeholder_matches(matches): + unit = match.repl_unit + placeholder = _PlaceholderInfo( + modality=match.modality, + start_idx=match.start_idx, + unit=unit, + unit_count=(match.end_idx - match.start_idx) // len(unit), + ) + + if placeholder.unit_count >= min_unit_count: + yield placeholder + + +class BaseMultiModalProcessor(ABC): """ - Helper class to process multi-modal inputs to be used in vLLM. + Abstract base class to process multi-modal inputs to be used in vLLM. """ def __init__( @@ -546,6 +544,12 @@ def __init__( self.ctx = ctx self.metadata = metadata + def _get_hf_processor(self) -> ProcessorMixin: + return self.ctx.get_hf_processor() + + def _get_tokenizer(self) -> AnyTokenizer: + return self.ctx.tokenizer + def __call__( self, prompt: str, @@ -562,13 +566,13 @@ def _find_placeholders( # To avoid false positives from multi-input when detecting # whether placeholder tokens have been inserted, in case # the target sequence is a subset of the replacement tokens - min_placeholder_count: int = 16, + min_unit_count: int = 16, ) -> list[_PlaceholderInfo]: return list( iter_placeholders( all_prompt_repls, new_token_ids, - min_placeholder_count=min_placeholder_count, + min_unit_count=min_unit_count, )) def _apply_hf_processor( @@ -577,19 +581,49 @@ def _apply_hf_processor( mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], ) -> BatchFeature: - hf_processor = self.ctx.get_hf_processor() + hf_processor = self._get_hf_processor() + + processor_data = dict[str, Any]() + passthrough_data = dict[str, Any]() + for k, v in mm_data.items(): + # TODO: Make a separate modality for embedding inputs + # to avoid confusion + if k in ("image", "video", "audio"): + if isinstance(v, torch.Tensor) and v.ndim == 3: + # Pass through embedding inputs (single) + passthrough_data[f"{k}_embeds"] = [v] + elif is_list_of(v, torch.Tensor) and v[0].ndim == 2: + # Pass through embedding inputs (multi) + passthrough_data[f"{k}_embeds"] = v + else: + # Map keys to plural form, e.g.: image -> images + processor_data[f"{k}s"] = v + else: + processor_data[k] = v + + try: + hf_inputs = hf_processor( + text=prompt, # type: ignore + **processor_data, + **mm_processor_kwargs, + return_tensors="pt", + ) + except Exception as exc: + data = dict(text=prompt, **processor_data) - return hf_processor( - text=prompt, # type: ignore - **mm_data, - **mm_processor_kwargs, - ) + raise RuntimeError( + f"Failed to apply {type(hf_processor).__name__} " + f"on data={data} with kwargs={mm_processor_kwargs}") from exc + + hf_inputs.update(passthrough_data) + + return hf_inputs def _bind_prompt_replacements( self, mm_data: MultiModalDataDict, ) -> list[_BoundPromptReplacement[Any]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() return [ prompt_repl.bind(modality, tokenizer) @@ -604,7 +638,7 @@ def _apply_prompt_replacements( token_ids: list[int], prompt_repls: Sequence[_BoundPromptReplacement[Any]], ) -> tuple[list[int], str, list[_PlaceholderInfo]]: - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() mm_items = to_multi_format(mm_data) token_matches = find_token_matches(token_ids, prompt_repls) @@ -620,7 +654,7 @@ def _apply_prompt_replacements( # of the search text in the prompt, we instead perform string # replacement on the decoded token IDs, then encode them back. if all( - len(matches) >= len(mm_data[modality]) + len(matches) >= len(mm_items[modality]) for modality, matches in full_groupby_modality(token_matches) ): # yapf: disable token_ids = replace_token_matches( @@ -648,15 +682,6 @@ def _apply_prompt_replacements( placeholders = self._find_placeholders(matched_repls, token_ids) - # Sanity check - assert len(placeholders) == len(matched_repls), dict( - # Log this information for easier debugging - text=text, - token_ids=token_ids, - placeholders=placeholders, - matched_repls=matched_repls, - ) - return token_ids, text, placeholders def apply( @@ -678,7 +703,7 @@ def apply( 3. Extract information about the placeholder tokens from the processed token IDs. """ - tokenizer = self.ctx.tokenizer + tokenizer = self._get_tokenizer() hf_inputs = self._apply_hf_processor(prompt_text, mm_data, mm_processor_kwargs) @@ -717,3 +742,59 @@ def apply( mm_kwargs=mm_kwargs, mm_placeholders=mm_placeholders, ) + + @abstractmethod + def _get_dummy_mm_kwargs( + self, + mm_counts: Mapping[str, int], + ) -> MultiModalKwargs: + """ + Build the input that corresponds to `mm_max_tokens` in + :meth:`get_dummy_data`. + """ + raise NotImplementedError + + def get_dummy_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_max_tokens: Mapping[str, int], + ) -> DummyData: + # Avoid circular import + from vllm.sequence import SequenceData + + tokenizer = self._get_tokenizer() + + mm_placeholders = dict[str, _PlaceholderInfo]() + offset = 0 + + for modality, max_tokens in mm_max_tokens.items(): + if max_tokens == 0: + continue + + metadata = self.metadata[modality] + repl = metadata.prompt_repls[0].bind(modality, tokenizer) + repl_token_ids = repl.repl_unit.token_ids + + placeholders = _PlaceholderInfo( + modality=modality, + start_idx=offset, + unit=repl_token_ids, + unit_count=max_tokens // len(repl_token_ids), + ) + + mm_placeholders[modality] = placeholders + offset += placeholders.length + + prompt_token_ids = flatten_2d_lists( + [p.unit * p.unit_count for p in mm_placeholders.values()]) + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) + + return DummyData( + seq_data=SequenceData.from_seqs(prompt_token_ids), + multi_modal_data=self._get_dummy_mm_kwargs(mm_counts), + multi_modal_placeholders={ + modality: [p.to_range()] + for modality, p in mm_placeholders.items() + }, + ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index b992442d3b314..6ab6c0fe2f12e 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -9,12 +9,13 @@ from vllm.inputs import InputProcessingContext from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import ClassRegistry from .audio import AudioPlugin from .base import MultiModalInputMapper, MultiModalPlugin, MultiModalTokensCalc from .image import ImagePlugin from .inputs import MultiModalDataDict, MultiModalKwargs, NestedTensors -from .processing import MultiModalProcessor +from .processing import BaseMultiModalProcessor from .video import VideoPlugin if TYPE_CHECKING: @@ -25,7 +26,7 @@ N = TypeVar("N", bound=Type[nn.Module]) MultiModalProcessorFactory: TypeAlias = Callable[[InputProcessingContext], - MultiModalProcessor] + BaseMultiModalProcessor] """ Constructs a :class:`MultiModalProcessor` instance from the context. @@ -62,8 +63,8 @@ def __init__( plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: self._plugins = {p.get_data_key(): p for p in plugins} - self._processor_factories: Dict[Type[nn.Module], - MultiModalProcessorFactory] = {} + self._processor_factories = ClassRegistry[nn.Module, + MultiModalProcessorFactory]() # This is used for non-multimodal models self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} @@ -199,9 +200,12 @@ def register_max_image_tokens( """ return self.register_max_multimodal_tokens("image", max_mm_tokens) - def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + def get_max_tokens_by_modality( + self, + model_config: "ModelConfig", + ) -> Mapping[str, int]: """ - Get the maximum number of multi-modal tokens + Get the maximum number of tokens from each modality for profiling the memory usage of a model. See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. @@ -211,9 +215,23 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ limits_per_plugin = self._limits_by_model[model_config] - return sum((limits_per_plugin[key] * - plugin.get_max_multimodal_tokens(model_config)) - for key, plugin in self._plugins.items()) + return { + key: (limits_per_plugin[key] * + plugin.get_max_multimodal_tokens(model_config)) + for key, plugin in self._plugins.items() + } + + def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + return sum(self.get_max_tokens_by_modality(model_config).values()) def init_mm_limits_per_prompt( self, @@ -269,7 +287,8 @@ def register_processor( factory: MultiModalProcessorFactory, ): """ - Register a multi-modal processor to a model class. + Register a multi-modal processor to a model class. The processor + is constructed lazily, hence a factory method should be passed. When the model receives multi-modal data, the provided function is invoked to transform the data into a dictionary of model inputs. @@ -306,7 +325,7 @@ def create_processor( self, model_config: "ModelConfig", tokenizer: AnyTokenizer, - ) -> MultiModalProcessor: + ) -> BaseMultiModalProcessor: """ Create a multi-modal processor for a specific model and tokenizer. """ diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index d4333b7519b47..c898ca4e6573e 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -535,11 +535,13 @@ def repeat_and_pad_placeholder_tokens( return new_prompt, new_token_ids, placeholder_ranges -def consecutive_placeholder_ranges(num_items: int, - item_size: int) -> List[PlaceholderRange]: +def consecutive_placeholder_ranges( + num_items: int, + item_size: int, + initial_offset: int = 0) -> List[PlaceholderRange]: """Returns a list of consecutive PlaceholderRanges of a fixed size""" return [ - PlaceholderRange(offset=i * item_size, length=item_size) - for i in range(num_items) + PlaceholderRange(offset=initial_offset + i * item_size, + length=item_size) for i in range(num_items) ] diff --git a/vllm/outputs.py b/vllm/outputs.py index 2d256803edfe8..86264f604f6bc 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -53,8 +53,8 @@ def __repr__(self) -> str: @dataclass -class EmbeddingOutput: - """The output data of one completion output of a request. +class PoolingOutput: + """The output data of one pooling output of a request. Args: embedding: The embedding vector, which is a list of floats. The @@ -63,7 +63,7 @@ class EmbeddingOutput: embedding: List[float] def __repr__(self) -> str: - return (f"EmbeddingOutput(" + return (f"PoolingOutput(" f"embedding={len(self.embedding)})") @@ -316,18 +316,18 @@ def __repr__(self) -> str: f"multi_modal_placeholders={self.multi_modal_placeholders})") -class EmbeddingRequestOutput: +class PoolingRequestOutput: """ - The output data of an embedding request to the LLM. + The output data of a pooling request to the LLM. Args: - request_id (str): A unique identifier for the embedding request. - outputs (EmbeddingOutput): The embedding results for the given input. + request_id (str): A unique identifier for the pooling request. + outputs (PoolingOutput): The pooling results for the given input. prompt_token_ids (List[int]): A list of token IDs used in the prompt. - finished (bool): A flag indicating whether the embedding is completed. + finished (bool): A flag indicating whether the pooling is completed. """ - def __init__(self, request_id: str, outputs: "EmbeddingOutput", + def __init__(self, request_id: str, outputs: "PoolingOutput", prompt_token_ids: List[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids @@ -336,11 +336,11 @@ def __init__(self, request_id: str, outputs: "EmbeddingOutput", @classmethod def from_seq_group(cls, - seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + seq_group: 'SequenceGroup') -> "PoolingRequestOutput": if seq_group.embeddings is None: raise ValueError( "Embeddings are missing in seq_group for EmbeddingRequest.") - output = EmbeddingOutput(seq_group.embeddings) + output = PoolingOutput(seq_group.embeddings) prompt_token_ids = seq_group.prompt_token_ids finished = seq_group.is_finished() @@ -348,15 +348,15 @@ def from_seq_group(cls, def __repr__(self): """ - Returns a string representation of an EmbeddingRequestOutput instance. + Returns a string representation of an PoolingRequestOutput instance. The representation includes the request_id and the number of outputs, - providing a quick overview of the embedding request's results. + providing a quick overview of the pooling request's results. Returns: - str: A string representation of the EmbeddingRequestOutput instance. + str: A string representation of the PoolingRequestOutput instance. """ - return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + return (f"PoolingRequestOutput(request_id='{self.request_id}', " f"outputs={repr(self.outputs)}, " f"prompt_token_ids={self.prompt_token_ids}, " f"finished={self.finished})") @@ -415,7 +415,30 @@ def create(seq_group: SequenceGroup, # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: - return EmbeddingRequestOutput.from_seq_group(seq_group) + return PoolingRequestOutput.from_seq_group(seq_group) else: return RequestOutput.from_seq_group(seq_group, use_cache, seq_id_to_seq_group) + + +def __getattr__(name: str): + import warnings + + if name == "EmbeddingOutput": + msg = ("EmbeddingOutput has been renamed to PoolingOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingOutput + + if name == "EmbeddingRequestOutput": + msg = ("EmbeddingRequestOutput has been renamed to " + "PoolingRequestOutput. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PoolingRequestOutput + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 7cb8ac4b0a1e0..419237c252ffd 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,5 +1,5 @@ from .interface import _Backend # noqa: F401 -from .interface import Platform, PlatformEnum, UnspecifiedPlatform +from .interface import CpuArchEnum, Platform, PlatformEnum, UnspecifiedPlatform current_platform: Platform @@ -120,4 +120,4 @@ def cuda_is_jetson() -> bool: else: current_platform = UnspecifiedPlatform() -__all__ = ['Platform', 'PlatformEnum', 'current_platform'] +__all__ = ['Platform', 'PlatformEnum', 'current_platform', 'CpuArchEnum'] diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 3e22c87f61fac..680ee74129739 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -19,6 +19,7 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU + device_name: str = "cpu" device_type: str = "cpu" dispatch_key: str = "CPU" @@ -45,7 +46,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: import vllm.envs as envs from vllm.utils import GiB_bytes model_config = vllm_config.model_config - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if not model_config.enforce_eager: logger.warning( diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 5e9ce551f2332..846a1869da228 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -72,6 +72,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: class CudaPlatformBase(Platform): _enum = PlatformEnum.CUDA + device_name: str = "cuda" device_type: str = "cuda" dispatch_key: str = "CUDA" diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 3071136e43b85..10aaa6d54962c 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -12,6 +12,7 @@ class HpuPlatform(Platform): _enum = PlatformEnum.HPU + device_name: str = "hpu" device_type: str = "hpu" dispatch_key: str = "HPU" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 3328665029039..0be7df7941b8b 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -1,4 +1,5 @@ import enum +import platform import random from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union @@ -37,6 +38,14 @@ class PlatformEnum(enum.Enum): UNSPECIFIED = enum.auto() +class CpuArchEnum(enum.Enum): + X86 = enum.auto() + ARM = enum.auto() + POWERPC = enum.auto() + OTHER = enum.auto() + UNKNOWN = enum.auto() + + class DeviceCapability(NamedTuple): major: int minor: int @@ -56,11 +65,13 @@ def to_int(self) -> int: class Platform: _enum: PlatformEnum + device_name: str device_type: str # available dispatch keys: # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa # use "CPU" as a fallback for platforms not registered in PyTorch dispatch_key: str = "CPU" + supported_quantization: list[str] = [] def is_cuda(self) -> bool: return self._enum == PlatformEnum.CUDA @@ -171,6 +182,34 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: """ pass + @classmethod + def verify_quantization(cls, quant: str) -> None: + """ + Verify whether the quantization is supported by the current platform. + """ + if cls.supported_quantization and \ + quant not in cls.supported_quantization: + raise ValueError( + f"{quant} quantization is currently not supported in " + f"{cls.device_name}.") + + @classmethod + def get_cpu_architecture(cls) -> CpuArchEnum: + """ + Determine the CPU architecture of the current system. + Returns CpuArchEnum indicating the architecture type. + """ + machine = platform.machine().lower() + + if machine in ("x86_64", "amd64", "i386", "i686"): + return CpuArchEnum.X86 + elif machine.startswith("arm") or machine.startswith("aarch"): + return CpuArchEnum.ARM + elif machine.startswith("ppc"): + return CpuArchEnum.POWERPC + + return CpuArchEnum.OTHER if machine else CpuArchEnum.UNKNOWN + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 4c4d778ed3dd4..87655ea198303 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -10,7 +10,9 @@ class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON + device_name: str = "neuron" device_type: str = "neuron" + supported_quantization: list[str] = ["neuron_quant"] @classmethod def get_device_name(cls, device_id: int = 0) -> str: diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index ea5ec7b40b95c..29b61e955d9ab 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -23,6 +23,7 @@ class OpenVinoPlatform(Platform): _enum = PlatformEnum.OPENVINO + device_name: str = "openvino" device_type: str = "openvino" dispatch_key: str = "CPU" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index d2f44c3e423e3..3c14fbc179f69 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum, _Backend @@ -35,8 +36,13 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM + device_name: str = "rocm" device_type: str = "cuda" dispatch_key: str = "CUDA" + supported_quantization: list[str] = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8", "gguf" + ] @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @@ -79,3 +85,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "vllm.spec_decode.spec_decode_worker.create_spec_worker" else: parallel_config.worker_cls = "vllm.worker.worker.Worker" + + @classmethod + def verify_quantization(cls, quant: str) -> None: + super().verify_quantization(quant) + if quant == "awq" and not envs.VLLM_USE_TRITON_AWQ: + logger.warning( + "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" + " is not set, enabling VLLM_USE_TRITON_AWQ.") + envs.VLLM_USE_TRITON_AWQ = True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 137af57023ea9..b138f7e1c54c5 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -16,8 +16,10 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU + device_name: str = "tpu" device_type: str = "tpu" dispatch_key: str = "XLA" + supported_quantization: list[str] = ["tpu_int8"] @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 69388a8e0f27c..9665786f4c499 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -16,6 +16,7 @@ class XPUPlatform(Platform): _enum = PlatformEnum.XPU + device_name: str = "xpu" device_type: str = "xpu" dispatch_key: str = "XPU" diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 3c64726ca3344..17f604ea0e202 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -4,6 +4,7 @@ import torch import vllm.envs as envs +from vllm.platforms import current_platform logger = logging.getLogger(__name__) @@ -25,6 +26,23 @@ def load_general_plugins(): os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' # see https://github.com/vllm-project/vllm/issues/10619 torch._inductor.config.compile_threads = 1 + if current_platform.is_xpu(): + # see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158 # noqa + os.environ['TORCH_COMPILE_DISABLE'] = 'True' + if current_platform.is_hpu(): + # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1) + # does not support torch.compile + # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for + # torch.compile support + is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1' + if is_lazy: + # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158 + torch._dynamo.config.disable = True + # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only) + # requires enabling lazy collectives + # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 + os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true' + global plugins_loaded if plugins_loaded: return @@ -39,7 +57,7 @@ def load_general_plugins(): discovered_plugins = entry_points(group='vllm.general_plugins') if len(discovered_plugins) == 0: - logger.info("No plugins found.") + logger.debug("No plugins found.") return logger.info("Available plugins:") for plugin in discovered_plugins: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5c6df5aaf5446..fc77f3ca529b2 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -293,8 +293,9 @@ def __post_init__(self) -> None: raise ValueError( f"best_of must be greater than or equal to n, " f"got n={self.n} and best_of={self.best_of}.") - self._real_n = self.n - self.n = self.best_of + if not self._real_n: + self._real_n = self.n + self.n = self.best_of if 0 < self.temperature < _MAX_TEMP: logger.warning( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d249b37c780e4..676ac5eb3609d 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -120,6 +120,9 @@ def sampler_output( indices_of_seq_with_bonus_tokens) model_outputs.append(model_output) + # move indices to device to avoid stream sync + indices_of_seq_with_bonus_tokens = torch.tensor( + indices_of_seq_with_bonus_tokens, device=self.device) filtered_model_outputs = self._filter_model_output( model_outputs, indices_of_seq_with_bonus_tokens) return filtered_model_outputs, True @@ -189,7 +192,7 @@ def _expand_execute_model_request( @staticmethod def _filter_model_output( expanded_batch_outputs: List[SamplerOutput], - output_indices_to_retain: List[int]) -> List[SamplerOutput]: + output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]: """ Filters the model output to include only the specified sequence outputs. This method contracts the expanded batch output from the @@ -199,8 +202,8 @@ def _filter_model_output( Args: expanded_batch_output (List[SamplerOutput]): The expanded output batch from the model. - output_indices_to_retain (List[int]): Indices of the model outputs - to retain. + output_indices_to_retain (torch.Tensor): Indices of the model + outputs to retain. Returns: List[SamplerOutput]: A list containing the filtered model diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 53634f7b0b366..2689802161987 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -54,6 +54,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": speculative_config: SpeculativeConfig = vllm_config.speculative_config assert speculative_config is not None + if vllm_config.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError("Speculative decoding is currently " + "incompatible with pipeline parallelism") + draft_worker_kwargs = kwargs.copy() kwargs["model_runner_cls"] = TargetModelRunner @@ -104,7 +108,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": return spec_decode_worker -# Reminder: Please update docs/source/serving/compatibility_matrix.rst +# Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid class SpecDecodeWorker(LoraNotSupportedWorkerBase): """Worker which implements speculative decoding. diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py index 6a114b513f382..c0b3d2585a962 100644 --- a/vllm/transformers_utils/tokenizer_group/__init__.py +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -1,7 +1,7 @@ from typing import Optional, Type -from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, - TokenizerPoolConfig) +from vllm.config import (LoRAConfig, ModelConfig, ParallelConfig, + SchedulerConfig, TokenizerPoolConfig) from vllm.executor.ray_utils import ray from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup @@ -16,10 +16,11 @@ def init_tokenizer_from_configs(model_config: ModelConfig, scheduler_config: SchedulerConfig, parallel_config: ParallelConfig, - enable_lora: bool): + lora_config: LoRAConfig): init_kwargs = dict(tokenizer_id=model_config.tokenizer, - enable_lora=enable_lora, + enable_lora=bool(lora_config), max_num_seqs=scheduler_config.max_num_seqs, + max_loras=lora_config.max_loras if lora_config else 0, max_input_length=None, tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py index e516eeabaadef..761b07f34d2f9 100644 --- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -21,8 +21,9 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.enable_lora = enable_lora self.max_input_length = max_input_length self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + max_loras = tokenizer_config.get("max_loras", 0) self.lora_tokenizers = LRUCache[AnyTokenizer]( - capacity=max_num_seqs if enable_lora else 0) + capacity=max(max_loras, max_num_seqs) if enable_lora else 0) @classmethod def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], diff --git a/vllm/utils.py b/vllm/utils.py index f19cc86898fc6..025a99455433e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,5 +1,6 @@ import argparse import asyncio +import concurrent import contextlib import datetime import enum @@ -19,13 +20,13 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task -from collections import defaultdict +from collections import UserDict, defaultdict from collections.abc import Iterable, Mapping from functools import lru_cache, partial, wraps from platform import uname -from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, - Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, - Type, TypeVar, Union, overload) +from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, + Dict, Generic, Hashable, List, Literal, Optional, + OrderedDict, Set, Tuple, Type, TypeVar, Union, overload) from uuid import uuid4 import numpy as np @@ -42,11 +43,14 @@ from vllm.logger import enable_trace_function_call, init_logger from vllm.platforms import current_platform +if TYPE_CHECKING: + from vllm.config import VllmConfig + logger = init_logger(__name__) # Exception strings for non-implemented encoder/decoder scenarios -# Reminder: Please update docs/source/serving/compatibility_matrix.rst +# Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid STR_NOT_IMPL_ENC_DEC_SWA = \ @@ -339,24 +343,16 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) -@lru_cache(maxsize=None) -def get_vllm_instance_id() -> str: - """ - If the environment variable VLLM_INSTANCE_ID is set, return it. - Otherwise, return a random UUID. - Instance id represents an instance of the VLLM. All processes in the same - instance should have the same instance id. - """ - return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" - - @lru_cache(maxsize=None) def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() -def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]: +def make_async( + func: Callable[P, T], + executor: Optional[concurrent.futures.Executor] = None +) -> Callable[P, Awaitable[T]]: """Take a blocking function, and run it on in an executor thread. This function prevents the blocking function from blocking the @@ -367,7 +363,7 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]: def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: loop = asyncio.get_event_loop() p_func = partial(func, *args, **kwargs) - return loop.run_in_executor(executor=None, func=p_func) + return loop.run_in_executor(executor=executor, func=p_func) return _async_wrapper @@ -998,7 +994,7 @@ def find_nccl_library() -> str: return so_file -def enable_trace_function_call_for_thread() -> None: +def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None: """Set up function tracing for the current thread, if enabled via the VLLM_TRACE_FUNCTION environment variable """ @@ -1010,7 +1006,8 @@ def enable_trace_function_call_for_thread() -> None: filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" f"_thread_{threading.get_ident()}_" f"at_{datetime.datetime.now()}.log").replace(" ", "_") - log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + log_path = os.path.join(tmp_dir, "vllm", + f"vllm-instance-{vllm_config.instance_id}", filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) @@ -1518,13 +1515,13 @@ def value(self): # Adapted from: https://stackoverflow.com/a/47212782/5082708 -class LazyDict(Mapping, Generic[T]): +class LazyDict(Mapping[str, T], Generic[T]): def __init__(self, factory: Dict[str, Callable[[], T]]): self._factory = factory self._dict: Dict[str, T] = {} - def __getitem__(self, key) -> T: + def __getitem__(self, key: str) -> T: if key not in self._dict: if key not in self._factory: raise KeyError(key) @@ -1541,6 +1538,22 @@ def __len__(self): return len(self._factory) +class ClassRegistry(UserDict[Type[T], _V]): + + def __getitem__(self, key: Type[T]) -> _V: + for cls in key.mro(): + if cls in self.data: + return self.data[cls] + + raise KeyError(key) + + def __contains__(self, key: object) -> bool: + if not isinstance(key, type): + return False + + return any(cls in self.data for cls in key.mro()) + + def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: """ Create a weak reference to a tensor. diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 5f8535eaa303f..d37989055c2e5 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -6,8 +6,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) -from vllm.forward_context import get_forward_context -from vllm.utils import direct_register_custom_op from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -113,13 +111,14 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] attn_metadata: Metadata for attention. Returns: @@ -135,117 +134,42 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - output = torch.empty_like(query) - torch.ops.vllm.unified_v1_flash_attention( - output, - query, - key, - value, - self.num_heads, - self.head_size, - self.num_kv_heads, - kv_cache, + if attn_metadata is None: + # Profiling run. + return output + + num_actual_tokens = attn_metadata.num_actual_tokens + + # Reshape the input keys and values and store them in the cache. + key_cache = kv_cache[0] + value_cache = kv_cache[1] + torch.ops._C_cache_ops.reshape_and_cache_flash( + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + attn_metadata.slot_mapping, self.kv_cache_dtype, k_scale, v_scale, - self.scale, - self.sliding_window, - self.alibi_slopes, - self.logits_soft_cap, ) - return output + # Compute attention and update output up to `num_actual_tokens`. + flash_attn_varlen_func( + q=query[:num_actual_tokens], + k=key_cache, + v=value_cache, + out=output[:num_actual_tokens], + cu_seqlens_q=attn_metadata.query_start_loc, + max_seqlen_q=attn_metadata.max_query_len, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=True, + alibi_slopes=self.alibi_slopes, + window_size=self.sliding_window, + block_table=attn_metadata.block_table, + softcap=self.logits_soft_cap, + ) -def unified_v1_flash_attention( - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> None: - context = get_forward_context() - current_metadata = context.dynamic_forward_context - if current_metadata is None: - # Profiling run. - return - - assert current_metadata is not None - assert isinstance(current_metadata, FlashAttentionMetadata) - attn_metadata: FlashAttentionMetadata = current_metadata - num_actual_tokens = attn_metadata.num_actual_tokens - - # Reshape the query, key, and value tensors. - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - # Reshape the input keys and values and store them in the cache. - key_cache = kv_cache[0] - value_cache = kv_cache[1] - torch.ops._C_cache_ops.reshape_and_cache_flash( - key[:num_actual_tokens], - value[:num_actual_tokens], - key_cache, - value_cache, - attn_metadata.slot_mapping, - kv_cache_dtype, - k_scale, - v_scale, - ) - - attn_output = flash_attn_varlen_func( - q=query[:num_actual_tokens], - k=key_cache, - v=value_cache, - cu_seqlens_q=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_query_len, - cu_seqlens_k=attn_metadata.seq_start_loc, - max_seqlen_k=attn_metadata.max_seq_len, - softmax_scale=softmax_scale, - causal=True, - alibi_slopes=alibi_slopes, - window_size=window_size, - block_table=attn_metadata.block_table, - softcap=logits_soft_cap, - ) - attn_output = attn_output.view(num_actual_tokens, -1) - # TODO(woosuk): Optimize this. - output[:num_actual_tokens].copy_(attn_output) - - -def unified_v1_flash_attention_fake( - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> None: - return - - -direct_register_custom_op( - op_name="unified_v1_flash_attention", - op_func=unified_v1_flash_attention, - mutates_args=["kv_cache", "output"], - fake_impl=unified_v1_flash_attention_fake, -) + return output diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8eb3fb976eb87..b492a755e6dd5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -17,12 +17,15 @@ def __init__( self, block_size: int, num_gpu_blocks: int, + max_model_len: int, sliding_window: Optional[int] = None, enable_caching: bool = True, num_preallocate_tokens: int = 64, ) -> None: self.block_size = block_size self.num_gpu_blocks = num_gpu_blocks + self.max_model_len = max_model_len + self.max_num_blocks_per_req = cdiv(max_model_len, block_size) self.sliding_window = sliding_window self.enable_caching = enable_caching # NOTE(woosuk): To avoid frequent block allocation, we preallocate some @@ -132,7 +135,14 @@ def append_slots( num_new_blocks = min( num_new_blocks + self.num_preallocate_blocks, self.free_block_queue.num_free_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(req_blocks), ) + assert num_new_blocks > 0 new_blocks = self._get_new_blocks(num_new_blocks) req_blocks.extend(new_blocks) @@ -212,7 +222,14 @@ def allocate_slots( num_required_blocks + self.num_preallocate_blocks, self.free_block_queue.num_free_blocks - num_evictable_computed_blocks, + # Should not exceed the maximum number of blocks per request. + # This is especially because the block table has the shape + # [..., max_num_blocks_per_req]. + # TODO(woosuk): Check and reject requests if + # num_prompt_tokens + max_tokens > max_model_len. + self.max_num_blocks_per_req - len(computed_blocks), ) + assert num_new_blocks > 0 # Concatenate the computed block IDs and the new block IDs. new_blocks = self._get_new_blocks(num_new_blocks) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ba50a9786d805..1203d35fc985f 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -33,22 +33,23 @@ def __init__( # TODO: Support LoRA. assert lora_config is None, "V1 does not support LoRA yet." + # Scheduling constraints. + self.max_num_running_reqs = self.scheduler_config.max_num_seqs + self.max_num_scheduled_tokens = \ + self.scheduler_config.max_num_batched_tokens + self.max_model_len = self.scheduler_config.max_model_len + num_gpu_blocks = cache_config.num_gpu_blocks assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 - # Create the block space manager. + # Create the KV cache manager. self.kv_cache_manager = KVCacheManager( block_size=self.cache_config.block_size, num_gpu_blocks=num_gpu_blocks, + max_model_len=self.max_model_len, sliding_window=self.cache_config.sliding_window, enable_caching=self.cache_config.enable_prefix_caching) self.block_size = self.cache_config.block_size - # Scheduling constraints. - self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens - self.max_model_len = self.scheduler_config.max_model_len - # req_id -> Request self.requests: Dict[str, Request] = {} # Priority queues for requests. @@ -72,12 +73,12 @@ def __init__( # has the Transformer architecture (e.g., ViT). # FIXME(woosuk): Below are placeholder values. We need to calculate the # actual values from the configurations. - self.max_num_encoder_input_tokens = 2048 + self.max_num_encoder_input_tokens = 16384 # NOTE(woosuk): For the models without encoder (e.g., text-only models), # the encoder cache will not be initialized and used, regardless of # the cache size. This is because the memory space for the encoder cache # is preallocated in the profiling run. - self.encoder_cache_manager = EncoderCacheManager(cache_size=2048) + self.encoder_cache_manager = EncoderCacheManager(cache_size=16384) def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 967124fd850ea..3cf0e610ae7af 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -1,11 +1,11 @@ import enum from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union import msgspec from vllm.lora.request import LoRARequest -from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict +from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -35,9 +35,8 @@ class EngineCoreRequest: # always be tokenized? prompt: Optional[str] prompt_token_ids: List[int] - mm_data: Optional[MultiModalDataDict] + mm_inputs: Optional[List[MultiModalKwargs]] mm_placeholders: Optional[MultiModalPlaceholderDict] - mm_processor_kwargs: Optional[Dict[str, Any]] sampling_params: SamplingParams eos_token_id: Optional[int] arrival_time: float diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index a17c8eac4b77c..4ef372fd8464b 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -9,7 +9,7 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams @@ -51,7 +51,7 @@ def __init__( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, parallel_config=vllm_config.parallel_config, - enable_lora=bool(vllm_config.lora_config)) + lora_config=vllm_config.lora_config) self.tokenizer.ping() # Request streams (map of request_id -> AsyncStream). @@ -133,7 +133,7 @@ async def add_request( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: """Add new request to the AsyncLLM.""" if self.detokenizer.is_request_active(request_id): diff --git a/vllm/v1/engine/async_stream.py b/vllm/v1/engine/async_stream.py index 3e6c759ad5ebd..35449238c3259 100644 --- a/vllm/v1/engine/async_stream.py +++ b/vllm/v1/engine/async_stream.py @@ -1,11 +1,11 @@ import asyncio from typing import Any, AsyncGenerator, Callable, Optional, Type, Union -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import PoolingRequestOutput, RequestOutput class AsyncStream: - """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + """A stream of RequestOutputs or PoolingRequestOutputs for a request that can be iterated over asynchronously via an async generator.""" STOP_ITERATION = Exception() # Sentinel @@ -16,7 +16,7 @@ def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self._queue: asyncio.Queue = asyncio.Queue() self._finished = False - def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + def put(self, item: Union[RequestOutput, PoolingRequestOutput, Exception]) -> None: if not self._finished: self._queue.put_nowait(item) @@ -32,7 +32,7 @@ def finish( async def generator( self - ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: finished = False try: while True: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 34f99dd30ef2e..751eb3b40a68d 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -67,6 +67,7 @@ def __init__( def _initialize_kv_caches(self, cache_config: CacheConfig) -> Tuple[int, int]: + start = time.time() num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( ) @@ -80,18 +81,14 @@ def _initialize_kv_caches(self, num_cpu_blocks = 0 self.model_executor.initialize_cache(num_gpu_blocks) + elapsed = time.time() - start + logger.info(("init engine (profile, create kv cache, " + "warmup model) took %.2f seconds"), elapsed) return num_gpu_blocks, num_cpu_blocks def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" - req = Request.from_engine_core_request(request) - # FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may - # take 10-50 ms, which can cause a spike in the latency. We should - # consider moving this to a separate thread. - if req.mm_data: - req.mm_inputs = self.mm_input_mapper.process_inputs( - req.mm_data, req.mm_processor_kwargs) self.scheduler.add_request(req) def abort_requests(self, request_ids: List[str]): diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index bd19d998a4adb..994e68669108e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -1,5 +1,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union +from typing_extensions import TypeVar + from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics_types import StatLoggerBase @@ -12,7 +14,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer @@ -21,6 +24,8 @@ logger = init_logger(__name__) +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) + class LLMEngine: """Legacy LLMEngine for backwards compatibility.""" @@ -46,7 +51,7 @@ def __init__( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, parallel_config=vllm_config.parallel_config, - enable_lora=bool(vllm_config.lora_config)) + lora_config=vllm_config.lora_config) self.tokenizer.ping() # Processor (convert Inputs --> EngineCoreRequests) @@ -169,5 +174,18 @@ def start_profile(self): def stop_profile(self): self.engine_core.profile(False) - def get_tokenizer_group(self, group_type): - pass + def get_tokenizer_group( + self, + group_type: Type[_G] = BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. " + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py index 594c973678235..7ad6882b04520 100644 --- a/vllm/v1/engine/mm_input_mapper.py +++ b/vllm/v1/engine/mm_input_mapper.py @@ -12,6 +12,7 @@ def __init__( model_config: ModelConfig, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ): + self.model_config = model_config self.mm_registry = mm_registry self.multi_modal_input_mapper = mm_registry.create_input_mapper( model_config) @@ -32,7 +33,7 @@ def process_inputs( num_images = len(image_inputs) for i in range(num_images): mm_input = self.multi_modal_input_mapper( - {"image": [image_inputs[i]]}, + {"image": image_inputs[i]}, mm_processor_kwargs=mm_processor_kwargs, ) mm_inputs.append(mm_input) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 5c1577190c75a..120fc64969552 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -7,13 +7,15 @@ from vllm.inputs.parse import is_encoder_decoder_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs, + MultiModalRegistry) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest +from vllm.v1.engine.mm_input_mapper import MMInputMapper class Processor: @@ -39,6 +41,9 @@ def __init__( self.input_processor = input_registry.create_input_processor( model_config) + # Multi-modal (huggingface) input mapper + self.mm_input_mapper = MMInputMapper(model_config) + # TODO: run in an ThreadpoolExecutor or BackgroundProcess. # This ideally should releases the GIL, so we should not block the # asyncio loop while this is running. @@ -96,6 +101,17 @@ def process_inputs( sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) + # Preprocess multi-modal data + if len(decoder_inputs.multi_modal_data) == 0: + mm_inputs = None + elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs): + mm_inputs = [decoder_inputs.multi_modal_data] + else: + mm_inputs = self.mm_input_mapper.process_inputs( + decoder_inputs.multi_modal_data, + decoder_inputs.mm_processor_kwargs, + ) + # Make Request for Detokenizer. detokenizer_request = DetokenizerRequest( request_id, @@ -113,9 +129,8 @@ def process_inputs( request_id, decoder_inputs.prompt, decoder_inputs.prompt_token_ids, - decoder_inputs.multi_modal_data, + mm_inputs, decoder_inputs.multi_modal_placeholders, - decoder_inputs.mm_processor_kwargs, sampling_params, eos_token_id, arrival_time, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 51fb4003e5fe0..6bc1e4d5c769f 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -45,9 +45,6 @@ def __init__( self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 - # Raw multimodal data before the mm input mapper (e.g., PIL images). - self.mm_data = self.inputs.multi_modal_data - self.mm_processor_kwargs = self.inputs.mm_processor_kwargs mm_positions = self.inputs.multi_modal_placeholders if mm_positions: # FIXME(woosuk): Support other modalities. @@ -55,7 +52,10 @@ def __init__( else: self.mm_positions = [] # Output of the mm input mapper (e.g., image tensors). - self.mm_inputs: List[MultiModalKwargs] = [] + if self.inputs.multi_modal_inputs: + self.mm_inputs = self.inputs.multi_modal_inputs + else: + self.mm_inputs: List[MultiModalKwargs] = [] @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": @@ -64,9 +64,10 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": inputs=token_inputs( prompt_token_ids=request.prompt_token_ids, prompt=request.prompt, - multi_modal_data=request.mm_data, + multi_modal_data=None, + multi_modal_inputs=request.mm_inputs, multi_modal_placeholders=request.mm_placeholders, - mm_processor_kwargs=request.mm_processor_kwargs, + mm_processor_kwargs=None, ), sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, @@ -110,7 +111,7 @@ def get_finished_reason(self) -> Union[str, None]: return RequestStatus.get_finished_reason(self.status) def has_encoder_inputs(self) -> bool: - return len(self.mm_data) > 0 + return len(self.mm_inputs) > 0 @property def num_encoder_inputs(self) -> int: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a40f4d6c9cd04..9f1b37d3b8e8c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,7 +8,6 @@ import torch.distributed import torch.nn as nn -from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context @@ -100,7 +99,11 @@ def __init__( == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. - self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)] + # The convention is different. + # self.cudagraph_batch_sizes sorts in ascending order. + # The batch sizes in the config are in descending order. + self.cudagraph_batch_sizes = list( + reversed(self.vllm_config.compilation_config.capture_sizes)) self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) @@ -257,7 +260,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] # where M is the max_model_len. - token_indices = positions_np + req_indices * self.max_model_len + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) token_indices = torch.from_numpy(token_indices) input_ids = torch.empty((total_num_scheduled_tokens, ), dtype=torch.int32, @@ -270,9 +274,15 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): out=input_ids) # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ - token_indices // self.block_size] - block_offsets = token_indices % self.block_size + req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size] + block_offsets = torch.from_numpy(positions_np % self.block_size) slot_mapping = torch.empty((total_num_scheduled_tokens, ), dtype=torch.int32, device="cpu", @@ -548,10 +558,9 @@ def profile_run(self) -> None: torch.tensor([], dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] - with set_compile_context(self.cudagraph_batch_sizes): - # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.model, self.max_num_tokens, + dummy_kv_caches) logits = self.model.compute_logits(hidden_states, None) logits = logits[:self.max_num_tokens] # TODO(woosuk): Consider the memory usage of the sampler. diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py similarity index 98% rename from vllm/worker/cpu_embedding_model_runner.py rename to vllm/worker/cpu_pooling_model_runner.py index 3954e4c4c8a5b..17b2fd2564a04 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_pooling_model_runner.py @@ -16,12 +16,12 @@ @dataclasses.dataclass(frozen=True) class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): """ - Used by the CPUEmbeddingModelRunner. + Used by the CPUPoolingModelRunner. """ pooling_metadata: Optional["PoolingMetadata"] = None -class CPUEmbeddingModelRunner( +class CPUPoolingModelRunner( CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( ModelInputForCPUWithPoolingMetadata) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index cf04808b73372..4fad1a3f4caeb 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -14,9 +14,9 @@ from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.worker.cpu_embedding_model_runner import CPUEmbeddingModelRunner from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase +from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) @@ -164,7 +164,7 @@ def __init__( else {"return_hidden_states": True} ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner if self.model_config.task == "embedding": - ModelRunnerClass = CPUEmbeddingModelRunner + ModelRunnerClass = CPUPoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = CPUEncoderDecoderModelRunner self.model_runner: CPUModelRunnerBase = ModelRunnerClass( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index ae18c79c980c8..5697fbbaa2041 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -25,8 +25,7 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata, - _get_graph_batch_size) + ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -465,7 +464,8 @@ def _prepare_encoder_model_input_tensors( # We will be using CUDA graph replay for this decode. max_len_of_block_table = self.get_max_block_per_batch() batch_size = len(encoder_seq_lens) - graph_batch_size = _get_graph_batch_size(batch_size) + graph_batch_size = self.vllm_config.get_graph_batch_size( + batch_size) assert graph_batch_size >= batch_size cuda_graph_pad_size = graph_batch_size - batch_size # extend the cross_block_tables and encoder_seq_lens to match diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1f654a9cce465..1bc5f65c7127f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -18,10 +18,9 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState -from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.distributed import get_pp_group +from vllm.distributed import get_kv_transfer_group, get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -63,16 +62,7 @@ logger = init_logger(__name__) LORA_WARMUP_RANK = 8 -_BATCH_SIZE_ALIGNMENT = 8 -# all the token sizes that **can** be captured by cudagraph. -# they can be arbitrarily large. -# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. -# the actual sizes to capture will be determined by the model, -# depending on the model's max_num_seqs. -# NOTE: _get_graph_batch_size needs to be updated if this list is changed. -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) -] + _NUM_WARMUP_ITERS = 2 TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") @@ -763,7 +753,6 @@ def _use_captured_graph(self, max_decode_seq_len: int, max_encoder_seq_len: int = 0) -> bool: return (decode_only and not self.runner.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_decode_seq_len <= self.runner.max_seq_len_to_capture and max_encoder_seq_len <= self.runner.max_seq_len_to_capture and batch_size <= self.runner.max_batchsize_to_capture) @@ -811,7 +800,7 @@ def _get_cuda_graph_pad_size(self, max_encoder_seq_len): return -1 - graph_batch_size = _get_graph_batch_size(batch_size) + graph_batch_size = VllmConfig.get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size return graph_batch_size - batch_size @@ -1023,7 +1012,7 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = _get_max_graph_batch_size( + self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size( self.scheduler_config.max_num_seqs) self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ @@ -1333,14 +1322,7 @@ def profile_run(self) -> None: dtype=self.model_config.dtype, device=self.device) - graph_batch_size = self.max_batchsize_to_capture - batch_size_capture_list = [ - bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size - ] - if self.model_config.enforce_eager: - batch_size_capture_list = [] - with set_compile_context(batch_size_capture_list): - self.execute_model(model_input, kv_caches, intermediate_tensors) + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return @@ -1459,18 +1441,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: dtype=self.model_config.dtype, device=self.device) - graph_batch_size = self.max_batchsize_to_capture - batch_size_capture_list = [ - bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size - ] - with self.attn_state.graph_capture( max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for virtual_engine in range( self.parallel_config.pipeline_parallel_size): - for batch_size in reversed(batch_size_capture_list): + for batch_size in \ + self.vllm_config.compilation_config.capture_sizes: attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( batch_size, @@ -1666,6 +1644,24 @@ def execute_model( else: model_executable = self.model + # Receive KV cache in distributed KV cache transfer setting + # In disagg prefill setting, it will also recv hidden states and bypass + # model forwarding + # In KV cache database setting, it will change the model input so that + # we can skip prefilling on tokens that successfully received KV caches + # NOTE: The receive operation is blocking + bypass_model_exec = False + if self.need_recv_kv(model_input, kv_caches): + hidden_or_intermediate_states, bypass_model_exec, model_input = \ + get_kv_transfer_group().recv_kv_caches_and_hidden_states( + # model is used to know which layer the current worker + # is working on, so that we can receive KV for only those + # layers. + model_executable, + model_input, + kv_caches=kv_caches + ) + multi_modal_kwargs = model_input.multi_modal_kwargs or {} seqlen_agnostic_kwargs = { "finished_requests_ids": model_input.finished_requests_ids, @@ -1677,21 +1673,36 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - with set_forward_context(model_input.attn_metadata, self.vllm_config): - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + if not bypass_model_exec: + with set_forward_context(model_input.attn_metadata, + self.vllm_config): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalKwargs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end.record() + # Sending KV cache in distributed KV cache transfer setting + # NOTE: the send operation is non-blocking + if self.need_send_kv(model_input, kv_caches): + get_kv_transfer_group().send_kv_caches_and_hidden_states( + # model_executable is used to know which layer the current + # worker is working on, so that we can send KV for only those + # layers. + model_executable, + model_input, + kv_caches, + hidden_or_intermediate_states, + ) + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: if (self.is_driver_worker @@ -1759,6 +1770,56 @@ def execute_model( return [output] + def need_recv_kv(self, model_input, kv_caches) -> bool: + """Check if we need to receive kv-cache from the other worker. + We need to receive KV when + 1. current vLLM instance is KV cache consumer/decode vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + if self.vllm_config.kv_transfer_config is None: + return False + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return self.vllm_config.kv_transfer_config.is_kv_consumer and ( + not is_profile_run) and is_prefill_run + + def need_send_kv(self, model_input, kv_caches) -> bool: + """Check if we need to send kv-cache to the other worker. + We need to send KV when + 1. current vLLM instance is KV cache producer/prefill vLLM instance + 2. this batch is not a profiling run + 3. this batch is a prefill run + + Args: + model_input: input to the model executable + kv_caches: vLLM's paged memory + """ + + if self.vllm_config.kv_transfer_config is None: + return False + + prefill_meta = model_input.attn_metadata.prefill_metadata + + # check if the current run is profiling + is_profile_run = (kv_caches[0].numel() == 0) + # check if the current run is prefill + is_prefill_run = prefill_meta is not None + + return self.vllm_config.kv_transfer_config.is_kv_producer and ( + not is_profile_run) and is_prefill_run + # NOTE: this is nn.Module so the profiler can properly capture/group # kernels calls made within the graph @@ -1910,37 +1971,3 @@ def forward( return self.output_buffers["hidden_states"] return self.output_buffers - - -def _get_graph_batch_size(batch_size: int) -> int: - """Returns the padded batch size given actual batch size. - - Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, - 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... - """ - if batch_size <= 2: - return batch_size - elif batch_size <= 4: - return 4 - else: - return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // - _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) - - -def _get_max_graph_batch_size(max_num_seqs: int) -> int: - """ - max_num_seqs: Maximum number of sequences in a batch. - _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. - - pad the max_num_seqs if necessary by calling _get_graph_batch_size, - which will deal with some edge cases like 1, 2, 4. - - if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size. - if not, it means the padded size is larger than the largest size in - _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. - """ - padded_size = _get_graph_batch_size(max_num_seqs) - if padded_size in _BATCH_SIZES_TO_CAPTURE: - return padded_size - assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 6f8627bd18a41..e08a61e31fe42 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -819,7 +819,7 @@ def _pythonize_sampler_output( for sgdx, (seq_group, sample_result) in enumerate(zip(seq_groups, samples_list)): - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid # (Check for Guided Decoding) if seq_group.sampling_params.logits_processors: diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 205f8a337ce6c..0bf522d5333ed 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -489,7 +489,7 @@ def model_profile_run(): block_size = cache_config.block_size seq_num_blocks = (seq_len + block_size - 1) // block_size - seq_data, dummy_multi_modal_data = input_registry \ + dummy_data = input_registry \ .dummy_data_for_profiling(model_config, seq_len, mm_registry) @@ -498,11 +498,11 @@ def model_profile_run(): seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=block_tables, lora_request=None, - multi_modal_data=dummy_multi_modal_data) + multi_modal_data=dummy_data.multi_modal_data) seqs.append(seq) self.model_runner.block_size = tmp_cache_config.block_size diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/pooling_model_runner.py similarity index 98% rename from vllm/worker/embedding_model_runner.py rename to vllm/worker/pooling_model_runner.py index f56805918fd15..1beae1e3884c5 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -21,12 +21,12 @@ @dataclasses.dataclass(frozen=True) class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): """ - Used by the EmbeddingModelRunner. + Used by the PoolingModelRunner. """ pooling_metadata: Optional["PoolingMetadata"] = None -class EmbeddingModelRunner( +class PoolingModelRunner( GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( ModelInputForGPUWithPoolingMetadata) @@ -52,7 +52,7 @@ def execute_model( ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( - "EmbeddingModelRunner does not support multi-step execution.") + "PoolingModelRunner does not support multi-step execution.") if self.lora_config: assert model_input.lora_requests is not None diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index f43635464ef00..5f71ec0c14df8 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -13,7 +13,7 @@ def assert_enc_dec_mr_supported_scenario( a supported scenario. ''' - # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # Reminder: Please update docs/source/usage/compatibility_matrix.rst # If the feature combo become valid if enc_dec_mr.cache_config.enable_prefix_caching: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 24e7bc760b0c0..094dd5a5d08b3 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -8,8 +8,9 @@ import torch.distributed import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, +from vllm.config import VllmConfig +from vllm.distributed import (ensure_kv_transfer_initialized, + ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger @@ -22,9 +23,9 @@ from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine -from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.pooling_model_runner import PoolingModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, WorkerInput) @@ -75,7 +76,7 @@ def __init__( ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner if model_config.task == "embedding": - ModelRunnerClass = EmbeddingModelRunner + ModelRunnerClass = PoolingModelRunner elif self.model_config.is_encoder_decoder: ModelRunnerClass = EncoderDecoderModelRunner self.model_runner: GPUModelRunnerBase = ModelRunnerClass( @@ -144,7 +145,7 @@ def init_device(self) -> None: raise RuntimeError( f"Not support device type: {self.device_config.device}") # Initialize the distributed environment. - init_worker_distributed_environment(self.parallel_config, self.rank, + init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, self.local_rank) # Set random seed. @@ -457,20 +458,22 @@ def get_cache_block_size_bytes(self) -> int: def init_worker_distributed_environment( - parallel_config: ParallelConfig, + vllm_config: VllmConfig, rank: int, distributed_init_method: Optional[str] = None, local_rank: int = -1, ) -> None: """Initialize the distributed environment.""" + parallel_config = vllm_config.parallel_config set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + ensure_kv_transfer_initialized(vllm_config) + def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7aaa8b453cff1..6d00102e0a324 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -43,6 +43,7 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config + self.kv_transfer_config = vllm_config.kv_transfer_config @abstractmethod def init_device(self) -> None: @@ -438,7 +439,7 @@ def init_worker(self, *args, **kwargs): Here we inject some common logic before initializing the worker. Arguments are passed to the worker class constructor. """ - enable_trace_function_call_for_thread() + enable_trace_function_call_for_thread(self.vllm_config) # see https://github.com/NVIDIA/nccl/issues/1234 os.environ['NCCL_CUMEM_ENABLE'] = '0'