diff --git a/.github/workflows/accuracy-test.yml b/.github/workflows/accuracy-test.yml deleted file mode 100644 index 374f0d2856d..00000000000 --- a/.github/workflows/accuracy-test.yml +++ /dev/null @@ -1,46 +0,0 @@ -name: Accuracy Test - -on: - push: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - pull_request: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - workflow_dispatch: - -concurrency: - group: accuracy-test-${{ github.ref }} - cancel-in-progress: true - -jobs: - accuracy-test: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: accuracy - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - source $HOME/venv/bin/activate - echo "$HOME/venv/bin" >> $GITHUB_PATH - - pip install --upgrade pip - pip install -e "python[all]" - pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - - git clone https://github.com/merrymercy/human-eval.git - cd human-eval - pip install -e . - - - name: Evaluate Accuracy - run: | - cd test/srt - python3 test_eval_accuracy_large.py - timeout-minutes: 10 diff --git a/.github/workflows/cache-purge.yml b/.github/workflows/cache-purge.yml deleted file mode 100644 index c699f49885f..00000000000 --- a/.github/workflows/cache-purge.yml +++ /dev/null @@ -1,27 +0,0 @@ -name: Weekly Cache Purge - -on: - schedule: - - cron: '0 0 * * 0' # Every Sunday at 00:00 - workflow_dispatch: - -jobs: - purge-cache: - if: github.repository == 'sgl-project/sglang' - runs-on: self-hosted - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Purge pip cache - run: | - source $HOME/venv/bin/activate - echo "$HOME/venv/bin" >> $GITHUB_PATH - pip cache purge - - - name: Update dependencies - run: | - pip install --upgrade pip - pip install -e "python[all]" - pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml deleted file mode 100644 index ad271c37edb..00000000000 --- a/.github/workflows/e2e-test.yml +++ /dev/null @@ -1,54 +0,0 @@ -name: E2E Test - -on: - push: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - pull_request: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - workflow_dispatch: - -concurrency: - group: e2e-test-${{ github.ref }} - cancel-in-progress: true - -jobs: - e2e-test: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: e2e - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - source $HOME/venv/bin/activate - echo "$HOME/venv/bin" >> $GITHUB_PATH - - pip install --upgrade pip - pip install -e "python[all]" - pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - - - name: Benchmark Serving Throughput - run: | - cd test/srt - python3 -m unittest test_serving_throughput.TestServingThroughput.test_default - timeout-minutes: 10 - - - name: Benchmark Serving Throughput (w/o RadixAttention) - run: | - cd test/srt - python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_radix_cache - timeout-minutes: 10 - - - name: Benchmark Serving Throughput (w/o ChunkedPrefill) - run: | - cd test/srt - python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_chunked_prefill - timeout-minutes: 10 diff --git a/.github/workflows/lint.yml 
b/.github/workflows/lint.yml index 07614050640..4857f844f27 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,19 +1,22 @@ name: Lint -on: [push, pull_request] +on: [pull_request] jobs: lint: - runs-on: ubuntu-20.04 + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + + - name: Set up Python 3.9 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.9 + - name: Install pre-commit hook run: | python -m pip install pre-commit pre-commit install + - name: Linting run: pre-commit run --all-files diff --git a/.github/workflows/moe-test.yml b/.github/workflows/moe-test.yml deleted file mode 100644 index 51f7d022614..00000000000 --- a/.github/workflows/moe-test.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: MoE Test - -on: - push: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - pull_request: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - workflow_dispatch: - -concurrency: - group: moe-test-${{ github.ref }} - cancel-in-progress: true - -jobs: - moe-test: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: accuracy - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - source $HOME/venv/bin/activate - echo "$HOME/venv/bin" >> $GITHUB_PATH - - pip install --upgrade pip - pip install -e "python[all]" - pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - - - name: Benchmark MOE Serving Throughput - uses: nick-fields/retry@v3 - with: - timeout_minutes: 15 - max_attempts: 2 - retry_on: error - command: | - cd test/srt - python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default - python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default_without_radix_cache diff --git a/.github/workflows/nightly-eval.yml b/.github/workflows/nightly-eval.yml new file mode 100644 index 00000000000..4ac911c9a92 --- /dev/null +++ b/.github/workflows/nightly-eval.yml @@ -0,0 +1,35 @@ +name: Nightly Evaluation + +on: + schedule: + - cron: '0 0 * * *' + push: + branches: + - main + paths: + - "python/sglang/version.py" + workflow_dispatch: + +concurrency: + group: nightly-eval-${{ github.ref }} + cancel-in-progress: true + +jobs: + nightly-eval-2-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 2-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Nightly gsm8k Accuracy + timeout-minutes: 60 + run: | + cd test/srt + python3 test_nightly_gsm8k_eval.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml new file mode 100644 index 00000000000..5784a097584 --- /dev/null +++ b/.github/workflows/pr-test.yml @@ -0,0 +1,201 @@ +name: PR Test + +on: + push: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + pull_request: + branches: [ main ] + paths: + - "python/sglang/**" + - "test/**" + workflow_dispatch: + +concurrency: + group: pr-test-${{ github.ref }} + cancel-in-progress: true + +jobs: + unit-test-frontend: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: 
| + pip install --upgrade pip + pip install -e "python[dev]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Run test + timeout-minutes: 20 + run: | + cd test/lang + python3 run_suite.py --suite minimal + + unit-test-backend-part-0: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[dev]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Run test + timeout-minutes: 20 + run: | + cd test/srt + python3 run_suite.py --suite minimal --range-begin 0 --range-end 8 + + unit-test-backend-part-1: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[dev]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Run test + timeout-minutes: 20 + run: | + cd test/srt + python3 run_suite.py --suite minimal --range-begin 8 + + performance-test-1-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Benchmark Serving Throughput + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_serving_throughput.TestServingThroughput.test_default + + - name: Benchmark Serving Latency + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_serving_latency.TestServingLatency.test_default + + - name: Benchmark Serving Throughput (w/o RadixAttention) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_radix_cache + + - name: Benchmark Serving Throughput (w/o ChunkedPrefill) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_chunked_prefill + + performance-test-2-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 2-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + - name: Benchmark Serving Throughput (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default + + - name: Benchmark Serving Latency (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_moe_serving_latency.TestServingLatency.test_default + + - name: Benchmark Serving Throughput (w/o RadixAttention) (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default_without_radix_cache + + accuracy-test-1-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + 
uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + git clone https://github.com/merrymercy/human-eval.git + cd human-eval + pip install -e . + + - name: Evaluate Accuracy + timeout-minutes: 20 + run: | + cd test/srt + python3 test_eval_accuracy_large.py + + accuracy-test-2-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 2-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install -e "python[all]" + pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall + + git clone https://github.com/merrymercy/human-eval.git + cd human-eval + pip install -e . + + - name: Evaluate Accuracy + timeout-minutes: 20 + run: | + cd test/srt + python3 test_moe_eval_accuracy_large.py + + finish: + needs: [ + unit-test-frontend, unit-test-backend-part-0, unit-test-backend-part-1, + performance-test-1-gpu, performance-test-2-gpu, + accuracy-test-1-gpu, accuracy-test-2-gpu + ] + runs-on: ubuntu-latest + steps: + - name: Finish + run: echo "This is an empty step to ensure that all jobs are completed." diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml deleted file mode 100644 index 3422cde40d9..00000000000 --- a/.github/workflows/unit-test.yml +++ /dev/null @@ -1,50 +0,0 @@ -name: Unit Test - -on: - push: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - pull_request: - branches: [ main ] - paths: - - "python/sglang/**" - - "test/**" - workflow_dispatch: - -concurrency: - group: unit-test-${{ github.ref }} - cancel-in-progress: true - -jobs: - unit-test: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: unit - - steps: - - name: Checkout code - uses: actions/checkout@v3 - - - name: Install dependencies - run: | - source $HOME/venv/bin/activate - echo "$HOME/venv/bin" >> $GITHUB_PATH - - pip install --upgrade pip - pip install -e "python[all]" - pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - pip install accelerate - pip install sentence_transformers - - - name: Test Backend Runtime - run: | - cd test/srt - python3 run_suite.py --suite minimal - timeout-minutes: 18 - - - name: Test Frontend Language - run: | - cd test/lang - python3 run_suite.py --suite minimal - timeout-minutes: 10 diff --git a/README.md b/README.md index c7d47d67866..eb3099cf7ae 100644 --- a/README.md +++ b/README.md @@ -17,17 +17,18 @@ SGLang is a fast serving framework for large language models and vision language It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language. The core features include: -- **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, flashinfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). +- **Fast Backend Runtime**: Efficient serving with RadixAttention for prefix caching, jump-forward constrained decoding, continuous batching, token attention (paged attention), tensor parallelism, FlashInfer kernels, and quantization (AWQ/FP8/GPTQ/Marlin). 
- **Flexible Frontend Language**: Enables easy programming of LLM applications with chained generation calls, advanced prompting, control flow, multiple modalities, parallelism, and external interactions. ## News +- [2024/09] 🔥 SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). - [2024/07] 🔥 Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). -- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)). - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
More +- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)). - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)). - [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)). @@ -55,7 +56,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ ### Method 2: From source ``` # Use the last release branch -git clone -b v0.2.13 https://github.com/sgl-project/sglang.git +git clone -b v0.3.0 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -82,6 +83,7 @@ docker run --gpus all \ ### Method 4: Using docker compose
+More > This method is recommended if you plan to serve it as a service. > A better approach is to use the [k8s-sglang-service.yaml](./docker/k8s-sglang-service.yaml). @@ -93,6 +95,7 @@ docker run --gpus all \ ### Method 5: Run on Kubernetes or Clouds with SkyPilot
+More To deploy on Kubernetes or 12+ clouds, you can use [SkyPilot](https://github.com/skypilot-org/skypilot). @@ -132,7 +135,7 @@ sky status --endpoint 30000 sglang ### Common Notes -- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. If you are using NVIDIA GPU devices below sm80, such as T4, you can't use SGLang for the time being. We expect to resolve this issue soon, so please stay tuned. If you encounter any FlashInfer-related issues on sm80+ devices (e.g., A100, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise a issue. +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. ## Backend: SGLang Runtime (SRT) @@ -186,6 +189,13 @@ response = client.chat.completions.create( max_tokens=64, ) print(response) + +# Text embedding +response = client.embeddings.create( + model="default", + input="How are you today", +) +print(response) ``` It supports streaming, vision, and most features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/). @@ -195,7 +205,7 @@ It supports streaming, vision, and most features of the Chat/Completions/Models/ ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --tp 2 ``` -- Add `--dp 2` to enable multi-GPU data parallelism. It can also be used together with tensor parallelism. Data parallelism is better for throughput if there is enough memory. +- Add `--dp 2` to enable multi-GPU data parallelism. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. 
``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --dp 2 --tp 2 ``` @@ -222,57 +232,75 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct ### Supported Models +**Generative Models** - Llama / Llama 2 / Llama 3 / Llama 3.1 - Mistral / Mixtral / Mistral NeMo - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE - DeepSeek / DeepSeek 2 -- LLaVA 1.5 / 1.6 - - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 30000` -- LLaVA-NeXT-Video - - see [examples/usage/llava_video](examples/usage/llava_video) +- [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) + - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava` + - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava` + - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) +- LLaVA 1.5 / 1.6 / NeXT + - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3` + - `python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8 --chat-template=chatml-llava` + - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py) - Yi-VL - - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py). - StableLM - Command-R - DBRX - Grok - ChatGLM - InternLM 2 +- Exaone 3 + +**Embedding Models** + +- e5-mistral +- gte-Qwen2 + - `python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct --is-embedding` Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/en/model_support.md). #### Use Models From ModelScope -To use model from [ModelScope](https://www.modelscope.cn), setting environment variable SGLANG_USE_MODELSCOPE. +
+More + +To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE. ``` export SGLANG_USE_MODELSCOPE=true ``` Launch [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) Server ``` SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 -``` +``` + +
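Since the ModelScope-launched server exposes the same OpenAI-compatible API as any other SGLang server, you can smoke-test it with the standard client. This is a minimal sketch that assumes the Qwen2-7B-Instruct server started above is listening on the default address `http://localhost:30000`:

```python
import openai

# Assumes the ModelScope-backed server launched above is serving on port 30000.
client = openai.Client(base_url="http://localhost:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "List 3 countries and their capitals."}],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```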
#### Run Llama 3.1 405B +
+More ```bash -## Run 405B (fp8) on a single node +# Run 405B (fp8) on a single node python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8 -## Run 405B (fp16) on two nodes -# replace the `172.16.4.52:20000` with your own first node ip address and port, disable CUDA Graph temporarily - -# on the first node -GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph --mem-frac 0.75 +# Run 405B (fp16) on two nodes +## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port +GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph -# on the second -GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph --mem-frac 0.75 +## on the second node, replace the `172.16.4.52:20000` with your own first node ip address and port +GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph ```
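Once both ranks have started, rank 0 serves the HTTP API for the whole deployment. The sketch below assumes the default port 30000 and the placeholder first-node address `172.16.4.52` from the commands above; it sends a small request to the native `/generate` endpoint as a sanity check:

```python
import requests

# Assumption: the first node (node-rank 0) serves on the default port 30000.
response = requests.post(
    "http://172.16.4.52:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(response.json())
```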
+ ### Benchmark Performance -- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, consider using `sglang.bench_serving`. +- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`. + Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle. + A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, please use `sglang.bench_serving` instead. ``` python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32 ``` @@ -353,7 +381,7 @@ print(state["answer_1"]) #### More Examples Anthropic and VertexAI (Gemini) models are also supported. -You can find more examples at [examples/quick_start](examples/quick_start). +You can find more examples at [examples/quick_start](examples/frontend_language/quick_start). ### Language Feature To begin with, import sglang. @@ -366,7 +394,7 @@ You can implement your prompt flow in a function decorated by `sgl.function`. You can then invoke the function with `run` or `run_batch`. The system will manage the state, chat template, parallelism and batching for you. -The complete code for the examples below can be found at [readme_examples.py](examples/usage/readme_examples.py) +The complete code for the examples below can be found at [readme_examples.py](examples/frontend_language/usage/readme_examples.py) #### Control Flow You can use any Python code within the function body, including control flow, nested function calls, and external libraries. @@ -405,7 +433,7 @@ def tip_suggestion(s): s += "In summary" + sgl.gen("summary") ``` -#### Multi Modality +#### Multi-Modality Use `sgl.image` to pass an image as input. ```python @@ -415,7 +443,7 @@ def image_qa(s, image_file, question): s += sgl.assistant(sgl.gen("answer", max_tokens=256) ``` -See also [srt_example_llava.py](examples/quick_start/srt_example_llava.py). +See also [srt_example_llava.py](examples/frontend_language/quick_start/local_example_llava_next.py). #### Constrained Decoding Use `regex` to specify a regular expression as a decoding constraint. @@ -459,7 +487,7 @@ def character_gen(s, name): s += sgl.gen("json_output", max_tokens=256, regex=character_regex) ``` -See also [json_decode.py](examples/usage/json_decode.py) for an additional example on specifying formats with Pydantic models. +See also [json_decode.py](examples/frontend_language/usage/json_decode.py) for an additional example of specifying formats with Pydantic models. #### Batching Use `run_batch` to run a batch of requests with continuous batching. @@ -521,7 +549,6 @@ def chat_example(s): - The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. - The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. 
It is compatible with `temperature=0` and `temperature != 0`. - ## Benchmark And Performance ![8b_throughput](https://lmsys.org/images/blog/sglang_llama3/8b_throughput.svg) ![70b_fp8_throughput](https://lmsys.org/images/blog/sglang_llama3/70b_fp8_throughput.svg) diff --git a/benchmark/benchmark_vllm_060/README.md b/benchmark/benchmark_vllm_060/README.md new file mode 100644 index 00000000000..5a1247c5f4b --- /dev/null +++ b/benchmark/benchmark_vllm_060/README.md @@ -0,0 +1,89 @@ +## How to reproduce the benchmark results for SGLang v0.3.0 compared to vLLM v0.6.0 + +In short, with multi step enabled, in online scenarios that we benchmarked, the Median TTFT of vLLM is **3 times** that of SGLang, and the Median ITL is **10 times** that of SGLang. Lower Median TTFT and ITL are better. vLLM's multi-step optimization did not improve throughput while ensuring lower Median TTFT and ITL. Also, under maximum throughput benchmark, if vLLM does not set gpu util to 0.95 separately and uses the default configuration instead, its maximum throughput is **lower** than that of SGLang. + +## Online benchmark results + +### Llama 3.1 8B Instruct 1 x A100 80G + +| RPS | Num prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | +|------|-------------|--------|--------------------|-------------|-------------|------------| +| 4 | 1200 | SGLang | 1564.17 | **31.98** | 13.17 | **11.93** | +| 4 | 1200 | vLLM | 1691.97 | **100.48** | 14.14 | **129.32** | +| 8 | 2400 | SGLang | 2175.02 | **35.68** | 17.85 | **14.41** | +| 8 | 2400 | vLLM | 2137.16 | **120.39** | 17.09 | **158.63** | + +### Llama 3.1 70B Insruct 4 x H100 80G + +| RPS | Num Prompts | Engine | Median E2E Latency | Median TTFT | Median TPOT | Median ITL | +|------|-------------|--------|--------------------|-------------|-------------|------------| +| 4 | 1200 | SGLang | 3005.24 | **53.94** | 25.03 | **21.67** | +| 4 | 1200 | vLLM | 2915.60 | **179.15** | 23.58 | **231.23** | +| 8 | 2400 | SGLang | 4064.98 | **58.11** | 33.07 | **24.45** | +| 8 | 2400 | vLLM | 3752.38 | **207.12** | 29.15 | **275.32** | + +## Offline benchmark results + +### Llama 3.1 8B Instruct 1 x A100 80G + +| RPS | Num Prompts | Engine | Request throughput | Output token throughput | +|------|-------------|--------|--------------------|-------------------------| +| inf | 5000 | SGLang | 22.03 | **4281.51** | +| inf | 5000 | vLLM | 21.27 | **4132.37** | + +### Llama 3.1 70B Insruct 4 x H100 80G + +| RPS | Num Prompts | Engine | Request throughput | Output token throughput | +|------|-------------|--------|--------------------|-------------------------| +| inf | 5000 | SGLang | 19.84 | **3856.01** | +| inf | 5000 | vLLM | 19.04 | **3700.64** | + +## Installation + +```bash +# install sglang v0.3.0 +pip install --upgrade pip +pip install "sglang[all]"==0.3.0 +pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ + +# install vllm v0.6.0 +pip install vllm==0.6.0 +``` + +## Notes + +We referred to the reproduction method in https://github.com/vllm-project/vllm/issues/8176, and added the `--num-scheduler-steps 10` parameter when starting the vLLM server. The `gpu_memory_utilization` of vLLM is by default 0.9 at both TP 1 and TP 4, while SGLang's `mem_frac` is 0.88 at TP 1 and 0.85 at TP 4, so we manually set it to 0.88 at TP 4. 
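As a cross-check of the headline claim, the "3x TTFT / 10x ITL" ratios can be recomputed directly from the RPS 4 rows of the online tables above:

```python
# Median TTFT and ITL (ms) copied from the RPS 4 rows of the online benchmark tables above.
rows = {
    "Llama 3.1 8B, 1 x A100": {"sglang": (31.98, 11.93), "vllm": (100.48, 129.32)},
    "Llama 3.1 70B, 4 x H100": {"sglang": (53.94, 21.67), "vllm": (179.15, 231.23)},
}

for name, r in rows.items():
    ttft_ratio = r["vllm"][0] / r["sglang"][0]
    itl_ratio = r["vllm"][1] / r["sglang"][1]
    print(f"{name}: vLLM/SGLang median TTFT {ttft_ratio:.1f}x, median ITL {itl_ratio:.1f}x")
```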
+ +## Online benchmarks + +```bash +# Llama 3.1 8B Instruct on 1 x A100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096 + +# Llama 3.1 70B Instruct on 4 x H100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096 + +# bench serving +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 1200 --request-rate 4 +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 2400 --request-rate 8 +python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 1200 --request-rate 4 +python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 2400 --request-rate 8 +``` + +## Offline benchmarks + +```bash +# Llama 3.1 8B Instruct on 1 x A100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --enable-torch-compile --disable-radix-cache +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --disable-log-requests --num-scheduler-steps 10 --max_model_len 4096 + +# Llama 3.1 70B Instruct on 4 x H100 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-70B-Instruct --disable-radix-cache --tp 4 --mem-frac 0.88 +python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3.1-70B-Instruct --disable-log-requests --num-scheduler-steps 10 --tensor 4 --max_model_len 4096 + +# bench serving +python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --num-prompts 5000 +python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --num-prompts 5000 +``` diff --git a/docs/en/custom_chat_template.md b/docs/en/custom_chat_template.md index 815c7e6760b..3760bbc6a18 100644 --- a/docs/en/custom_chat_template.md +++ b/docs/en/custom_chat_template.md @@ -1,6 +1,9 @@ # Custom Chat Template in SGLang Runtime -By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3. +**NOTE**: There are two chat template systems in SGLang project. This document is about setting a custom chat template for the OpenAI-compatible API server (defined at [conversation.py](../../python/sglang/srt/conversation.py)). It is NOT related to the chat template used in the SGLang language frontend (defined at [chat_template.py](../../python/sglang/lang/chat_template.py)). + +By default, the server uses the chat template specified in the model tokenizer from Hugging Face. +It should just work for most official models such as Llama-2/Llama-3. If needed, you can also override the chat template when launching the server: diff --git a/docs/en/hyperparameter_tuning.md b/docs/en/hyperparameter_tuning.md index 02a0657c3f0..f2bf9d55f3d 100644 --- a/docs/en/hyperparameter_tuning.md +++ b/docs/en/hyperparameter_tuning.md @@ -6,7 +6,7 @@ Achieving a large batch size is the most important thing for attaining high thro When the server is running at full load, look for the following in the log: -```[gpu=0] Decode batch. 
#running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417``` +```Decode batch. #running-req: 233, #token: 370959, token usage: 0.82, gen throughput (token/s): 4594.01, #queue-req: 417``` ### Tune Your Request Submission Speed `#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed. diff --git a/docs/en/sampling_params.md b/docs/en/sampling_params.md index 7d866e69295..71de754fde2 100644 --- a/docs/en/sampling_params.md +++ b/docs/en/sampling_params.md @@ -1,5 +1,8 @@ # Sampling Parameters in SGLang Runtime This doc describes the sampling parameters of the SGLang Runtime. +It is the low-level endpoint of the runtime. +If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API +](https://github.com/sgl-project/sglang?tab=readme-ov-file#openai-compatible-api). The `/generate` endpoint accepts the following arguments in the JSON format. @@ -47,6 +50,12 @@ top_p: float = 1.0, top_k: int = -1, # Min-p sampling min_p: float = 0.0, +# DRY sampling +dry_multiplier: float = 0.0, +dry_base: float = 0.0, +dry_allowed_length: int = 2, +dry_penalty_last_n: int = 0, +dry_sequence_breakers: Optional[List[str]] = [], # Whether to ignore EOS token. ignore_eos: bool = False, # Whether to skip the special tokens during detokenization. @@ -57,6 +66,9 @@ spaces_between_special_tokens: bool = True, regex: Optional[str] = None, # Do parallel sampling and return `n` outputs. n: int = 1, +# Constrains the output to follow a given JSON schema. +# `regex` and `json_schema` cannot be set at the same time. +json_schema: Optional[str] = None, ## Penalties. See [Performance Implications on Penalties] section below for more informations. @@ -140,7 +152,7 @@ print("") Launch a server ``` -python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000 +python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --chat-template chatml-llava ``` Download an image @@ -155,7 +167,9 @@ import requests response = requests.post( "http://localhost:30000/generate", json={ - "text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nDescribe this picture ASSISTANT:", + "text": "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n\nDescribe this image in a very short sentence.<|im_end|>\n" + "<|im_start|>assistant\n", "image_data": "example_image.png", "sampling_params": { "temperature": 0, diff --git a/docs/en/setup_github_runner.md b/docs/en/setup_github_runner.md index 97a7f26266b..8e817dcc88c 100644 --- a/docs/en/setup_github_runner.md +++ b/docs/en/setup_github_runner.md @@ -1,89 +1,44 @@ -# Set up self hosted runner for GitHub Action +# Set Up Self-hosted Runners for GitHub Action -## Config Runner +## Add a Runner -```bash -# https://github.com/sgl-project/sglang/settings/actions/runners/new?arch=x64&os=linux -# Involves some TOKEN and other private information, click the link to view specific steps. -``` +### Step 1: Start a docker container. -## Start Runner +You can mount a folder for the shared huggingface model weights cache. The command below uses `/tmp/huggingface` as an example. 
-add `/lib/systemd/system/e2e.service` ``` -[Unit] -StartLimitIntervalSec=0 -[Service] -Environment="CUDA_VISIBLE_DEVICES=7" -Environment="XDG_CACHE_HOME=/data/.cache" -Environment="HF_TOKEN=hf_xx" -Environment="OPENAI_API_KEY=sk-xx" -Environment="HOME=/data/zhyncs/runner-v1" -Environment="SGLANG_IS_IN_CI=true" -Restart=always -RestartSec=1 -ExecStart=/data/zhyncs/runner-v1/actions-runner/run.sh -[Install] -WantedBy=multi-user.target +docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 +docker run --shm-size 64g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash ``` -add `/lib/systemd/system/unit.service` -``` -[Unit] -StartLimitIntervalSec=0 -[Service] -Environment="CUDA_VISIBLE_DEVICES=6" -Environment="XDG_CACHE_HOME=/data/.cache" -Environment="HF_TOKEN=hf_xx" -Environment="OPENAI_API_KEY=sk-xx" -Environment="HOME=/data/zhyncs/runner-v2" -Environment="SGLANG_IS_IN_CI=true" -Restart=always -RestartSec=1 -ExecStart=/data/zhyncs/runner-v2/actions-runner/run.sh -[Install] -WantedBy=multi-user.target -``` +### Step 2: Configure the runner by `config.sh` + +Run these commands inside the container. -add `/lib/systemd/system/accuracy.service` ``` -[Unit] -StartLimitIntervalSec=0 -[Service] -Environment="CUDA_VISIBLE_DEVICES=5" -Environment="XDG_CACHE_HOME=/data/.cache" -Environment="HF_TOKEN=hf_xx" -Environment="OPENAI_API_KEY=sk-xx" -Environment="HOME=/data/zhyncs/runner-v3" -Environment="SGLANG_IS_IN_CI=true" -Restart=always -RestartSec=1 -ExecStart=/data/zhyncs/runner-v3/actions-runner/run.sh -[Install] -WantedBy=multi-user.target +apt update && apt install -y curl python3-pip git +export RUNNER_ALLOW_RUNASROOT=1 ``` -```bash -cd /data/zhyncs/runner-v1 -python3 -m venv venv +Then follow https://github.com/sgl-project/sglang/settings/actions/runners/new?arch=x64&os=linux to run `config.sh` -cd /data/zhyncs/runner-v2 -python3 -m venv venv +**Notes** +- Do not need to specify the runner group +- Give it a name (e.g., `test-sgl-gpu-0`) and some labels (e.g., `1-gpu-runner`). The labels can be editted later in Github Settings. +- Do not need to change the work folder. 
-cd /data/zhyncs/runner-v3 -python3 -m venv venv +### Step 3: Run the runner by `run.sh` -sudo systemctl daemon-reload - -sudo systemctl start e2e -sudo systemctl enable e2e -sudo systemctl status e2e - -sudo systemctl start unit -sudo systemctl enable unit -sudo systemctl status unit +- Set up environment variables +``` +export HF_HOME=/hf_home +export SGLANG_IS_IN_CI=true +export HF_TOKEN=hf_xxx +export OPENAI_API_KEY=sk-xxx +export CUDA_VISIBLE_DEVICES=0 +``` -sudo systemctl start accuracy -sudo systemctl enable accuracy -sudo systemctl status accuracy +- Run it forever ``` +while true; do ./run.sh; echo "Restarting..."; sleep 2; done +``` \ No newline at end of file diff --git a/examples/quick_start/anthropic_example_chat.py b/examples/frontend_language/quick_start/anthropic_example_chat.py similarity index 100% rename from examples/quick_start/anthropic_example_chat.py rename to examples/frontend_language/quick_start/anthropic_example_chat.py diff --git a/examples/quick_start/anthropic_example_complete.py b/examples/frontend_language/quick_start/anthropic_example_complete.py similarity index 100% rename from examples/quick_start/anthropic_example_complete.py rename to examples/frontend_language/quick_start/anthropic_example_complete.py diff --git a/examples/quick_start/azure_openai_example_chat.py b/examples/frontend_language/quick_start/azure_openai_example_chat.py similarity index 100% rename from examples/quick_start/azure_openai_example_chat.py rename to examples/frontend_language/quick_start/azure_openai_example_chat.py diff --git a/examples/quick_start/gemini_example_chat.py b/examples/frontend_language/quick_start/gemini_example_chat.py similarity index 100% rename from examples/quick_start/gemini_example_chat.py rename to examples/frontend_language/quick_start/gemini_example_chat.py diff --git a/examples/quick_start/gemini_example_complete.py b/examples/frontend_language/quick_start/gemini_example_complete.py similarity index 100% rename from examples/quick_start/gemini_example_complete.py rename to examples/frontend_language/quick_start/gemini_example_complete.py diff --git a/examples/quick_start/gemini_example_multimodal_chat.py b/examples/frontend_language/quick_start/gemini_example_multimodal_chat.py similarity index 100% rename from examples/quick_start/gemini_example_multimodal_chat.py rename to examples/frontend_language/quick_start/gemini_example_multimodal_chat.py diff --git a/examples/quick_start/images/cat.jpeg b/examples/frontend_language/quick_start/images/cat.jpeg similarity index 100% rename from examples/quick_start/images/cat.jpeg rename to examples/frontend_language/quick_start/images/cat.jpeg diff --git a/examples/quick_start/images/dog.jpeg b/examples/frontend_language/quick_start/images/dog.jpeg similarity index 100% rename from examples/quick_start/images/dog.jpeg rename to examples/frontend_language/quick_start/images/dog.jpeg diff --git a/examples/quick_start/srt_example_chat.py b/examples/frontend_language/quick_start/local_example_chat.py similarity index 98% rename from examples/quick_start/srt_example_chat.py rename to examples/frontend_language/quick_start/local_example_chat.py index b1e1658a2a9..e1e4b62ccac 100644 --- a/examples/quick_start/srt_example_chat.py +++ b/examples/frontend_language/quick_start/local_example_chat.py @@ -1,6 +1,6 @@ """ Usage: -python3 srt_example_chat.py +python3 local_example_chat.py """ import sglang as sgl diff --git a/examples/quick_start/srt_example_complete.py 
b/examples/frontend_language/quick_start/local_example_complete.py similarity index 97% rename from examples/quick_start/srt_example_complete.py rename to examples/frontend_language/quick_start/local_example_complete.py index 056245979f4..00a451cf642 100644 --- a/examples/quick_start/srt_example_complete.py +++ b/examples/frontend_language/quick_start/local_example_complete.py @@ -1,6 +1,6 @@ """ Usage: -python3 srt_example_complete.py +python3 local_example_complete.py """ import sglang as sgl diff --git a/examples/quick_start/srt_example_llava.py b/examples/frontend_language/quick_start/local_example_llava_next.py similarity index 74% rename from examples/quick_start/srt_example_llava.py rename to examples/frontend_language/quick_start/local_example_llava_next.py index 5d8f752394f..fc5a1d04c65 100644 --- a/examples/quick_start/srt_example_llava.py +++ b/examples/frontend_language/quick_start/local_example_llava_next.py @@ -1,8 +1,9 @@ """ -Usage: python3 srt_example_llava.py +Usage: python3 local_example_llava_next.py """ import sglang as sgl +from sglang.lang.chat_template import get_chat_template @sgl.function @@ -44,10 +45,17 @@ def batch(): if __name__ == "__main__": - runtime = sgl.Runtime( - model_path="liuhaotian/llava-v1.6-vicuna-7b", - tokenizer_path="llava-hf/llava-1.5-7b-hf", - ) + import multiprocessing as mp + + mp.set_start_method("spawn", force=True) + + runtime = sgl.Runtime(model_path="lmms-lab/llama3-llava-next-8b") + runtime.endpoint.chat_template = get_chat_template("llama-3-instruct") + + # Or you can use the 72B model + # runtime = sgl.Runtime(model_path="lmms-lab/llava-next-72b", tp_size=8) + # runtime.endpoint.chat_template = get_chat_template("chatml-llava") + sgl.set_default_backend(runtime) print(f"chat template: {runtime.endpoint.chat_template.name}") diff --git a/examples/quick_start/openai_example_chat.py b/examples/frontend_language/quick_start/openai_example_chat.py similarity index 100% rename from examples/quick_start/openai_example_chat.py rename to examples/frontend_language/quick_start/openai_example_chat.py diff --git a/examples/quick_start/openai_example_complete.py b/examples/frontend_language/quick_start/openai_example_complete.py similarity index 100% rename from examples/quick_start/openai_example_complete.py rename to examples/frontend_language/quick_start/openai_example_complete.py diff --git a/examples/quick_start/openrouter_example_chat.py b/examples/frontend_language/quick_start/openrouter_example_chat.py similarity index 100% rename from examples/quick_start/openrouter_example_chat.py rename to examples/frontend_language/quick_start/openrouter_example_chat.py diff --git a/examples/quick_start/together_example_chat.py b/examples/frontend_language/quick_start/together_example_chat.py similarity index 100% rename from examples/quick_start/together_example_chat.py rename to examples/frontend_language/quick_start/together_example_chat.py diff --git a/examples/quick_start/together_example_complete.py b/examples/frontend_language/quick_start/together_example_complete.py similarity index 100% rename from examples/quick_start/together_example_complete.py rename to examples/frontend_language/quick_start/together_example_complete.py diff --git a/examples/usage/chinese_regex.py b/examples/frontend_language/usage/chinese_regex.py similarity index 100% rename from examples/usage/chinese_regex.py rename to examples/frontend_language/usage/chinese_regex.py diff --git a/examples/usage/choices_logprob.py b/examples/frontend_language/usage/choices_logprob.py 
similarity index 100% rename from examples/usage/choices_logprob.py rename to examples/frontend_language/usage/choices_logprob.py diff --git a/examples/usage/cot_decoding.py b/examples/frontend_language/usage/cot_decoding.py similarity index 100% rename from examples/usage/cot_decoding.py rename to examples/frontend_language/usage/cot_decoding.py diff --git a/examples/usage/json_decode.py b/examples/frontend_language/usage/json_decode.py similarity index 100% rename from examples/usage/json_decode.py rename to examples/frontend_language/usage/json_decode.py diff --git a/examples/usage/json_logprobs.py b/examples/frontend_language/usage/json_logprobs.py similarity index 100% rename from examples/usage/json_logprobs.py rename to examples/frontend_language/usage/json_logprobs.py diff --git a/examples/usage/llava_video/srt_example_llava_v.py b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py similarity index 85% rename from examples/usage/llava_video/srt_example_llava_v.py rename to examples/frontend_language/usage/llava_video/srt_example_llava_v.py index 27ba862d30d..02bab342ac5 100644 --- a/examples/usage/llava_video/srt_example_llava_v.py +++ b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py @@ -1,7 +1,8 @@ """ Usage: pip install opencv-python-headless -python3 srt_example_llava.py + +python3 srt_example_llava_v.py """ import argparse @@ -9,6 +10,8 @@ import os import time +import requests + import sglang as sgl @@ -121,6 +124,20 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= if __name__ == "__main__": + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + + cache_dir = os.path.expanduser("~/.cache") + file_path = os.path.join(cache_dir, "jobs.mp4") + + os.makedirs(cache_dir, exist_ok=True) + + response = requests.get(url) + response.raise_for_status() # Raise an exception for bad responses + + with open(file_path, "wb") as f: + f.write(response.content) + + print(f"File downloaded and saved to: {file_path}") # Create the parser parser = argparse.ArgumentParser( description="Run video processing with specified port." @@ -148,7 +165,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= parser.add_argument( "--video-dir", type=str, - default="./videos/Q98Z4OTh8RwmDonc.mp4", + default=os.path.expanduser("~/.cache/jobs.mp4"), help="The directory or path for the processed video files.", ) parser.add_argument( @@ -167,13 +184,9 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= # Parse the arguments args = parser.parse_args() - cur_port = args.port - cur_chunk = args.chunk_idx - num_chunks = args.num_chunks - num_frames = args.num_frames if "34b" in args.model_path.lower(): @@ -184,20 +197,19 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= print("Invalid model path. 
Please specify a valid model path.") exit() - model_overide_args = {} - - model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride - model_overide_args["architectures"] = ["LlavaVidForCausalLM"] - model_overide_args["num_frames"] = args.num_frames - model_overide_args["model_type"] = "llava" + model_override_args = {} + model_override_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride + model_override_args["architectures"] = ["LlavaVidForCausalLM"] + model_override_args["num_frames"] = args.num_frames + model_override_args["model_type"] = "llava" if "34b" in args.model_path.lower(): - model_overide_args["image_token_index"] = 64002 + model_override_args["image_token_index"] = 64002 if args.num_frames == 32: - model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} - model_overide_args["max_sequence_length"] = 4096 * 2 - model_overide_args["tokenizer_model_max_length"] = 4096 * 2 + model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_override_args["max_sequence_length"] = 4096 * 2 + model_override_args["tokenizer_model_max_length"] = 4096 * 2 elif args.num_frames < 32: pass else: @@ -211,14 +223,13 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= tokenizer_path=tokenizer_path, port=cur_port, additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4], - model_overide_args=model_overide_args, + model_override_args=model_override_args, tp_size=1, ) sgl.set_default_backend(runtime) print(f"chat template: {runtime.endpoint.chat_template.name}") # Run a single request - # try: print("\n========== single ==========\n") root = args.video_dir if os.path.isfile(root): @@ -240,13 +251,10 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= ) # Calculate the average processing time print(f"Average processing time per video: {average_time:.2f} seconds") runtime.shutdown() - # except Exception as e: - # print(e) - runtime.shutdown() - # # # Run a batch of requests + # # Run a batch of requests # print("\n========== batch ==========\n") # if not os.path.exists(args.save_dir): # os.makedirs(args.save_dir) - # batch(args.video_dir,args.save_dir,cur_chunk, num_chunks, num_frames, num_chunks) + # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks) # runtime.shutdown() diff --git a/examples/usage/llava_video/srt_example_llava_v.sh b/examples/frontend_language/usage/llava_video/srt_example_llava_v.sh similarity index 100% rename from examples/usage/llava_video/srt_example_llava_v.sh rename to examples/frontend_language/usage/llava_video/srt_example_llava_v.sh diff --git a/examples/usage/openai_chat_speculative.py b/examples/frontend_language/usage/openai_chat_speculative.py similarity index 100% rename from examples/usage/openai_chat_speculative.py rename to examples/frontend_language/usage/openai_chat_speculative.py diff --git a/examples/usage/openai_speculative.py b/examples/frontend_language/usage/openai_speculative.py similarity index 100% rename from examples/usage/openai_speculative.py rename to examples/frontend_language/usage/openai_speculative.py diff --git a/examples/usage/parallel_sample.py b/examples/frontend_language/usage/parallel_sample.py similarity index 100% rename from examples/usage/parallel_sample.py rename to examples/frontend_language/usage/parallel_sample.py diff --git a/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb 
b/examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb similarity index 100% rename from examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb rename to examples/frontend_language/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb diff --git a/examples/usage/readme_examples.py b/examples/frontend_language/usage/readme_examples.py similarity index 100% rename from examples/usage/readme_examples.py rename to examples/frontend_language/usage/readme_examples.py diff --git a/examples/usage/streaming.py b/examples/frontend_language/usage/streaming.py similarity index 100% rename from examples/usage/streaming.py rename to examples/frontend_language/usage/streaming.py diff --git a/examples/usage/triton/Dockerfile b/examples/frontend_language/usage/triton/Dockerfile similarity index 100% rename from examples/usage/triton/Dockerfile rename to examples/frontend_language/usage/triton/Dockerfile diff --git a/examples/usage/triton/README.md b/examples/frontend_language/usage/triton/README.md similarity index 100% rename from examples/usage/triton/README.md rename to examples/frontend_language/usage/triton/README.md diff --git a/examples/usage/triton/models/character_generation/1/model.py b/examples/frontend_language/usage/triton/models/character_generation/1/model.py similarity index 100% rename from examples/usage/triton/models/character_generation/1/model.py rename to examples/frontend_language/usage/triton/models/character_generation/1/model.py diff --git a/examples/usage/triton/models/character_generation/config.pbtxt b/examples/frontend_language/usage/triton/models/character_generation/config.pbtxt similarity index 100% rename from examples/usage/triton/models/character_generation/config.pbtxt rename to examples/frontend_language/usage/triton/models/character_generation/config.pbtxt diff --git a/examples/quick_start/srt_example_yi_vl.py b/examples/quick_start/srt_example_yi_vl.py deleted file mode 100644 index 66c7d57126c..00000000000 --- a/examples/quick_start/srt_example_yi_vl.py +++ /dev/null @@ -1,70 +0,0 @@ -""" -Usage: python3 srt_example_yi_vl.py - -Requirements: transformers==4.38 -""" - -import sglang as sgl - - -@sgl.function -def image_qa(s, image_path, question): - s += sgl.user(sgl.image(image_path) + question) - s += sgl.assistant(sgl.gen("answer")) - - -def single(): - state = image_qa.run( - image_path="images/cat.jpeg", - question="What is this?", - max_new_tokens=64, - stop="###", - ) - print(state["answer"], "\n") - - -def stream(): - state = image_qa.run( - image_path="images/cat.jpeg", - question="What is this?", - max_new_tokens=64, - stream=True, - stop="###", - ) - - for out in state.text_iter("answer"): - print(out, end="", flush=True) - print() - - -def batch(): - states = image_qa.run_batch( - [ - {"image_path": "images/cat.jpeg", "question": "What is this?"}, - {"image_path": "images/dog.jpeg", "question": "What is this?"}, - ], - max_new_tokens=64, - stop="###", - ) - for s in states: - print(s["answer"], "\n") - - -if __name__ == "__main__": - runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-6B") - # runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-34B") - sgl.set_default_backend(runtime) - - # Run a single request - print("\n========== single ==========\n") - single() - - # Stream output - print("\n========== stream ==========\n") - stream() - - # Run a batch of requests - print("\n========== batch ==========\n") - batch() - - runtime.shutdown() diff --git a/examples/usage/async_io.py 
b/examples/runtime/async_io_api.py similarity index 100% rename from examples/usage/async_io.py rename to examples/runtime/async_io_api.py diff --git a/examples/usage/llava/http_llama3_llava_test.py b/examples/runtime/llava_onevision/http_llama3_llava_test.py similarity index 94% rename from examples/usage/llava/http_llama3_llava_test.py rename to examples/runtime/llava_onevision/http_llama3_llava_test.py index 813a26af531..a019e214d6f 100644 --- a/examples/usage/llava/http_llama3_llava_test.py +++ b/examples/runtime/llava_onevision/http_llama3_llava_test.py @@ -4,7 +4,7 @@ # Installing latest sglang. # Endpoint Service CLI: -# python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --tokenizer-path lmms-lab/llama3-llava-next-8b-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4 +python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 python3 http_llama3_llava_test.py @@ -16,7 +16,6 @@ import asyncio import copy import json -import time import aiohttp import requests diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/llava_onevision/http_llava_onevision_test.py new file mode 100644 index 00000000000..0c93d2ce2b2 --- /dev/null +++ b/examples/runtime/llava_onevision/http_llava_onevision_test.py @@ -0,0 +1,264 @@ +""" +Usage: + +python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava + +python3 http_llava_onevision_test.py +""" + +import base64 +import io +import os +import sys +import time + +import numpy as np +import openai +import requests +from decord import VideoReader, cpu +from PIL import Image + +# pip install httpx==0.23.3 +# pip install decord +# pip install protobuf==3.20.0 + + +def download_video(url, cache_dir): + file_path = os.path.join(cache_dir, "jobs.mp4") + os.makedirs(cache_dir, exist_ok=True) + + response = requests.get(url) + response.raise_for_status() + + with open(file_path, "wb") as f: + f.write(response.content) + + print(f"File downloaded and saved to: {file_path}") + return file_path + + +def create_openai_client(base_url): + return openai.Client(api_key="EMPTY", base_url=base_url) + + +def image_stream_request_test(client): + print("----------------------Image Stream Request Test----------------------") + stream_request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "text", + "text": "Please describe this image. 
Please list the benchmarks and the models.", + }, + ], + }, + ], + temperature=0.7, + max_tokens=1024, + stream=True, + ) + stream_response = "" + + for chunk in stream_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + stream_response += content + sys.stdout.write(content) + sys.stdout.flush() + + print("-" * 30) + + +def multi_image_stream_request_test(client): + print( + "----------------------Multi-Images Stream Request Test----------------------" + ) + stream_request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + }, + }, + { + "type": "text", + "text": "I have shown you two images. Please describe the two images to me.", + }, + ], + }, + ], + temperature=0.7, + max_tokens=1024, + stream=True, + ) + stream_response = "" + + for chunk in stream_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + stream_response += content + sys.stdout.write(content) + sys.stdout.flush() + + print("-" * 30) + + +def video_stream_request_test(client, video_path): + print("------------------------Video Stream Request Test----------------------") + messages = prepare_video_messages(video_path) + + video_request = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=1024, + stream=True, + ) + print("-" * 30) + video_response = "" + + for chunk in video_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + video_response += content + sys.stdout.write(content) + sys.stdout.flush() + print("-" * 30) + + +def image_speed_test(client): + print("----------------------Image Speed Test----------------------") + start_time = time.time() + request = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "text", + "text": "Please describe this image. 
Please list the benchmarks and the models.", + }, + ], + }, + ], + temperature=0, + max_tokens=1024, + ) + end_time = time.time() + response = request.choices[0].message.content + print(response) + print("-" * 30) + print_speed_test_results(request, start_time, end_time) + + +def video_speed_test(client, video_path): + print("------------------------Video Speed Test------------------------") + messages = prepare_video_messages(video_path) + + start_time = time.time() + video_request = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=1024, + ) + end_time = time.time() + video_response = video_request.choices[0].message.content + print(video_response) + print("-" * 30) + print_speed_test_results(video_request, start_time, end_time) + + +def prepare_video_messages(video_path): + max_frames_num = 32 + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace( + 0, total_frame_num - 1, max_frames_num, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + pil_img = Image.fromarray(frame) + buff = io.BytesIO() + pil_img.save(buff, format="JPEG") + base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") + base64_frames.append(base64_str) + + messages = [{"role": "user", "content": []}] + frame_format = { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{}"}, + } + + for base64_frame in base64_frames: + frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format( + base64_frame + ) + messages[0]["content"].append(frame_format.copy()) + + prompt = {"type": "text", "text": "Please describe the video in detail."} + messages[0]["content"].append(prompt) + + return messages + + +def print_speed_test_results(request, start_time, end_time): + total_tokens = request.usage.total_tokens + completion_tokens = request.usage.completion_tokens + prompt_tokens = request.usage.prompt_tokens + + print(f"Total tokens: {total_tokens}") + print(f"Completion tokens: {completion_tokens}") + print(f"Prompt tokens: {prompt_tokens}") + print(f"Time taken: {end_time - start_time} seconds") + print(f"Token per second: {total_tokens / (end_time - start_time)}") + print(f"Completion token per second: {completion_tokens / (end_time - start_time)}") + print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}") + + +def main(): + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + cache_dir = os.path.expanduser("~/.cache") + video_path = download_video(url, cache_dir) + + client = create_openai_client("http://127.0.0.1:30000/v1") + + image_stream_request_test(client) + multi_image_stream_request_test(client) + video_stream_request_test(client, video_path) + image_speed_test(client) + video_speed_test(client, video_path) + + +if __name__ == "__main__": + main() diff --git a/examples/usage/llava/http_qwen_llava_test.py b/examples/runtime/llava_onevision/http_qwen_llava_test.py similarity index 95% rename from examples/usage/llava/http_qwen_llava_test.py rename to examples/runtime/llava_onevision/http_qwen_llava_test.py index 1c29658c609..dca56e7a33c 100644 --- a/examples/usage/llava/http_qwen_llava_test.py +++ b/examples/runtime/llava_onevision/http_qwen_llava_test.py @@ -4,7 +4,7 @@ # Installing latest sglang. 
# Endpoint Service CLI: -# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4 +python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8 python3 http_qwen_llava_test.py @@ -16,7 +16,6 @@ import asyncio import copy import json -import time import aiohttp import requests diff --git a/examples/usage/openai_batch_chat.py b/examples/runtime/openai_batch_chat.py similarity index 100% rename from examples/usage/openai_batch_chat.py rename to examples/runtime/openai_batch_chat.py diff --git a/examples/usage/openai_batch_complete.py b/examples/runtime/openai_batch_complete.py similarity index 100% rename from examples/usage/openai_batch_complete.py rename to examples/runtime/openai_batch_complete.py diff --git a/examples/usage/llava/srt_llava_next_test.py b/examples/usage/llava/srt_llava_next_test.py deleted file mode 100644 index 0f9621648a7..00000000000 --- a/examples/usage/llava/srt_llava_next_test.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Usage: python3 srt_example_llava.py -""" - -from PIL import ImageFile - -import sglang as sgl -from sglang.lang.chat_template import get_chat_template -from sglang.srt.utils import load_image - -ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images - - -@sgl.function -def image_qa(s, image, question): - s += sgl.user(sgl.image(image) + question) - s += sgl.assistant(sgl.gen("answer")) - - -def single(): - image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" - pil_image, _ = load_image(image_url) - state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512) - print(state["answer"], "\n") - - -def stream(): - image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" - pil_image, _ = load_image(image_url) - state = image_qa.run( - image=pil_image, - question="Please generate short caption for this image.", - max_new_tokens=512, - temperature=0, - stream=True, - ) - - for out in state.text_iter("answer"): - print(out, end="", flush=True) - print() - - -def batch(): - image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg" - pil_image, _ = load_image(image_url) - states = image_qa.run_batch( - [ - {"image": pil_image, "question": "What is this?"}, - {"image": pil_image, "question": "What is this?"}, - ], - max_new_tokens=512, - ) - for s in states: - print(s["answer"], "\n") - - -if __name__ == "__main__": - import multiprocessing as mp - - mp.set_start_method("spawn", force=True) - runtime = sgl.Runtime( - model_path="lmms-lab/llama3-llava-next-8b", - tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer", - ) - runtime.endpoint.chat_template = get_chat_template("llama-3-instruct") - # runtime = sgl.Runtime( - # model_path="lmms-lab/llava-next-72b", - # tokenizer_path="lmms-lab/llavanext-qwen-tokenizer", - # ) - # runtime.endpoint.chat_template = get_chat_template("chatml-llava") - sgl.set_default_backend(runtime) - print(f"chat template: {runtime.endpoint.chat_template.name}") - - # Or you can use API models - # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview")) - # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision")) - - # Run a single request - print("\n========== single ==========\n") - single() - - # Stream output - print("\n========== stream ==========\n") - stream() - - # Run a batch of requests - print("\n========== batch ==========\n") - batch() - - runtime.shutdown() diff --git 
a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 b/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 deleted file mode 100644 index 32d912dbfa1..00000000000 Binary files a/examples/usage/llava_video/videos/Q98Z4OTh8RwmDonc.mp4 and /dev/null differ diff --git a/examples/usage/openai_parallel_sample.py b/examples/usage/openai_parallel_sample.py deleted file mode 100644 index 753e66c744f..00000000000 --- a/examples/usage/openai_parallel_sample.py +++ /dev/null @@ -1,153 +0,0 @@ -import openai - -client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") - -# Text completion -response = client.completions.create( - model="default", - prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", - n=1, - temperature=0.8, - max_tokens=32, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", - n=5, - temperature=0.8, - max_tokens=320, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little", - n=3, - temperature=0.8, - max_tokens=32, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=["The name of the famous soccer player is"], - n=1, - temperature=0.8, - max_tokens=128, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=["The name of the famous soccer player is ", "The capital of US is"], - n=1, - temperature=0.8, - max_tokens=32, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=["The name of the famous soccer player is ", "The capital of US is"], - n=3, - temperature=0.8, - max_tokens=32, -) -print(response) - - -response = client.completions.create( - model="default", - prompt=[ - "prompt1: I am a robot and I want to learn like humans. Now let's begin a tale. Once upon a time, there was a small", - "prompt2: As a robot, my goal is to understand human learning. Let's start a story. In a faraway land, there lived a tiny", - "prompt3: Being a robot, I aspire to study like people. Let's share a story. Long ago, there was a little", - "prompt4: I am a robot aiming to learn like humans. Let's narrate a story. Once, in a distant kingdom, there was a young", - "prompt5: As a robot, I seek to learn in human ways. Let's tell a story. 
Once upon a time, in a small village, there was a young", - ], - n=1, - temperature=0.8, - max_tokens=320, -) -print(response) - - -# Text completion -response = client.completions.create( - model="default", - prompt=[ - "The capital of France is", - "The capital of Germany is", - "The capital of US is", - ], - n=3, - temperature=0.8, - max_tokens=32, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - logprobs=True, - top_logprobs=3, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - n=1, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - logprobs=True, - top_logprobs=3, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0.8, - max_tokens=1, - n=4, -) -print(response) diff --git a/examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png b/examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png deleted file mode 100644 index 2ea09fdc602..00000000000 Binary files a/examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png and /dev/null differ diff --git a/python/pyproject.toml b/python/pyproject.toml index 7ba4b4c6bda..daf09ea25de 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.2.13" +version = "0.3.0" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" @@ -20,14 +20,14 @@ dependencies = [ ] [project.optional-dependencies] -srt = ["aiohttp", "fastapi", "hf_transfer", "huggingface_hub", "interegular", +srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", "psutil", "pydantic", "python-multipart", "torch", "uvicorn", "uvloop", "zmq", - "vllm==0.5.4", "outlines>=0.0.44"] + "vllm==0.5.5", "outlines>=0.0.44"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] -test = ["jsonlines", "matplotlib", "pandas"] +test = ["jsonlines", "matplotlib", "pandas", "sentence_transformers", "accelerate"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] diff --git a/python/sglang/README.md b/python/sglang/README.md index c92144254a2..481c69affec 100644 --- a/python/sglang/README.md +++ b/python/sglang/README.md @@ -2,8 +2,8 @@ - `lang`: The frontend language. - `srt`: The backend engine for running local models. (SRT = SGLang Runtime). -- `test`: Test utilities. -- `api.py`: Public API. +- `test`: The test utilities. +- `api.py`: The public APIs. 
- `bench_latency.py`: Benchmark a single static batch. - `bench_serving.py`: Benchmark online serving with dynamic requests. - `global_config.py`: The global configs and constants. diff --git a/python/sglang/api.py b/python/sglang/api.py index 3a2f747bec2..71f129f8740 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -69,6 +69,12 @@ def gen( min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + # DRY sampling + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], ignore_eos: Optional[bool] = None, return_logprob: Optional[bool] = None, logprob_start_len: Optional[int] = None, @@ -78,6 +84,7 @@ def gen( choices: Optional[List[str]] = None, choices_method: Optional[ChoicesSamplingMethod] = None, regex: Optional[str] = None, + json_schema: Optional[str] = None, ): """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md""" @@ -107,6 +114,11 @@ def gen( min_p, frequency_penalty, presence_penalty, + dry_multiplier, + dry_base, + dry_allowed_length, + dry_penalty_last_n, + dry_sequence_breakers, ignore_eos, return_logprob, logprob_start_len, @@ -114,6 +126,7 @@ def gen( return_text_in_logprobs, dtype, regex, + json_schema, ) @@ -128,6 +141,11 @@ def gen_int( min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], ignore_eos: Optional[bool] = None, return_logprob: Optional[bool] = None, logprob_start_len: Optional[int] = None, @@ -145,6 +163,11 @@ def gen_int( min_p, frequency_penalty, presence_penalty, + dry_multiplier, + dry_base, + dry_allowed_length, + dry_penalty_last_n, + dry_sequence_breakers, ignore_eos, return_logprob, logprob_start_len, @@ -166,6 +189,11 @@ def gen_string( min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], ignore_eos: Optional[bool] = None, return_logprob: Optional[bool] = None, logprob_start_len: Optional[int] = None, @@ -183,6 +211,11 @@ def gen_string( min_p, frequency_penalty, presence_penalty, + dry_multiplier, + dry_base, + dry_allowed_length, + dry_penalty_last_n, + dry_sequence_breakers, ignore_eos, return_logprob, logprob_start_len, diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index ba1a81d54dc..9006b7150aa 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -11,26 +11,34 @@ ## plot the results in series of lines: python -m sglang.bench_latency --result-filename out.jsonl --graph-sql="select run_name, batch_size, prefill_throughput from results" - # Usage (correctness test): python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct ## Reference output (of the correctness test above, can be gpu dependent): -prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633], - [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633], - [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]], - device='cuda:0', dtype=torch.float16) -prefill logits (final) 
tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141], - [-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742], - [-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]], - device='cuda:0', dtype=torch.float16) - The capital of France is. +input_ids=[[1, 450, 7483, 310, 3444, 338], [1, 450, 7483, 310, 278, 3303, 13187, 290, 338], [1, 20628, 338, 263, 6575, 1460, 2462, 322, 306, 763]] + +prefill logits (first half): tensor([[-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], + [-10.0312, -9.5000, 0.8931, ..., -4.9414, -3.2422, -3.3633], + [ -9.1875, -10.2500, 2.7129, ..., -4.3359, -4.0664, -4.1328]], + device='cuda:0') + +prefill logits (final): tensor([[-8.3125, -7.1172, 3.3457, ..., -4.9570, -4.1328, -3.4141], + [-8.9141, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0781], + [-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4570]], + device='cuda:0') + +========== Prompt 0 ========== + The capital of France is Paris. The capital of the United States is Washington, D.C. - The capital of the United Kindom is. + +========== Prompt 1 ========== + The capital of the United Kindom is London. The capital of the United Kingdom is London. The capital of the - Today is a sunny day and I like go for a walk in the park. + +========== Prompt 2 ========== + Today is a sunny day and I like to go for a walk in the park. I'm going to the park """ @@ -111,7 +119,11 @@ def load_model(server_args, tp_rank): suppress_other_loggers() rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None - model_config = ModelConfig(path=server_args.model_path) + model_config = ModelConfig( + server_args.model_path, + server_args.trust_remote_code, + context_length=server_args.context_length, + ) model_runner = ModelRunner( model_config=model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -196,16 +208,16 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - output = model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids = batch.sample(output.next_token_logits) - return next_token_ids, output.next_token_logits, batch + sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids = sample_output.batch_next_token_ids.tolist() + return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): - batch.prepare_for_decode(input_token_ids.cpu().numpy()) - output = model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids = batch.sample(output.next_token_logits) - return next_token_ids, output.next_token_logits + batch.prepare_for_decode(input_token_ids) + sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids = sample_output.batch_next_token_ids.tolist() + return next_token_ids, logits_output.next_token_logits @torch.inference_mode() @@ -221,12 +233,12 @@ def correctness_test( # Prepare inputs input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) - rank_print(f"{input_ids=}") + rank_print(f"\n{input_ids=}\n") if bench_args.cut_len > 0: # Prefill next_token_ids, next_token_logits, batch = extend(reqs, model_runner) - rank_print("prefill logits (first half)", next_token_logits) + rank_print(f"prefill logits (first half): {next_token_logits} \n") # Prepare extend inputs reqs = prepare_extend_inputs_for_correctness_test( @@ -235,7 +247,7 @@ def correctness_test( # Extend next_token_ids, next_token_logits, batch = extend(reqs, model_runner) - 
rank_print("prefill logits (final)", next_token_logits) + rank_print(f"prefill logits (final): {next_token_logits} \n") # Decode output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))] @@ -246,7 +258,8 @@ def correctness_test( # Print for i in range(len(reqs)): - rank_print(tokenizer.decode(output_ids[i])) + rank_print(f"========== Prompt {i} ==========") + rank_print(tokenizer.decode(output_ids[i]), "\n") @torch.inference_mode() @@ -288,6 +301,7 @@ def latency_test_run_once( measurement_results["prefill_throughput"] = throughput # Decode + decode_latencies = [] for i in range(output_len): torch.cuda.synchronize() tic = time.time() @@ -296,17 +310,18 @@ def latency_test_run_once( latency = time.time() - tic tot_latency += latency throughput = batch_size / latency + decode_latencies.append(latency) if i < 5: rank_print( f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" ) - avg_decode_latency = (tot_latency - prefill_latency) / output_len - avg_decode_throughput = batch_size / avg_decode_latency + med_decode_latency = np.median(decode_latencies) + med_decode_throughput = batch_size / med_decode_latency rank_print( - f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s" + f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s" ) - measurement_results["avg_decode_latency"] = avg_decode_latency - measurement_results["avg_decode_throughput"] = avg_decode_throughput + measurement_results["median_decode_latency"] = med_decode_latency + measurement_results["median_decode_throughput"] = med_decode_throughput throughput = (input_len + output_len) * batch_size / tot_latency rank_print( diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index d5f16e2ae54..7bd5aa0901f 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -11,10 +11,6 @@ def __init__(self): # Default backend of the language self.default_backend = None - # Runtime constants: Request dependency time due to network delay - self.request_dependency_delay = 0.02 - self.wait_for_new_request_delay = 0.0006 - # Runtime constants: New generation token ratio estimation self.init_new_token_ratio = 0.7 self.base_min_new_token_ratio = 0.1 diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 5012f646ea1..344b51d2dc2 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -4,7 +4,7 @@ from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend -from sglang.lang.chat_template import get_chat_template_by_model_path +from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor from sglang.lang.ir import ( @@ -23,6 +23,7 @@ def __init__( base_url: str, api_key: Optional[str] = None, verify: Optional[str] = None, + chat_template_name: Optional[str] = None, ): super().__init__() self.support_concate_and_append = True @@ -39,9 +40,12 @@ def __init__( self._assert_success(res) self.model_info = res.json() - self.chat_template = get_chat_template_by_model_path( - self.model_info["model_path"] - ) + if chat_template_name: + self.chat_template = get_chat_template(chat_template_name) + else: + self.chat_template = 
get_chat_template_by_model_path( + self.model_info["model_path"] + ) def get_model_name(self): return self.model_info["model_path"] diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index bfde4bbdb6a..fa300b25f02 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass from enum import Enum, auto -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Tuple class ChatTemplateStyle(Enum): @@ -137,7 +137,7 @@ def get_chat_template_by_model_path(model_path): register_chat_template( ChatTemplate( name="chatml-llava", - default_system_prompt="Answer the questions.", + default_system_prompt="You are a helpful assistant.", role_prefix_and_suffix={ "system": ("<|im_start|>system\n", "<|im_end|>\n"), "user": ("<|im_start|>user\n", "<|im_end|>\n"), @@ -145,7 +145,7 @@ def get_chat_template_by_model_path(model_path): }, style=ChatTemplateStyle.PLAIN, stop_str=("<|im_end|>",), - image_token=" \n", + image_token="\n", ) ) @@ -322,12 +322,17 @@ def match_chat_ml(model_path: str): if "tinyllama" in model_path: return get_chat_template("chatml") # Now the suffix for qwen2 chat model is "instruct" - if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path): + if ( + "qwen" in model_path + and ("chat" in model_path or "instruct" in model_path) + and ("llava" not in model_path) + ): return get_chat_template("qwen") if ( "llava-v1.6-34b" in model_path or "llava-v1.6-yi-34b" in model_path or "llava-next-video-34b" in model_path + or "llava-onevision-qwen2" in model_path ): return get_chat_template("chatml-llava") diff --git a/python/sglang/lang/compiler.py b/python/sglang/lang/compiler.py index 5e1b411fc29..cdc78f5afbf 100644 --- a/python/sglang/lang/compiler.py +++ b/python/sglang/lang/compiler.py @@ -1,7 +1,7 @@ import multiprocessing from concurrent.futures import ThreadPoolExecutor from queue import Queue -from typing import List, Union +from typing import List, Union, Optional from sglang.global_config import global_config from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program @@ -133,6 +133,11 @@ def run( min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], backend=None, **kwargs, ): @@ -149,6 +154,11 @@ def run( min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_sequence_breakers=dry_sequence_breakers, ) return self.run_internal(backend, kwargs, default_sampling_para) @@ -165,6 +175,11 @@ def run_batch( min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], backend=None, num_threads: Union[str, int] = "auto", ): @@ -184,6 +199,11 @@ def run_batch( min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + 
dry_sequence_breakers=dry_sequence_breakers, ) # Extract prefix by tracing and cache it diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 306d280c7f0..c2731857ff0 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -666,6 +666,11 @@ def _resolve_sampling_params(self, sampling_params): "min_p", "frequency_penalty", "presence_penalty", + "dry_multiplier", + "dry_base", + "dry_allowed_length", + "dry_penalty_last_n", + "dry_sequence_breakers", "ignore_eos", "return_logprob", "logprob_start_len", @@ -673,6 +678,7 @@ def _resolve_sampling_params(self, sampling_params): "return_text_in_logprobs", "dtype", "regex", + "json_schema", ]: value = getattr(sampling_params, item, None) if value is not None: @@ -854,6 +860,8 @@ def get_meta_info(self, name): return self.stream_executor.get_meta_info(name) def __iadd__(self, other): + if other is None: + raise ValueError("Tried to append None to state.") self.stream_executor.submit(other) return self diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 199a7ac7a4e..67e17c76085 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -25,11 +25,17 @@ class SglSamplingParams: min_p: float = 0.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + dry_multiplier: float = 0.0 + dry_base: float = 0.0 + dry_allowed_length: int = 2 + dry_penalty_last_n: int = 0 + dry_sequence_breakers: Optional[List[str]] = () ignore_eos: bool = False return_logprob: Optional[bool] = None logprob_start_len: Optional[int] = (None,) top_logprobs_num: Optional[int] = (None,) return_text_in_logprobs: Optional[bool] = (None,) + json_schema: Optional[str] = None # for constrained generation, not included in to_xxx_kwargs dtype: Optional[str] = None @@ -46,11 +52,17 @@ def clone(self): self.min_p, self.frequency_penalty, self.presence_penalty, + self.dry_multiplier, + self.dry_base, + self.dry_allowed_length, + self.dry_penalty_last_n, + self.dry_sequence_breakers, self.ignore_eos, self.return_logprob, self.logprob_start_len, self.top_logprobs_num, self.return_text_in_logprobs, + self.json_schema, ) def to_openai_kwargs(self): @@ -62,8 +74,15 @@ def to_openai_kwargs(self): "stop": self.stop or None, "temperature": self.temperature, "top_p": self.top_p, + "min_p": self.min_p, + "top_k": self.top_k, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_penalty_last_n": self.dry_penalty_last_n, + "dry_sequence_breakers": self.dry_sequence_breakers, } def to_vertexai_kwargs(self): @@ -78,6 +97,12 @@ def to_vertexai_kwargs(self): "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k if self.top_k > 0 else None, + "min_p": self.min_p, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_penalty_last_n": self.dry_penalty_last_n, + "dry_sequence_breakers": self.dry_sequence_breakers, } def to_anthropic_kwargs(self): @@ -106,6 +131,11 @@ def to_litellm_kwargs(self): "top_p": self.top_p, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_penalty_last_n": self.dry_penalty_last_n, + "dry_sequence_breakers": self.dry_sequence_breakers, } def to_srt_kwargs(self): @@ -119,8 
+149,14 @@ def to_srt_kwargs(self): "min_p": self.min_p, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_penalty_last_n": self.dry_penalty_last_n, + "dry_sequence_breakers": self.dry_sequence_breakers, "ignore_eos": self.ignore_eos, "regex": self.regex, + "json_schema": self.json_schema, } @@ -155,6 +191,11 @@ def run( min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], ignore_eos: bool = False, return_logprob: Optional[bool] = None, logprob_start_len: Optional[int] = None, @@ -176,6 +217,11 @@ def run( min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_sequence_breakers=dry_sequence_breakers, ignore_eos=ignore_eos, return_logprob=return_logprob, logprob_start_len=logprob_start_len, @@ -198,6 +244,11 @@ def run_batch( min_p: float = 0.0, frequency_penalty: float = 0.0, presence_penalty: float = 0.0, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], ignore_eos: bool = False, return_logprob: Optional[bool] = None, logprob_start_len: Optional[int] = None, @@ -237,6 +288,11 @@ def run_batch( min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_sequence_breakers=dry_sequence_breakers, ignore_eos=ignore_eos, return_logprob=return_logprob, logprob_start_len=logprob_start_len, @@ -418,6 +474,11 @@ def __init__( min_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], ignore_eos: Optional[bool] = None, return_logprob: Optional[bool] = None, logprob_start_len: Optional[int] = None, @@ -425,6 +486,7 @@ def __init__( return_text_in_logprobs: Optional[bool] = None, dtype: Optional[type] = None, regex: Optional[str] = None, + json_schema: Optional[str] = None, ): """Call the model to generate. 
See the meaning of the arguments in docs/en/sampling_params.md""" super().__init__() @@ -439,6 +501,11 @@ def __init__( min_p=min_p, frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, + dry_multiplier=dry_multiplier, + dry_base=dry_base, + dry_allowed_length=dry_allowed_length, + dry_penalty_last_n=dry_penalty_last_n, + dry_sequence_breakers=dry_sequence_breakers, ignore_eos=ignore_eos, return_logprob=return_logprob, logprob_start_len=logprob_start_len, @@ -446,6 +513,7 @@ def __init__( return_text_in_logprobs=return_text_in_logprobs, dtype=dtype, regex=regex, + json_schema=json_schema, ) def __repr__(self): diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index c34dd211672..43eefef4efa 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -5,25 +5,22 @@ from sglang.srt.server import ServerArgs, launch_server if __name__ == "__main__": - model_overide_args = {} - - model_overide_args["mm_spatial_pool_stride"] = 2 - model_overide_args["architectures"] = ["LlavaVidForCausalLM"] - model_overide_args["num_frames"] = 16 - model_overide_args["model_type"] = "llavavid" - if model_overide_args["num_frames"] == 32: - model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} - model_overide_args["max_sequence_length"] = 4096 * 2 - model_overide_args["tokenizer_model_max_length"] = 4096 * 2 - model_overide_args["model_max_length"] = 4096 * 2 - parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + model_override_args = {} + model_override_args["mm_spatial_pool_stride"] = 2 + model_override_args["architectures"] = ["LlavaVidForCausalLM"] + model_override_args["num_frames"] = 16 + model_override_args["model_type"] = "llavavid" + if model_override_args["num_frames"] == 32: + model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"} + model_override_args["max_sequence_length"] = 4096 * 2 + model_override_args["tokenizer_model_max_length"] = 4096 * 2 + model_override_args["model_max_length"] = 4096 * 2 if "34b" in args.model_path.lower(): - model_overide_args["image_token_index"] = 64002 - - server_args = ServerArgs.from_cli_args(args) + model_override_args["image_token_index"] = 64002 - launch_server(server_args, model_overide_args, None) + launch_server(server_args, model_override_args, None) diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py new file mode 100644 index 00000000000..9e74366709f --- /dev/null +++ b/python/sglang/srt/configs/__init__.py @@ -0,0 +1,5 @@ +from sglang.srt.configs.exaone import ExaoneConfig + +__all__ = [ + "ExaoneConfig", +] diff --git a/python/sglang/srt/configs/exaone.py b/python/sglang/srt/configs/exaone.py new file mode 100644 index 00000000000..7b0a2d290da --- /dev/null +++ b/python/sglang/srt/configs/exaone.py @@ -0,0 +1,195 @@ +# coding=utf-8 +# Copyright 2024 The LG AI Research EXAONE Lab. All rights reserved. +# Copyright 2024 The LG CNS AI Engineering Team. +# Copyright 2023-2024 SGLang Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" EXAONE model configuration """ +from typing import Any, Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, Any] = {} + + +# ruff: noqa: E501 +class ExaoneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.ExaoneModel`. It is used to + instantiate an EXAONE model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a configuration similar to that of the EXAONE model. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 102400): + Vocabulary size of the EXAONE model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.ExaoneModel`. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, `optional`, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (:obj:`int`, `optional`): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if + `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to + `num_attention_heads`. + intermediate_size (:obj:`int`, `optional`, defaults to `hidden_size * 4`): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"silu"`): + The non-linear activation function (function or string) in the decoder. + rope_theta (:obj:`float`, `optional`, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (:obj:`Dict`, `optional`): + Dictionary containing the scaling configuration for the RoPE embeddings. 
NOTE: if you apply a new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend updating this value + accordingly. + Expected contents: + `rope_type` (:obj:`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (:obj:`float`, `optional`): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (:obj:`int`, `optional`): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (:obj:`float`, `optional`): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to the value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (:obj:`float`, `optional`): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (:obj:`float`, `optional`): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (:obj:`List[float]`, `optional`): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2. + `long_factor` (:obj:`List[float]`, `optional`): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2. + `low_freq_factor` (:obj:`float`, `optional`): + Only used with 'llama3'. Scaling factor applied to the low-frequency components of the RoPE. + `high_freq_factor` (:obj:`float`, `optional`): + Only used with 'llama3'. Scaling factor applied to the high-frequency components of the RoPE. + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. + bos_token_id (:obj:`int`, `optional`, defaults to 0): + Beginning of stream token id. + eos_token_id (:obj:`int`, `optional`, defaults to 2): + End of stream token id. 
+ tie_word_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether to tie weight embeddings + gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example:: + + >>> from transformers import EXAONEModel, ExaoneConfig + + >>> # Initializing a EXAONE configuration + >>> configuration = ExaoneConfig() + + >>> # Initializing a model from configuration + >>> model = EXAONEModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.configs + """ + + model_type = "exaone" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=102400, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + intermediate_size=None, + activation_function="silu", + rope_theta=10000.0, + rope_scaling=None, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + **kwargs + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_layers + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.activation_function = activation_function + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs + ) diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index fa41f90de3c..57c49130622 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -15,6 +15,8 @@ """Cache for the compressed finite state machine.""" +from outlines.fsm.json_schema import build_regex_from_schema + from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained.base_tool_cache import BaseToolCache @@ -26,9 +28,12 @@ def __init__( tokenizer_args_dict, enable=True, skip_tokenizer_init=False, + json_schema_mode=False, ): super().__init__(enable=enable) + self.json_schema_mode = json_schema_mode + if ( skip_tokenizer_init or tokenizer_path.endswith(".json") @@ -72,5 +77,9 @@ def fset(self, value): tokenizer_path, **tokenizer_args_dict ) - def init_value(self, regex): - return RegexGuide(regex, self.outlines_tokenizer) + def init_value(self, value): + if self.json_schema_mode: + regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*") + return RegexGuide(regex, self.outlines_tokenizer), regex + else: + return RegexGuide(value, self.outlines_tokenizer) diff --git a/python/sglang/srt/constrained/jump_forward.py b/python/sglang/srt/constrained/jump_forward.py index b00c48d4784..244931e0509 100644 --- 
a/python/sglang/srt/constrained/jump_forward.py +++ b/python/sglang/srt/constrained/jump_forward.py @@ -23,6 +23,7 @@ import interegular import outlines.caching +from outlines.fsm.json_schema import build_regex_from_schema from sglang.srt.constrained import ( FSMInfo, diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 5ee12169740..dbc376d9593 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -34,6 +34,7 @@ class SeparatorStyle(IntEnum): NO_COLON_TWO = auto() ADD_NEW_LINE_SINGLE = auto() LLAMA2 = auto() + LLAMA3 = auto() CHATGLM = auto() CHATML = auto() CHATINTERN = auto() @@ -137,6 +138,20 @@ def get_prompt(self) -> str: else: ret += role + ":" return ret + elif self.sep_style == SeparatorStyle.LLAMA3: + ret = "<|begin_of_text|>" + if self.system_message: + ret += system_prompt + else: + ret += "" + for i, (role, message) in enumerate(self.messages): + if message: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + ret += f"{message.strip()}<|eot_id|>" + else: + ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" + # print(ret) + return ret elif self.sep_style == SeparatorStyle.LLAMA2: seps = [self.sep, self.sep2] if self.system_message: @@ -371,7 +386,16 @@ def generate_chat_conv( for message in request.messages: msg_role = message.role if msg_role == "system": - conv.system_message = message.content + if isinstance(message.content, str): + conv.system_message = message.content + elif isinstance(message.content, list): + if ( + len(message.content) != 1 + or getattr(message.content[0], "type", None) != "text" + ): + raise ValueError("The system message should be a single text.") + else: + conv.system_message = getattr(message.content[0], "text", "") elif msg_role == "user": # Handle the various types of Chat Request content types here. role = conv.roles[0] @@ -379,16 +403,40 @@ def generate_chat_conv( conv.append_message(conv.roles[0], message.content) else: real_content = "" + # calculate number of image_url + num_image_url = 0 + for content in message.content: + if content.type == "image_url": + num_image_url += 1 + if num_image_url > 1: + image_token = "" + else: + image_token = "\n" for content in message.content: if content.type == "text": + if num_image_url > 16: + real_content += "\n" # for video real_content += content.text elif content.type == "image_url": # NOTE: Only works for llava - real_content += "\n" + real_content += image_token conv.append_image(content.image_url.url) conv.append_message(conv.roles[0], real_content) elif msg_role == "assistant": - conv.append_message(conv.roles[1], message.content) + parsed_content = "" + if isinstance(message.content, str): + parsed_content = message.content + elif isinstance(message.content, list): + if ( + len(message.content) != 1 + or getattr(message.content[0], "type", None) != "text" + ): + raise ValueError( + "The assistant's response should be a single text." 
+ ) + else: + parsed_content = getattr(message.content[0], "text", "") + conv.append_message(conv.roles[1], parsed_content) else: raise ValueError(f"Unknown role: {msg_role}") @@ -425,6 +473,18 @@ def generate_chat_conv( ) ) +register_conv_template( + Conversation( + name="chatml-llava", + system_template="<|im_start|>system\n{system_message}", + system_message="You are a helpful assistant.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep_style=SeparatorStyle.CHATML, + sep="<|im_end|>", + stop_str=["<|endoftext|>", "<|im_end|>"], + ) +) + register_conv_template( Conversation( name="vicuna_v1.1", @@ -437,6 +497,17 @@ def generate_chat_conv( ) ) +register_conv_template( + Conversation( + name="llava_llama_3", + system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", + system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", + roles=("user", "assistant"), + sep_style=SeparatorStyle.LLAMA3, + sep="", + stop_str=["<|end_of_text|>", "<|eot_id|>"], + ) +) # Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442 register_conv_template( Conversation( diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 525d2954399..ae3070c5a78 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -15,6 +15,7 @@ """Utilities for Huggingface Transformers.""" +import contextlib import functools import json import os @@ -34,15 +35,20 @@ try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig + from sglang.srt.configs import ExaoneConfig + _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { ChatGLMConfig.model_type: ChatGLMConfig, DbrxConfig.model_type: DbrxConfig, + ExaoneConfig.model_type: ExaoneConfig, } except ImportError: # We want this file to run without vllm dependency _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {} -from sglang.srt.utils import is_multimodal_model +for name, cls in _CONFIG_REGISTRY.items(): + with contextlib.suppress(ValueError): + AutoConfig.register(name, cls) def download_from_hf(model_path: str): @@ -52,17 +58,11 @@ def download_from_hf(model_path: str): return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) -def get_config_json(model_path: str): - with open(os.path.join(model_path, "config.json")) as f: - config = json.load(f) - return config - - def get_config( model: str, trust_remote_code: bool, revision: Optional[str] = None, - model_overide_args: Optional[dict] = None, + model_override_args: Optional[dict] = None, ): config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision @@ -70,8 +70,8 @@ def get_config( if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) - if model_overide_args: - config.update(model_overide_args) + if model_override_args: + config.update(model_override_args) return config @@ -89,10 +89,10 @@ def get_config( def get_context_length(config): - """Get the context length of a model from a huggingface model config.""" + """Get the context length of a model from a huggingface model configs.""" rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling: - rope_scaling_factor = 
config.rope_scaling["factor"] + rope_scaling_factor = config.rope_scaling.get("factor", 1) if "original_max_position_embeddings" in rope_scaling: rope_scaling_factor = 1 if config.rope_scaling.get("rope_type", None) == "llama3": @@ -119,40 +119,12 @@ def get_tokenizer( tokenizer_revision: Optional[str] = None, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - if tokenizer_name.endswith(".json"): - return TiktokenTokenizer(tokenizer_name) - - if tokenizer_name.endswith(".model"): - return SentencePieceTokenizer(tokenizer_name) - """Gets a tokenizer for the given model name via Huggingface.""" - if is_multimodal_model(tokenizer_name): - processor = get_processor( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - tokenizer_revision=tokenizer_revision, - **kwargs, - ) - tokenizer = processor.tokenizer - return tokenizer - if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False - if ( - "llama" in tokenizer_name.lower() - and kwargs.get("use_fast", True) - and tokenizer_name != _FAST_LLAMA_TOKENIZER - ): - warnings.warn( - "For some LLaMA V1 models, initializing the fast tokenizer may " - "take a long time. To reduce the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer." - ) try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, @@ -210,135 +182,3 @@ def get_processor( **kwargs, ) return processor - - -class TiktokenTokenizer: - def __init__(self, tokenizer_path): - import tiktoken - from jinja2 import Template - - PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" - - # Read JSON - name = "tmp-json" - with open(tokenizer_path, "rb") as fin: - tok_dict = json.load(fin) - - mergeable_ranks = { - bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"] - } - special_tokens = { - bytes(item["bytes"]).decode(): item["token"] - for item in tok_dict["special_tokens"] - } - assert tok_dict["word_split"] == "V1" - - default_allowed_special = None - - kwargs = { - "name": name, - "pat_str": tok_dict.get("pat_str", PAT_STR_B), - "mergeable_ranks": mergeable_ranks, - "special_tokens": special_tokens, - } - if "default_allowed_special" in tok_dict: - default_allowed_special = set( - [ - bytes(bytes_list).decode() - for bytes_list in tok_dict["default_allowed_special"] - ] - ) - if "vocab_size" in tok_dict: - kwargs["explicit_n_vocab"] = tok_dict["vocab_size"] - - PAD = "<|pad|>" - EOS = "<|eos|>" - SEP = "<|separator|>" - - DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP} - - tokenizer = tiktoken.Encoding(**kwargs) - tokenizer._default_allowed_special = default_allowed_special or set() - tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS - - def encode_patched( - self, - text: str, - *, - allowed_special: Union[ - Literal["all"], AbstractSet[str] - ] = set(), # noqa: B006 - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - ) -> List[int]: - if isinstance(allowed_special, set): - allowed_special |= self._default_allowed_special - return tiktoken.Encoding.encode( - self, - text, - allowed_special=allowed_special, - disallowed_special=(), - ) - - tokenizer.encode = functools.partial(encode_patched, tokenizer) - - # Convert to HF interface - self.tokenizer = tokenizer - self.eos_token_id = tokenizer._special_tokens[EOS] - self.vocab_size = tokenizer.n_vocab - 
self.chat_template = Template( - "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" - ) - - def encode(self, x, add_special_tokens=False): - return self.tokenizer.encode(x) - - def decode(self, x): - return self.tokenizer.decode(x) - - def batch_decode( - self, batch, skip_special_tokens=True, spaces_between_special_tokens=False - ): - if isinstance(batch[0], int): - batch = [[x] for x in batch] - return self.tokenizer.decode_batch(batch) - - def apply_chat_template(self, messages, tokenize, add_generation_prompt): - ret = self.chat_template.render( - messages=messages, add_generation_prompt=add_generation_prompt - ) - return self.encode(ret) if tokenize else ret - - -class SentencePieceTokenizer: - def __init__(self, tokenizer_path): - import sentencepiece as spm - from jinja2 import Template - - tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path) - - # Convert to HF interface - self.tokenizer = tokenizer - self.eos_token_id = tokenizer.eos_id() - self.vocab_size = tokenizer.vocab_size() - self.chat_template = Template( - "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" - ) - - def encode(self, x, add_special_tokens=False): - return self.tokenizer.encode(x) - - def decode(self, x): - return self.tokenizer.decode(x) - - def batch_decode( - self, batch, skip_special_tokens=True, spaces_between_special_tokens=False - ): - if isinstance(batch[0], int): - batch = [[x] for x in batch] - return self.tokenizer.decode(batch) - - def apply_chat_template(self, messages, tokenize, add_generation_prompt): - ret = self.chat_template.render( - messages=messages, add_generation_prompt=add_generation_prompt - ) - return self.encode(ret) if tokenize else ret diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index a6f05610bd4..9047197af2f 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -13,27 +13,125 @@ """Fused operators for activation layers.""" +from typing import Optional + import torch +import torch.nn as nn import torch.nn.functional as F -from flashinfer.activation import silu_and_mul +from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.utils import set_weight_attrs class SiluAndMul(CustomOp): - def __init__(self, **kwargs): - super().__init__() - self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8 - def forward_native(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:] def forward_cuda(self, x: torch.Tensor) -> 
torch.Tensor: - if self.is_lower_sm80: - return self.forward_native(x) - d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) silu_and_mul(x, out) return out + + +class GeluAndMul(CustomOp): + def __init__(self, approximate="tanh"): + super().__init__() + self.approximate = approximate + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if self.approximate == "tanh": + gelu_tanh_and_mul(x, out) + elif self.approximate == "none": + gelu_and_mul(x, out) + else: + raise RuntimeError("GeluAndMul only support tanh or none") + return out + + +class ScaledActivation(nn.Module): + """An activation function with post-scale parameters. + + This is used for some quantization methods like AWQ. + """ + + def __init__( + self, + act_module: nn.Module, + intermediate_size: int, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.act = act_module + self.input_is_parallel = input_is_parallel + if input_is_parallel: + tp_size = get_tensor_model_parallel_world_size() + intermediate_size_per_partition = divide(intermediate_size, tp_size) + else: + intermediate_size_per_partition = intermediate_size + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.scales = nn.Parameter( + torch.empty(intermediate_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.act(x) / self.scales + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + if self.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + shard_size = param_data.shape[0] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +_ACTIVATION_REGISTRY = { + "gelu": nn.GELU(), + "gelu_pytorch_tanh": nn.GELU(approximate="tanh"), +} + + +def get_act_fn( + act_fn_name: str, + quant_config: Optional[QuantizationConfig] = None, + intermediate_size: Optional[int] = None, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_REGISTRY: + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") + + act_fn = _ACTIVATION_REGISTRY[act_fn_name] + if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names(): + if intermediate_size is None: + raise ValueError( + "intermediate_size must be specified for scaled " + "activation functions." 
+ ) + return ScaledActivation( + act_fn, intermediate_size, input_is_parallel, params_dtype + ) + return act_fn diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py index eef3c000968..dc92a65480c 100644 --- a/python/sglang/srt/layers/decode_attention.py +++ b/python/sglang/srt/layers/decode_attention.py @@ -26,7 +26,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict -if global_server_args_dict.get("attention_reduce_in_fp32", False): +if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 REDUCE_TORCH_TYPE = torch.float32 else: diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 097adca3ca0..6c7686971e0 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -127,8 +127,7 @@ def _fwd_kernel( ) k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk = tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: offs_kpe = ( offs_kv_loc[None, :] * stride_buf_kbs @@ -140,7 +139,7 @@ def _fwd_kernel( mask=mask_n[None, :], other=0.0, ) - qk += tl.dot(qpe, kpe) + qk += tl.dot(qpe.to(kpe.dtype), kpe) qk *= sm_scale if logit_cap > 0: @@ -179,9 +178,7 @@ def _fwd_kernel( ) k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - + qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: offs_kpe = ( (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) @@ -276,9 +273,17 @@ def extend_attention_fwd( BLOCK_DV = Lv if CUDA_CAPABILITY[0] >= 9: - BLOCK_M, BLOCK_N = (128, 64) + if Lq <= 256: + BLOCK_M, BLOCK_N = (128, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) elif CUDA_CAPABILITY[0] >= 8: - BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64) + if Lq <= 128: + BLOCK_M, BLOCK_N = (128, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) diff --git a/python/sglang/srt/layers/fused_moe/layer.py b/python/sglang/srt/layers/fused_moe/layer.py index 0b17c14ffd8..e08ec5c58a8 100644 --- a/python/sglang/srt/layers/fused_moe/layer.py +++ b/python/sglang/srt/layers/fused_moe/layer.py @@ -239,7 +239,7 @@ def weight_loader( weight_name: str, shard_id: int, expert_id: int, - pre_sharded: bool, + use_presharded_weights: bool = False, ): param_data = param.data @@ -273,7 +273,7 @@ def weight_loader( else: tp_rank = get_tensor_model_parallel_rank() shard_size = self.intermediate_size_per_partition - if pre_sharded: + if use_presharded_weights: shard = slice(None) else: shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 6cea85404a0..4c24f50ffe4 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,7 +19,12 @@ import torch import torch.nn as nn -from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from flashinfer.norm import ( + fused_add_rmsnorm, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + rmsnorm, +) from vllm.model_executor.custom_op import CustomOp @@ -32,15 +37,12 @@ def __init__( super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps - self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8 def 
forward_cuda( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if self.is_lower_sm80: - return self.forward_native(x, residual) if residual is not None: fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) @@ -66,3 +68,44 @@ def forward_native( return x else: return x, residual + + +class GemmaRMSNorm(CustomOp): + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x * (1.0 + self.weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + gemma_fused_add_rmsnorm( + x, residual, self.weight.data, self.variance_epsilon + ) + return x, residual + out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon) + return out diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index a5ba06de02e..b81f3d2a040 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -29,7 +29,7 @@ @dataclasses.dataclass -class LogitProcessorOutput: +class LogitsProcessorOutput: # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor # The logprobs of the next tokens. shape: [#seq, vocab_size] @@ -180,12 +180,12 @@ def forward( if hasattr(self.config, "final_logit_softcapping"): last_logits.div_(self.config.final_logit_softcapping) - last_logits = torch.tanh(last_logits) + torch.tanh(last_logits, out=last_logits) last_logits.mul_(self.config.final_logit_softcapping) # Return only last_logits if logprob is not requested if not logits_metadata.return_logprob: - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=None, normalized_prompt_logprobs=None, @@ -209,7 +209,7 @@ def forward( else: output_top_logprobs = None - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=None, @@ -241,7 +241,7 @@ def forward( if hasattr(self.config, "final_logit_softcapping"): all_logits.div_(self.config.final_logit_softcapping) - all_logits = torch.tanh(all_logits) + torch.tanh(all_logits, out=all_logits) all_logits.mul_(self.config.final_logit_softcapping) all_logprobs = all_logits @@ -278,7 +278,7 @@ def forward( # Remove the last token logprob for the prefill tokens. 
input_token_logprobs = input_token_logprobs[:-1] - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index a02673dc374..91735a1b810 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -203,7 +203,6 @@ def forward(self, q, k, v, input_metadata: InputMetadata): return self.decode_forward(q, k, v, input_metadata) def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): - k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) - v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) - k_cache[input_metadata.out_cache_loc] = cache_k - v_cache[input_metadata.out_cache_loc] = cache_v + input_metadata.token_to_kv_pool.set_kv_buffer( + self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v + ) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 3006e765c88..6cb7d5b5508 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,4 +1,6 @@ +import dataclasses import logging +from typing import Tuple, Union import torch from flashinfer.sampling import ( @@ -7,8 +9,11 @@ top_k_top_p_sampling_from_probs, top_p_renorm_prob, ) +from torch.library import custom_op as torch_custom_op from vllm.model_executor.custom_op import CustomOp +from sglang.srt.layers.logits_processor import LogitsProcessorOutput + # TODO: move this dict to another place from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -16,37 +21,76 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class SampleOutput: + success: torch.Tensor + probs: torch.Tensor + batch_next_token_ids: torch.Tensor + + class Sampler(CustomOp): def __init__(self): super().__init__() + # FIXME: torch.multinomial has too many bugs + self.forward_native = self.forward_cuda + self.is_torch_compile = False + + def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): + # min-token, presence, frequency + if sampling_info.linear_penalties is not None: + logits += sampling_info.linear_penalties + + # repetition + if sampling_info.scaling_penalties is not None: + logits = torch.where( + logits > 0, + logits / sampling_info.scaling_penalties, + logits * sampling_info.scaling_penalties, + ) + + return logits - def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): + def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): # Post process logits logits = logits.contiguous() logits.div_(sampling_info.temperatures) + if self.is_torch_compile: + # FIXME: Temporary workaround for unknown bugs in torch.compile + logits.add_(0) + if sampling_info.logit_bias is not None: logits.add_(sampling_info.logit_bias) if sampling_info.vocab_mask is not None: - logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf")) + logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf")) - logits = sampling_info.penalizer_orchestrator.apply(logits) + logits = self._apply_penalties(logits, sampling_info) - probs = torch.softmax(logits, dim=-1) + return torch.softmax(logits, dim=-1) + + def forward_cuda( + self, + logits: Union[torch.Tensor, LogitsProcessorOutput], + sampling_info: 
SamplingBatchInfo, + ): + if isinstance(logits, LogitsProcessorOutput): + logits = logits.next_token_logits + + probs = self._get_probs(logits, sampling_info) if not global_server_args_dict["disable_flashinfer_sampling"]: max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device ) - if sampling_info.min_ps.any(): + if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_p_renorm_prob(probs, sampling_info.top_ps) batch_next_token_ids, success = min_p_sampling_from_probs( probs, uniform_samples, sampling_info.min_ps ) else: - batch_next_token_ids, success = top_k_top_p_sampling_from_probs( + batch_next_token_ids, success = flashinfer_top_k_top_p( probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps ) else: @@ -55,18 +99,48 @@ def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps ) - if not torch.all(success): - logging.warning("Sampling failed, fallback to top_k=1 strategy") - probs = probs.masked_fill(torch.isnan(probs), 0.0) - argmax_ids = torch.argmax(probs, dim=-1) - batch_next_token_ids = torch.where( - success, batch_next_token_ids, argmax_ids - ) + return SampleOutput(success, probs, batch_next_token_ids) + + def forward_native( + self, + logits: Union[torch.Tensor, LogitsProcessorOutput], + sampling_info: SamplingBatchInfo, + ): + if isinstance(logits, LogitsProcessorOutput): + logits = logits.next_token_logits + + probs = self._get_probs(logits, sampling_info) + + batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch( + probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps + ) + + return SampleOutput(success, probs, batch_next_token_ids) - return batch_next_token_ids - def forward_native(): - raise NotImplementedError("Native forward is not implemented yet.") +@torch_custom_op("my_lib::flashinfer_top_k_top_p", mutates_args={}) +def flashinfer_top_k_top_p( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + # NOTE: we do not use min_p either in CUDA or in torch.compile + return top_k_top_p_sampling_from_probs(probs, uniform_samples, top_ks, top_ps) + + +@flashinfer_top_k_top_p.register_fake +def _( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_ks: torch.Tensor, + top_ps: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + bs = probs.shape[0] + return ( + torch.ones(bs, dtype=torch.bool, device=probs.device), + torch.zeros(bs, dtype=torch.int32, device=probs.device), + ) def top_k_top_p_min_p_sampling_from_probs_torch( @@ -87,7 +161,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch( probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) try: - sampled_index = torch.multinomial(probs_sort, num_samples=1) + # FIXME: torch.multinomial does not support num_samples = 1 + sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[ + :, :1 + ] except RuntimeError as e: logger.warning(f"Sampling error: {e}") batch_next_token_ids = torch.zeros( diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py index 38229cd4660..ba626d4cffc 100644 --- a/python/sglang/srt/managers/controller_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -35,7 +35,7 @@
TokenizedGenerateReqInput, ) from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import kill_parent_process +from sglang.srt.utils import configure_logger, kill_parent_process from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -71,12 +71,12 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_overide_args, + model_override_args, ): # Parse args self.server_args = server_args self.port_args = port_args - self.model_overide_args = model_overide_args + self.model_override_args = model_override_args self.load_balance_method = LoadBalanceMethod.from_str( server_args.load_balance_method ) @@ -114,7 +114,7 @@ def start_dp_worker(self, dp_worker_id: int): self.server_args, self.port_args, pipe_controller_writer, - self.model_overide_args, + self.model_override_args, True, gpu_ids, dp_worker_id, @@ -189,17 +189,14 @@ def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer, - model_overide_args: dict, + model_override_args: dict, ): """Start a controller process.""" - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) + configure_logger(server_args) try: - controller = ControllerMulti(server_args, port_args, model_overide_args) + controller = ControllerMulti(server_args, port_args, model_override_args) except Exception: pipe_writer.send(get_exception_traceback()) raise diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index 422db943f6b..2ae37059c10 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -27,7 +27,7 @@ launch_tp_servers, ) from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import kill_parent_process +from sglang.srt.utils import configure_logger, kill_parent_process from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_overide_args: dict, + model_override_args: dict, gpu_ids: List[int], is_data_parallel_worker: bool, dp_worker_id: int, @@ -52,7 +52,7 @@ def __init__( self.dp_worker_id = dp_worker_id self.mp_queue = mp_queue - # Init communication + # Init inter-process communication context = zmq.Context(2) if not self.is_dp_worker: @@ -76,7 +76,7 @@ def __init__( tp_rank_range, server_args, port_args.nccl_ports[dp_worker_id], - model_overide_args, + model_override_args, ) # Launch tp rank 0 @@ -85,7 +85,7 @@ def __init__( 0, server_args, port_args.nccl_ports[dp_worker_id], - model_overide_args, + model_override_args, ) self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group @@ -126,18 +126,18 @@ def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer: multiprocessing.connection.Connection, - model_overide_args: dict, + model_override_args: dict, is_data_parallel_worker: bool = False, gpu_ids: List[int] = None, dp_worker_id: int = None, queue: multiprocessing.connection.Connection = None, ): """Start a controller process.""" - - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) + if is_data_parallel_worker: + logger_prefix = f" DP{dp_worker_id} TP0" + else: + logger_prefix = " TP0" + configure_logger(server_args, prefix=logger_prefix) if not is_data_parallel_worker: tp_size_local = server_args.tp_size // server_args.nnodes @@ -149,7 +149,7 @@ 
def start_controller_process( controller = ControllerSingle( server_args, port_args, - model_overide_args, + model_override_args, gpu_ids, is_data_parallel_worker, dp_worker_id, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 9a4306372b1..cd5f63125cb 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -56,6 +56,7 @@ def __init__( server_args: ServerArgs, port_args: PortArgs, ): + # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_router = context.socket(zmq.PULL) self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}") @@ -75,10 +76,13 @@ def __init__( self.decode_status = {} async def handle_loop(self): + """The event loop that handles requests""" + while True: - recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj() + recv_obj = await self.recv_from_router.recv_pyobj() if isinstance(recv_obj, BatchEmbeddingOut): + # If it is embedding model, no detokenization is needed. self.send_to_tokenizer.send_pyobj( BatchEmbeddingOut( rids=recv_obj.rids, @@ -88,19 +92,18 @@ async def handle_loop(self): ) ) continue - - if isinstance(recv_obj, UpdateWeightReqOutput): + elif isinstance(recv_obj, UpdateWeightReqOutput): + # If it is a weight update request, no detokenization is needed. + self.send_to_tokenizer.send_pyobj(recv_obj) + continue + elif self.tokenizer is None: + # If the tokenizer is skipped, no detokenization is needed self.send_to_tokenizer.send_pyobj(recv_obj) continue assert isinstance(recv_obj, BatchTokenIDOut) bs = len(recv_obj.rids) - if self.tokenizer is None: - # Send BatchTokenIDOut if no tokenizer init'ed. - self.send_to_tokenizer.send_pyobj(recv_obj) - continue - # Initialize decode status read_ids, surr_ids = [], [] for i in range(bs): @@ -134,6 +137,7 @@ async def handle_loop(self): spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0], ) + # Incremental decoding output_strs = [] for i in range(bs): s = self.decode_status[recv_obj.rids[i]] diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 56e3d8f7990..5b91ff62e9d 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -18,8 +18,9 @@ processes (TokenizerManager, DetokenizerManager, Controller). 
""" +import copy import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason @@ -55,6 +56,7 @@ def post_init(self): self.text is not None and self.input_ids is not None ): raise ValueError("Either text or input_ids should be provided.") + if ( isinstance(self.sampling_params, dict) and self.sampling_params.get("n", 1) != 1 @@ -161,10 +163,10 @@ class TokenizedGenerateReqInput: input_ids: List[int] # The pixel values for input images pixel_values: List[float] - # The hash of input images - image_hash: int - # The image size - image_size: List[int] + # The hash values of input images + image_hashes: List[int] + # The image sizes + image_sizes: List[List[int]] # The sampling parameters sampling_params: SamplingParams # Whether to return the logprobs @@ -248,6 +250,10 @@ class BatchTokenIDOut: meta_info: List[Dict] finished_reason: List[BaseFinishReason] + def __post_init__(self): + # deepcopy meta_info to avoid modification in place + self.meta_info = copy.deepcopy(self.meta_info) + @dataclass class BatchStrOut: diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py index 04169e80861..3a70bfe5482 100644 --- a/python/sglang/srt/managers/policy_scheduler.py +++ b/python/sglang/srt/managers/policy_scheduler.py @@ -108,18 +108,24 @@ class PrefillAdder: def __init__( self, tree_cache: BasePrefixCache, + running_batch: ScheduleBatch, + new_token_ratio: float, rem_total_tokens: int, rem_input_tokens: int, rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, ): self.tree_cache = tree_cache + self.running_batch = running_batch + self.new_token_ratio = new_token_ratio self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens + self.total_tokens = rem_total_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= mixed_with_decode_tokens + self.req_states = None self.can_run_list = [] self.new_inflight_req = None self.log_hit_tokens = 0 @@ -136,16 +142,14 @@ def no_remaining_tokens(self): ) ) - def remove_running_tokens( - self, running_batch: ScheduleBatch, new_token_ratio: float - ): + def remove_running_tokens(self, running_batch: ScheduleBatch): self.rem_total_tokens -= sum( [ min( (r.sampling_params.max_new_tokens - len(r.output_ids)), CLIP_MAX_NEW_TOKENS, ) - * new_token_ratio + * self.new_token_ratio for r in running_batch.reqs ] ) @@ -161,7 +165,29 @@ def _prefill_one_req( self.log_hit_tokens += prefix_len self.log_input_tokens += extend_input_len + def add_inflight_req_ignore_eos(self, req: Req): + truncated = req.extend_input_len > self.rem_chunk_tokens + req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) + req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len] + self.can_run_list.append(req) + + self._prefill_one_req( + 0, + req.extend_input_len, + ( + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS) + if not truncated + else 0 + ), + ) + + # Return if chunked prefill not finished + return req if truncated else None + def add_inflight_req(self, req: Req): + if req.sampling_params.ignore_eos: + return self.add_inflight_req_ignore_eos(req) + truncated = req.extend_input_len > self.rem_chunk_tokens req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens) req.fill_ids = req.fill_ids[: 
len(req.prefix_indices) + req.extend_input_len] @@ -190,7 +216,81 @@ def _lock_node(self, last_node: TreeNode): delta = self.tree_cache.dec_lock_ref(last_node) self.rem_total_tokens += delta + def add_one_req_ignore_eos(self, req: Req): + def get_req_state(r): + new_token_ratio = ( + 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio + ) + tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len( + r.output_ids + ) + tokens_occupied = len(r.origin_input_ids) + len(r.output_ids) + + if tokens_left > 0: + return (tokens_left, tokens_occupied) + + return None + + if self.req_states is None: + self.req_states = [] + if self.running_batch is not None: + for r in self.running_batch.reqs: + state = get_req_state(r) + if state is not None: + self.req_states.append(state) + for r in self.can_run_list: + state = get_req_state(r) + if state is not None: + self.req_states.append(state) + state = get_req_state(req) + if state is not None: + self.req_states.append(state) + + self.req_states.sort(key=lambda x: x[0]) + else: + state = get_req_state(req) + if state is not None: + for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): + if tokens_left >= state[0]: + self.req_states.insert(i, state) + break + else: + self.req_states.append(state) + + tokens_freed = 0 + for i, (tokens_left, tokens_occupied) in enumerate(self.req_states): + decode_steps = ( + self.req_states[i + 1][0] + if i + 1 < len(self.req_states) + else tokens_left + ) + bs = len(self.req_states) - i + if self.total_tokens + tokens_freed - decode_steps * bs <= 0: + return False + tokens_freed += tokens_occupied + + if req.extend_input_len <= self.rem_chunk_tokens: + self.can_run_list.append(req) + self._prefill_one_req( + 0, + req.extend_input_len, + min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS), + ) + else: + # Chunked prefill + trunc_len = self.rem_chunk_tokens + req.extend_input_len = trunc_len + req.fill_ids = req.fill_ids[:trunc_len] + self.can_run_list.append(req) + self.new_inflight_req = req + self._prefill_one_req(0, trunc_len, 0) + + return True + def add_one_req(self, req: Req): + if req.sampling_params.ignore_eos and self.tree_cache.disable: + return self.add_one_req_ignore_eos(req) + total_tokens = req.extend_input_len + min( req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS ) @@ -233,4 +333,4 @@ def add_one_req(self, req: Req): self.tree_cache.inc_lock_ref(req.last_node) self._prefill_one_req(prefix_len, trunc_len, 0) - return True + return True and not self.no_remaining_tokens() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 75c33bb8b4a..c80cf2e2723 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,7 +19,7 @@ import logging from dataclasses import dataclass -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch @@ -29,13 +31,17 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +if TYPE_CHECKING: + from sglang.srt.layers.sampler import SampleOutput + + INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access global_server_args_dict = { "disable_flashinfer": False, "disable_flashinfer_sampling": False, - 
"attention_reduce_in_fp32": False, + "triton_attention_reduce_in_fp32": False, "enable_mla": False, } @@ -121,8 +127,8 @@ def __init__(self, rid, origin_input_text, origin_input_ids): # For vision input self.pixel_values = None - self.image_size = None - self.image_offset = None + self.image_sizes = None + self.image_offsets = None self.pad_value = None # Prefix info @@ -172,19 +178,22 @@ def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): def adjust_max_prefix_ids(self): self.fill_ids = self.origin_input_ids + self.output_ids input_len = len(self.fill_ids) - max_prefix_len = input_len + + # FIXME: To work around some bugs in logprob computation, we need to ensure each + # request has at least one token. Later, we can relax this requirement and use `input_len`. + max_prefix_len = input_len - 1 if self.sampling_params.max_new_tokens > 0: # Need at least one token to compute logits max_prefix_len = min(max_prefix_len, input_len - 1) if self.return_logprob: - max_prefix_len = min(max_prefix_len, self.logprob_start_len) - if self.normalized_prompt_logprob is None: # Need at least two tokens to compute normalized logprob max_prefix_len = min(max_prefix_len, input_len - 2) + max_prefix_len = min(max_prefix_len, self.logprob_start_len) + max_prefix_len = max(max_prefix_len, 0) return self.fill_ids[:max_prefix_len] # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313 @@ -262,7 +271,14 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): all_text = self.origin_input_text + self.decoded_text + jump_forward_str all_ids = self.tokenizer.encode(all_text) + if not all_ids: + logger.warning("Encoded all_text resulted in empty all_ids") + return False + prompt_tokens = len(self.origin_input_ids_unpadded) + if prompt_tokens > len(all_ids): + logger.warning("prompt_tokens is larger than encoded all_ids") + return False if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]: # TODO(lsyin): fix token fusion @@ -593,12 +609,12 @@ def check_for_jump_forward(self, model_runner): if req.pixel_values is not None: ( req.origin_input_ids, - req.image_offset, + req.image_offsets, ) = model_runner.model.pad_input_ids( req.origin_input_ids_unpadded, req.pad_value, - req.pixel_values.shape, - req.image_size, + req.pixel_values, + req.image_sizes, ) jump_forward_reqs.append(req) @@ -671,11 +687,17 @@ def merge(self, other: "ScheduleBatch"): self.top_logprobs_nums.extend(other.top_logprobs_nums) self.return_logprob = any(req.return_logprob for req in self.reqs) - def sample(self, logits: torch.Tensor): - from sglang.srt.layers.sampler import Sampler - - sampler = Sampler() - - batch_next_token_ids = sampler(logits, self.sampling_info) + def check_sample_results(self, sample_output: SampleOutput): + if not torch.all(sample_output.success): + probs = sample_output.probs + batch_next_token_ids = sample_output.batch_next_token_ids + logging.warning("Sampling failed, fallback to top_k=1 strategy") + probs = probs.masked_fill(torch.isnan(probs), 0.0) + argmax_ids = torch.argmax(probs, dim=-1) + batch_next_token_ids = torch.where( + sample_output.success, batch_next_token_ids, argmax_ids + ) + sample_output.probs = probs + sample_output.batch_next_token_ids = batch_next_token_ids - return batch_next_token_ids + return sample_output.batch_next_token_ids diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 
328519cb26e..6af82064152 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -21,8 +21,9 @@ import logging import multiprocessing as mp import os -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union +import fastapi import numpy as np import transformers import uvloop @@ -76,35 +77,38 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_overide_args: dict = None, + model_override_args: dict = None, ): self.server_args = server_args + # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_detokenizer = context.socket(zmq.PULL) self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}") - self.send_to_router = context.socket(zmq.PUSH) - self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}") + self.send_to_controller = context.socket(zmq.PUSH) + self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}") + # Read model args self.model_path = server_args.model_path self.served_model_name = server_args.served_model_name self.hf_config = get_config( self.model_path, trust_remote_code=server_args.trust_remote_code, - model_overide_args=model_overide_args, + model_override_args=model_override_args, + ) + self.is_generation = is_generation_model( + self.hf_config.architectures, self.server_args.is_embedding + ) + self.context_len = server_args.context_length or get_context_length( + self.hf_config ) - self.is_generation = is_generation_model(self.hf_config.architectures) - - if server_args.context_length is not None: - self.context_len = server_args.context_length - else: - self.context_len = get_context_length(self.hf_config) + # Create tokenizer if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(self.model_path): + if is_multimodal_model(self.hf_config.architectures): self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -112,6 +116,9 @@ def __init__( ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # We want to parallelize the image pre-processing so we + # create an executor for it self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, mp_context=mp.get_context("fork"), @@ -124,40 +131,24 @@ def __init__( trust_remote_code=server_args.trust_remote_code, ) + # Store states self.to_create_loop = True self.rid_to_state: Dict[str, ReqState] = {} - # for update model weights + # For update model weights self.model_update_lock = asyncio.Lock() self.model_update_result = None - async def get_pixel_values(self, image_data): - aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) - grid_pinpoints = ( - self.hf_config.image_grid_pinpoints if aspect_ratio == "anyres" else None - ) - if self.executor is not None: - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, - get_pixel_values, - image_data, - aspect_ratio, - grid_pinpoints, - ) - else: - return get_pixel_values( - image_data, aspect_ratio, grid_pinpoints, self.processor - ) - async def generate_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, ): if self.to_create_loop: self.create_handle_loop() while self.model_update_lock.locked(): - await asyncio.sleep(0) + await 
asyncio.sleep(0.001) obj.post_init() is_single = obj.is_single @@ -172,9 +163,9 @@ async def generate_request( async def _handle_single_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], - request, - index=None, - is_cache_for_prefill=False, + request: Optional[fastapi.Request] = None, + index: Optional[int] = None, + is_cache_for_prefill: Optional[bool] = False, ): if not is_cache_for_prefill: # The normal case with a single prompt not_use_index = index is None @@ -194,7 +185,7 @@ async def _handle_single_request( ) if self.is_generation: - pixel_values, image_hash, image_size = await self._get_pixel_values( + pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data if not_use_index else obj.image_data[index] ) return_logprob = ( @@ -207,7 +198,6 @@ async def _handle_single_request( ) if return_logprob and logprob_start_len == -1: logprob_start_len = len(input_ids) - 1 - top_logprobs_num = ( obj.top_logprobs_num if not_use_index @@ -250,13 +240,14 @@ async def _handle_single_request( sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params.max_new_tokens = 0 - pixel_values, image_hash, image_size = await self._get_pixel_values( + pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data[0] ) return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] + # Send to the controller if self.is_generation: if return_logprob and logprob_start_len == -1: logprob_start_len = len(input_ids) - 1 @@ -265,8 +256,8 @@ async def _handle_single_request( input_text, input_ids, pixel_values, - image_hash, - image_size, + image_hashes, + image_sizes, sampling_params, return_logprob, logprob_start_len, @@ -280,31 +271,31 @@ async def _handle_single_request( input_ids, sampling_params, ) + self.send_to_controller.send_pyobj(tokenized_obj) - self.send_to_router.send_pyobj(tokenized_obj) - + # Recv results event = asyncio.Event() state = ReqState([], False, event) self.rid_to_state[rid] = state if not is_cache_for_prefill: - async for response in self._wait_for_response( - event, state, obj, rid, request - ): + async for response in self._wait_for_response(state, obj, rid, request): yield response else: assert self.is_generation - await self._wait_for_cache_prefill_response(event, state, obj, rid, request) + await self._wait_for_cache_prefill_response(state, obj, rid, request) yield input_ids async def _handle_batch_request( - self, obj: Union[GenerateReqInput, EmbeddingReqInput], request + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, ): batch_size = obj.batch_size if self.is_generation: parallel_sample_num = obj.parallel_sample_num if parallel_sample_num != 1: - # Send prefill requests to cache the common input + # Send prefill requests to cache the common prefix parallel_sample_num += 1 input_id_result = [] if obj.input_ids is None else None for i in range(batch_size): @@ -352,8 +343,8 @@ async def _handle_batch_request( if self.is_generation: if obj.return_logprob[index] and obj.logprob_start_len[index] == -1: obj.logprob_start_len[index] = len(input_ids) - 1 - pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data[index] + pixel_values, image_hashes, image_sizes = ( + await self._get_pixel_values(obj.image_data[index]) ) tokenized_obj = TokenizedGenerateReqInput( @@ -361,8 +352,8 @@ async def _handle_batch_request( input_text, input_ids, pixel_values, - image_hash, - 
image_size, + image_hashes, + image_sizes, sampling_params, obj.return_logprob[index], obj.logprob_start_len[index], @@ -376,7 +367,7 @@ async def _handle_batch_request( input_ids, sampling_params, ) - self.send_to_router.send_pyobj(tokenized_obj) + self.send_to_controller.send_pyobj(tokenized_obj) event = asyncio.Event() state = ReqState([], False, event) @@ -384,7 +375,6 @@ async def _handle_batch_request( generators.append( self._wait_for_response( - event, state, obj, rid, @@ -395,17 +385,17 @@ async def _handle_batch_request( ) # Then process the responses based on streaming option - is_stream = hasattr(obj, "stream") and obj.stream tasks = [asyncio.create_task(gen.__anext__()) for gen in generators] - output_list = [] + output_list = [None] * len(tasks) + # Recv results while tasks: done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) for task in done: - gen_index = tasks.index(task) + cur_index = tasks.index(task) try: result = task.result() @@ -413,14 +403,14 @@ async def _handle_batch_request( if is_stream: yield result else: - output_list.append(result) + output_list[result["index"]] = result - tasks[gen_index] = asyncio.create_task( - generators[gen_index].__anext__() + tasks[cur_index] = asyncio.create_task( + generators[cur_index].__anext__() ) except StopAsyncIteration: - del generators[gen_index] - del tasks[gen_index] + del generators[cur_index] + del tasks[cur_index] if not is_stream: yield output_list @@ -439,27 +429,18 @@ def _get_sampling_params(self, sampling_params_data: dict): sampling_params.verify() return sampling_params - async def _get_pixel_values(self, image_data): - if isinstance(image_data, list) and len(image_data) > 0: - return await self.get_pixel_values(image_data[0]) - elif isinstance(image_data, str): - return await self.get_pixel_values(image_data) - else: - return None, None, None - async def _wait_for_response( self, - event: asyncio.Event, state: ReqState, obj: Union[GenerateReqInput, EmbeddingReqInput], rid: str, - request, - index: int = None, + request: Optional[fastapi.Request] = None, + index: Optional[int] = None, response_index: int = 0, ): while True: try: - await asyncio.wait_for(event.wait(), timeout=4) + await asyncio.wait_for(state.event.wait(), timeout=4) except asyncio.TimeoutError: if request is not None and await request.is_disconnected(): for rid in [obj.rid] if obj.is_single else obj.rid: @@ -493,16 +474,15 @@ async def _wait_for_response( yield out break - event.clear() + state.event.clear() yield out async def _wait_for_cache_prefill_response( self, - event: asyncio.Event, state: ReqState, obj: GenerateReqInput, rid: str, - request, + request: Optional[fastapi.Request] = None, ): while True: try: @@ -520,9 +500,18 @@ async def _wait_for_cache_prefill_response( def flush_cache(self): req = FlushCacheReq() - self.send_to_router.send_pyobj(req) + self.send_to_controller.send_pyobj(req) + + def abort_request(self, rid: str): + if rid not in self.rid_to_state: + return + del self.rid_to_state[rid] + req = AbortReq(rid) + self.send_to_controller.send_pyobj(req) - async def update_weights(self, obj: UpdateWeightReqInput, request): + async def update_weights( + self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None + ): if self.to_create_loop: self.create_handle_loop() @@ -535,7 +524,7 @@ async def update_weights(self, obj: UpdateWeightReqInput, request): # wait for the previous generation requests to finish while len(self.rid_to_state) > 0: await asyncio.sleep(0) - 
self.send_to_router.send_pyobj(obj) + self.send_to_controller.send_pyobj(obj) self.model_update_result = asyncio.Future() result = await self.model_update_result if result.success: @@ -546,13 +535,6 @@ async def update_weights(self, obj: UpdateWeightReqInput, request): else: return False, "Another update is in progress. Please try again later." - def abort_request(self, rid: str): - if rid not in self.rid_to_state: - return - del self.rid_to_state[rid] - req = AbortReq(rid) - self.send_to_router.send_pyobj(req) - def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. async def abort_request(): @@ -568,11 +550,16 @@ async def abort_request(): return background_tasks def create_handle_loop(self): + if not self.to_create_loop: + return + self.to_create_loop = False loop = asyncio.get_event_loop() loop.create_task(self.handle_loop()) async def handle_loop(self): + """The event loop that handles requests""" + while True: recv_obj: Union[ BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut, UpdateWeightReqOutput @@ -669,11 +656,75 @@ def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): ) return top_logprobs + async def _get_pixel_values(self, image_data: List[Union[str, bytes]]): + if not image_data: + return None, None, None + + aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) + grid_pinpoints = ( + self.hf_config.image_grid_pinpoints + if hasattr(self.hf_config, "image_grid_pinpoints") + and "anyres" in aspect_ratio + else None + ) + + if isinstance(image_data, list) and len(image_data) > 0: + # Multiple images + if len(image_data) > 1: + aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres + pixel_values, image_hashes, image_sizes = [], [], [] + for img_data in image_data: + pixel_v, image_h, image_s = await self._process_single_image( + img_data, aspect_ratio, grid_pinpoints + ) + pixel_values.append(pixel_v) + image_hashes.append(image_h) + image_sizes.append(image_s) + + if isinstance(pixel_values[0], np.ndarray): + pixel_values = np.stack(pixel_values, axis=0) + else: + # A single image + pixel_values, image_hash, image_size = await self._process_single_image( + image_data[0], aspect_ratio, grid_pinpoints + ) + image_hashes = [image_hash] + image_sizes = [image_size] + elif isinstance(image_data, str): + # A single image + pixel_values, image_hash, image_size = await self._process_single_image( + image_data, aspect_ratio, grid_pinpoints + ) + image_hashes = [image_hash] + image_sizes = [image_size] + else: + raise ValueError(f"Invalid image data: {image_data}") + + return pixel_values, image_hashes, image_sizes + + async def _process_single_image( + self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str + ): + if self.executor is not None: + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, + _process_single_image_task, + image_data, + aspect_ratio, + grid_pinpoints, + ) + else: + return _process_single_image_task( + image_data, aspect_ratio, grid_pinpoints, self.processor + ) + global global_processor def init_global_processor(server_args: ServerArgs): + """Init the global processor for multi modal models.""" global global_processor transformers.logging.set_verbosity_error() global_processor = get_processor( @@ -683,13 +734,17 @@ def init_global_processor(server_args: ServerArgs): ) -def get_pixel_values( - image_data, image_aspect_ratio=None, 
image_grid_pinpoints=None, processor=None +def _process_single_image_task( + image_data: Union[str, bytes], + image_aspect_ratio: Optional[str] = None, + image_grid_pinpoints: Optional[str] = None, + processor=None, ): try: processor = processor or global_processor image, image_size = load_image(image_data) if image_size is not None: + # It is a video with multiple images image_hash = hash(image_data) pixel_values = processor.image_processor(image)["pixel_values"] for _ in range(len(pixel_values)): @@ -697,20 +752,28 @@ def get_pixel_values( pixel_values = np.stack(pixel_values, axis=0) return pixel_values, image_hash, image_size else: + # It is an image image_hash = hash(image_data) if image_aspect_ratio == "pad": image = expand2square( image, tuple(int(x * 255) for x in processor.image_processor.image_mean), ) - pixel_values = processor.image_processor(image)["pixel_values"][0] - elif image_aspect_ratio == "anyres": + pixel_values = processor.image_processor(image.convert("RGB"))[ + "pixel_values" + ][0] + elif image_aspect_ratio == "anyres" or ( + image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio + ): pixel_values = process_anyres_image( image, processor.image_processor, image_grid_pinpoints ) else: pixel_values = processor.image_processor(image)["pixel_values"][0] - pixel_values = pixel_values.astype(np.float16) + + if isinstance(pixel_values, np.ndarray): + pixel_values = pixel_values.astype(np.float16) + return pixel_values, image_hash, image.size except Exception: - print("Exception in TokenizerManager:\n" + get_exception_traceback()) + logger.error("Exception in TokenizerManager:\n" + get_exception_traceback()) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 41f90830123..d914a71c27a 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -31,7 +31,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.layers.logits_processor import LogitProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -56,6 +56,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( + configure_logger, is_multimodal_model, set_random_seed, suppress_other_loggers, @@ -75,7 +76,7 @@ def __init__( tp_rank: int, server_args: ServerArgs, nccl_port: int, - model_overide_args: dict, + model_override_args: dict, ): suppress_other_loggers() @@ -92,8 +93,9 @@ def __init__( server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length, - model_overide_args=model_overide_args, + model_override_args=model_override_args, ) + self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -106,7 +108,7 @@ def __init__( if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None else: - if is_multimodal_model(server_args.model_path): + if is_multimodal_model(self.model_config.hf_config.architectures): self.processor = get_processor( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, @@ -144,7 +146,6 @@ def __init__( # Print info logger.info( - f"[gpu={self.gpu_id}] " f"max_total_num_tokens={self.max_total_num_tokens}, " 
f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_running_requests={self.max_running_requests}, " @@ -196,6 +197,16 @@ def __init__( "trust_remote_code": server_args.trust_remote_code, }, skip_tokenizer_init=server_args.skip_tokenizer_init, + json_schema_mode=False, + ) + self.json_fsm_cache = FSMCache( + server_args.tokenizer_path, + { + "tokenizer_mode": server_args.tokenizer_mode, + "trust_remote_code": server_args.trust_remote_code, + }, + skip_tokenizer_init=server_args.skip_tokenizer_init, + json_schema_mode=True, ) self.jump_forward_cache = JumpForwardCache() @@ -210,6 +221,7 @@ def __init__( ) self.new_token_ratio = self.min_new_token_ratio self.new_token_ratio_decay = global_config.new_token_ratio_decay + self.do_not_get_new_batch = False def exposed_step(self, recv_reqs: List): try: @@ -242,7 +254,13 @@ def exposed_step(self, recv_reqs: List): @torch.inference_mode() def forward_step(self): - new_batch = self.get_new_prefill_batch() + if self.current_inflight_req is not None: + self.do_not_get_new_batch = False + + new_batch = ( + self.get_new_prefill_batch() if not self.do_not_get_new_batch else None + ) + self.do_not_get_new_batch = False if new_batch is not None: # Run a new prefill batch @@ -283,7 +301,7 @@ def print_decode_stats(self): self.num_generated_tokens = 0 self.last_stats_tic = time.time() logger.info( - f"[gpu={self.gpu_id}] Decode batch. " + f"Decode batch. " f"#running-req: {len(self.running_batch.reqs)}, " f"#token: {num_used}, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " @@ -322,29 +340,42 @@ def handle_generate_request( if self.model_runner.is_generation: req.pixel_values = recv_req.pixel_values if req.pixel_values is not None: + # Use image hash as fake token_ids, which is then used + # for prefix matching + image_hash = hash(tuple(recv_req.image_hashes)) req.pad_value = [ - (recv_req.image_hash) % self.model_config.vocab_size, - (recv_req.image_hash >> 16) % self.model_config.vocab_size, - (recv_req.image_hash >> 32) % self.model_config.vocab_size, - (recv_req.image_hash >> 64) % self.model_config.vocab_size, + (image_hash) % self.model_config.vocab_size, + (image_hash >> 16) % self.model_config.vocab_size, + (image_hash >> 32) % self.model_config.vocab_size, + (image_hash >> 64) % self.model_config.vocab_size, ] - req.image_size = recv_req.image_size + req.image_sizes = recv_req.image_sizes ( req.origin_input_ids, - req.image_offset, + req.image_offsets, ) = self.model_runner.model.pad_input_ids( req.origin_input_ids_unpadded, req.pad_value, - req.pixel_values.shape, - req.image_size, + req.pixel_values, + req.image_sizes, ) req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream + # Init regex fsm fron json + if req.sampling_params.json_schema is not None: + req.regex_fsm, computed_regex_string = self.json_fsm_cache.query( + req.sampling_params.json_schema + ) + if not self.disable_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( + computed_regex_string + ) + # Init regex fsm - if req.sampling_params.regex is not None: + elif req.sampling_params.regex is not None: req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) if not self.disable_regex_jump_forward: req.jump_forward_map = self.jump_forward_cache.query( @@ -385,6 +416,8 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: adder = PrefillAdder( self.tree_cache, + self.running_batch, + self.new_token_ratio, 
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.max_prefill_tokens, self.chunked_prefill_size, @@ -392,7 +425,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: ) if self.running_batch is not None: - adder.remove_running_tokens(self.running_batch, self.new_token_ratio) + adder.remove_running_tokens(self.running_batch) has_inflight = self.current_inflight_req is not None if self.current_inflight_req is not None: @@ -404,11 +437,12 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: ) for req in self.waiting_queue: + if adder.no_remaining_tokens(): + break req.init_next_round_input(None if prefix_computed else self.tree_cache) res = adder.add_one_req(req) if ( not res - or adder.no_remaining_tokens() or running_bs + len(adder.can_run_list) >= self.max_running_requests ): break @@ -437,7 +471,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: if num_mixed_running > 0: logger.info( - f"[gpu={self.gpu_id}] Prefill batch" + f"Prefill batch" f"(mixed #running-req: {num_mixed_running}). " f"#new-seq: {len(can_run_list)}, " f"#new-token: {adder.log_input_tokens}, " @@ -447,7 +481,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: ) else: logger.info( - f"[gpu={self.gpu_id}] Prefill batch. " + f"Prefill batch. " f"#new-seq: {len(can_run_list)}, " f"#new-token: {adder.log_input_tokens}, " f"#cached-token: {adder.log_hit_tokens}, " @@ -480,21 +514,29 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids = batch.sample(output.next_token_logits) + sample_output, logits_output = self.model_runner.forward( + batch, ForwardMode.EXTEND + ) + next_token_ids = batch.check_sample_results(sample_output) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) # Move logprobs to cpu - if output.next_token_logprobs is not None: - output.next_token_logprobs = output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), - next_token_ids, - ].tolist() - output.input_token_logprobs = output.input_token_logprobs.tolist() - output.normalized_prompt_logprobs = ( - output.normalized_prompt_logprobs.tolist() + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs[ + torch.arange( + len(next_token_ids), device=next_token_ids.device + ), + next_token_ids, + ].tolist() + ) + logits_output.input_token_logprobs = ( + logits_output.input_token_logprobs.tolist() + ) + logits_output.normalized_prompt_logprobs = ( + logits_output.normalized_prompt_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() @@ -533,12 +575,14 @@ def forward_prefill_batch(self, batch: ScheduleBatch): self.req_to_token_pool.free(req.req_pool_idx) if req.return_logprob: - self.add_logprob_return_values(i, req, pt, next_token_ids, output) + self.add_logprob_return_values( + i, req, pt, next_token_ids, logits_output + ) pt += req.extend_input_len else: assert batch.extend_num_tokens != 0 - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - embeddings = output.embeddings.tolist() + logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND) + embeddings = logits_output.embeddings.tolist() # Check finish conditions for i, req in enumerate(batch.reqs): @@ -566,7 +610,7 @@ def add_logprob_return_values( req: Req, pt: int, 
next_token_ids: List[int], - output: LogitProcessorOutput, + output: LogitsProcessorOutput, ): if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] @@ -625,7 +669,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): self.new_token_ratio = new_token_ratio logger.info( - "decode out of memory happened, " + "Decode out of memory happened. " f"#retracted_reqs: {len(retracted_reqs)}, " f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" ) @@ -648,15 +692,17 @@ def forward_decode_batch(self, batch: ScheduleBatch): batch.prepare_for_decode() # Forward and sample the next tokens - output = self.model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids = batch.sample(output.next_token_logits) + sample_output, logits_output = self.model_runner.forward( + batch, ForwardMode.DECODE + ) + next_token_ids = batch.check_sample_results(sample_output) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) # Move logprobs to cpu - if output.next_token_logprobs is not None: - next_token_logprobs = output.next_token_logprobs[ + if logits_output.next_token_logprobs is not None: + next_token_logprobs = logits_output.next_token_logprobs[ torch.arange(len(next_token_ids), device=next_token_ids.device), next_token_ids, ].tolist() @@ -664,6 +710,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): next_token_ids = next_token_ids.tolist() # Check finish condition + has_finished = False for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) @@ -676,13 +723,17 @@ def forward_decode_batch(self, batch: ScheduleBatch): if req.finished(): self.tree_cache.cache_finished_req(req) + has_finished = True if req.return_logprob: req.output_token_logprobs.append( (next_token_logprobs[i], next_token_id) ) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(output.output_top_logprobs[i]) + req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + + if not has_finished: + self.do_not_get_new_batch = True self.handle_finished_requests(batch) @@ -840,16 +891,18 @@ def run_tp_server( tp_rank: int, server_args: ServerArgs, nccl_port: int, - model_overide_args: dict, + model_override_args: dict, ): - """Run a tensor parallel server.""" + """Run a tensor parallel model server.""" + configure_logger(server_args, prefix=f" TP{tp_rank}") + try: model_server = ModelTpServer( gpu_id, tp_rank, server_args, nccl_port, - model_overide_args, + model_override_args, ) tp_cpu_group = model_server.model_runner.tp_group.cpu_group @@ -866,14 +919,14 @@ def launch_tp_servers( tp_rank_range: List[int], server_args: ServerArgs, nccl_port: int, - model_overide_args: dict, + model_override_args: dict, ): """Launch multiple tensor parallel servers.""" procs = [] for i in tp_rank_range: proc = multiprocessing.Process( target=run_tp_server, - args=(gpu_ids[i], i, server_args, nccl_port, model_overide_args), + args=(gpu_ids[i], i, server_args, nccl_port, model_override_args), ) proc.start() procs.append(proc) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 68cefbbf9f7..fef74321ac6 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -16,7 +16,8 @@ """Memory pool.""" import logging -from typing import List, Union +from abc import ABC, abstractmethod +from typing import List, Tuple, Union import torch @@ -52,14 
+53,21 @@ def clear(self): self.free_slots = list(range(self.size)) -class BaseTokenToKVPool: +class BaseTokenToKVPool(ABC): """A memory pool that maps a token to its kv cache locations""" def __init__( self, size: int, + dtype: torch.dtype, ): self.size = size + self.dtype = dtype + if dtype == torch.float8_e5m2: + # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2 + self.store_dtype = torch.uint8 + else: + self.store_dtype = dtype # We also add one slot. This slot is used for writing dummy output from padded tokens. self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda") @@ -112,6 +120,28 @@ def clear(self): # We also add one slot. This slot is used for writing dummy output from padded tokens. self.mem_state[0] = False + @abstractmethod + def get_key_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + @abstractmethod + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + @abstractmethod + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + @abstractmethod + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ) -> None: + raise NotImplementedError() + class MHATokenToKVPool(BaseTokenToKVPool): @@ -123,26 +153,52 @@ def __init__( head_dim: int, layer_num: int, ): - super().__init__(size) + super().__init__(size, dtype) # [size, head_num, head_dim] for each layer self.k_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + torch.empty( + (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" + ) for _ in range(layer_num) ] self.v_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + torch.empty( + (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" + ) for _ in range(layer_num) ] def get_key_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.k_buffer[layer_id].view(self.dtype) return self.k_buffer[layer_id] def get_value_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.v_buffer[layer_id].view(self.dtype) return self.v_buffer[layer_id] def get_kv_buffer(self, layer_id: int): - return self.k_buffer[layer_id], self.v_buffer[layer_id] + return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) + + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + if cache_v.dtype != self.dtype: + cache_v = cache_v.to(self.dtype) + if self.store_dtype != self.dtype: + self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) + else: + self.k_buffer[layer_id][loc] = cache_k + self.v_buffer[layer_id][loc] = cache_v class MLATokenToKVPool(BaseTokenToKVPool): @@ -155,23 +211,41 @@ def __init__( qk_rope_head_dim: int, layer_num: int, ): - super().__init__(size) + super().__init__(size, dtype) self.kv_lora_rank = kv_lora_rank self.kv_buffer = [ torch.empty( (size + 1, 1, kv_lora_rank + qk_rope_head_dim), - dtype=dtype, + dtype=self.store_dtype, device="cuda", ) for _ in range(layer_num) ] def get_key_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.kv_buffer[layer_id].view(self.dtype) return self.kv_buffer[layer_id] def get_value_buffer(self, layer_id: int): + 
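# NOTE (illustrative sketch, not part of this patch; assumes PyTorch >= 2.1 with float8 support):
# the BaseTokenToKVPool change above stores float8_e5m2 caches in a uint8 buffer because indexed
# assignment (index_put) is not implemented for float8_e5m2, and reinterprets the bytes with
# Tensor.view(dtype) on read and write. Both dtypes are 1 byte wide, so view() costs nothing.
import torch

store = torch.zeros((16, 1, 8), dtype=torch.uint8)       # backing buffer (store_dtype)
cache_k = torch.randn(4, 1, 8).to(torch.float8_e5m2)     # incoming keys (dtype)
loc = torch.tensor([0, 3, 5, 7])

store[loc] = cache_k.view(torch.uint8)                   # write path, as in set_kv_buffer
k_buffer = store.view(torch.float8_e5m2)                 # read path, as in get_key_buffer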
if self.store_dtype != self.dtype: + return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype) return self.kv_buffer[layer_id][..., : self.kv_lora_rank] def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) + + def set_kv_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + ): + if cache_k.dtype != self.dtype: + cache_k = cache_k.to(self.dtype) + if self.store_dtype != self.dtype: + self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + else: + self.kv_buffer[layer_id][loc] = cache_k diff --git a/python/sglang/srt/mm_utils.py b/python/sglang/srt/mm_utils.py index e09c8215c6d..7918f3f7111 100644 --- a/python/sglang/srt/mm_utils.py +++ b/python/sglang/srt/mm_utils.py @@ -13,10 +13,25 @@ limitations under the License. """ -# Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py +# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py +""" +Utilities for multi-modal models. + +This python file mainly contains utilities that were used in the +image processing logic of llava-next including operations such as +anyres and anyres_max + +Currently supports the anyres and anyres_max operation for CLIP and +SigLip. For more information, you may refer to the paper or the blog + +LLaVA-NeXT : https://llava-vl.github.io/blog/2024-01-30-llava-next/ +LLaVA-Onevision : https://arxiv.org/pdf/2408.03326 + +""" import ast import base64 import math +import re from io import BytesIO import numpy as np @@ -40,10 +55,13 @@ def select_best_resolution(original_size, possible_resolutions): min_wasted_resolution = float("inf") for width, height in possible_resolutions: + # Calculate the downscaled size to keep the aspect ratio scale = min(width / original_width, height / original_height) downscaled_width, downscaled_height = int(original_width * scale), int( original_height * scale ) + + # Calculate effective and wasted resolutions effective_resolution = min( downscaled_width * downscaled_height, original_width * original_height ) @@ -129,6 +147,26 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): Returns: tuple: The shape of the image patch grid in the format (width, height). """ + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: @@ -149,6 +187,31 @@ def process_anyres_image(image, processor, grid_pinpoints): Returns: np.array: An np array containing the processed image patches. 
""" + if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints: + try: + patch_size = processor.size[0] + except Exception as e: + patch_size = processor.size["shortest_edge"] + assert patch_size in [ + 224, + 336, + 384, + 448, + 512, + ], "patch_size should be in [224, 336, 384, 448, 512]" + # Use regex to extract the range from the input string + matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints) + range_start = tuple(map(int, matches[0])) + range_end = tuple(map(int, matches[-1])) + # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1]) + grid_pinpoints = [ + (i, j) + for i in range(range_start[0], range_end[0] + 1) + for j in range(range_start[1], range_end[1] + 1) + ] + # Multiply all elements by patch_size + grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints] + if type(grid_pinpoints) is list: possible_resolutions = grid_pinpoints else: @@ -156,15 +219,24 @@ def process_anyres_image(image, processor, grid_pinpoints): best_resolution = select_best_resolution(image.size, possible_resolutions) image_padded = resize_and_pad_image(image, best_resolution) - patches = divide_to_patches(image_padded, processor.crop_size["height"]) - - image_original_resize = image.resize( - (processor.size["shortest_edge"], processor.size["shortest_edge"]) + # For Siglip processor, only have size but no crop size + crop_size = ( + processor.crop_size["height"] + if "crop_size" in processor.__dict__ + else processor.size["height"] ) + shortest_edge = ( + processor.size["shortest_edge"] + if "shortest_edge" in processor.size + else processor.size["height"] + ) + patches = divide_to_patches(image_padded, crop_size) + + image_original_resize = image.resize((shortest_edge, shortest_edge)) image_patches = [image_original_resize] + patches image_patches = [ - processor.preprocess(image_patch)["pixel_values"][0] + processor.preprocess(image_patch.convert("RGB"))["pixel_values"][0] for image_patch in image_patches ] return np.stack(image_patches, axis=0) @@ -255,7 +327,7 @@ def process_images(images, image_processor, model_cfg): ) image = image_processor.preprocess(image)["pixel_values"][0] new_images.append(image) - elif image_aspect_ratio == "anyres": + elif "anyres" in image_aspect_ratio: for image in images: image = process_anyres_image( image, image_processor, model_cfg.image_grid_pinpoints diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py index ed496515cd3..edf89f6b977 100644 --- a/python/sglang/srt/model_config.py +++ b/python/sglang/srt/model_config.py @@ -33,17 +33,17 @@ def __init__( trust_remote_code: bool = True, revision: Optional[str] = None, context_length: Optional[int] = None, - model_overide_args: Optional[dict] = None, + model_override_args: Optional[dict] = None, ) -> None: self.path = path self.trust_remote_code = trust_remote_code self.revision = revision - self.model_overide_args = model_overide_args + self.model_override_args = model_override_args self.hf_config = get_config( self.path, trust_remote_code, revision, - model_overide_args=model_overide_args, + model_override_args=model_override_args, ) self.hf_text_config = get_hf_text_config(self.hf_config) if context_length is not None: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index d045be56d84..4459213b02f 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -17,6 +17,7 
@@ import bisect from contextlib import contextmanager +from typing import Callable, List import torch from flashinfer import BatchDecodeWithPagedKVCacheWrapper @@ -25,16 +26,18 @@ from vllm.model_executor.custom_op import CustomOp from sglang.srt.layers.logits_processor import ( - LogitProcessorOutput, LogitsMetadata, LogitsProcessor, + LogitsProcessorOutput, ) +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, InputMetadata, update_flashinfer_indices, ) +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import monkey_patch_vllm_all_gather @@ -43,20 +46,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): if isinstance(sub, CustomOp): if reverse: sub._forward_method = sub.forward_cuda + setattr(sub, "is_torch_compile", False) else: sub._forward_method = sub.forward_native + setattr(sub, "is_torch_compile", True) if isinstance(sub, torch.nn.Module): _to_torch(sub, reverse) @contextmanager def patch_model( - model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator" + model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" ): backup_ca_comm = None try: - if use_compile: + if enable_compile: _to_torch(model) monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm @@ -65,7 +70,7 @@ def patch_model( else: yield model.forward finally: - if use_compile: + if enable_compile: _to_torch(model, reverse=True) monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm @@ -86,7 +91,7 @@ def set_torch_compile_config(): class CudaGraphRunner: def __init__( self, - model_runner, + model_runner: "ModelRunner", max_batch_size_to_capture: int, use_torch_compile: bool, disable_padding: bool, @@ -143,18 +148,22 @@ def __init__( self.flashinfer_kv_indices.clone(), ] + # Sampling inputs + vocab_size = model_runner.model_config.vocab_size + self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) + self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] if use_torch_compile: set_torch_compile_config() - def can_run(self, batch_size): + def can_run(self, batch_size: int): if self.disable_padding: return batch_size in self.graphs else: return batch_size <= self.max_bs - def capture(self, batch_size_list): + def capture(self, batch_size_list: List[int]): self.batch_size_list = batch_size_list with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream @@ -175,7 +184,7 @@ def capture(self, batch_size_list): self.output_buffers[bs] = output_buffers self.flashinfer_handlers[bs] = flashinfer_handler - def capture_one_batch_size(self, bs, forward): + def capture_one_batch_size(self, bs: int, forward: Callable): graph = torch.cuda.CUDAGraph() stream = self.stream @@ -234,6 +243,7 @@ def capture_one_batch_size(self, bs, forward): def run_once(): input_metadata = InputMetadata( forward_mode=ForwardMode.DECODE, + sampling_info=self.sampling_info[:bs], batch_size=bs, req_pool_indices=req_pool_indices, seq_lens=seq_lens, @@ -298,27 +308,35 @@ def replay(self, batch: ScheduleBatch): self.flashinfer_handlers[bs], ) + # Sampling inputs + self.sampling_info.inplace_assign(raw_bs, batch.sampling_info) + # Replay torch.cuda.synchronize() self.graphs[bs].replay() torch.cuda.synchronize() - output = self.output_buffers[bs] + sample_output, logits_output = self.output_buffers[bs] # Unpad if bs != raw_bs: - output = 
LogitProcessorOutput( - next_token_logits=output.next_token_logits[:raw_bs], + logits_output = LogitsProcessorOutput( + next_token_logits=logits_output.next_token_logits[:raw_bs], next_token_logprobs=None, normalized_prompt_logprobs=None, input_token_logprobs=None, input_top_logprobs=None, output_top_logprobs=None, ) + sample_output = SampleOutput( + sample_output.success[:raw_bs], + sample_output.probs[:raw_bs], + sample_output.batch_next_token_ids[:raw_bs], + ) # Extract logprobs if batch.return_logprob: - output.next_token_logprobs = torch.nn.functional.log_softmax( - output.next_token_logits, dim=-1 + logits_output.next_token_logprobs = torch.nn.functional.log_softmax( + logits_output.next_token_logits, dim=-1 ) return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) if return_top_logprob: @@ -326,8 +344,8 @@ def replay(self, batch: ScheduleBatch): forward_mode=ForwardMode.DECODE, top_logprobs_nums=batch.top_logprobs_nums, ) - output.output_top_logprobs = LogitsProcessor.get_top_logprobs( - output.next_token_logprobs, logits_metadata + logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + logits_output.next_token_logprobs, logits_metadata )[1] - return output + return sample_output, logits_output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index bac0a05378d..a443b113d44 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,16 +18,19 @@ """ModelRunner runs the forward passes of the models.""" from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List import numpy as np import torch +import triton +import triton.language as tl from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo class ForwardMode(IntEnum): @@ -42,6 +47,7 @@ class InputMetadata: """Store all inforamtion of a forward pass.""" forward_mode: ForwardMode + sampling_info: SamplingBatchInfo batch_size: int req_pool_indices: torch.Tensor seq_lens: torch.Tensor @@ -58,6 +64,7 @@ class InputMetadata: # For extend extend_seq_lens: torch.Tensor = None + extend_prefix_lens: torch.Tensor = None extend_start_loc: torch.Tensor = None extend_no_prefix: bool = None @@ -69,8 +76,8 @@ class InputMetadata: # For multimodal pixel_values: List[torch.Tensor] = None - image_sizes: List[List[int]] = None - image_offsets: List[int] = None + image_sizes: List[List[List[int]]] = None + image_offsets: List[List[int]] = None # Trition attention backend triton_max_seq_len: int = 0 @@ -87,15 +94,8 @@ class InputMetadata: def init_multimuldal_info(self, batch: ScheduleBatch): reqs = batch.reqs self.pixel_values = [r.pixel_values for r in reqs] - self.image_sizes = [r.image_size for r in reqs] - self.image_offsets = [ - ( - (r.image_offset - batch.prefix_lens_cpu[i]) - if r.image_offset is not None - else 0 - ) - for i, r in enumerate(reqs) - ] + self.image_sizes = [r.image_sizes for r in reqs] + self.image_offsets = [r.image_offsets for r in reqs] def compute_positions(self, batch: 
ScheduleBatch): position_ids_offsets = batch.position_ids_offsets @@ -148,6 +148,7 @@ def compute_extend_infos(self, batch: ScheduleBatch): for i, r in enumerate(batch.reqs) ] self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") + self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") self.extend_start_loc = torch.zeros_like(self.seq_lens) self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) @@ -174,6 +175,7 @@ def from_schedule_batch( ): ret = cls( forward_mode=forward_mode, + sampling_info=batch.sampling_info, batch_size=batch.batch_size(), req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, @@ -184,6 +186,8 @@ def from_schedule_batch( top_logprobs_nums=batch.top_logprobs_nums, ) + ret.sampling_info.prepare_penalties() + ret.compute_positions(batch) ret.compute_extend_infos(batch) @@ -233,10 +237,10 @@ def init_flashinfer_handlers( prefix_lens_cpu, flashinfer_use_ragged, ): - if self.forward_mode != ForwardMode.DECODE: - prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda") - else: + if self.forward_mode == ForwardMode.DECODE: prefix_lens = None + else: + prefix_lens = self.extend_prefix_lens update_flashinfer_indices( self.forward_mode, @@ -260,6 +264,42 @@ def init_flashinfer_handlers( ) +@triton.jit +def create_flashinfer_kv_indices_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices_ptr, + page_kernel_lens_ptr, + kv_indptr, + kv_start_idx, + max_context_len, + kv_indices_ptr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(axis=0) + req_pool_index = tl.load(req_pool_indices_ptr + pid) + kv_indices_offset = tl.load(kv_indptr + pid) + + kv_start = 0 + kv_end = 0 + if kv_start_idx: + kv_start = tl.load(kv_start_idx + pid).to(tl.int32) + kv_end = kv_start + kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32) + + req_to_token_ptr += req_pool_index * max_context_len + kv_indices_ptr += kv_indices_offset + + ld_offset = kv_start + tl.arange(0, BLOCK_SIZE) + st_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = ld_offset < kv_end + data = tl.load(req_to_token_ptr + ld_offset, mask=mask) + tl.store(kv_indices_ptr + st_offset, data, mask=mask) + ld_offset += BLOCK_SIZE + st_offset += BLOCK_SIZE + + def update_flashinfer_indices( forward_mode, model_runner, @@ -283,17 +323,18 @@ def update_flashinfer_indices( kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - req_pool_indices_cpu = req_pool_indices.cpu().numpy() - paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() - kv_indices = torch.cat( - [ - model_runner.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i] - ] - for i in range(batch_size) - ], - dim=0, - ).contiguous() + + kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + create_flashinfer_kv_indices_triton[(batch_size,)]( + model_runner.req_to_token_pool.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + model_runner.req_to_token_pool.req_to_token.size(1), + kv_indices, + ) + kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda") if forward_mode == ForwardMode.DECODE: @@ -310,6 +351,8 @@ def update_flashinfer_indices( num_kv_heads, head_dim, 1, + data_type=model_runner.kv_cache_dtype, + q_data_type=model_runner.dtype, ) else: # extend part 
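# NOTE (reference sketch, not part of this patch): create_flashinfer_kv_indices_triton above
# gathers, for each request, the KV-cache slot indices recorded in req_to_token and writes them
# into one flat kv_indices array at the offsets given by kv_indptr. It produces the same result
# as the removed per-request torch.cat over CPU copies, but stays entirely on the GPU.
import torch

def build_kv_indices_reference(req_to_token, req_pool_indices, paged_kernel_lens, kv_start_idx=None):
    chunks = []
    for i, pool_idx in enumerate(req_pool_indices.tolist()):
        start = 0 if kv_start_idx is None else int(kv_start_idx[i])
        end = start + int(paged_kernel_lens[i])
        chunks.append(req_to_token[pool_idx, start:end])
    return torch.cat(chunks).to(torch.int32)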
@@ -361,18 +404,17 @@ def update_flashinfer_indices( kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda") kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) - req_pool_indices_cpu = req_pool_indices.cpu().numpy() - paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() - kv_indices = torch.cat( - [ - model_runner.req_to_token_pool.req_to_token[ - req_pool_indices_cpu[i], - kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i], - ] - for i in range(batch_size) - ], - dim=0, - ).contiguous() + + kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + create_flashinfer_kv_indices_triton[(batch_size,)]( + model_runner.req_to_token_pool.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + model_runner.req_to_token_pool.req_to_token.size(1), + kv_indices, + ) if forward_mode == ForwardMode.DECODE: # CUDA graph uses different flashinfer_decode_wrapper @@ -388,6 +430,8 @@ def update_flashinfer_indices( num_kv_heads, head_dim, 1, + data_type=model_runner.kv_cache_dtype, + q_data_type=model_runner.dtype, ) else: # extend part diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b91191c5dc3..3d3e0cde9d1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,9 +20,8 @@ import importlib.resources import logging import pkgutil -import warnings from functools import lru_cache -from typing import Optional, Type +from typing import Optional, Tuple, Type import torch import torch.nn as nn @@ -45,13 +44,15 @@ from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.model_config import AttentionArch +from sglang.srt.model_config import AttentionArch, ModelConfig from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( @@ -70,7 +71,7 @@ class ModelRunner: def __init__( self, - model_config, + model_config: ModelConfig, mem_fraction_static: float, gpu_id: int, tp_rank: int, @@ -86,28 +87,49 @@ def __init__( self.tp_size = tp_size self.nccl_port = nccl_port self.server_args = server_args - self.is_multimodal_model = is_multimodal_model(self.model_config) + self.is_multimodal_model = is_multimodal_model( + self.model_config.hf_config.architectures + ) global_server_args_dict.update( { "disable_flashinfer": server_args.disable_flashinfer, "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, - "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32, + "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "enable_mla": server_args.enable_mla, } ) + if self.is_multimodal_model: + logger.info( + "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." 
+ ) + server_args.chunked_prefill_size = None + server_args.mem_fraction_static *= 0.95 + + min_per_gpu_memory = self.init_torch_distributed() + self.load_model() + self.init_memory_pool( + min_per_gpu_memory, + server_args.max_num_reqs, + server_args.max_total_tokens, + ) + self.init_cublas() + self.init_flashinfer() + self.init_cuda_graphs() + + def init_torch_distributed(self): # Init torch distributed torch.cuda.set_device(self.gpu_id) - logger.info(f"[gpu={self.gpu_id}] Init nccl begin.") + logger.info("Init nccl begin.") - if not server_args.enable_p2p_check: + if not self.server_args.enable_p2p_check: monkey_patch_vllm_p2p_access_check(self.gpu_id) - if server_args.nccl_init_addr: - nccl_init_method = f"tcp://{server_args.nccl_init_addr}" + if self.server_args.nccl_init_addr: + nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}" else: nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}" - set_custom_all_reduce(not server_args.disable_custom_all_reduce) + set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) init_distributed_environment( backend="nccl", world_size=self.tp_size, @@ -116,43 +138,41 @@ def __init__( distributed_init_method=nccl_init_method, ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) - total_gpu_memory = get_available_gpu_memory( + min_per_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 ) self.tp_group = get_tp_group() + # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph, + # so we disable padding in cuda graph. + if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)): + self.server_args.disable_cuda_graph_padding = True + logger.info( + "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism." + ) + + # Check memory for tensor parallelism if self.tp_size > 1: - total_local_gpu_memory = get_available_gpu_memory(self.gpu_id) - if total_local_gpu_memory < total_gpu_memory * 0.9: + local_gpu_memory = get_available_gpu_memory(self.gpu_id) + if min_per_gpu_memory < local_gpu_memory * 0.9: raise ValueError( "The memory capacity is unbalanced. Some GPUs may be occupied by other processes." ) - # Load the model and create memory pool - self.load_model() - self.init_memory_pool( - total_gpu_memory, - server_args.max_num_reqs, - server_args.max_total_tokens, - ) - self.init_cublas() - self.init_flashinfer() - - if self.is_generation: - # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models - # Capture cuda graphs - self.init_cuda_graphs() + return min_per_gpu_memory def load_model(self): + torch.set_num_threads(1) logger.info( - f"[gpu={self.gpu_id}] Load weight begin. " - f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" + f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) if torch.cuda.get_device_capability()[0] < 8: logger.info( - "Compute capability below sm80 use float16 due to lack of bfloat16 support." + "Compute capability below sm80. Use float16 due to lack of bfloat16 support." ) self.server_args.dtype = "float16" + if torch.cuda.get_device_capability()[1] < 5: + raise RuntimeError("SGLang only supports sm75 and above.") monkey_patch_vllm_dummy_weight_loader() self.device_config = DeviceConfig() @@ -168,45 +188,46 @@ def load_model(self): skip_tokenizer_init=True, ) + # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints + # Drop this after Sept, 2024. 
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8: - # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints self.model_config.hf_config.num_key_value_heads = 8 self.vllm_model_config.hf_config.num_key_value_heads = 8 monkey_patch_vllm_qvk_linear_loader() self.dtype = self.vllm_model_config.dtype - if self.model_config.model_overide_args is not None: + if self.model_config.model_override_args is not None: self.vllm_model_config.hf_config.update( - self.model_config.model_overide_args + self.model_config.model_override_args ) self.model = get_model( model_config=self.vllm_model_config, - device_config=self.device_config, load_config=self.load_config, - lora_config=None, - multimodal_config=None, + device_config=self.device_config, parallel_config=None, scheduler_config=None, + lora_config=None, cache_config=None, ) self.sliding_window_size = ( - self.model.get_window_size() - if hasattr(self.model, "get_window_size") + self.model.get_attention_sliding_window_size() + if hasattr(self.model, "get_attention_sliding_window_size") else None ) self.is_generation = is_generation_model( - self.model_config.hf_config.architectures + self.model_config.hf_config.architectures, self.server_args.is_embedding ) logger.info( - f"[gpu={self.gpu_id}] Load weight end. " + f"Load weight end. " f"type={type(self.model).__name__}, " f"dtype={self.dtype}, " f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) - def update_weights(self, model_path, load_format): + def update_weights(self, model_path: str, load_format: str): + """Update weights in-place.""" from vllm.model_executor.model_loader.loader import ( DefaultModelLoader, device_loading_context, @@ -215,13 +236,14 @@ def update_weights(self, model_path, load_format): from vllm.model_executor.model_loader.utils import set_default_torch_dtype logger.info( - f"[gpu={self.gpu_id}] Update weights begin. " + f"Update weights begin. 
" f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) target_device = torch.device(self.device_config.device) try: + # TODO: Use a better method to check this vllm_model_config = VllmModelConfig( model=model_path, quantization=self.server_args.quantization, @@ -288,10 +310,10 @@ def model_load_weights(model, iter): self.load_config = load_config self.model_config.path = model_path - logger.info(f"[gpu={self.gpu_id}] Update weights end.") + logger.info("Update weights end.") return True, "Succeeded to update model weights" - def profile_max_num_token(self, total_gpu_memory): + def profile_max_num_token(self, total_gpu_memory: int): available_gpu_memory = get_available_gpu_memory( self.gpu_id, distributed=self.tp_size > 1 ) @@ -302,7 +324,7 @@ def profile_max_num_token(self, total_gpu_memory): cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * self.model_config.num_hidden_layers - * torch._utils._element_size(self.dtype) + * torch._utils._element_size(self.kv_cache_dtype) ) else: cell_size = ( @@ -310,7 +332,7 @@ def profile_max_num_token(self, total_gpu_memory): * self.model_config.head_dim * self.model_config.num_hidden_layers * 2 - * torch._utils._element_size(self.dtype) + * torch._utils._element_size(self.kv_cache_dtype) ) rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static @@ -319,8 +341,20 @@ def profile_max_num_token(self, total_gpu_memory): return max_num_token def init_memory_pool( - self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None + self, + total_gpu_memory: int, + max_num_reqs: int = None, + max_total_tokens: int = None, ): + if self.server_args.kv_cache_dtype == "auto": + self.kv_cache_dtype = self.dtype + elif self.server_args.kv_cache_dtype == "fp8_e5m2": + self.kv_cache_dtype = torch.float8_e5m2 + else: + raise ValueError( + f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." + ) + self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: @@ -357,7 +391,7 @@ def init_memory_pool( ): self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, - dtype=self.dtype, + dtype=self.kv_cache_dtype, kv_lora_rank=self.model_config.kv_lora_rank, qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, @@ -368,13 +402,13 @@ def init_memory_pool( else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, - dtype=self.dtype, + dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(self.tp_size), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, ) logger.info( - f"[gpu={self.gpu_id}] Memory pool end. " + f"Memory pool end. 
" f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB" ) @@ -388,6 +422,7 @@ def init_cublas(self): return c def init_flashinfer(self): + """Init flashinfer attention kernel wrappers.""" if self.server_args.disable_flashinfer: assert ( self.sliding_window_size is None @@ -448,16 +483,24 @@ def init_flashinfer(self): ) def init_cuda_graphs(self): + """Capture cuda graphs.""" + if not self.is_generation: + # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models + return + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer: self.cuda_graph_runner = None return - logger.info( - f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes." - ) - batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)] + logger.info("Capture cuda graph begin. This can take up to several minutes.") + + if self.server_args.disable_cuda_graph_padding: + batch_size_list = list(range(1, 32)) + [64, 128] + else: + batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)] + self.cuda_graph_runner = CudaGraphRunner( self, max_batch_size_to_capture=max(batch_size_list), @@ -470,15 +513,19 @@ def init_cuda_graphs(self): raise Exception( f"Capture cuda graph failed: {e}\n" "Possible solutions:\n" - "1. disable torch compile by not using --enable-torch-compile\n" - "2. disable cuda graph by --disable-cuda-graph\n" - "3. set --mem-fraction-static to a smaller value\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value\n" + "3. disable torch compile by not using --enable-torch-compile\n" "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" ) @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): + if ( + self.cuda_graph_runner + and self.cuda_graph_runner.can_run(len(batch.reqs)) + and batch.sampling_info.can_run_in_cuda_graph() + ): return self.cuda_graph_runner.replay(batch) input_metadata = InputMetadata.from_schedule_batch( @@ -498,9 +545,18 @@ def forward_extend(self, batch: ScheduleBatch): batch, forward_mode=ForwardMode.EXTEND, ) - return self.model.forward( - batch.input_ids, input_metadata.positions, input_metadata - ) + if self.is_generation: + return self.model.forward( + batch.input_ids, input_metadata.positions, input_metadata + ) + else: + # Only embedding models have get_embedding parameter + return self.model.forward( + batch.input_ids, + input_metadata.positions, + input_metadata, + get_embedding=True, + ) @torch.inference_mode() def forward_extend_multi_modal(self, batch: ScheduleBatch): @@ -518,7 +574,9 @@ def forward_extend_multi_modal(self, batch: ScheduleBatch): input_metadata.image_offsets, ) - def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): + def forward( + self, batch: ScheduleBatch, forward_mode: ForwardMode + ) -> Tuple[SampleOutput, LogitsProcessorOutput]: if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: return self.forward_extend_multi_modal(batch) elif forward_mode == ForwardMode.DECODE: @@ -549,16 +607,6 @@ def import_model_classes(): assert entry.__name__ not in model_arch_name_to_cls model_arch_name_to_cls[entry.__name__] = entry - # compat: some models such as chatglm has incorrect class set in config.json - # usage: [ tuple("From_Entry_Class_Name": EntryClass), ] - if hasattr(module, 
"EntryClassRemapping") and isinstance( - module.EntryClassRemapping, list - ): - for remap in module.EntryClassRemapping: - if isinstance(remap, tuple) and len(remap) == 2: - assert remap[0] not in model_arch_name_to_cls - model_arch_name_to_cls[remap[0]] = remap[1] - return model_arch_name_to_cls @@ -574,4 +622,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: # Monkey patch model loader -setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) +setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt) diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 0a22f994bb4..94b405f8e8f 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -17,7 +17,7 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn @@ -31,20 +31,18 @@ ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata LoraConfig = None @@ -383,17 +381,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) @@ -410,6 +402,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) -EntryClass = ChatGLMForCausalLM -# compat: glm model.config class == ChatGLMModel -EntryClassRemapping = [("ChatGLMModel", ChatGLMForCausalLM)] +class ChatGLMModel(ChatGLMForCausalLM): + pass + + +EntryClass = [ChatGLMForCausalLM, ChatGLMModel] diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index f6d6f6e1f94..c360106f97c 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -64,6 +64,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from 
sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -326,6 +327,7 @@ def __init__( self.config = config self.quant_config = quant_config self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() self.model = CohereModel(config, quant_config) @torch.no_grad() @@ -340,9 +342,11 @@ def forward( positions, input_metadata, ) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 39ac4aefa72..b3a76b56ae2 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -382,6 +383,7 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -391,9 +393,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 59fd1ec7ed8..b939602c1ba 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,6 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -385,6 +386,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -394,9 +396,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 13dd477392e..bb80e2da2f5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple import torch +from flashinfer import bmm_fp8 from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig @@ -45,6 +46,7 @@ from 
sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -160,6 +162,15 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 +def input_to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + class DeepseekV2Attention(nn.Module): def __init__( @@ -254,11 +265,6 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale - # self.attn = Attention(self.num_heads, - # self.qk_head_dim, - # self.scaling, - # num_kv_heads=self.num_heads) - # TODO, support head_size 192 self.attn = RadixAttention( self.num_local_heads, @@ -282,7 +288,7 @@ def forward( q = self.q_proj(hidden_states)[0].view( -1, self.num_local_heads, self.qk_head_dim ) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) @@ -416,12 +422,9 @@ def __init__( v_head_dim=self.kv_lora_rank, ) - kv_b_proj = self.kv_b_proj - w_kc, w_vc = kv_b_proj.weight.unflatten( - 0, (-1, qk_nope_head_dim + v_head_dim) - ).split([qk_nope_head_dim, v_head_dim], dim=1) - self.w_kc = w_kc - self.w_vc = w_vc + self.w_kc = None + self.w_vc = None + self.w_scale = None def forward( self, @@ -442,8 +445,17 @@ def forward( -1, self.num_local_heads, self.qk_head_dim ) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - q_nope_out = q_input[..., : self.kv_lora_rank] - torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1)) + + if self.w_kc.dtype == torch.float8_e4m3fn: + q_nope_val, q_nope_scale = input_to_float8( + q_nope.transpose(0, 1), torch.float8_e4m3fn + ) + q_nope_out = bmm_fp8( + q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 + ) + else: + q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) + q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1) latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] v_input = latent_cache[..., : self.kv_lora_rank] @@ -458,16 +470,21 @@ def forward( attn_output = self.attn(q_input, k_input, v_input, input_metadata) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - attn_bmm_output = attn_output.new_empty( - q_len, self.num_local_heads, self.v_head_dim - ) - torch.bmm( - attn_output.transpose(0, 1), - self.w_vc.transpose(1, 2).contiguous(), - out=attn_bmm_output.transpose(0, 1), - ) - attn_output = attn_bmm_output.flatten(1, 2) + if self.w_vc.dtype == torch.float8_e4m3fn: + attn_output_val, attn_output_scale = input_to_float8( + attn_output.transpose(0, 1), torch.float8_e4m3fn + ) + attn_bmm_output = bmm_fp8( + attn_output_val, + self.w_vc, + attn_output_scale, + self.w_scale, + torch.bfloat16, + ) + else: + attn_bmm_output = 
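# NOTE (worked example, not part of this patch): input_to_float8 above quantizes a tensor to
# float8_e4m3fn with a single per-tensor scale and returns the reciprocal scale, which bmm_fp8
# uses to rescale the product back to bfloat16.
import torch

x = torch.tensor([896.0, -448.0, 224.0])
finfo = torch.finfo(torch.float8_e4m3fn)                 # finfo.max == 448.0
scale = finfo.max / x.abs().max().clamp(min=1e-12)       # 0.5
x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
inv_scale = scale.float().reciprocal()                   # 2.0, passed to bmm_fp8 as the scale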
torch.bmm(attn_output.transpose(0, 1), self.w_vc) + attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) output, _ = self.o_proj(attn_output) return output @@ -632,6 +649,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() def forward( self, @@ -640,9 +658,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -695,7 +715,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, ) @@ -711,5 +731,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ) weight_loader(param, loaded_weight) + if global_server_args_dict["enable_mla"]: + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) + self_attn.w_vc = w_vc.contiguous().transpose(1, 2) + if hasattr(self_attn.kv_b_proj, "weight_scale"): + self_attn.w_scale = self_attn.kv_b_proj.weight_scale + del self_attn.kv_b_proj + EntryClass = DeepseekV2ForCausalLM diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py new file mode 100644 index 00000000000..bb077f2c87d --- /dev/null +++ b/python/sglang/srt/models/exaone.py @@ -0,0 +1,368 @@ +""" +Copyright 2024 The LGcns AI Engineering Team +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Adapted from llama2.py +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler +from sglang.srt.model_executor.forward_batch_info import InputMetadata + + +class ExaoneGatedMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + def __init__( + self, + config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 500000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 4096, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.rotary_dim = int( + self.head_dim * getattr(config, "partial_rotary_factor", 1) + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, input_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneDecoderLayer(nn.Module): + def __init__( + self, + config, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, "max_position_embeddings", 4096) + self.self_attn = ExaoneAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + rms_norm_eps = config.layer_norm_epsilon + self.ln_1 = RMSNorm(config.hidden_size, eps=rms_norm_eps) + self.ln_2 = RMSNorm(config.hidden_size, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + 
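The decoder layer above calls ln_1/ln_2 with an optional residual, relying on an RMSNorm that both adds the residual and hands back the un-normalized sum for the next layer. A plain-PyTorch approximation of that calling convention, assuming the standard x * rsqrt(mean(x^2) + eps) * weight formulation used by the Llama-style RMSNorm here:

```python
from typing import Optional

import torch
from torch import nn


class NaiveRMSNorm(nn.Module):
    """Reference RMSNorm with the fused-residual calling convention."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            x = x + residual  # fold in the carried residual first
            residual = x      # the un-normalized sum is what gets carried forward
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        out = x * torch.rsqrt(variance + self.eps) * self.weight
        return out if residual is None else (out, residual)


norm = NaiveRMSNorm(16)
h = torch.randn(2, 16)
h_norm = norm(h)                        # first layer: no residual yet
h_norm, res = norm(h_norm, residual=h)  # later layers: (normed, new residual)
```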
hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class ExaoneModel(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.h = nn.ModuleList( + [ + ExaoneDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.h.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + rms_norm_eps = config.layer_norm_epsilon + self.ln_f = RMSNorm(config.hidden_size, eps=rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.wte(input_ids) + else: + hidden_states = input_embeds + residual = None + for i in range(len(self.h)): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + input_metadata, + residual, + ) + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class ExaoneForCausalLM(nn.Module): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.transformer = ExaoneModel(config, quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_metadata: InputMetadata, + input_embeds: torch.Tensor = None, + ) -> LogitsProcessorOutput: + hidden_states = self.transformer( + input_ids, positions, input_metadata, input_embeds + ) + logits_output = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "c_fc_0", 0), + ("gate_up_proj", "c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + + name = name.replace("attn.attention", "self_attn") + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
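The load_weights loop above folds separate q_proj / k_proj / v_proj (and c_fc_0 / c_fc_1) checkpoint tensors into the fused qkv_proj / gate_up_proj parameters via stacked_params_mapping. A toy illustration of just the name-rewriting step; the real weight_loader callbacks then place each shard, and the sample checkpoint names below are hypothetical:

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "c_fc_0", 0),
    ("gate_up_proj", "c_fc_1", 1),
]


def resolve(checkpoint_name: str):
    """Map a checkpoint tensor name to (fused_param_name, shard_id)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None  # not a stacked parameter


print(resolve("transformer.h.0.self_attn.k_proj.weight"))
# ('transformer.h.0.self_attn.qkv_proj.weight', 'k')
print(resolve("transformer.h.0.mlp.c_fc_0.weight"))
# ('transformer.h.0.mlp.gate_up_proj.weight', 0)
```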
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = ExaoneForCausalLM diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 990937f5180..5a6e5df37fe 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -23,7 +23,6 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -34,9 +33,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -60,7 +61,7 @@ def __init__( bias=False, quant_config=quant_config, ) - self.act_fn = GeluAndMul() + self.act_fn = GeluAndMul("none") def forward(self, x): gate_up, _ = self.gate_up_proj(x) @@ -287,6 +288,7 @@ def __init__( self.quant_config = quant_config self.model = GemmaModel(config, quant_config=quant_config) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -297,9 +299,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return (sample_output, logits_output) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 80b99742e3f..77ebd8564c6 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -22,12 +22,6 @@ from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size - -# FIXME: temporary solution, remove after next vllm release -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.activation import GeluAndMul - -# from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -39,63 +33,20 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import GeluAndMul +from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor 
import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata # Aligned with HF's implementation, using sliding window inclusive with the last token # SGLang assumes exclusive -def get_window_size(config): +def get_attention_sliding_window_size(config): return config.sliding_window - 1 -class GemmaRMSNorm(CustomOp): - """RMS normalization for Gemma. - - Two differences from the above RMSNorm: - 1. x * (1 + w) instead of x * w. - 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. - """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - ) -> None: - super().__init__() - self.weight = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward_native( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """PyTorch-native implementation equivalent to forward().""" - orig_dtype = x.dtype - if residual is not None: - x = x + residual - residual = x - - x = x.float() - variance = x.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) - # See https://github.com/huggingface/transformers/pull/29402 - x = x * (1.0 + self.weight.float()) - x = x.to(orig_dtype) - return x if residual is None else (x, residual) - - def forward_cuda( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm. - return self.forward_native(x, residual) - - # FIXME: temporary solution, remove after next vllm release from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding @@ -135,7 +86,7 @@ def __init__( "function. Please set `hidden_act` and `hidden_activation` to " "`gelu_pytorch_tanh`." 
) - self.act_fn = GeluAndMul(approximate="tanh") + self.act_fn = GeluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) @@ -213,7 +164,11 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, layer_id=layer_idx, - sliding_window_size=get_window_size(config) if use_sliding_window else None, + sliding_window_size=( + get_attention_sliding_window_size(config) + if use_sliding_window + else None + ), logit_cap=self.config.attn_logit_softcapping, ) @@ -392,6 +347,7 @@ def __init__( self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -402,12 +358,14 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output - def get_window_size(self): - return get_window_size(self.config) + def get_attention_sliding_window_size(self): + return get_attention_sliding_window_size(self.config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 9a9e2aec3a7..dc828f0142e 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -23,7 +23,6 @@ from transformers import GPTBigCodeConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -33,8 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -261,6 +262,7 @@ def __init__( if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -270,9 +272,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 75b086fd6a1..3c2a2c65eae 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -46,6 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from 
sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -273,9 +274,9 @@ def forward( ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) + hidden_states.mul_(self.config.embedding_multiplier_scale) else: hidden_states = input_embeds - hidden_states.mul_(self.config.embedding_multiplier_scale) for i in range(len(self.layers)): hidden_states = self.layers[i](positions, hidden_states, input_metadata) @@ -284,7 +285,7 @@ def forward( return hidden_states -class Grok1ModelForCausalLM(nn.Module): +class Grok1ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, @@ -295,12 +296,15 @@ def __init__( self.config = config self.quant_config = quant_config self.model = Grok1Model(config, quant_config=quant_config) - # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size) - self.logits_processor = LogitsProcessor(config, skip_all_gather=True) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + + self.use_presharded_weights = True + warnings.filterwarnings("ignore", category=FutureWarning) def forward( @@ -311,9 +315,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -356,6 +362,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) + if self.use_presharded_weights: + extra_kwargs = { + "use_presharded_weights": self.use_presharded_weights + } + else: + extra_kwargs = {} + param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -364,7 +377,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_name, shard_id=shard_id, expert_id=expert_id, - pre_sharded=get_tensor_model_parallel_world_size() > 1, + **extra_kwargs, ) break else: @@ -406,4 +419,10 @@ def _prepare_presharded_weights( return hf_folder, hf_weights_files, use_safetensors -EntryClass = Grok1ModelForCausalLM +class Grok1ModelForCausalLM(Grok1ForCausalLM): + """An alias for backward-compatbility.""" + + pass + + +EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM] diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index f2947e991b5..c0e4d19e128 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,6 +40,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -262,6 +263,7 @@ def __init__( self.model = InternLM2Model(config, quant_config) self.output = 
ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -272,9 +274,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.output.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama.py similarity index 86% rename from python/sglang/srt/models/llama2.py rename to python/sglang/srt/models/llama.py index 9de8d33c5c1..926d87db8b7 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama.py @@ -39,8 +39,9 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -294,7 +295,6 @@ def __init__( config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, - efficient_weight_load=False, ) -> None: super().__init__() self.config = config @@ -302,6 +302,9 @@ def __init__( self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() + + self.param_dict = dict(self.named_parameters()) @torch.no_grad() def forward( @@ -310,53 +313,35 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, - ) -> LogitProcessorOutput: + ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output - def get_module_name(self, name): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id, num_shard) - ("qkv_proj", "q_proj", "q", 3), - ("qkv_proj", "k_proj", "k", 3), - ("qkv_proj", "v_proj", "v", 3), - ("gate_up_proj", "gate_proj", 0, 2), - ("gate_up_proj", "up_proj", 1, 2), - ] - for param_name, weight_name, shard_id, num_shard in stacked_params_mapping: - if weight_name in name: - return ( - name.replace(weight_name, param_name)[: -len(".weight")], - num_shard, - ) - return name[: -len(".weight")], 1 - - def get_num_params(self): - params_dict = dict(self.named_parameters()) - return len(params_dict) - - def load_weights( - self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 
1), + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), ] - params_dict = dict(self.named_parameters()) + params_dict = self.param_dict - def load_weights_per_param(name, loaded_weight): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: - return + continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. - return + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -364,8 +349,6 @@ def load_weights_per_param(name, loaded_weight): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -373,18 +356,14 @@ def load_weights_per_param(name, loaded_weight): else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: - return - if name.startswith("model.vision_tower") and name not in params_dict: - return + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if name is None or loaded_weight is None: - for name, loaded_weight in weights: - load_weights_per_param(name, loaded_weight) - else: - load_weights_per_param(name, loaded_weight) + +class Phi3ForCausalLM(LlamaForCausalLM): + pass -EntryClass = LlamaForCausalLM +EntryClass = [LlamaForCausalLM, Phi3ForCausalLM] diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 02224971d6a..db424ff1804 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -16,17 +16,16 @@ from typing import Iterable, Optional, Tuple import torch -import tqdm from torch import nn from transformers import LlamaConfig from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.layers.logits_processor import LogitProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.model_executor.forward_batch_info import InputMetadata -from sglang.srt.models.llama2 import LlamaModel +from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel class LlamaForClassification(nn.Module): @@ -42,10 +41,12 @@ def __init__( self.model = LlamaModel(config, quant_config=quant_config) self.classification_head = nn.Linear( - config.hidden_size, config.classification_out_size + config.hidden_size, config.classification_out_size, bias=False ) self.eos_token_id = config.eos_token_id + self.param_dict = dict(self.named_parameters()) + @torch.no_grad() def forward( self, @@ -65,7 +66,7 @@ def forward( (input_metadata.batch_size, self.config.classification_out_size) ).to(input_ids.device) - return LogitProcessorOutput( + logits_output = LogitsProcessorOutput( 
next_token_logits=scores, next_token_logprobs=scores, normalized_prompt_logprobs=scores, @@ -74,50 +75,38 @@ def forward( output_top_logprobs=None, ) + # A dummy to make this work + sample_output = SampleOutput( + success=torch.full( + size=(scores.shape[0],), + fill_value=True, + dtype=torch.bool, + ), + probs=torch.full( + size=(scores.shape[0], 1), + fill_value=1.0, + dtype=torch.float16, + ), + batch_next_token_ids=torch.full( + size=(scores.shape[0],), + fill_value=0, + dtype=torch.long, + ), + ) + return sample_output, logits_output + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - if get_tensor_model_parallel_rank() == 0: - weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5)) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name or "projector" in name: - continue - if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - if "lm_head" in name: - continue + params_dict = self.param_dict - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. 
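Because LlamaForClassification never samples tokens, the hunk above fabricates a trivially successful SampleOutput so every model can return the same (sample_output, logits_output) pair. A standalone sketch of that placeholder, with SampleOutput approximated here as a simple dataclass holding the three fields the diff fills in:

```python
from dataclasses import dataclass

import torch


@dataclass
class DummySampleOutput:
    """Stand-in for sglang's SampleOutput, for illustration only."""

    success: torch.Tensor
    probs: torch.Tensor
    batch_next_token_ids: torch.Tensor


def dummy_sample_output(batch_size: int) -> DummySampleOutput:
    # Every request "succeeds", with a single probability mass of 1.0
    # and a placeholder next-token id of 0.
    return DummySampleOutput(
        success=torch.ones(batch_size, dtype=torch.bool),
        probs=torch.ones(batch_size, 1, dtype=torch.float16),
        batch_next_token_ids=torch.zeros(batch_size, dtype=torch.long),
    )


print(dummy_sample_output(4))
```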
- if name.endswith(".bias") and name not in params_dict: - continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue + for name, loaded_weight in weights: + if "classification_head" in name: param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + elif "lm_head" in name: + continue + else: + LlamaForCausalLM.load_weights(self, [(name, loaded_weight)]) EntryClass = LlamaForClassification diff --git a/python/sglang/srt/models/llama_embedding.py b/python/sglang/srt/models/llama_embedding.py index e8e6780472d..fe407b29f24 100644 --- a/python/sglang/srt/models/llama_embedding.py +++ b/python/sglang/srt/models/llama_embedding.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Tuple +from typing import Iterable, Tuple import torch from torch import nn @@ -7,7 +7,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.model_executor.model_runner import InputMetadata -from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel +from sglang.srt.models.llama import LlamaModel class LlamaEmbeddingModel(nn.Module): @@ -16,7 +16,6 @@ def __init__( config: LlamaConfig, quant_config=None, cache_config=None, - efficient_weight_load=False, ) -> None: super().__init__() self.model = LlamaModel(config, quant_config=quant_config) @@ -29,7 +28,11 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, + get_embedding: bool = True, ) -> EmbeddingPoolerOutput: + assert ( + get_embedding + ), "LlamaEmbeddingModel / MistralModel is only used for embedding" hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) return self.pooler(hidden_states, input_metadata) @@ -53,6 +56,9 @@ def load_weights_per_param(name, loaded_weight): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. return + if name.startswith("model.vision_tower") and name not in params_dict: + return + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -60,8 +66,6 @@ def load_weights_per_param(name, loaded_weight): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -70,8 +74,6 @@ def load_weights_per_param(name, loaded_weight): # Skip loading extra bias for GPTQ models. 
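Both the classification model above and the LLaVA loaders later in this diff now hand weights to the wrapped language model one (name, tensor) pair at a time instead of collecting the whole iterator. A sketch of that dispatch pattern with hypothetical wrapper classes:

```python
from typing import Iterable, Tuple

import torch


class ToyLanguageModel:
    def __init__(self) -> None:
        self.loaded = {}

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        for name, tensor in weights:
            self.loaded[name] = tensor


class ToyWrapper:
    """Hypothetical classification/multimodal wrapper around a language model."""

    def __init__(self) -> None:
        self.language_model = ToyLanguageModel()
        self.extra = {}

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        for name, tensor in weights:
            if name.startswith("classification_head"):
                self.extra[name] = tensor  # handled locally
            elif name.startswith("lm_head"):
                continue  # intentionally skipped
            else:
                # Delegate a single pair; no need to materialize the iterator.
                self.language_model.load_weights([(name, tensor)])


wrapper = ToyWrapper()
wrapper.load_weights(
    [
        ("classification_head.weight", torch.zeros(2, 4)),
        ("lm_head.weight", torch.zeros(8, 4)),
        ("model.layers.0.mlp.gate_proj.weight", torch.zeros(4, 4)),
    ]
)
print(sorted(wrapper.language_model.loaded))  # ['model.layers.0.mlp.gate_proj.weight']
```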
if name.endswith(".bias") and name not in params_dict: return - if name.startswith("model.vision_tower") and name not in params_dict: - return param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) @@ -83,6 +85,8 @@ def load_weights_per_param(name, loaded_weight): load_weights_per_param(name, loaded_weight) -EntryClass = LlamaEmbeddingModel -# compat: e5-mistral model.config class == MistralModel -EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)] +class MistralModel(LlamaEmbeddingModel): + pass + + +EntryClass = [LlamaEmbeddingModel, MistralModel] diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index a885a6e5953..2e3c9ceba1a 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -15,6 +15,8 @@ """Inference-only LLaVa model compatible with HuggingFace weights.""" +import math +import re from typing import Iterable, List, Optional, Tuple import numpy as np @@ -26,6 +28,7 @@ LlavaConfig, MistralConfig, Qwen2Config, + SiglipVisionModel, ) from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.config import CacheConfig @@ -38,59 +41,73 @@ unpad_image_shape, ) from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM -class LlavaLlamaForCausalLM(nn.Module): - def __init__( +class LlavaBaseForCausalLM(nn.Module): + def pad_input_ids( self, - config: LlavaConfig, - quant_config: Optional[QuantizationConfig] = None, - cache_config: Optional[CacheConfig] = None, - ) -> None: - super().__init__() - self.config = config - self.vision_tower = None - self.config.vision_config.hidden_size = config.mm_hidden_size - self.config.text_config.hidden_size = config.hidden_size - self.multi_modal_projector = LlavaMultiModalProjector(config) - self.language_model = LlamaForCausalLM(config, quant_config=quant_config) - if "unpad" in getattr(config, "mm_patch_merge_type", ""): - self.language_model.model.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size, dtype=torch.float16) - ) + input_ids: List[int], + pad_value: List[int], + pixel_values: List, + image_sizes: List[List[int]], + ): + # hardcode for spatial_unpad + anyres + image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad" + offset_list = [] + for image_s in image_sizes: + if len(image_sizes) > 16: + # 2x2 pooling with stride 2 + new_image_feature_len = ( + math.ceil(self.image_size / self.patch_size / 2) ** 2 + ) + else: + new_image_feature_len = self.image_feature_len # multiimage - def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): - new_image_feature_len = self.image_feature_len - # now only support spatial_unpad + anyres - if self.mm_patch_merge_type.startswith("spatial"): height = width = self.num_patches_per_side - if pt_shape[0] > 1: - if self.image_aspect_ratio == "anyres": - num_patch_width, num_patch_height = get_anyres_image_grid_shape( - image_size, - self.image_grid_pinpoints, - self.vision_tower.config.image_size, + if "anyres" in image_aspect_ratio: + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + image_s, + self.image_grid_pinpoints, + self.vision_tower.config.image_size, + ) + h = num_patch_height * height + w = num_patch_width * 
width + new_h, new_w = unpad_image_shape(h, w, image_s) + + if "anyres_max" in self.config.image_aspect_ratio: + matched_anyres_max_num_patches = re.match( + r"anyres_max_(\d+)", self.config.image_aspect_ratio + ) + if matched_anyres_max_num_patches: + max_num_patches = int(matched_anyres_max_num_patches.group(1)) + # times = math.sqrt(h * w / (max_num_patches * unit**2)) + times = math.sqrt( + new_h * new_w / (max_num_patches * self.image_feature_len) ) - if "unpad" in self.mm_patch_merge_type: - h = num_patch_height * height - w = num_patch_width * width - new_h, new_w = unpad_image_shape(h, w, image_size) - new_image_feature_len += new_h * (new_w + 1) - - pad_ids = pad_value * ( - (new_image_feature_len + len(pad_value)) // len(pad_value) - ) - offset = input_ids.index(self.config.image_token_index) - # old_len + pad_len - 1, because we need to remove image_token_id - new_input_ids = ( - input_ids[:offset] - + pad_ids[:new_image_feature_len] - + input_ids[offset + 1 :] - ) - return new_input_ids, offset + if times > 1.1: + new_h = int(new_h // times) + new_w = int(new_w // times) + new_image_feature_len += new_h * (new_w + 1) + + pad_ids = pad_value * ( + (new_image_feature_len + len(pad_value)) // len(pad_value) + ) + # print("calculated new_image_feature_len: ", new_image_feature_len) + try: + offset = input_ids.index(self.config.image_token_index) + except ValueError: + offset = 0 + # old_len + pad_len - 1, because we need to remove image_token_id + input_ids = ( + input_ids[:offset] + + pad_ids[:new_image_feature_len] + + input_ids[offset + 1 :] + ) + offset_list.append(offset) + return input_ids, offset_list def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -122,18 +139,15 @@ def forward( if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size - # Embed text input + # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Embed vision input - need_vision = ( - (positions[input_metadata.extend_start_loc] < self.image_feature_len) - .cpu() - .numpy() + # Whether the requests need vision inputs + max_image_offset = np.array( + [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] ) - # FIXME: We need to substract the length of the system prompt - has_pixel = np.array([pixel_values[i] is not None for i in range(bs)]) - need_vision = need_vision & has_pixel + start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + need_vision = start_positions <= max_image_offset if need_vision.any(): pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] @@ -163,27 +177,73 @@ def forward( if self.mm_patch_merge_type.startswith("spatial"): new_image_features = [] + height = width = self.num_patches_per_side for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: + if len(image_sizes[image_idx]) == 1: + image_aspect_ratio = ( + self.config.image_aspect_ratio + ) # single image + else: + image_aspect_ratio = "pad" # multi image + # image_aspect_ratio = ( + # "anyres" if len(image_sizes[image_idx]) == 1 else "pad" + # ) + if ( + image_feature.shape[0] > 1 + and "anyres" in image_aspect_ratio + ): base_image_feature = image_feature[0] image_feature = image_feature[1:] - height = width = self.num_patches_per_side assert height * width == base_image_feature.shape[0] - if self.image_aspect_ratio == "anyres": - ( - num_patch_width, - num_patch_height, - ) = 
get_anyres_image_grid_shape( - image_sizes[image_idx], - self.image_grid_pinpoints, - self.vision_tower.config.image_size, + + if "anyres_max" in image_aspect_ratio: + matched_anyres_max_num_patches = re.match( + r"anyres_max_(\d+)", image_aspect_ratio ) + if matched_anyres_max_num_patches: + max_num_patches = int( + matched_anyres_max_num_patches.group(1) + ) + + if ( + image_aspect_ratio == "anyres" + or "anyres_max" in image_aspect_ratio + ): + vision_tower_image_size = self.image_size + try: + num_patch_width, num_patch_height = ( + get_anyres_image_grid_shape( + image_sizes[image_idx][0], + self.config.image_grid_pinpoints, + vision_tower_image_size, + ) + ) + except Exception as e: + print(f"Error: {e}") + num_patch_width, num_patch_height = 2, 2 image_feature = image_feature.view( num_patch_height, num_patch_width, height, width, -1 ) else: - raise NotImplementedError() + image_feature = image_feature.view( + 2, 2, height, width, -1 + ) + + # ( + # num_patch_width, + # num_patch_height, + # ) = get_anyres_image_grid_shape( + # image_sizes[image_idx][0], + # self.image_grid_pinpoints, + # self.vision_tower.config.image_size, + # ) + + # image_feature = image_feature.view( + # num_patch_height, num_patch_width, height, width, -1 + # ) + if "unpad" in self.mm_patch_merge_type: + unit = image_feature.shape[2] image_feature = image_feature.permute( 4, 0, 2, 1, 3 ).contiguous() @@ -191,8 +251,23 @@ def forward( 2, 3 ) image_feature = unpad_image( - image_feature, image_sizes[image_idx] + image_feature, image_sizes[image_idx][0] ) + if ( + "anyres_max" in image_aspect_ratio + and matched_anyres_max_num_patches + ): + c, h, w = image_feature.shape + times = math.sqrt( + h * w / (max_num_patches * unit**2) + ) + if times > 1.1: + image_feature = image_feature[None] + image_feature = nn.functional.interpolate( + image_feature, + [int(h // times), int(w // times)], + mode="bilinear", + )[0] image_feature = torch.cat( ( image_feature, @@ -213,43 +288,63 @@ def forward( image_feature = torch.cat( (base_image_feature, image_feature), dim=0 ) + image_feature = image_feature.unsqueeze(0) else: - image_feature = image_feature[0] - if "unpad" in self.mm_patch_merge_type: - image_feature = torch.cat( - ( - image_feature, - self.language_model.model.image_newline[None], - ), - dim=0, + if image_feature.shape[0] > 16: # video + # 2x2 pooling + num_of_frames = image_feature.shape[0] + image_feature = image_feature.view( + num_of_frames, height, width, -1 + ) + image_feature = image_feature.permute( + 0, 3, 1, 2 + ).contiguous() # N, C, H, W + height, weight = image_feature.shape[2:] + scaled_shape = [ + math.ceil(height / 2), + math.ceil(weight / 2), + ] + image_feature = nn.functional.interpolate( + image_feature, size=scaled_shape, mode="bilinear" ) + image_feature = ( + image_feature.flatten(2) + .transpose(1, 2) + .contiguous() + ) # N, C, H*W + new_image_features.append(image_feature) image_features = new_image_features + # Fill in the placeholder for the image extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: continue start_idx = extend_start_loc_cpu[i] - pad_len, pad_dim = image_features[pt].shape # 576, 4096 - dim = input_embeds.shape[1] - assert ( - pad_dim == dim - ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) - # Fill in the placeholder for the image - try: - input_embeds[ - start_idx - + image_offsets[i] : start_idx - + image_offsets[i] - 
+ pad_len - ] = image_features[pt] - except RuntimeError as e: - print(f"RuntimeError in llava image encoding: {e}") - print(input_embeds.shape) - print(start_idx, image_offsets[i]) + prefix_len = prefix_lens_cpu[i] + + # Multiple images + for j, image_offset in enumerate(image_offsets[i]): + if image_offset < prefix_len: + continue + + tmp_image_feature = image_features[pt][j] + pad_len = tmp_image_feature.shape[0] + + left_idx = start_idx + (image_offset - prefix_len) + right_idx = start_idx + (image_offset - prefix_len) + pad_len + try: + input_embeds[left_idx:right_idx] = tmp_image_feature + except RuntimeError as e: + print(f"RuntimeError in image encoding: {e}") + print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}") + print( + f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}" + ) pt += 1 return self.language_model( @@ -259,12 +354,20 @@ def forward( return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # load clip vision model by cfg['mm_vision_tower']: - # huggingface_name or path_of_clip_relative_to_llava_model_dir + # Load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + # We put the initialization here instead of __init__ to allow it being reused by other subclasses. vision_path = self.config.mm_vision_tower - self.vision_tower = CLIPVisionModel.from_pretrained( - vision_path, torch_dtype=torch.float16 - ).cuda() + if "clip" in vision_path: + self.vision_tower = CLIPVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + elif "siglip" in vision_path: + self.vision_tower = SiglipVisionModel.from_pretrained( + vision_path, torch_dtype=torch.float16 + ).cuda() + # Siglip needs all feature tokens + self.config.mm_vision_select_feature = "full" self.vision_tower.eval() self.vision_feature_layer = self.config.mm_vision_select_layer @@ -276,8 +379,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square") self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None) - self.image_feature_len = int((self.image_size / self.patch_size) ** 2) - if self.vision_feature_select_strategy == "patch": + self.image_feature_len = int((self.image_size // self.patch_size) ** 2) + if ( + self.vision_feature_select_strategy == "patch" + or self.vision_feature_select_strategy == "full" + ): pass elif self.vision_feature_select_strategy == "cls_patch": self.image_feature_len += 1 @@ -289,42 +395,61 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): "model.mm_projector.0": "multi_modal_projector.linear_1", "model.mm_projector.2": "multi_modal_projector.linear_2", "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). + "model.image_newline": "language_model.model.image_newline", } params_dict = dict(self.named_parameters()) - weights = list(weights) for name, loaded_weight in weights: - # FIXME: why projector weights read two times? 
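The placeholder-filling logic above maps an image's token offset into a slice of the flattened extend buffer and skips images whose tokens are already covered by the cached prefix. A small sketch of that index arithmetic with hypothetical numbers:

```python
def image_slice(start_idx: int, image_offset: int, prefix_len: int, pad_len: int):
    """Return the [left, right) slice in input_embeds for one image, or None
    if the image's placeholder tokens already sit inside the cached prefix."""
    if image_offset < prefix_len:
        return None
    left = start_idx + (image_offset - prefix_len)
    return left, left + pad_len


# Request starts at position 100 of the flattened batch, 32 prompt tokens are
# already cached, the image placeholder begins at token 40 and spans 576 tokens.
print(image_slice(start_idx=100, image_offset=40, prefix_len=32, pad_len=576))
# (108, 684)

# An image fully inside the cached prefix is skipped.
print(image_slice(start_idx=100, image_offset=10, prefix_len=32, pad_len=576))
# None
```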
- if "projector" in name or "vision_tower" in name: + if "projector" in name or "vision_tower" in name or "image_newline" in name: for weight_name, param_name in projector_weights.items(): if weight_name in name: name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - # load language model - self.language_model.load_weights(weights) - - monkey_path_clip_vision_embed_forward() + else: + self.language_model.load_weights([(name, loaded_weight)]) @property def num_patches_per_side(self): return self.image_size // self.patch_size -class LlavaQwenForCausalLM(LlavaLlamaForCausalLM): +class LlavaLlamaForCausalLM(LlavaBaseForCausalLM): + def __init__( + self, + config: LlavaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.vision_tower = None + self.config.vision_config.hidden_size = config.mm_hidden_size + self.config.text_config.hidden_size = config.hidden_size + + self.multi_modal_projector = LlavaMultiModalProjector(config) + self.language_model = LlamaForCausalLM(config, quant_config=quant_config) + if "unpad" in getattr(config, "mm_patch_merge_type", ""): + self.language_model.model.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size, dtype=torch.float16) + ) + + +class LlavaQwenForCausalLM(LlavaBaseForCausalLM): def __init__( self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, ) -> None: - super().__init__(config, quant_config=quant_config, cache_config=cache_config) + super().__init__() + self.config = config self.vision_tower = None + if getattr(self.config, "vision_config", None) is None: self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower) - if getattr(self.config, "text_config", None) is None: self.config.text_config = Qwen2Config(self.config._name_or_path) @@ -333,7 +458,6 @@ def __init__( if getattr(self.config, "projector_hidden_act", None) is None: self.config.projector_hidden_act = "gelu" - if getattr(self.config, "image_token_index", None) is None: self.config.image_token_index = 151646 @@ -345,19 +469,20 @@ def __init__( ) -class LlavaMistralForCausalLM(LlavaLlamaForCausalLM): +class LlavaMistralForCausalLM(LlavaBaseForCausalLM): def __init__( self, config: LlavaConfig, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, ) -> None: - super().__init__(config, quant_config=quant_config, cache_config=cache_config) + super().__init__() + self.config = config self.vision_tower = None + if getattr(self.config, "vision_config", None) is None: self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower) - if getattr(self.config, "text_config", None) is None: self.config.text_config = MistralConfig(self.config._name_or_path) @@ -366,7 +491,6 @@ def __init__( if getattr(self.config, "projector_hidden_act", None) is None: self.config.projector_hidden_act = "gelu" - if getattr(self.config, "image_token_index", None) is None: self.config.image_token_index = 32000 @@ -378,36 +502,4 @@ def __init__( ) -first_call = True - - -def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - - # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. 
- global first_call - if first_call: - self.patch_embedding.cpu().float() - first_call = False - pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") - patch_embeds = self.patch_embedding(pixel_values).cuda().half() - - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -def monkey_path_clip_vision_embed_forward(): - import transformers - - setattr( - transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, - "forward", - clip_vision_embed_forward, - ) - - EntryClass = [LlavaLlamaForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM] diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 8b81251d692..f268ecbbcd7 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,13 +26,8 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.mm_utils import ( - get_anyres_image_grid_shape, - unpad_image, - unpad_image_shape, -) from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.models.llama import LlamaForCausalLM class LlavaVidForCausalLM(nn.Module): @@ -59,23 +54,14 @@ def __init__( torch.empty(config.text_config.hidden_size, dtype=torch.float16) ) - def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + def pad_input_ids( + self, + input_ids: List[int], + pad_value: List[int], + pixel_values: List, + image_sizes: List[List[int]], + ): new_image_feature_len = self.image_feature_len - # now only support spatial_unpad + anyres - # if self.mm_patch_merge_type.startswith("spatial"): - # height = width = self.num_patches_per_side - # if pt_shape[0] > 1: - # if self.image_aspect_ratio == "anyres": - # num_patch_width, num_patch_height = get_anyres_image_grid_shape( - # image_size, - # self.image_grid_pinpoints, - # self.vision_tower.config.image_size, - # ) - # if "unpad" in self.mm_patch_merge_type: - # h = num_patch_height * height - # w = num_patch_width * width - # new_h, new_w = unpad_image_shape(h, w, image_size) - # new_image_feature_len += new_h * (new_w + 1) pad_ids = pad_value * ( (new_image_feature_len + len(pad_value)) // len(pad_value) @@ -87,7 +73,7 @@ def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): + pad_ids[:new_image_feature_len] + input_ids[offset + 1 :] ) - return new_input_ids, offset + return new_input_ids, [offset] def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -133,22 +119,18 @@ def forward( if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size - # Embed text input + # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Embed vision input - need_vision = ( - (positions[input_metadata.extend_start_loc] < self.image_feature_len) - .cpu() - .numpy() + # Whether the requests need vision inputs + max_image_offset = np.array( + [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] ) - # FIXME: We need to substract the length of the system prompt - has_pixel = np.array([pixel_values[i] is not None 
for i in range(bs)]) - need_vision = need_vision & has_pixel + start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() + need_vision = start_positions <= max_image_offset if need_vision.any(): pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] - image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]] ########## Encode Image ######## @@ -183,31 +165,36 @@ def forward( new_image_features.append(image_feature.flatten(0, 1)) image_features = new_image_features + # Fill in the placeholder for the image extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy() + prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy() pt = 0 for i in range(bs): if not need_vision[i]: continue start_idx = extend_start_loc_cpu[i] - pad_len, pad_dim = image_features[pt].shape # 576, 4096 - dim = input_embeds.shape[1] - assert ( - pad_dim == dim - ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim) - # Fill in the placeholder for the image - try: - input_embeds[ - start_idx - + image_offsets[i] : start_idx - + image_offsets[i] - + pad_len - ] = image_features[pt] - except RuntimeError as e: - print(f"RuntimeError in llava image encoding: {e}") - print(input_embeds.shape) - print(start_idx, image_offsets[i]) - pt += 1 + prefix_len = prefix_lens_cpu[i] + + # Multiple images + for image_offset in image_offsets[i]: + if image_offset < prefix_len: + continue + + tmp_image_feature = image_features[pt] + pad_len = tmp_image_feature.shape[0] + + left_idx = start_idx + (image_offset - prefix_len) + right_idx = start_idx + (image_offset - prefix_len) + pad_len + try: + input_embeds[left_idx:right_idx] = tmp_image_feature + except RuntimeError as e: + print(f"RuntimeError in image encoding: {e}") + print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}") + print( + f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}" + ) + pt += 1 return self.language_model( input_ids, positions, input_metadata, input_embeds=input_embeds @@ -216,8 +203,9 @@ def forward( return self.language_model(input_ids, positions, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # load clip vision model by cfg['mm_vision_tower']: - # huggingface_name or path_of_clip_relative_to_llava_model_dir + # Load clip vision model by cfg['mm_vision_tower']: + # huggingface_name or path_of_clip_relative_to_llava_model_dir + # We put the initialization here instead of __init__ to allow it being reused by other subclasses. vision_path = self.config.mm_vision_tower self.vision_tower = CLIPVisionModel.from_pretrained( vision_path, torch_dtype=torch.float16 @@ -251,12 +239,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1", "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2", "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). + "model.image_newline": "language_model.model.image_newline", } params_dict = dict(self.named_parameters()) - weights = list(weights) for name, loaded_weight in weights: # FIXME: why projector weights read two times? 
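As in the LLaVA loader, the projector_weights table above remaps legacy checkpoint prefixes (now including model.image_newline) onto the module names this wrapper exposes. A minimal sketch of that substitution, with hypothetical tensor names:

```python
projector_weights = {
    "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
    "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
    "model.vision_tower.vision_tower": "vision_tower",
    "model.image_newline": "language_model.model.image_newline",
}


def remap(name: str) -> str:
    """Rewrite a checkpoint tensor name to the wrapper's parameter name."""
    for old_prefix, new_prefix in projector_weights.items():
        if old_prefix in name:
            return name.replace(old_prefix, new_prefix)
    return name


print(remap("model.vision_resampler.mm_projector.0.weight"))
# multi_modal_projector.linear_1.weight
print(remap("model.image_newline"))
# language_model.model.image_newline
```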
- if "projector" in name or "vision_tower" in name: + if "projector" in name or "vision_tower" in name or "image_newline" in name: for weight_name, param_name in projector_weights.items(): if weight_name in name: name = name.replace(weight_name, param_name) @@ -267,47 +255,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - - # load language model - self.language_model.load_weights(weights) - - monkey_path_clip_vision_embed_forward() + else: + self.language_model.load_weights([(name, loaded_weight)]) @property def num_patches_per_side(self): return self.image_size // self.patch_size -first_call = True - - -def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: - batch_size = pixel_values.shape[0] - - # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G. - global first_call - if first_call: - self.patch_embedding.cpu().float() - first_call = False - pixel_values = pixel_values.to(dtype=torch.float32, device="cpu") - patch_embeds = self.patch_embedding(pixel_values).cuda().half() - - patch_embeds = patch_embeds.flatten(2).transpose(1, 2) - - class_embeds = self.class_embedding.expand(batch_size, 1, -1) - embeddings = torch.cat([class_embeds, patch_embeds], dim=1) - embeddings = embeddings + self.position_embedding(self.position_ids) - return embeddings - - -def monkey_path_clip_vision_embed_forward(): - import transformers - - setattr( - transformers.models.clip.modeling_clip.CLIPVisionEmbeddings, - "forward", - clip_vision_embed_forward, - ) - - EntryClass = LlavaVidForCausalLM diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 49ff1926f39..0028ae67a8c 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -39,6 +39,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -297,6 +298,7 @@ def __init__( self.scale_width = self.config.hidden_size / self.config.dim_model_base self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -314,9 +316,11 @@ def forward( lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, lm_head_weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mistral.py b/python/sglang/srt/models/mistral.py index 614c1c1d747..1430ece436e 100644 --- a/python/sglang/srt/models/mistral.py +++ b/python/sglang/srt/models/mistral.py @@ -15,12 +15,11 @@ """Inference-only Mistral model.""" -from sglang.srt.models.llama2 import LlamaForCausalLM +from sglang.srt.models.llama import LlamaForCausalLM class MistralForCausalLM(LlamaForCausalLM): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + pass EntryClass = MistralForCausalLM diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index d11f6c95198..85f4576c46d 100644 --- 
a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -41,6 +41,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -299,6 +300,7 @@ def __init__( self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() def forward( self, @@ -308,9 +310,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -358,7 +362,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, ) diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index b02e925c5a0..97ac09ee629 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,6 +45,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -333,6 +334,7 @@ def __init__( self.model = MixtralModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -343,9 +345,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 93dae9585c3..4958a812985 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,6 +39,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -251,6 +252,7 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -260,10 +262,11 @@ def forward( input_metadata: InputMetadata, ): hidden_states = 
self.transformer(input_ids, positions, input_metadata) - next_tokens = self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - return next_tokens + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index d1295bd8cc8..6bb5c0b9066 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -38,7 +38,9 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata Qwen2Config = None @@ -275,6 +277,8 @@ def __init__( self.model = Qwen2Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @torch.no_grad() def forward( @@ -283,11 +287,17 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, + get_embedding: bool = False, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata - ) + if not get_embedding: + logits_output = self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, input_metadata + ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output + else: + return self.pooler(hidden_states, input_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -306,6 +316,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue @@ -313,8 +326,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -323,8 +334,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
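The same forward-path change repeats across minicpm, mixtral, qwen, qwen2 and the other models touched in this patch; its shared shape is roughly the sketch below (names follow the diff, the surrounding class wiring is assumed), so callers receive the sampled ids while still keeping the logits, for example for logprob reporting:

def forward(self, input_ids, positions, input_metadata, input_embeds=None):
    # Run the transformer stack, compute logits, then sample inside the model.
    hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
    logits_output = self.logits_processor(
        input_ids, hidden_states, self.lm_head.weight, input_metadata
    )
    # The sampler reads per-request settings attached to the batch.
    sample_output = self.sampler(logits_output, input_metadata.sampling_info)
    return sample_output, logits_output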
if name.endswith(".bias") and name not in params_dict: continue - if name.startswith("model.vision_tower") and name not in params_dict: - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 9bdbd750660..67b5a6ce663 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -35,10 +35,8 @@ ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -49,6 +47,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -366,6 +365,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -376,20 +376,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - - def compute_logits( - self, - input_ids: torch.Tensor, - hidden_states: torch.Tensor, - input_metadata: InputMetadata, - ) -> torch.Tensor: - logits = self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata - ) - return logits + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -401,24 +392,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = [ - # These are the weights for the experts - # (param_name, weight_name, expert_id, shard_id) - ( - ( - "experts.w13_weight" - if weight_name in ["gate_proj", "up_proj"] - else "experts.w2_weight" - ), - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - shard_id, - ) - for expert_id in range(self.config.num_experts) - for shard_id, weight_name in enumerate( - ["gate_proj", "down_proj", "up_proj"] - ) - ] + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: @@ -458,7 +437,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader( param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id, ) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 9e10f12f2a2..a3102baabd4 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -40,6 +40,7 @@ from sglang.srt.layers.activation import 
SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -249,6 +250,7 @@ def __init__( self.model = StableLMEpochModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -259,9 +261,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 11d4cda1c00..0f86206d821 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -24,10 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.models.llava import ( - LlavaLlamaForCausalLM, - monkey_path_clip_vision_embed_forward, -) +from sglang.srt.models.llava import LlavaLlamaForCausalLM class YiVLForCausalLM(LlavaLlamaForCausalLM): @@ -50,7 +47,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): self.config._name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder, - ).cuda() + ).to("cuda") self.vision_tower.eval() @@ -94,8 +91,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load language model self.language_model.load_weights(weights) - monkey_path_clip_vision_embed_forward() - class YiVLMultiModalProjector(nn.Module): def __init__(self, config: LlavaConfig): diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 582457ae049..92c69000cd6 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -123,7 +123,7 @@ def create_streaming_error_response( def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg): global chat_template_name - print(f"Use chat template: {chat_template_arg}") + logger.info(f"Use chat template: {chat_template_arg}") if not chat_template_exists(chat_template_arg): if not os.path.exists(chat_template_arg): raise RuntimeError( @@ -275,10 +275,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe end_point = batch_storage[batch_id].endpoint file_request_list = [] all_requests = [] + request_ids = [] for line in lines: request_data = json.loads(line) file_request_list.append(request_data) body = request_data["body"] + request_ids.append(request_data["custom_id"]) # Although streaming is supported for standalone completions, it is not supported in # batch mode (multiple completions in single request). 
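The batch-processing path now records each input line's custom_id and reuses it as the internal request id; a small sketch of that bookkeeping with hypothetical input lines:

import json

# Hypothetical two-line batch input, matching the JSONL shape read above.
lines = [
    '{"custom_id": "req-1", "url": "/v1/completions", "body": {"model": "m", "prompt": "Hi"}}',
    '{"custom_id": "req-2", "url": "/v1/completions", "body": {"model": "m", "prompt": "Yo"}}',
]

file_request_list, request_ids = [], []
for line in lines:
    request_data = json.loads(line)
    file_request_list.append(request_data)
    request_ids.append(request_data["custom_id"])  # reused as the internal rid

# v1_generate_request / v1_chat_generate_request then receive request_ids and
# attach them as rid, which is what abort_request(rid=...) keys on later.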
@@ -289,12 +291,16 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe all_requests.append(ChatCompletionRequest(**body)) elif end_point == "/v1/completions": all_requests.append(CompletionRequest(**body)) + if end_point == "/v1/chat/completions": adapted_request, request = v1_chat_generate_request( - all_requests, tokenizer_manager + all_requests, tokenizer_manager, request_ids=request_ids ) elif end_point == "/v1/completions": - adapted_request, request = v1_generate_request(all_requests) + adapted_request, request = v1_generate_request( + all_requests, request_ids=request_ids + ) + try: ret = await tokenizer_manager.generate_request(adapted_request).__anext__() if not isinstance(ret, list): @@ -326,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe } all_ret.append(response_json) completed_requests += 1 + # Write results to a new file output_file_id = f"backend_result_file-{uuid.uuid4()}" global storage_dir @@ -355,7 +362,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe } except Exception as e: - print("error in SGLang:", e) + logger.error("error in SGLang:", e) # Update batch status to "failed" retrieve_batch = batch_storage[batch_id] retrieve_batch.status = "failed" @@ -372,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str): return batch_response +async def v1_cancel_batch(tokenizer_manager, batch_id: str): + # Retrieve the batch job from the in-memory storage + batch_response = batch_storage.get(batch_id) + if batch_response is None: + raise HTTPException(status_code=404, detail="Batch not found") + + # Only do cancal when status is "validating" or "in_progress" + if batch_response.status in ["validating", "in_progress"]: + # Start cancelling the batch asynchronously + asyncio.create_task( + cancel_batch( + tokenizer_manager=tokenizer_manager, + batch_id=batch_id, + input_file_id=batch_response.input_file_id, + ) + ) + + # Update batch status to "cancelling" + batch_response.status = "cancelling" + + return batch_response + else: + raise HTTPException( + status_code=500, + detail=f"Current status is {batch_response.status}, no need to cancel", + ) + + +async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str): + try: + # Update the batch status to "cancelling" + batch_storage[batch_id].status = "cancelling" + + # Retrieve the input file content + input_file_request = file_id_request.get(input_file_id) + if not input_file_request: + raise ValueError("Input file not found") + + # Parse the JSONL file and process each request + input_file_path = file_id_storage.get(input_file_id) + with open(input_file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + file_request_list = [] + request_ids = [] + for line in lines: + request_data = json.loads(line) + file_request_list.append(request_data) + request_ids.append(request_data["custom_id"]) + + # Cancel requests by request_ids + for rid in request_ids: + tokenizer_manager.abort_request(rid=rid) + + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "cancelled" + + except Exception as e: + logger.error("error in SGLang:", e) + # Update batch status to "failed" + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "failed" + retrieve_batch.failed_at = int(time.time()) + retrieve_batch.errors = {"message": str(e)} + + async def v1_retrieve_file(file_id: str): # Retrieve the batch job from the in-memory storage file_response = file_id_response.get(file_id) @@ -392,7 
+465,9 @@ def iter_file(): return StreamingResponse(iter_file(), media_type="application/octet-stream") -def v1_generate_request(all_requests: List[CompletionRequest]): +def v1_generate_request( + all_requests: List[CompletionRequest], request_ids: List[str] = None +): prompts = [] sampling_params_list = [] return_logprobs = [] @@ -433,7 +508,14 @@ def v1_generate_request(all_requests: List[CompletionRequest]): "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, "repetition_penalty": request.repetition_penalty, + "min_p": request.min_p, + "dry_multiplier": request.dry_multiplier, + "dry_base": request.dry_base, + "dry_allowed_length": request.dry_allowed_length, + "dry_penalty_last_n": request.dry_penalty_last_n, + "dry_sequence_breakers": request.dry_sequence_breakers, "regex": request.regex, + "json_schema": request.json_schema, "n": request.n, "ignore_eos": request.ignore_eos, } @@ -463,6 +545,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]): logprob_start_len=logprob_start_lens, return_text_in_logprobs=True, stream=all_requests[0].stream, + rid=request_ids, ) if len(all_requests) == 1: @@ -745,7 +828,9 @@ async def generate_stream_resp(): def v1_chat_generate_request( - all_requests: List[ChatCompletionRequest], tokenizer_manager + all_requests: List[ChatCompletionRequest], + tokenizer_manager, + request_ids: List[str] = None, ): input_ids = [] sampling_params_list = [] @@ -765,8 +850,23 @@ def v1_chat_generate_request( if not isinstance(request.messages, str): # Apply chat template and its stop strings. if chat_template_name is None: + openai_compatible_messages = [] + for message in request.messages: + if isinstance(message.content, str): + openai_compatible_messages.append( + {"role": message.role, "content": message.content} + ) + else: + content_list = message.dict()["content"] + for content in content_list: + if content["type"] == "text": + openai_compatible_messages.append( + {"role": message.role, "content": content["text"]} + ) prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( - request.messages, tokenize=True, add_generation_prompt=True + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, ) stop = request.stop image_data = None @@ -798,10 +898,17 @@ def v1_chat_generate_request( "stop": stop, "stop_token_ids": request.stop_token_ids, "top_p": request.top_p, + "min_p": request.min_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, "repetition_penalty": request.repetition_penalty, + "dry_multiplier": request.dry_multiplier, + "dry_base": request.dry_base, + "dry_allowed_length": request.dry_allowed_length, + "dry_penalty_last_n": request.dry_penalty_last_n, + "dry_sequence_breakers": request.dry_sequence_breakers, "regex": request.regex, + "json_schema": request.json_schema, "n": request.n, } ) @@ -832,6 +939,7 @@ def v1_chat_generate_request( top_logprobs_num=top_logprobs_nums, stream=all_requests[0].stream, return_text_in_logprobs=True, + rid=request_ids, ) if len(all_requests) == 1: return adapted_request, all_requests[0] diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 758e48edefb..844551af1cc 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -161,11 +161,17 @@ class CompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
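The extra fields added to CompletionRequest (and ChatCompletionRequest below) flow straight into SamplingParams. A hypothetical request exercising them, with placeholder URL and model name:

import requests

payload = {
    "model": "placeholder-model",
    "prompt": "List three colors:",
    "max_tokens": 64,
    "min_p": 0.05,
    "dry_multiplier": 0.8,        # 0.0 disables the DRY penalty
    "dry_base": 1.75,             # example value; the server default is 0.0
    "dry_allowed_length": 2,
    "dry_penalty_last_n": 256,    # 0 means the whole context is considered
    "dry_sequence_breakers": ["\n", ":", "\"", "*"],
    "json_schema": None,          # a JSON-schema string; mutually exclusive with "regex"
}
resp = requests.post("http://127.0.0.1:30000/v1/completions", json=payload)
print(resp.json())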
regex: Optional[str] = None + json_schema: Optional[str] = None ignore_eos: Optional[bool] = False min_tokens: Optional[int] = 0 repetition_penalty: Optional[float] = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) - + min_p: float = 0.0 + dry_multiplier: float = 0.0 + dry_base: float = 0.0 + dry_allowed_length: int = 2 + dry_penalty_last_n: int = 0 + dry_sequence_breakers: Optional[List[str]] = [] class CompletionResponseChoice(BaseModel): index: int @@ -199,11 +205,6 @@ class CompletionStreamResponse(BaseModel): usage: Optional[UsageInfo] = None -class ChatCompletionMessageGenericParam(BaseModel): - role: Literal["system", "assistant"] - content: str - - class ChatCompletionMessageContentTextPart(BaseModel): type: Literal["text"] text: str @@ -224,6 +225,11 @@ class ChatCompletionMessageContentImagePart(BaseModel): ] +class ChatCompletionMessageGenericParam(BaseModel): + role: Literal["system", "assistant"] + content: Union[str, List[ChatCompletionMessageContentTextPart]] + + class ChatCompletionMessageUserParam(BaseModel): role: Literal["user"] content: Union[str, List[ChatCompletionMessageContentPart]] @@ -262,10 +268,16 @@ class ChatCompletionRequest(BaseModel): # Extra parameters for SRT backend only and will be ignored by OpenAI models. regex: Optional[str] = None + json_schema: Optional[str] = None min_tokens: Optional[int] = 0 repetition_penalty: Optional[float] = 1.0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) - + min_p: float = 0.0 + dry_multiplier: float = 0.0 + dry_base: float = 0.0 + dry_allowed_length: int = 2 + dry_penalty_last_n: int = 0 + dry_sequence_breakers: Optional[List[str]] = [] class ChatMessage(BaseModel): role: Optional[str] = None diff --git a/python/sglang/srt/sampling/penaltylib/__init__.py b/python/sglang/srt/sampling/penaltylib/__init__.py index 43fff0fca44..5ab2a7dc2f5 100644 --- a/python/sglang/srt/sampling/penaltylib/__init__.py +++ b/python/sglang/srt/sampling/penaltylib/__init__.py @@ -3,6 +3,7 @@ from .penalizers.min_new_tokens import BatchedMinNewTokensPenalizer from .penalizers.presence_penalty import BatchedPresencePenalizer from .penalizers.repetition_penalty import BatchedRepetitionPenalizer +from .penalizers.dry_penalty import BatchedDryPenalizer __all__ = [ "BatchedFrequencyPenalizer", @@ -10,4 +11,5 @@ "BatchedPresencePenalizer", "BatchedRepetitionPenalizer", "BatchedPenalizerOrchestrator", + "BatchedDryPenalizer", ] diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/dry_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/dry_penalty.py new file mode 100644 index 00000000000..f86ecfd5d8e --- /dev/null +++ b/python/sglang/srt/sampling/penaltylib/penalizers/dry_penalty.py @@ -0,0 +1,127 @@ +import typing +from collections import defaultdict +import torch + +from ..orchestrator import _BatchedPenalizer, _TokenIDs + +class BatchedDryPenalizer(_BatchedPenalizer): + """ + DRY (Don't Repeat Yourself) penalizer penalizes tokens based on their repetition patterns in the input. 
+ """ + + multipliers: torch.Tensor = None + bases: torch.Tensor = None + allowed_lengths: torch.Tensor = None + sequence_breakers: typing.List[set[int]] = None + ranges: torch.Tensor = None + input_ids: torch.Tensor = None + output_ids: torch.Tensor = None + + def _is_required(self) -> bool: + return any( + req.sampling_params.dry_multiplier != 0.0 + for req in self.orchestrator.reqs() + ) + + def _prepare(self): + self.multipliers = torch.tensor( + [req.sampling_params.dry_multiplier for req in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device + ) + self.bases = torch.tensor( + [req.sampling_params.dry_base for req in self.orchestrator.reqs()], + dtype=torch.float32, + device=self.orchestrator.device + ) + self.allowed_lengths = torch.tensor( + [req.sampling_params.dry_allowed_length for req in self.orchestrator.reqs()], + dtype=torch.float32, # Ensure this is float to match other tensors + device=self.orchestrator.device + ) + self.sequence_breakers = [ + [req.tokenizer.encode(f'a{prompt}', add_special_tokens=False)[-1] + for prompt in req.sampling_params.dry_sequence_breakers] + for req in self.orchestrator.reqs() + ] + self.ranges = torch.tensor( + [req.sampling_params.dry_penalty_last_n for req in self.orchestrator.reqs()], + dtype=torch.int64, + device=self.orchestrator.device + ) + + def _teardown(self): + del self.multipliers + del self.bases + del self.allowed_lengths + del self.sequence_breakers + del self.ranges + + self.multipliers = None + self.bases = None + self.allowed_lengths = None + self.sequence_breakers = None + self.ranges = None + + def _cumulate_input_tokens(self, input_ids: _TokenIDs): + self.input_ids = input_ids.token_ids + + def _cumulate_output_tokens(self, output_ids: _TokenIDs): + self.output_ids = output_ids.token_ids + + def _apply(self, logits: torch.Tensor) -> torch.Tensor: + batch_size, seq_length = logits.shape[0], logits.shape[1] + max_back_length = 50 # Limit the backward match to 50 to prevent overflow + for i in range(batch_size): + if self.output_ids is not None: + input_ids = self.input_ids[i] = torch.cat( + [self.input_ids[i], self.output_ids], dim=0 + ) + else: + input_ids = self.input_ids[i] + input_ids = input_ids.tolist() + range_limit = min(self.ranges[i].item(), len(input_ids)) + input_ids = input_ids[-range_limit:] if range_limit > 0 else input_ids + last_token = input_ids[-1] + if last_token in self.sequence_breakers[i]: + continue + + match_indices = [idx for idx, val in enumerate(input_ids[:-1]) if val == last_token] + match_lengths = defaultdict(int) + + for idx in match_indices: + next_token = input_ids[idx + 1] + if next_token in self.sequence_breakers[i]: + continue + match_length = 1 + while match_length < max_back_length and idx - match_length >= 0: + previous_token = input_ids[-(match_length + 1)] + if input_ids[idx - match_length] != previous_token: + break + if previous_token in self.sequence_breakers[i]: + break + match_length += 1 + match_lengths[next_token] = max(match_length, match_lengths[next_token]) + + for token, match_length in match_lengths.items(): + if match_length >= self.allowed_lengths[i].item(): + penalty = self.multipliers[i].item() * self.bases[i].item() ** (match_length - self.allowed_lengths[i].item()) + logits[i, token] -= penalty + + return logits + + def _filter( + self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor + ): + self.multipliers = self.multipliers[indices_tensor_to_keep] + self.bases = self.bases[indices_tensor_to_keep] + 
self.allowed_lengths = self.allowed_lengths[indices_tensor_to_keep] + self.sequence_breakers = [self.sequence_breakers[i] for i in indices_to_keep] + self.ranges = self.ranges[indices_tensor_to_keep] + + def _merge(self, their: "BatchedDryPenalizer"): + self.multipliers = torch.cat([self.multipliers, their.multipliers], dim=0) + self.bases = torch.cat([self.bases, their.bases], dim=0) + self.allowed_lengths = torch.cat([self.allowed_lengths, their.allowed_lengths], dim=0) + self.sequence_breakers.extend(their.sequence_breakers) + self.ranges = torch.cat([self.ranges, their.ranges], dim=0) \ No newline at end of file diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py index 178cb54b24c..749a3b0bf4d 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py @@ -54,9 +54,7 @@ def _cumulate_input_tokens(self, input_ids: _TokenIDs): pass def _cumulate_output_tokens(self, output_ids: _TokenIDs): - self.cumulated_frequency_penalties += ( - self.frequency_penalties * output_ids.occurrence_count() - ) + pass def _apply(self, logits: torch.Tensor) -> torch.Tensor: logits -= self.cumulated_frequency_penalties diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index bc70a9018ed..0da0e0ddb9b 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -21,16 +21,64 @@ class SamplingBatchInfo: top_ps: torch.Tensor = None top_ks: torch.Tensor = None min_ps: torch.Tensor = None - penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None + + # Dispatch in CUDA graph + need_min_p_sampling: bool = False + + # Bias Tensors logit_bias: torch.Tensor = None vocab_mask: torch.Tensor = None + # Penalizer + penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None + linear_penalties: torch.Tensor = None + scaling_penalties: torch.Tensor = None + + def can_run_in_cuda_graph(self): + # Vocab bias and min_ps are not supported in CUDA graph + return ( + self.logit_bias is None + and self.vocab_mask is None + and self.linear_penalties is None + and self.scaling_penalties is None + and not self.need_min_p_sampling + ) + + @classmethod + def dummy_one(cls, max_bs: int, vocab_size: int): + ret = cls(vocab_size=vocab_size) + ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda") + ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda") + ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda") + return ret + + def __getitem__(self, key): + if isinstance(key, slice): + # NOTE:This method is only used in CUDA graph + assert self.can_run_in_cuda_graph() + return SamplingBatchInfo( + vocab_size=self.vocab_size, + temperatures=self.temperatures[key], + top_ps=self.top_ps[key], + top_ks=self.top_ks[key], + ) + else: + raise NotImplementedError + + def inplace_assign(self, bs: int, other: SamplingBatchInfo): + # NOTE:This method is only used in CUDA graph + assert self.can_run_in_cuda_graph() + + self.vocab_size = other.vocab_size + self.temperatures[:bs] = other.temperatures + self.top_ps[:bs] = other.top_ps + self.top_ks[:bs] = other.top_ks + @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): device = "cuda" reqs = batch.reqs ret = cls(vocab_size=vocab_size) - ret.temperatures = torch.tensor( 
[r.sampling_params.temperature for r in reqs], dtype=torch.float, @@ -45,6 +93,7 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): ret.min_ps = torch.tensor( [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device ) + ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs) # Each penalizers will do nothing if they evaluate themselves as not required by looking at # the sampling_params of the requests (See {_is_required()} of each penalizers). So this @@ -62,6 +111,7 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): penaltylib.BatchedMinNewTokensPenalizer, penaltylib.BatchedPresencePenalizer, penaltylib.BatchedRepetitionPenalizer, + penaltylib.BatchedDryPenalizer, }, ) @@ -72,6 +122,25 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): return ret + def prepare_penalties(self): + self.scaling_penalties = None + self.linear_penalties = None + + for penalizer in self.penalizer_orchestrator.penalizers.values(): + if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer): + if penalizer.is_prepared(): + self.scaling_penalties = penalizer.cumulated_repetition_penalties + else: + if penalizer.is_prepared(): + if self.linear_penalties is None: + bs = self.penalizer_orchestrator.batch.batch_size() + self.linear_penalties = torch.zeros( + (bs, self.vocab_size), + dtype=torch.float32, + device="cuda", + ) + self.linear_penalties = penalizer.apply(self.linear_penalties) + def update_regex_vocab_mask(self, batch: ScheduleBatch): bs, reqs = batch.batch_size(), batch.reqs device = "cuda" @@ -81,15 +150,15 @@ def update_regex_vocab_mask(self, batch: ScheduleBatch): self.vocab_mask = None if has_regex: + self.vocab_mask = torch.zeros( + bs, self.vocab_size, dtype=torch.bool, device=device + ) for i, req in enumerate(reqs): if req.regex_fsm is not None: - if self.vocab_mask is None: - self.vocab_mask = torch.zeros( - bs, self.vocab_size, dtype=torch.bool, device=device - ) + self.vocab_mask[i].fill_(1) self.vocab_mask[i][ req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens - ] = 1 + ] = 0 def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index c30717dd7cb..d0172937ec0 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -34,11 +34,17 @@ def __init__( frequency_penalty: float = 0.0, presence_penalty: float = 0.0, repetition_penalty: float = 1.0, + dry_multiplier: float = 0.0, + dry_base: float = 0.0, + dry_allowed_length: int = 2, + dry_penalty_last_n: int = 0, + dry_sequence_breakers: Optional[List[str]] = [], ignore_eos: bool = False, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, regex: Optional[str] = None, n: int = 1, + json_schema: Optional[str] = None, ) -> None: self.temperature = temperature self.top_p = top_p @@ -47,6 +53,11 @@ def __init__( self.frequency_penalty = frequency_penalty self.presence_penalty = presence_penalty self.repetition_penalty = repetition_penalty + self.dry_multiplier=dry_multiplier + self.dry_base=dry_base + self.dry_allowed_length=dry_allowed_length + self.dry_penalty_last_n=dry_penalty_last_n + self.dry_sequence_breakers=dry_sequence_breakers self.stop_strs = stop self.stop_token_ids = {*stop_token_ids} self.max_new_tokens = max_new_tokens @@ -56,6 +67,7 @@ def __init__( 
self.spaces_between_special_tokens = spaces_between_special_tokens self.regex = regex self.n = n + self.json_schema = json_schema # Process some special cases if self.temperature < _SAMPLING_EPS: @@ -106,6 +118,17 @@ def verify(self): f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got " f"{self.min_new_tokens}." ) + if self.dry_multiplier is not None and self.dry_multiplier > 0.0: + if self.dry_multiplier < 0: + raise ValueError( + f"dry_multiplier must be at least 0, got {self.dry_multiplier}." + ) + if self.dry_allowed_length < 0: + raise ValueError( + f"dry_allowed_length must be at least 0, got {self.dry_allowed_length}." + ) + if self.regex is not None and self.json_schema is not None: + raise ValueError("regex and json_schema cannot be both set.") def normalize(self, tokenizer): # Process stop strings @@ -136,8 +159,14 @@ def to_srt_kwargs(self): "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, + "min_p": self.min_p, "frequency_penalty": self.frequency_penalty, "presence_penalty": self.presence_penalty, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_penalty_last_n": self.dry_penalty_last_n, + "dry_sequence_breakers": self.dry_sequence_breakers, "ignore_eos": self.ignore_eos, "regex": self.regex, } diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 3ec5cd633f4..feaf91dd390 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -59,6 +59,7 @@ from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, v1_batches, + v1_cancel_batch, v1_chat_completions, v1_completions, v1_delete_file, @@ -74,6 +75,7 @@ add_api_key_middleware, allocate_init_ports, assert_pkg_version, + configure_logger, enable_show_time_cost, kill_child_process, maybe_set_triton_cache_manager, @@ -245,6 +247,12 @@ async def openai_v1_batches(raw_request: Request): return await v1_batches(tokenizer_manager, raw_request) +@app.post("/v1/batches/{batch_id}/cancel") +async def cancel_batches(batch_id: str): + # https://platform.openai.com/docs/api-reference/batch/cancel + return await v1_cancel_batch(tokenizer_manager, batch_id) + + @app.get("/v1/batches/{batch_id}") async def retrieve_batch(batch_id: str): return await v1_retrieve_batch(batch_id) @@ -264,21 +272,18 @@ async def retrieve_file_content(file_id: str): def launch_server( server_args: ServerArgs, - model_overide_args: Optional[dict] = None, + model_override_args: Optional[dict] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None, ): """Launch an HTTP server.""" global tokenizer_manager - logging.basicConfig( - level=getattr(logging, server_args.log_level.upper()), - format="%(message)s", - ) + configure_logger(server_args) server_args.check_server_args() _set_envs_and_config(server_args) - # Allocate ports + # Allocate ports for inter-process communications server_args.port, server_args.additional_ports = allocate_init_ports( server_args.port, server_args.additional_ports, @@ -312,7 +317,7 @@ def launch_server( tp_rank_range, server_args, ports[3], - model_overide_args, + model_override_args, ) try: @@ -323,21 +328,23 @@ def launch_server( return # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args) + tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args) if server_args.chat_template: load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) 
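This pairs with the /v1/batches/{batch_id}/cancel route registered above; a possible client call, with placeholder base URL and batch id:

import requests

base_url = "http://127.0.0.1:30000"   # placeholder
batch_id = "batch_abc123"             # id returned when the batch was created

# Mirrors the OpenAI Batch API; only "validating" or "in_progress" batches
# can be cancelled, otherwise the server responds with an error.
resp = requests.post(f"{base_url}/v1/batches/{batch_id}/cancel")
print(resp.json()["status"])          # expected: "cancelling"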
pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) if server_args.dp_size == 1: - start_process = start_controller_process_single + start_controller_process = start_controller_process_single else: - start_process = start_controller_process_multi + start_controller_process = start_controller_process_multi + proc_controller = mp.Process( - target=start_process, - args=(server_args, port_args, pipe_controller_writer, model_overide_args), + target=start_controller_process, + args=(server_args, port_args, pipe_controller_writer, model_override_args), ) proc_controller.start() + proc_detoken = mp.Process( target=start_detokenizer_process, args=( @@ -414,7 +421,7 @@ def _set_envs_and_config(server_args: ServerArgs): if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", - "0.1.5", + "0.1.6", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", @@ -494,7 +501,7 @@ class Runtime: def __init__( self, log_level: str = "error", - model_overide_args: Optional[dict] = None, + model_override_args: Optional[dict] = None, *args, **kwargs, ): @@ -515,9 +522,10 @@ def __init__( self.pid = None pipe_reader, pipe_writer = mp.Pipe(duplex=False) + proc = mp.Process( target=launch_server, - args=(self.server_args, model_overide_args, pipe_writer), + args=(self.server_args, model_override_args, pipe_writer), ) proc.start() pipe_writer.close() @@ -591,7 +599,7 @@ async def async_generate( def generate( self, - prompt: str, + prompt: Union[str, List[str]], sampling_params: Optional[Dict] = None, return_logprob: Optional[Union[List[bool], bool]] = False, logprob_start_len: Optional[Union[List[int], int]] = None, @@ -612,7 +620,7 @@ def generate( def encode( self, - prompt: str, + prompt: Union[str, List[str]], ): json_data = { "text": prompt, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 33451d645e7..8a56c02e162 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -33,11 +33,13 @@ class ServerArgs: skip_tokenizer_init: bool = False load_format: str = "auto" dtype: str = "auto" + kv_cache_dtype: str = "auto" trust_remote_code: bool = True context_length: Optional[int] = None quantization: Optional[str] = None served_model_name: Optional[str] = None chat_template: Optional[str] = None + is_embedding: bool = False # Port host: str = "127.0.0.1" @@ -81,13 +83,12 @@ class ServerArgs: disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False disable_disk_cache: bool = False + disable_custom_all_reduce: bool = False enable_mixed_chunk: bool = False enable_torch_compile: bool = False enable_p2p_check: bool = False enable_mla: bool = False - attention_reduce_in_fp32: bool = False - efficient_weight_load: bool = False - disable_custom_all_reduce: bool = False + triton_attention_reduce_in_fp32: bool = False # Distributed args nccl_init_addr: Optional[str] = None @@ -196,11 +197,23 @@ def add_cli_args(parser: argparse.ArgumentParser): '* "float" is shorthand for FP32 precision.\n' '* "float32" for FP32 precision.', ) + parser.add_argument( + "--kv-cache-dtype", + type=str, + default=ServerArgs.kv_cache_dtype, + choices=["auto", "fp8_e5m2"], + help='Data type for kv cache storage. "auto" will use model data type. 
"fp8_e5m2" is supported for CUDA 11.8+.', + ) parser.add_argument( "--trust-remote-code", action="store_true", help="Whether or not to allow for custom models defined on the Hub in their own modeling files.", ) + parser.add_argument( + "--is-embedding", + action="store_true", + help="Whether to use a CausalLM as an embedding model.", + ) parser.add_argument( "--context-length", type=int, @@ -404,10 +417,16 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", ) + parser.add_argument( + "--disable-custom-all-reduce", + action="store_true", + default=False, + help="Disable the custom all-reduce kernel and fall back to NCCL.", + ) parser.add_argument( "--enable-mixed-chunk", action="store_true", - help="Enabling mixing prefill and decode in a chunked batch.", + help="Enabling mixing prefill and decode in a batch when using chunked prefill.", ) parser.add_argument( "--enable-torch-compile", @@ -425,7 +444,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.", ) parser.add_argument( - "--attention-reduce-in-fp32", + "--triton-attention-reduce-in-fp32", action="store_true", help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." "This only affects Triton attention kernels.", @@ -435,12 +454,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) - parser.add_argument( - "--disable-custom-all-reduce", - action="store_true", - default=False, - help="Disable the custom all-reduce kernel and fall back to NCCL.", - ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -459,6 +472,11 @@ def check_server_args(self): assert not ( self.dp_size > 1 and self.node_rank is not None ), "multi-node data parallel is not supported" + if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: + logger.info( + "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True" + ) + self.trust_remote_code = False if "gemma-2" in self.model_path.lower(): logger.info("When using sliding window in gemma-2, turn on flashinfer.") self.disable_flashinfer = False diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index a15ea16307b..66a5679d756 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -26,7 +26,7 @@ import time from importlib.metadata import PackageNotFoundError, version from io import BytesIO -from typing import List, Optional +from typing import List, Optional, Union import numpy as np import psutil @@ -193,44 +193,30 @@ def allocate_init_ports( return ret_ports[0], ret_ports[1:num_ports_needed] -def get_int_token_logit_bias(tokenizer, vocab_size): - """Get the logit bias for integer-only tokens.""" - # a bug when model's vocab size > tokenizer.vocab_size - if tokenizer == None: - return [-1e5] * vocab_size - vocab_size = tokenizer.vocab_size - logit_bias = np.zeros(vocab_size, dtype=np.float32) - for t_id in range(vocab_size): - ss = tokenizer.decode([t_id]).strip() - if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id): - logit_bias[t_id] = -1e5 - - return logit_bias - - -def is_multimodal_model(model): - from sglang.srt.model_config import ModelConfig - - if isinstance(model, str): - model = model.lower() - return "llava" in model or "yi-vl" 
in model or "llava-next" in model - - if isinstance(model, ModelConfig): - model_path = model.path.lower() - return ( - "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path - ) +def is_multimodal_model(model_architectures): + if ( + "LlavaLlamaForCausalLM" in model_architectures + or "LlavaQwenForCausalLM" in model_architectures + or "LlavaMistralForCausalLM" in model_architectures + or "LlavaVidForCausalLM" in model_architectures + ): + return True + else: + return False - raise ValueError("unrecognized type") +def is_generation_model(model_architectures, is_embedding: bool = False): + # We have two ways to determine whether a model is a generative model. + # 1. Check the model architectue + # 2. check the `is_embedding` server args -def is_generation_model(model_architectures): if ( "LlamaEmbeddingModel" in model_architectures or "MistralModel" in model_architectures ): return False - return True + else: + return not is_embedding def decode_video_base64(video_base64): @@ -312,12 +298,14 @@ def decode_video_base64(video_base64): ) # Return an empty array and size tuple if no frames were found -def load_image(image_file): +def load_image(image_file: Union[str, bytes]): from PIL import Image image = image_size = None - if image_file.startswith("http://") or image_file.startswith("https://"): + if isinstance(image_file, bytes): + image = Image.open(BytesIO(image_file)) + elif image_file.startswith("http://") or image_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) response = requests.get(image_file, timeout=timeout) image = Image.open(BytesIO(response.content)) @@ -329,8 +317,10 @@ def load_image(image_file): elif image_file.startswith("video:"): image_file = image_file.replace("video:", "") image, image_size = decode_video_base64(image_file) - else: + elif isinstance(image_file, str): image = Image.open(BytesIO(base64.b64decode(image_file))) + else: + raise ValueError(f"Invalid image: {image}") return image, image_size @@ -347,7 +337,7 @@ def suppress_other_loggers(): logging.WARN ) logging.getLogger("vllm.selector").setLevel(logging.WARN) - logging.getLogger("vllm.utils").setLevel(logging.WARN) + logging.getLogger("vllm.utils").setLevel(logging.ERROR) def assert_pkg_version(pkg: str, min_version: str, message: str): @@ -417,7 +407,6 @@ def monkey_patch_vllm_dummy_weight_loader(): DummyModelLoader, LoRAConfig, ModelConfig, - MultiModalConfig, ParallelConfig, SchedulerConfig, _initialize_model, @@ -432,7 +421,6 @@ def load_model( model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - multimodal_config: Optional[MultiModalConfig], parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig, @@ -443,7 +431,6 @@ def load_model( model_config, self.load_config, lora_config, - multimodal_config, cache_config, ) @@ -451,10 +438,6 @@ def load_model( quant_method = getattr(module, "quant_method", None) if quant_method is not None: quant_method.process_weights_after_loading(module) - # FIXME: Remove this after Mixtral is updated - # to use quant_method. - if hasattr(module, "process_weights_after_loading"): - module.process_weights_after_loading() # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
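load_image now also accepts raw bytes in addition to URL, base64, and video: inputs; a brief usage sketch with a placeholder file name:

from sglang.srt.utils import load_image

# Raw bytes (for example read from an upload) are opened directly with PIL.
with open("cat.png", "rb") as f:      # placeholder file
    image, _ = load_image(f.read())

# HTTP(S) URLs still go through the original string-based branch.
image, _ = load_image("https://example.com/cat.png")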
@@ -691,7 +674,7 @@ def weight_loader_srt( setattr(QKVParallelLinear, "weight_loader", weight_loader_srt) -def add_api_key_middleware(app, api_key): +def add_api_key_middleware(app, api_key: str): @app.middleware("http") async def authentication(request, call_next): if request.method == "OPTIONS": @@ -703,7 +686,7 @@ async def authentication(request, call_next): return await call_next(request) -def prepare_model(model_path): +def prepare_model(model_path: str): if "SGLANG_USE_MODELSCOPE" in os.environ: if not os.path.exists(model_path): from modelscope import snapshot_download @@ -712,7 +695,7 @@ def prepare_model(model_path): return model_path -def prepare_tokenizer(tokenizer_path): +def prepare_tokenizer(tokenizer_path: str): if "SGLANG_USE_MODELSCOPE" in os.environ: if not os.path.exists(tokenizer_path): from modelscope import snapshot_download @@ -721,3 +704,13 @@ def prepare_tokenizer(tokenizer_path): tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"] ) return tokenizer_path + + +def configure_logger(server_args, prefix: str = ""): + format = f"[%(asctime)s{prefix}] %(message)s" + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format=format, + datefmt="%H:%M:%S", + force=True, + ) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 9386d7f7afd..ac69ab875b9 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -14,7 +14,7 @@ """ import json -import multiprocessing +import multiprocessing as mp import os from dataclasses import dataclass from typing import List, Union @@ -24,15 +24,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from sglang.srt.server import Runtime -from sglang.srt.utils import is_generation_model +from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER DEFAULT_PROMPTS = [ # the output of gemma-2-2b from SRT is unstable on the commented prompt # "The capital of France is", - "The capital of the United Kindom is", + "Apple is red. Banana is Yellow. " * 800 + "Apple is", + "The capital of the United Kingdom is", "Today is a sunny day and I like", "AI is a field of computer science focused on", - "Apple is red. Banana is Yellow. 
" * 800 + "Apple is", ] dirpath = os.path.dirname(__file__) @@ -63,44 +63,37 @@ class HFRunner: def __init__( self, model_path, - torch_dtype=torch.float16, - is_generation_model=None, + torch_dtype, + is_generation, ): - self.in_queue = multiprocessing.Queue() - self.out_queue = multiprocessing.Queue() + self.is_generation = is_generation + + self.in_queue = mp.Queue() + self.out_queue = mp.Queue() - self.model_proc = multiprocessing.Process( + self.model_proc = mp.Process( target=self.start_model_process, args=( self.in_queue, self.out_queue, model_path, torch_dtype, - is_generation_model, ), ) self.model_proc.start() - def start_model_process( - self, in_queue, out_queue, model_path, torch_dtype, is_generation_model - ): + def start_model_process(self, in_queue, out_queue, model_path, torch_dtype): self.tokenizer = AutoTokenizer.from_pretrained( model_path, torch_dtype=torch_dtype, - trust_remote_code=True, ) - self.is_generation_model = ( - is_generation_model(model_path) - if is_generation_model is None - else is_generation_model - ) - if self.is_generation_model: + if self.is_generation: self.model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch_dtype, + trust_remote_code=False, low_cpu_mem_usage=True, - trust_remote_code=True, ).cuda() else: from sentence_transformers import SentenceTransformer @@ -113,7 +106,7 @@ def start_model_process( while True: prompts, max_new_tokens = in_queue.get() if prompts is not None: - if self.is_generation_model: + if self.is_generation: output_strs = [] prefill_logprobs = [] for p in prompts: @@ -176,22 +169,20 @@ class SRTRunner: def __init__( self, model_path, + torch_dtype, + is_generation, tp_size=1, - torch_dtype=torch.float16, - is_generation_model=None, - port=5157, + port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER, ): - self.is_generation_model = ( - is_generation_model(model_path) - if is_generation_model is None - else is_generation_model - ) + self.is_generation = is_generation self.runtime = Runtime( model_path=model_path, tp_size=tp_size, dtype=get_dtype_str(torch_dtype), port=port, - mem_fraction_static=0.7, + mem_fraction_static=0.69, + trust_remote_code=False, + is_embedding=not self.is_generation, ) def forward( @@ -199,7 +190,7 @@ def forward( prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS, max_new_tokens=8, ): - if self.is_generation_model: + if self.is_generation: # the return value contains logprobs from prefill output_strs = [] top_input_logprobs = [] diff --git a/python/sglang/test/test_activation.py b/python/sglang/test/test_activation.py new file mode 100644 index 00000000000..357a23319bc --- /dev/null +++ b/python/sglang/test/test_activation.py @@ -0,0 +1,55 @@ +import itertools +import unittest + +import torch + +from sglang.srt.layers.activation import GeluAndMul + + +class TestGeluAndMul(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 2048] + D = [512, 4096, 5120, 13824] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed): + torch.manual_seed(seed) + + layer = GeluAndMul().to(dtype=dtype) + x = torch.randn(num_tokens, 2 * d, dtype=dtype) + + with torch.inference_mode(): + ref_out = layer.forward_native(x) + out = layer.forward_cuda(x) + + if dtype == torch.bfloat16: + atol = rtol = 1e-2 + else: + atol = rtol = 1e-3 + + self.assertTrue(torch.allclose(out, ref_out, atol=atol, 
rtol=rtol)) + + def test_gelu_and_mul(self): + for params in itertools.product( + self.NUM_TOKENS, + self.D, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + d=params[1], + dtype=params[2], + seed=params[3], + ): + self._run_gelu_and_mul_test(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/python/sglang/test/test_layernorm.py b/python/sglang/test/test_layernorm.py index ab61aa80405..770e69733db 100644 --- a/python/sglang/test/test_layernorm.py +++ b/python/sglang/test/test_layernorm.py @@ -3,7 +3,7 @@ import torch -from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm class TestRMSNorm(unittest.TestCase): @@ -56,5 +56,57 @@ def test_rms_norm(self): self._run_rms_norm_test(*params) +class TestGemmaRMSNorm(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + NUM_TOKENS = [7, 83, 4096] + HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199] + ADD_RESIDUAL = [False, True] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_gemma_rms_norm_test( + self, num_tokens, hidden_size, add_residual, dtype, seed + ): + torch.manual_seed(seed) + + layer = GemmaRMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + residual = torch.randn_like(x) * scale if add_residual else None + + with torch.inference_mode(): + ref_out = layer.forward_native(x, residual) + out = layer(x, residual) + + if add_residual: + self.assertTrue(torch.allclose(out[0], ref_out[0], atol=1e-3, rtol=1e-3)) + self.assertTrue(torch.allclose(out[1], ref_out[1], atol=1e-3, rtol=1e-3)) + else: + self.assertTrue(torch.allclose(out, ref_out, atol=1e-3, rtol=1e-3)) + + def test_gemma_rms_norm(self): + for params in itertools.product( + self.NUM_TOKENS, + self.HIDDEN_SIZES, + self.ADD_RESIDUAL, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + num_tokens=params[0], + hidden_size=params[1], + add_residual=params[2], + dtype=params[3], + seed=params[4], + ): + self._run_gemma_rms_norm_test(*params) + + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index ce402558550..bdecdff2f94 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -2,8 +2,12 @@ import json import re +import time + +import numpy as np import sglang as sgl +from sglang.utils import fetch_and_cache_jsonl def test_few_shot_qa(): @@ -447,3 +451,67 @@ def gen_character_spec(s): ) gen_character_spec().sync() + + +def test_hellaswag_select(): + """Benchmark the accuracy of sgl.select on the HellaSwag dataset.""" + + url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" + lines = fetch_and_cache_jsonl(url) + + # Construct prompts + def get_one_example(lines, i, include_answer): + ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " " + if include_answer: + ret += lines[i]["endings"][lines[i]["label"]] + return ret + + def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + num_questions = 200 + num_shots = 20 + few_shot_examples = get_few_shot_examples(lines, num_shots) + + questions = [] + choices = [] + labels = [] 
+ for i in range(len(lines[:num_questions])): + questions.append(get_one_example(lines, i, False)) + choices.append(lines[i]["endings"]) + labels.append(lines[i]["label"]) + arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)] + + ##################################### + ######### SGL Program Begin ######### + ##################################### + + import sglang as sgl + + @sgl.function + def few_shot_hellaswag(s, question, choices): + s += few_shot_examples + question + s += sgl.select("answer", choices=choices) + + ##################################### + ########## SGL Program End ########## + ##################################### + + # Run requests + tic = time.time() + rets = few_shot_hellaswag.run_batch( + arguments, + temperature=0, + num_threads=64, + progress_bar=True, + ) + preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))] + latency = time.time() - tic + + # Compute accuracy + accuracy = np.mean(np.array(preds) == np.array(labels)) + + return accuracy, latency diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 9f6aa68ab12..1b9b63e882f 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -2,12 +2,10 @@ import argparse import asyncio -import multiprocessing import os import subprocess import threading import time -import unittest from functools import partial from typing import Callable, List, Optional @@ -19,21 +17,23 @@ from sglang.global_config import global_config from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.srt.utils import kill_child_process from sglang.utils import get_exception_traceback DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Meta-Llama-3.1-8B-Instruct" DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" +DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 +DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Meta-Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it" +DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Meta-Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct" +DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8" +DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8" if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - DEFAULT_URL_FOR_MOE_TEST = "http://127.0.0.1:6157" - DEFAULT_URL_FOR_ACCURACY_TEST = "http://127.0.0.1:7157" - DEFAULT_URL_FOR_UNIT_TEST = "http://127.0.0.1:8157" - DEFAULT_URL_FOR_E2E_TEST = "http://127.0.0.1:9157" + DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 5157 + DEFAULT_URL_FOR_TEST = "http://127.0.0.1:6157" else: - DEFAULT_URL_FOR_MOE_TEST = "http://127.0.0.1:1157" - DEFAULT_URL_FOR_ACCURACY_TEST = "http://127.0.0.1:1257" - DEFAULT_URL_FOR_UNIT_TEST = "http://127.0.0.1:1357" - DEFAULT_URL_FOR_E2E_TEST = "http://127.0.0.1:1457" + DEFAULT_PORT_FOR_SRT_TEST_RUNNER = 1157 + DEFAULT_URL_FOR_TEST = "http://127.0.0.1:2157" def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): @@ -465,34 +465,36 @@ def run_unittest_files(files: List[str], timeout_per_file: float): success = True for filename in files: 
+ global process - def func(): - print(f"\n\nRun {filename}\n\n") - ret = unittest.main(module=None, argv=["", "-vb"] + [filename]) - - p = multiprocessing.Process(target=func) - - def run_one_file(): - p.start() - p.join() + def run_one_file(filename): + filename = os.path.join(os.getcwd(), filename) + print(f"\n\nRun:\npython3 {filename}\n\n", flush=True) + process = subprocess.Popen( + ["python3", filename], stdout=None, stderr=None, env=os.environ + ) + process.wait() + return process.returncode try: - run_with_timeout(run_one_file, timeout=timeout_per_file) - if p.exitcode != 0: - success = False - break + ret_code = run_with_timeout( + run_one_file, args=(filename,), timeout=timeout_per_file + ) + assert ret_code == 0 except TimeoutError: - p.terminate() + kill_child_process(process.pid) time.sleep(5) print( - f"\nTimeout after {timeout_per_file} seconds when running {filename}\n" + f"\nTimeout after {timeout_per_file} seconds when running {filename}\n", + flush=True, ) - return False + success = False + break if success: - print(f"Success. Time elapsed: {time.time() - tic:.2f}s") + print(f"Success. Time elapsed: {time.time() - tic:.2f}s", flush=True) else: - print(f"Fail. Time elapsed: {time.time() - tic:.2f}s") + print(f"Fail. Time elapsed: {time.time() - tic:.2f}s", flush=True) return 0 if success else -1 diff --git a/python/sglang/utils.py b/python/sglang/utils.py index c880d259d53..b212f6caa31 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -4,6 +4,7 @@ import importlib import json import logging +import os import signal import sys import traceback @@ -15,6 +16,7 @@ import numpy as np import requests +from tqdm import tqdm logger = logging.getLogger(__name__) @@ -260,3 +262,40 @@ def __getattr__(self, name: str): def __call__(self, *args, **kwargs): module = self._load() return module(*args, **kwargs) + + +def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"): + """Read and cache a jsonl file from a url.""" + + # Check if the cache file already exists + if os.path.exists(cache_file): + print("Loading data from cache...") + with open(cache_file, "r") as f: + data = [json.loads(line) for line in f] + else: + print("Downloading data from URL...") + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(cache_file, "wb") as f, tqdm( + desc=cache_file, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + # Convert the data to a list of dictionaries + with open(cache_file, "r") as f: + data = [json.loads(line) for line in f] + + return data diff --git a/python/sglang/version.py b/python/sglang/version.py index 11ef0928681..493f7415d73 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.2.13" +__version__ = "0.3.0" diff --git a/scripts/convert_yi_vl.py b/scripts/deprecated/convert_yi_vl.py similarity index 100% rename from scripts/convert_yi_vl.py rename to scripts/deprecated/convert_yi_vl.py diff --git a/scripts/convert_yi_vl.sh b/scripts/deprecated/convert_yi_vl.sh similarity index 100% rename from scripts/convert_yi_vl.sh rename to scripts/deprecated/convert_yi_vl.sh diff --git 
a/scripts/deprecated/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py index cafbd19fdf6..dbcafb88d7d 100644 --- a/scripts/deprecated/test_httpserver_classify.py +++ b/scripts/deprecated/test_httpserver_classify.py @@ -1,6 +1,6 @@ """ Usage: -python3 -m sglang.launch_server --model-path /model/llama-classification +python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification python3 test_httpserver_classify.py """ diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index d2d31161017..95aeddb9a14 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -3,23 +3,24 @@ python3 reference_hf.py --model TinyLlama/TinyLlama-1.1B-Chat-v0.4 Reference output: +========== Prompt 0 ========== +prefill logits (final) tensor([-8.3125, -7.1172, 3.3398, ..., -4.9531, -4.1328, -3.4141], + device='cuda:0') The capital of France is Paris. The capital of the United States is Washington, D.C. -The capital of Canada is Ottawa. -The capital of Japan is Tokyo -prefill logits tensor([-8.3125, -7.1172, 3.3398, ..., -4.9570, -4.1328, -3.4141], + +========== Prompt 1 ========== +prefill logits (final) tensor([-8.9062, -9.0156, 4.1484, ..., -4.9922, -4.4961, -4.0742], device='cuda:0') The capital of the United Kindom is London. The capital of the United Kingdom is London. -The capital of the United Kingdom is London. -The capital of the United Kingdom is London. -prefill logits tensor([-8.9062, -9.0156, 4.1406, ..., -4.9922, -4.4961, -4.0742], +The capital of + +========== Prompt 2 ========== +prefill logits (final) tensor([-9.6328, -9.0547, 4.0234, ..., -5.3047, -4.7148, -4.4609], device='cuda:0') Today is a sunny day and I like to go for a walk in the park. -I'm going to the park to play in the grass and water. -Today is a very -prefill logits tensor([-9.6328, -9.0547, 4.0195, ..., -5.3047, -4.7148, -4.4609], - device='cuda:0') +I'm going to the """ import argparse @@ -47,7 +48,7 @@ def normal_text(args): ] max_new_tokens = 16 - for p in prompts: + for i, p in enumerate(prompts): if isinstance(p, str): input_ids = t.encode(p, return_tensors="pt").cuda() else: @@ -60,7 +61,8 @@ def normal_text(args): prefill_logits = m.forward(input_ids).logits[0][-1] - print("prefill logits", prefill_logits) + print(f"\n========== Prompt {i} ==========") + print("prefill logits (final)", prefill_logits) print(output_str) diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index fcd86ae3d31..62c59592821 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -7,6 +7,7 @@ test_dtype_gen, test_expert_answer, test_few_shot_qa, + test_hellaswag_select, test_mt_bench, test_parallel_decoding, test_regex, @@ -62,6 +63,12 @@ def test_regex(self): def test_dtype_gen(self): test_dtype_gen() + def test_hellaswag_select(self): + # Run twice to capture more bugs + for _ in range(2): + accuracy, latency = test_hellaswag_select() + assert accuracy > 0.71 + if __name__ == "__main__": unittest.main() diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py index 67e47d90d3b..a5a73bf319f 100644 --- a/test/srt/models/test_embedding_models.py +++ b/test/srt/models/test_embedding_models.py @@ -13,6 +13,7 @@ limitations under the License. 
""" +import multiprocessing as mp import unittest import torch @@ -20,7 +21,10 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner from sglang.test.test_utils import get_similarities -MODELS = [("intfloat/e5-mistral-7b-instruct", 1)] +MODELS = [ + ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5), + ("intfloat/e5-mistral-7b-instruct", 1, 1e-5), +] TORCH_DTYPES = [torch.float16] @@ -32,9 +36,10 @@ def assert_close_prefill_logits( model_path, tp_size, torch_dtype, + prefill_tolerance, ) -> None: with HFRunner( - model_path, torch_dtype=torch_dtype, is_generation_model=False + model_path, torch_dtype=torch_dtype, is_generation=False ) as hf_runner: hf_outputs = hf_runner.forward(prompts) @@ -42,32 +47,34 @@ def assert_close_prefill_logits( model_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation_model=False, + is_generation=False, ) as srt_runner: - srt_outputs = srt_runner.forward( - prompts, - ) + srt_outputs = srt_runner.forward(prompts) for i in range(len(prompts)): hf_logits = torch.Tensor(hf_outputs.embed_logits[i]) srt_logits = torch.Tensor(srt_outputs.embed_logits[i]) - similarities = torch.tensor(get_similarities(hf_logits, srt_logits)) - print("max similarity diff", torch.max(abs(similarities - 1))) + similarity = torch.tensor(get_similarities(hf_logits, srt_logits)) + print("similarity diff", abs(similarity - 1)) - if hf_logits.shape[0] <= 100: - tolerance = 1e-2 + if len(prompts[i]) <= 1000: assert torch.all( - abs(similarities - 1) < tolerance - ), f"embeddings not all close" + abs(similarity - 1) < prefill_tolerance + ), "embeddings are not all close" def test_prefill_logits(self): - for model, tp_size in MODELS: + for model, tp_size, prefill_tolerance in MODELS: for torch_dtype in TORCH_DTYPES: self.assert_close_prefill_logits( - DEFAULT_PROMPTS, model, tp_size, torch_dtype + DEFAULT_PROMPTS, model, tp_size, torch_dtype, prefill_tolerance ) if __name__ == "__main__": - unittest.main(warnings="ignore") + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main() diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index bb56ebdad79..08288c510c9 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -13,6 +13,7 @@ limitations under the License. 
""" +import multiprocessing as mp import unittest import torch @@ -20,14 +21,47 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner MODELS = [ - ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1), - ("google/gemma-2-2b", 1, 3), + ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1), + ("google/gemma-2-2b", 1, 3, 3e-2, 1), + ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1), ] TORCH_DTYPES = [torch.float16] -class TestGenerationModels(unittest.TestCase): +def lcs(X, Y): + m = len(X) + n = len(Y) + L = [[0] * (n + 1) for _ in range(m + 1)] + + for i in range(m + 1): + for j in range(n + 1): + if i == 0 or j == 0: + L[i][j] = 0 + elif X[i - 1] == Y[j - 1]: + L[i][j] = L[i - 1][j - 1] + 1 + else: + L[i][j] = max(L[i - 1][j], L[i][j - 1]) + + return L[m][n] + + +def calculate_rouge_l(output_strs_list1, output_strs_list2): + rouge_l_scores = [] + + for s1, s2 in zip(output_strs_list1, output_strs_list2): + lcs_len = lcs(s1, s2) + precision = lcs_len / len(s1) if len(s1) > 0 else 0 + recall = lcs_len / len(s2) if len(s2) > 0 else 0 + if precision + recall > 0: + fmeasure = (2 * precision * recall) / (precision + recall) + else: + fmeasure = 0.0 + rouge_l_scores.append(fmeasure) + return rouge_l_scores + + +class TestGenerationModels(unittest.TestCase): def assert_close_prefill_logits_and_output_strs( self, prompts, @@ -35,10 +69,14 @@ def assert_close_prefill_logits_and_output_strs( tp_size, torch_dtype, max_new_tokens, + prefill_tolerance, + rouge_threshold, long_context_tolerance, ) -> None: + if model_path == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": + prompts = prompts[:-1] with HFRunner( - model_path, torch_dtype=torch_dtype, is_generation_model=True + model_path, torch_dtype=torch_dtype, is_generation=True ) as hf_runner: hf_outputs = hf_runner.forward(prompts, max_new_tokens=max_new_tokens) @@ -46,7 +84,7 @@ def assert_close_prefill_logits_and_output_strs( model_path, tp_size=tp_size, torch_dtype=torch_dtype, - is_generation_model=True, + is_generation=True, ) as srt_runner: srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens) @@ -56,28 +94,46 @@ def assert_close_prefill_logits_and_output_strs( print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs))) if hf_logprobs.shape[0] <= 100: - tolerance = 3e-2 assert torch.all( - abs(hf_logprobs - srt_logprobs) < tolerance - ), f"prefill logprobs not all close" - - print(hf_outputs.output_strs) - print(srt_outputs.output_strs) - assert hf_outputs.output_strs == srt_outputs.output_strs + abs(hf_logprobs - srt_logprobs) < prefill_tolerance + ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}" + + print(f"hf_outputs.output_strs={hf_outputs.output_strs}") + print(f"srt_outputs.output_strs={srt_outputs.output_strs}") + rouge_l_scores = calculate_rouge_l( + hf_outputs.output_strs, srt_outputs.output_strs + ) + print(f"rouge_l_scores={rouge_l_scores}") + assert all( + score >= rouge_threshold for score in rouge_l_scores + ), f"Not all ROUGE-L scores are greater than rouge_threshold={rouge_threshold}" def test_prefill_logits_and_output_strs(self): - for model, tp_size, long_context_tolerance in MODELS: + for ( + model, + tp_size, + long_context_tolerance, + prefill_tolerance, + rouge_threshold, + ) in MODELS: for torch_dtype in TORCH_DTYPES: - max_new_tokens = 8 + max_new_tokens = 32 self.assert_close_prefill_logits_and_output_strs( DEFAULT_PROMPTS, model, tp_size, torch_dtype, max_new_tokens, + 
prefill_tolerance=prefill_tolerance, + rouge_threshold=rouge_threshold, long_context_tolerance=long_context_tolerance, ) if __name__ == "__main__": - unittest.main(warnings="ignore") + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 4d3f7de30a0..cafcf3f2d59 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -5,19 +5,20 @@ suites = { "minimal": [ + "models/test_embedding_models.py", + "models/test_generation_models.py", + "sampling/penaltylib", "test_chunked_prefill.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", "test_large_max_new_tokens.py", "test_openai_server.py", + "test_json_constrained.py", "test_skip_tokenizer_init.py", "test_torch_compile.py", "test_triton_attn_backend.py", + "test_update_weights.py", "test_vision_openai_server.py", - "test_large_max_new_tokens.py", - "models/test_generation_models.py", - "models/test_embedding_models.py", - "sampling/penaltylib", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True @@ -32,6 +33,7 @@ tests.remove(target_suite_name) tests.extend(target_tests) + if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( @@ -47,6 +49,18 @@ choices=list(suites.keys()) + ["all"], help="The suite to run", ) + arg_parser.add_argument( + "--range-begin", + type=int, + default=0, + help="The begin index of the range of the files to run.", + ) + arg_parser.add_argument( + "--range-end", + type=int, + default=None, + help="The end index of the range of the files to run.", + ) args = arg_parser.parse_args() if args.suite == "all": @@ -54,5 +68,7 @@ else: files = suites[args.suite] + files = files[args.range_begin : args.range_end] + exit_code = run_unittest_files(files, args.timeout_per_file) exit(exit_code) diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index e72dc30f956..e3496102cb1 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -5,7 +5,12 @@ import requests from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestBatchPenalizerE2E(unittest.TestCase): @@ -13,11 +18,11 @@ class TestBatchPenalizerE2E(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = f"http://127.0.0.1:{8157}" + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=( "--random-seed", "0", @@ -107,4 +112,4 @@ def test_repetition_penalty(self): if __name__ == "__main__": - unittest.main(warnings="ignore") + unittest.main() diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py index 8d81dc0c3e1..2eb704dc919 100644 --- a/test/srt/test_chunked_prefill.py +++ b/test/srt/test_chunked_prefill.py @@ -5,7 +5,8 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -20,11 +21,11 @@ def 
run_mmlu(self, disable_radix_cache, enable_mixed_chunk): other_args += ["--enable-mixed-chunk"] model = DEFAULT_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_UNIT_TEST + base_url = DEFAULT_URL_FOR_TEST process = popen_launch_server( model, base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=other_args, ) diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py new file mode 100644 index 00000000000..230302f264f --- /dev/null +++ b/test/srt/test_create_kvindices.py @@ -0,0 +1,76 @@ +import itertools +import unittest + +import numpy as np +import torch + +from sglang.srt.model_executor.forward_batch_info import ( + create_flashinfer_kv_indices_triton, +) + + +class TestCreateKvIndices(unittest.TestCase): + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _run_test(self, batch, max_batch, max_context_len): + req_to_token = torch.arange( + max_batch * max_context_len, dtype=torch.int32, device="cuda" + ).reshape((max_batch, max_context_len)) + req_pool_indices = torch.tensor( + torch.from_numpy( + np.random.choice(range(max_batch), size=batch, replace=False) + ), + dtype=torch.int32, + device="cuda", + ) + paged_kernel_lens = torch.tensor( + torch.from_numpy( + np.random.choice(range(max_context_len), size=batch, replace=False) + ), + dtype=torch.int32, + device="cuda", + ) + + kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + # ref + req_pool_indices_cpu = req_pool_indices.cpu().numpy() + paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy() + kv_indices_ref = torch.cat( + [ + req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]] + for i in range(batch) + ], + dim=0, + ).contiguous() + + # triton + kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + create_flashinfer_kv_indices_triton[(batch,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + None, + req_to_token.size(1), + kv_indices_triton, + ) + + # Check + self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton)) + + def test_create_kvindices(self): + BATCH = [1, 37, 1786] + MAX_BATCH = 4096 + MAX_CONTEXT_LEN = 4096 + for batch in BATCH: + self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_embedding_openai_server.py b/test/srt/test_embedding_openai_server.py index fd8fec48e90..45f7850da99 100644 --- a/test/srt/test_embedding_openai_server.py +++ b/test/srt/test_embedding_openai_server.py @@ -4,17 +4,24 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_URL_FOR_UNIT_TEST, popen_launch_server +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestOpenAIServer(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = "intfloat/e5-mistral-7b-instruct" - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, api_key=cls.api_key + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, ) cls.base_url += "/v1" cls.tokenizer = get_tokenizer(cls.model) diff --git 
a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index 470ed11aa45..3729ad26b6a 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -5,8 +5,8 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_ACCURACY_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -15,11 +15,11 @@ class TestEvalAccuracyLarge(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=["--log-level-http", "warning"], ) diff --git a/test/srt/test_eval_accuracy_large_chunked_prefill.py b/test/srt/test_eval_accuracy_large_chunked_prefill.py index 951f481da32..02df2a7f56a 100644 --- a/test/srt/test_eval_accuracy_large_chunked_prefill.py +++ b/test/srt/test_eval_accuracy_large_chunked_prefill.py @@ -5,7 +5,8 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_ACCURACY_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -14,11 +15,11 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=["--log-level-http", "warning", "--chunked-prefill-size", "256"], ) diff --git a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py index 210c32b5196..8ba71e5c836 100644 --- a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py +++ b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py @@ -5,7 +5,8 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_ACCURACY_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -14,11 +15,11 @@ class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--log-level-http", "warning", diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index a4219b1a0a7..25aa0ca116b 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -5,7 +5,8 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -14,8 +15,10 @@ class TestEvalAccuracyMini(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST - cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + 
cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py new file mode 100644 index 00000000000..5393ecc33ca --- /dev/null +++ b/test/srt/test_json_constrained.py @@ -0,0 +1,96 @@ +import json +import unittest + +import openai +import requests + +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestJSONConstrained(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.json_schema = json.dumps( + { + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], + } + ) + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=300, api_key=cls.api_key + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): + headers = {"Authorization": f"Bearer {self.api_key}"} + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 128, + "n": n, + "stop_token_ids": [119690], + "json_schema": self.json_schema, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + headers=headers, + ) + print(json.dumps(response.json())) + print("=" * 100) + try: + js_obj = json.loads(response.json()["text"]) + except (TypeError, json.decoder.JSONDecodeError): + raise + assert isinstance(js_obj["name"], str) + assert isinstance(js_obj["population"], int) + + def test_json_generate(self): + self.run_decode() + + def test_json_openai(self): + client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "Introduce the capital of France."}, + ], + temperature=0, + max_tokens=128, + extra_body={"json_schema": self.json_schema}, + ) + text = response.choices[0].message.content + + try: + js_obj = json.loads(text) + except (TypeError, json.decoder.JSONDecodeError): + print("JSONDecodeError", text) + raise + assert isinstance(js_obj["name"], str) + assert isinstance(js_obj["population"], int) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py index f29adabced9..10b82706a61 100644 --- a/test/srt/test_large_max_new_tokens.py +++ b/test/srt/test_large_max_new_tokens.py @@ -10,7 +10,8 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -19,12 +20,12 @@ class TestOpenAIServer(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + 
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, other_args=("--max-total-token", "1024"), env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ}, diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py new file mode 100644 index 00000000000..d4b1354b793 --- /dev/null +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -0,0 +1,73 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MOE_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEvalAccuracyLarge(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MOE_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--log-level-http", + "warning", + "--tp", + "2", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=3000, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.63, f"{metrics}" + + def test_human_eval(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="humaneval", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.42, f"{metrics}" + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + assert metrics["score"] >= 0.63, f"{metrics}" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_moe_serving_latency.py b/test/srt/test_moe_serving_latency.py new file mode 100644 index 00000000000..9d521532316 --- /dev/null +++ b/test/srt/test_moe_serving_latency.py @@ -0,0 +1,45 @@ +import os +import subprocess +import unittest + +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import DEFAULT_MOE_MODEL_NAME_FOR_TEST + + +class TestServingLatency(unittest.TestCase): + def test_default(self): + command = [ + "python3", + "-m", + "sglang.bench_latency", + "--model", + DEFAULT_MOE_MODEL_NAME_FOR_TEST, + "--batch-size", + "1", + "--input", + "128", + "--output", + "8", + "--tp", + "2", + ] + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + stdout, stderr = process.communicate() + output = stdout.decode() + error = stderr.decode() + print(f"Output: {output}") + print(f"Error: {error}") + + lastline = output.split("\n")[-3] + value = float(lastline.split(" ")[-2]) + + if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + assert value > 125 + + kill_child_process(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index bbcd5122769..2acf626c1c4 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -7,7 +7,8 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MOE_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_MOE_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -22,16 +23,18 @@ 
def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size other_args.append("--disable-flashinfer") other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) other_args.extend(["--tensor-parallel-size", "2"]) - other_args.append("--enable-p2p-check") model = DEFAULT_MOE_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_MOE_TEST + base_url = DEFAULT_URL_FOR_TEST process = popen_launch_server( - model, base_url, timeout=300, other_args=other_args + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, ) # Run benchmark - num_prompts = 200 + num_prompts = 300 args = SimpleNamespace( backend="sglang", base_url=base_url, @@ -72,8 +75,7 @@ def test_default(self): ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - # A100 (PCIE) performance - assert res["output_throughput"] > 930 + assert res["output_throughput"] > 1800 def test_default_without_radix_cache(self): res = self.run_test( @@ -83,29 +85,7 @@ def test_default_without_radix_cache(self): ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - # A100 (PCIE) performance - assert res["output_throughput"] > 930 - - def test_default_without_chunked_prefill(self): - res = self.run_test( - disable_radix_cache=ServerArgs.disable_radix_cache, - disable_flashinfer=ServerArgs.disable_flashinfer, - chunked_prefill_size=-1, - ) - - if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - # A100 (PCIE) performance - print(res["output_throughput"]) - - def test_all_cases(self): - for disable_radix_cache in [False, True]: - for disable_flashinfer in [False, True]: - for chunked_prefill_size in [-1, 2048]: - self.run_test( - disable_radix_cache=False, - disable_flashinfer=False, - chunked_prefill_size=-1, - ) + assert res["output_throughput"] > 1950 if __name__ == "__main__": diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py new file mode 100644 index 00000000000..35e7d6eb7db --- /dev/null +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -0,0 +1,89 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_child_process +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1, + DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2, + DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1, + DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +def parse_models(model_string): + return [model.strip() for model in model_string.split(",") if model.strip()] + + +class TestEvalAccuracyLarge(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_groups = [ + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1), False, False), + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True), + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False), + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True), + ] + cls.base_url = DEFAULT_URL_FOR_TEST + + def setUp(self): + self.process = None + + def tearDown(self): + if self.process: + kill_child_process(self.process.pid) + + def launch_server(self, model, is_fp8, is_tp2): + other_args = ["--log-level-http", "warning", "--trust-remote-code"] + if is_fp8: + if "Llama-3" in model or "gemma-2" in model: + # compressed-tensors + other_args.extend(["--kv-cache-dtype", "fp8_e5m2"]) + elif "Qwen2-72B-Instruct-FP8" in model: + # bug + other_args.extend(["--quantization", "fp8"]) + else: + other_args.extend( 
+ ["--quantization", "fp8", "--kv-cache-dtype", "fp8_e5m2"] + ) + if is_tp2: + other_args.extend(["--tp", "2"]) + if "DeepSeek" in model: + other_args.append("--enable-mla") + + self.process = popen_launch_server( + model, + self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + def test_mgsm_en_all_models(self): + for model_group, is_fp8, is_tp2 in self.model_groups: + for model in model_group: + with self.subTest(model=model): + self.launch_server(model, is_fp8, is_tp2) + + args = SimpleNamespace( + base_url=self.base_url, + model=model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + print( + f"{'=' * 42}\n{model} - metrics={metrics} score={metrics['score']}\n{'=' * 42}\n" + ) + # loosely threshold + assert metrics["score"] > 0.5, f"score={metrics['score']} <= 0.5" + + self.tearDown() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 828f5ab532c..3fc5785517f 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -8,7 +8,8 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -17,10 +18,13 @@ class TestOpenAIServer(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, api_key=cls.api_key + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, ) cls.base_url += "/v1" cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST) @@ -252,8 +256,7 @@ def run_chat_completion_stream(self, logprobs, parallel_sample_num=1): index, True ), f"index {index} is not found in the response" - def run_batch(self, mode): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) + def _create_batch(self, mode, client): if mode == "completion": input_file_path = "complete_input.jsonl" # write content to input file @@ -329,9 +332,11 @@ def run_batch(self, mode): }, }, ] + with open(input_file_path, "w") as file: for line in content: file.write(json.dumps(line) + "\n") + with open(input_file_path, "rb") as file: uploaded_file = client.files.create(file=file, purpose="batch") if mode == "completion": @@ -344,6 +349,13 @@ def run_batch(self, mode): endpoint=endpoint, completion_window=completion_window, ) + + return batch_job, content, uploaded_file + + def run_batch(self, mode): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + batch_job, content, uploaded_file = self._create_batch(mode=mode, client=client) + while batch_job.status not in ["completed", "failed", "cancelled"]: time.sleep(3) print( @@ -366,6 +378,29 @@ def run_batch(self, mode): if line.strip() != "" ] assert len(results) == len(content) + for delete_fid in [uploaded_file.id, result_file_id]: + del_pesponse = client.files.delete(delete_fid) + assert del_pesponse.deleted + + def run_cancel_batch(self, mode): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + batch_job, _, uploaded_file = self._create_batch(mode=mode, client=client) + + assert batch_job.status not in ["cancelling", "cancelled"] + + batch_job = client.batches.cancel(batch_id=batch_job.id) + 
assert batch_job.status == "cancelling" + + while batch_job.status not in ["failed", "cancelled"]: + batch_job = client.batches.retrieve(batch_job.id) + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." + ) + time.sleep(3) + + assert batch_job.status == "cancelled" + del_response = client.files.delete(uploaded_file.id) + assert del_response.deleted def test_completion(self): for echo in [False, True]: @@ -410,6 +445,10 @@ def test_batch(self): for mode in ["completion", "chat"]: self.run_batch(mode) + def test_calcel_batch(self): + for mode in ["completion", "chat"]: + self.run_cancel_batch(mode) + def test_regex(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) diff --git a/test/srt/test_serving_latency.py b/test/srt/test_serving_latency.py new file mode 100644 index 00000000000..e762892c8eb --- /dev/null +++ b/test/srt/test_serving_latency.py @@ -0,0 +1,43 @@ +import os +import subprocess +import unittest + +from sglang.srt.utils import kill_child_process +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST + + +class TestServingLatency(unittest.TestCase): + def test_default(self): + command = [ + "python3", + "-m", + "sglang.bench_latency", + "--model", + DEFAULT_MODEL_NAME_FOR_TEST, + "--batch-size", + "1", + "--input", + "128", + "--output", + "8", + ] + process = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + stdout, stderr = process.communicate() + output = stdout.decode() + error = stderr.decode() + print(f"Output: {output}") + print(f"Error: {error}") + + lastline = output.split("\n")[-3] + value = float(lastline.split(" ")[-2]) + + if os.getenv("SGLANG_IS_IN_CI", "false") == "true": + assert value > 130 + + kill_child_process(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index 261ac6ec52f..d4ed12612ac 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -7,7 +7,8 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_E2E_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -23,13 +24,16 @@ def run_test(self, disable_radix_cache, disable_flashinfer, chunked_prefill_size other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) model = DEFAULT_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_E2E_TEST + base_url = DEFAULT_URL_FOR_TEST process = popen_launch_server( - model, base_url, timeout=300, other_args=other_args + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, ) # Run benchmark - num_prompts = 400 + num_prompts = 500 args = SimpleNamespace( backend="sglang", base_url=base_url, @@ -70,8 +74,7 @@ def test_default(self): ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - # A100 (PCIE) performance - assert res["output_throughput"] > 1400 + assert res["output_throughput"] > 2400 def test_default_without_radix_cache(self): res = self.run_test( @@ -81,8 +84,7 @@ def test_default_without_radix_cache(self): ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - # A100 (PCIE) performance - assert res["output_throughput"] > 1450 + assert res["output_throughput"] > 2800 def test_default_without_chunked_prefill(self): res = self.run_test( @@ -92,18 +94,7 @@ def test_default_without_chunked_prefill(self): ) if os.getenv("SGLANG_IS_IN_CI", "false") == "true": - # A100 
(PCIE) performance - assert res["output_throughput"] > 1400 - - def test_all_cases(self): - for disable_radix_cache in [False, True]: - for disable_flashinfer in [False, True]: - for chunked_prefill_size in [-1, 2048]: - self.run_test( - disable_radix_cache=False, - disable_flashinfer=False, - chunked_prefill_size=-1, - ) + assert res["output_throughput"] > 2400 if __name__ == "__main__": diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index 75010561514..b159bb55787 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -6,7 +6,8 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -15,9 +16,12 @@ class TestSkipTokenizerInit(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, other_args=["--skip-tokenizer-init"] + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--skip-tokenizer-init"], ) @classmethod diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 60f4cd58a3b..818aae2151e 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -6,7 +6,8 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -15,8 +16,10 @@ class TestSRTEndpoint(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST - cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index 5133d3cd3c2..e8cafa15d25 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -1,11 +1,14 @@ import unittest from types import SimpleNamespace +import requests + from sglang.srt.utils import kill_child_process from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -14,9 +17,12 @@ class TestTorchCompile(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"] + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--enable-torch-compile", "--disable-radix-cache"], ) @classmethod @@ -35,6 +41,33 @@ def test_mmlu(self): metrics = run_eval(args) assert metrics["score"] >= 0.6 + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, 
+ }, + ) + return response.json() + + def test_throughput(self): + import time + + max_tokens = 256 + + tic = time.time() + res = self.run_decode(max_tokens) + tok = time.time() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 152 + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py index 7a453d8be7e..a94ca921240 100644 --- a/test/srt/test_triton_attn_backend.py +++ b/test/srt/test_triton_attn_backend.py @@ -5,7 +5,8 @@ from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -14,9 +15,12 @@ class TestTritonAttnBackend(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=300, other_args=["--disable-flashinfer"] + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--disable-flashinfer"], ) @classmethod diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py index 64f84263aa9..7b8404c735f 100644 --- a/test/srt/test_update_weights.py +++ b/test/srt/test_update_weights.py @@ -6,7 +6,8 @@ from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_URL_FOR_UNIT_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -15,8 +16,10 @@ class TestReplaceWeights(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST - cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index c599d8b368a..4f764c09cd8 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -1,30 +1,38 @@ +import base64 +import io import json +import os import unittest +import numpy as np import openai +import requests +from decord import VideoReader, cpu +from PIL import Image -from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_child_process -from sglang.test.test_utils import DEFAULT_URL_FOR_UNIT_TEST, popen_launch_server +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestOpenAIVisionServer(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = "liuhaotian/llava-v1.6-vicuna-7b" - cls.base_url = DEFAULT_URL_FOR_UNIT_TEST + cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov" + cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( cls.model, cls.base_url, - timeout=300, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, other_args=[ "--chat-template", - "vicuna_v1.1", - "--tokenizer-path", - "llava-hf/llava-1.5-7b-hf", - "--log-requests", + "chatml-llava", + # "--log-requests", ], ) cls.base_url += "/v1" @@ -61,13 +69,180 @@ def test_chat_completion(self): assert 
response.choices[0].message.role == "assistant" text = response.choices[0].message.content assert isinstance(text, str) - assert "car" in text or "taxi" in text, text + assert "man" in text or "cab" in text, text assert response.id assert response.created assert response.usage.prompt_tokens > 0 assert response.usage.completion_tokens > 0 assert response.usage.total_tokens > 0 + def test_multi_turn_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + }, + }, + { + "type": "text", + "text": "Describe this image in a very short sentence.", + }, + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "There is a man at the back of a yellow cab ironing his clothes.", + } + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Repeat your previous answer."} + ], + }, + ], + temperature=0, + ) + + assert response.choices[0].message.role == "assistant" + text = response.choices[0].message.content + assert isinstance(text, str) + assert "man" in text or "cab" in text, text + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def test_mult_images_chat_completion(self): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + response = client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" + }, + }, + { + "type": "text", + "text": "I have two very different images. They are not related at all. 
" + "Please describe the first image in one sentence, and then describe the second image in another sentence.", + }, + ], + }, + ], + temperature=0, + ) + + assert response.choices[0].message.role == "assistant" + text = response.choices[0].message.content + assert isinstance(text, str) + print(text) + assert "man" in text and "taxi" in text, text + assert "logo" in text, text + assert response.id + assert response.created + assert response.usage.prompt_tokens > 0 + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def prepare_video_messages(self, video_path): + max_frames_num = 32 + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace( + 0, total_frame_num - 1, max_frames_num, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + pil_img = Image.fromarray(frame) + buff = io.BytesIO() + pil_img.save(buff, format="JPEG") + base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") + base64_frames.append(base64_str) + + messages = [{"role": "user", "content": []}] + frame_format = { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{}"}, + } + + for base64_frame in base64_frames: + frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format( + base64_frame + ) + messages[0]["content"].append(frame_format.copy()) + + prompt = {"type": "text", "text": "Please describe the video in detail."} + messages[0]["content"].append(prompt) + + return messages + + def test_video_chat_completion(self): + url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4" + cache_dir = os.path.expanduser("~/.cache") + file_path = os.path.join(cache_dir, "jobs.mp4") + os.makedirs(cache_dir, exist_ok=True) + + if not os.path.exists(file_path): + response = requests.get(url) + response.raise_for_status() + + with open(file_path, "wb") as f: + f.write(response.content) + + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + messages = self.prepare_video_messages(file_path) + + video_request = client.chat.completions.create( + model="default", + messages=messages, + temperature=0, + max_tokens=1024, + stream=True, + ) + + print("-" * 30) + video_response = "" + for chunk in video_request: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + video_response += content + print(content, end="", flush=True) + print("-" * 30) + + # Add assertions to validate the video response + self.assertIsNotNone(video_response) + self.assertGreater(len(video_response), 0) + def test_regex(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url)