From a0681be109e293df9c550a53f8025671c6197afe Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 25 Jul 2024 07:29:12 -0700
Subject: [PATCH 1/7] Misc bug fixes

---
 README.md                                    |  5 --
 .../srt/managers/controller/model_runner.py  |  8 +++
 python/sglang/srt/server.py                  |  5 +-
 python/sglang/srt/utils.py                   | 49 ++++++++++++++++++-
 4 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 6c13c1d6a79..68aab9e0d5a 100644
--- a/README.md
+++ b/README.md
@@ -70,11 +70,6 @@ docker run --gpus all \
 ```

 ### Common Notes
-- If you see errors from the Triton compiler, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html) by
-```
-pip uninstall -y triton triton-nightly
-pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
-```
 - If you cannot install FlashInfer, check out its [installation](https://docs.flashinfer.ai/installation.html#) page. If you still cannot install it, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.

diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index b5a7c06163c..ac353b64465 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -33,6 +33,7 @@
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
+    monkey_patch_vllm_qvk_linear_loader,
 )

 logger = logging.getLogger("srt.model_runner")
@@ -118,6 +119,13 @@ def load_model(self):
             seed=42,
             skip_tokenizer_init=True,
         )
+
+        if is_llama3_405b_fp8(self.model_config):
+            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+            self.model_config.hf_config.num_key_value_heads = 8
+            vllm_model_config.hf_config.num_key_value_heads = 8
+            monkey_patch_vllm_qvk_linear_loader()
+
         self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)

diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index fbbdd06cbdd..762db53228a 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -202,15 +202,12 @@ def launch_server(
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
-
-    if server_args.tp_size // server_args.dp_size > 1:
+    if server_args.tp_size * server_args.dp_size > 1:
         # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
         maybe_set_triton_cache_manager()
-
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
-
     if server_args.enable_torch_compile:
         _set_torch_compile_config()

diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index a7a9f26d4cf..0063c8d4100 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -21,6 +21,7 @@
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from starlette.middleware.base import BaseHTTPMiddleware
+from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -471,7 +472,7 @@ def maybe_set_triton_cache_manager() -> None:
     cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
     if cache_manger is None:
         manager = "sglang.srt.utils:CustomCacheManager"
-        logger.info("Setting Triton cache manager to: %s", manager)
+        logger.debug("Setting Triton cache manager to: %s", manager)
         os.environ["TRITON_CACHE_MANAGER"] = manager
@@ -615,3 +616,49 @@ def set_ulimit(target_soft_limit=65535):
         resource.setrlimit(resource_type, (target_soft_limit, current_hard))
     except ValueError as e:
         logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
+
+
+def is_llama3_405b_fp8(model_config):
+    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
+    if (
+        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
+        and model_config.hf_config.hidden_size == 16384
+        and model_config.hf_config.intermediate_size == 53248
+        and model_config.hf_config.num_hidden_layers == 126
+        and model_config.hf_config.num_key_value_heads == 16
+        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
+    ):
+        return True
+    return False
+
+
+def monkey_patch_vllm_qvk_linear_loader():
+    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
+    from vllm.model_executor.layers.linear import QKVParallelLinear
+
+    def get_original_weight(loaded_weight, head_dim):
+        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
+        dim = loaded_weight.shape[1]
+        for i in range(n_kv_head):
+            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
+                2 * i * head_dim : (2 * i + 1) * head_dim, :
+            ]
+        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
+        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
+        return original_kv_weight
+
+    def weight_loader_srt(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        if (
+            loaded_shard_id in ["k", "v"]
+            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
+        ):
+            loaded_weight = get_original_weight(loaded_weight, self.head_size)
+
+        self.weight_loader(self, param, loaded_weight, loaded_shard_id)
+
+    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)

From 43028230210334cd47d786877990f44947b8c736 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 25 Jul 2024 07:29:53 -0700
Subject: [PATCH 2/7] update

---
 python/sglang/srt/managers/controller/model_runner.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index ac353b64465..8459c98b81a 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -30,6 +30,7 @@
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
+    is_llama3_405b_fp8,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,

From 4fe7d2508c401b7b48d2aed342736883e942ae92 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 25 Jul 2024 07:31:06 -0700
Subject: [PATCH 3/7] update

---
 python/sglang/srt/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 0063c8d4100..16b536669e7 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -659,6 +659,6 @@ def weight_loader_srt(
         ):
             loaded_weight = get_original_weight(loaded_weight, self.head_size)

-        self.weight_loader(self, param, loaded_weight, loaded_shard_id)
+        self.weight_loader(param, loaded_weight, loaded_shard_id)

     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)

From c4677978baf0122c633adbc9d542e4bec0bf4835 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 25 Jul 2024 07:34:15 -0700
Subject: [PATCH 4/7] update

---
 python/sglang/srt/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 16b536669e7..a99e96f4857 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -635,6 +635,7 @@ def is_llama3_405b_fp8(model_config):
 def monkey_patch_vllm_qvk_linear_loader():
     """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
     from vllm.model_executor.layers.linear import QKVParallelLinear
+    origin_weight_loader = QKVParallelLinear.weight_loader

     def get_original_weight(loaded_weight, head_dim):
         n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
@@ -659,6 +660,6 @@ def weight_loader_srt(
         ):
             loaded_weight = get_original_weight(loaded_weight, self.head_size)

-        self.weight_loader(param, loaded_weight, loaded_shard_id)
+        origin_weight_loader(param, loaded_weight, loaded_shard_id)

     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)

From 4083ae166644789953cd3e1a70164d66359e4f7b Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 25 Jul 2024 07:35:16 -0700
Subject: [PATCH 5/7] update

---
 python/sglang/srt/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index a99e96f4857..26d6047b879 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -660,6 +660,6 @@ def weight_loader_srt(
         ):
             loaded_weight = get_original_weight(loaded_weight, self.head_size)

-        origin_weight_loader(param, loaded_weight, loaded_shard_id)
+        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)

From 254a140f271be283fd5a8dd9f49db63dafff0744 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 25 Jul 2024 07:38:26 -0700
Subject: [PATCH 6/7] update readme

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 68aab9e0d5a..3b7d1ed6eee 100644
--- a/README.md
+++ b/README.md
@@ -152,6 +152,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 ```
 - If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/custom_chat_template.md).
 - To enable fp8 quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
+- To enable experimental torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.

 ### Supported Models

From f5c4485af49342c8935edad9daf1407a270f5e04 Mon Sep 17 00:00:00 2001
From: Ying Sheng
Date: Thu, 25 Jul 2024 07:41:17 -0700
Subject: [PATCH 7/7] update

---
 python/sglang/srt/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 26d6047b879..e4367f4a415 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -635,6 +635,7 @@ def is_llama3_405b_fp8(model_config):
 def monkey_patch_vllm_qvk_linear_loader():
     """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
     from vllm.model_executor.layers.linear import QKVParallelLinear
+
     origin_weight_loader = QKVParallelLinear.weight_loader

     def get_original_weight(loaded_weight, head_dim):
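Note on the KV de-duplication in [PATCH 1/7]: the Meta-Llama-3.1-405B-FP8 checkpoint stores each key/value head twice along the output dimension, so the "k" and "v" shards arrive with `2 * total_num_kv_heads * head_size` rows, and `get_original_weight` compacts the first copy of each head to the front and slices off the duplicates. Below is a self-contained sketch replaying that logic on a toy tensor; the shapes and values are illustrative only, not taken from the real checkpoint.

```python
import torch

# Toy dimensions (illustrative; the real checkpoint uses 405B-scale shapes).
head_dim, n_kv_head, hidden = 4, 2, 6

# Fake checkpoint weight in which each KV head appears twice, back to back:
# blocks along dim 0 are [h0, h0, h1, h1].
heads = [torch.full((head_dim, hidden), float(i)) for i in range(n_kv_head)]
duplicated = torch.cat([blk for h in heads for blk in (h, h)], dim=0)

def get_original_weight(loaded_weight, head_dim):
    # Same logic as the patched helper: move the first copy of each head
    # to the front, then keep only the first n_kv_head * head_dim rows.
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    for i in range(n_kv_head):
        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
            2 * i * head_dim : (2 * i + 1) * head_dim, :
        ]
    return loaded_weight[: n_kv_head * head_dim, :]

restored = get_original_weight(duplicated.clone(), head_dim)
assert restored.shape == (n_kv_head * head_dim, hidden)
assert torch.equal(restored, torch.cat(heads, dim=0))  # each head once
```

Because basic slicing returns a view, the helper de-duplicates in place without allocating a second full-size buffer, which matters at 405B scale.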
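Note on [PATCH 3/7] through [PATCH 5/7]: they converge on the standard monkey-patching pattern. Calling `self.weight_loader(...)` inside the replacement dispatches to the patched class attribute and recurses, so the original must be captured from the class before rebinding; and because a function looked up on the class is unbound, it has to be invoked with an explicit `self`. A minimal sketch of the final shape of the patch, using a hypothetical `Linear` class standing in for vLLM's `QKVParallelLinear`:

```python
class Linear:
    def weight_loader(self, name):
        return f"loaded:{name}"

# Capture the original function off the *class* before rebinding, as
# patch 4/7 does; calling self.weight_loader inside the wrapper would
# hit the patched attribute and recurse forever.
origin_weight_loader = Linear.weight_loader

def weight_loader_srt(self, name):
    # A function looked up on the class is unbound, so pass `self`
    # explicitly -- the fix that lands in patch 5/7.
    return origin_weight_loader(self, name).upper()

Linear.weight_loader = weight_loader_srt
assert Linear().weight_loader("kv") == "LOADED:KV"
```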