Skip to content

Commit

Permalink
Fix bugs (fp8 checkpoints, triton cache manager) (#729)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 authored Jul 25, 2024
1 parent ae0f613 commit 8fbba3d
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 10 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ docker run --gpus all \
```

### Common Notes
- If you see errors from the Triton compiler, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html) by
```
pip uninstall -y triton triton-nightly
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```
- If you cannot install FlashInfer, check out its [installation](https://docs.flashinfer.ai/installation.html#) page. If you still cannot install it, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server.
- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.

Expand Down Expand Up @@ -157,6 +152,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
```
- If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/custom_chat_template.md).
- To enable fp8 quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
- To enable experimental torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.

### Supported Models

Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/managers/controller/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
is_llama3_405b_fp8,
is_multimodal_model,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
monkey_patch_vllm_qvk_linear_loader,
)

logger = logging.getLogger("srt.model_runner")
Expand Down Expand Up @@ -118,6 +120,13 @@ def load_model(self):
seed=42,
skip_tokenizer_init=True,
)

if is_llama3_405b_fp8(self.model_config):
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
vllm_model_config.hf_config.num_key_value_heads = 8
monkey_patch_vllm_qvk_linear_loader()

self.dtype = vllm_model_config.dtype
if self.model_config.model_overide_args is not None:
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
Expand Down
5 changes: 1 addition & 4 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,15 +202,12 @@ def launch_server(
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
)

if server_args.tp_size // server_args.dp_size > 1:
if server_args.tp_size * server_args.dp_size > 1:
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()

if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)

if server_args.enable_torch_compile:
_set_torch_compile_config()

Expand Down
51 changes: 50 additions & 1 deletion python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
Expand Down Expand Up @@ -471,7 +472,7 @@ def maybe_set_triton_cache_manager() -> None:
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
if cache_manger is None:
manager = "sglang.srt.utils:CustomCacheManager"
logger.info("Setting Triton cache manager to: %s", manager)
logger.debug("Setting Triton cache manager to: %s", manager)
os.environ["TRITON_CACHE_MANAGER"] = manager


Expand Down Expand Up @@ -615,3 +616,51 @@ def set_ulimit(target_soft_limit=65535):
resource.setrlimit(resource_type, (target_soft_limit, current_hard))
except ValueError as e:
logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")


def is_llama3_405b_fp8(model_config):
"""Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
if (
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
and model_config.hf_config.hidden_size == 16384
and model_config.hf_config.intermediate_size == 53248
and model_config.hf_config.num_hidden_layers == 126
and model_config.hf_config.num_key_value_heads == 16
and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
):
return True
return False


def monkey_patch_vllm_qvk_linear_loader():
"""A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
from vllm.model_executor.layers.linear import QKVParallelLinear

origin_weight_loader = QKVParallelLinear.weight_loader

def get_original_weight(loaded_weight, head_dim):
n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
dim = loaded_weight.shape[1]
for i in range(n_kv_head):
loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
2 * i * head_dim : (2 * i + 1) * head_dim, :
]
original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
return original_kv_weight

def weight_loader_srt(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
):
if (
loaded_shard_id in ["k", "v"]
and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
):
loaded_weight = get_original_weight(loaded_weight, self.head_size)

origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)

0 comments on commit 8fbba3d

Please sign in to comment.