diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index b5a7c06163c..5357e4b9878 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -15,6 +15,7 @@
     BatchPrefillWithRaggedKVCacheWrapper,
 )
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+from torch.nn.parameter import Parameter
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -22,6 +23,7 @@
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.model_executor.layers.linear import QKVParallelLinear
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
@@ -38,6 +40,18 @@
 logger = logging.getLogger("srt.model_runner")
 
 
+def is_llama3_405b_fp8(model_config):
+    if (
+        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
+        and model_config.hf_config.hidden_size == 16384
+        and model_config.hf_config.intermediate_size == 53248
+        and model_config.hf_config.num_hidden_layers == 126
+        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
+    ):
+        return True
+    return False
+
+
 class ModelRunner:
     def __init__(
         self,
@@ -118,6 +132,9 @@ def load_model(self):
             seed=42,
             skip_tokenizer_init=True,
         )
+        if is_llama3_405b_fp8(self.model_config):
+            self.model_config.hf_config.num_key_value_heads = 8
+            vllm_model_config.hf_config.num_key_value_heads = 8
         self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     return model_arch_name_to_cls[model_arch]
 
 
+def get_original_weight(loaded_weight, head_dim):
+    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
+    dim = loaded_weight.shape[1]
+    for i in range(n_kv_head):
+        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
+            2 * i * head_dim : (2 * i + 1) * head_dim, :
+        ]
+    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
+    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
+    return original_kv_weight
+
+
+def get_weight_loader_srt(weight_loader):
+    def weight_loader_srt(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        if (
+            loaded_shard_id in ["k", "v"]
+            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
+        ):
+            loaded_weight = get_original_weight(loaded_weight, self.head_size)
+
+        weight_loader(self, param, loaded_weight, loaded_shard_id)
+
+    return weight_loader_srt
+
+
 # Monkey patch model loader
 setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
+original_weight_loader = QKVParallelLinear.weight_loader
+setattr(
+    QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
+)
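
The two helpers added in the last hunk de-duplicate the K/V projection weights of the fbgemm-fp8 Llama 3.1 405B checkpoint: the wrapped `weight_loader_srt` intercepts `k`/`v` shards whose row count is `2 * total_num_kv_heads * head_size`, and `get_original_weight` keeps the first copy of each duplicated `head_dim` block before delegating to vLLM's original `QKVParallelLinear.weight_loader`. Below is a minimal standalone sketch (not part of the patch) of that de-duplication; `N_KV_HEADS`, `HEAD_DIM`, and `HIDDEN` are toy illustration-only constants, while the patch itself targets `hidden_size=16384` and forces `num_key_value_heads=8`, assuming each KV head block is stored twice in consecutive rows.

```python
import torch

# Toy sizes for illustration only; the real checkpoint is much larger.
N_KV_HEADS = 4
HEAD_DIM = 8
HIDDEN = 32

# Emulate how the checkpoint is assumed to ship a K (or V) projection:
# every KV head's head_dim rows appear twice, back to back (h0, h0, h1, h1, ...).
unique = torch.randn(N_KV_HEADS, HEAD_DIM, HIDDEN)
duplicated = unique.repeat_interleave(2, dim=0).reshape(2 * N_KV_HEADS * HEAD_DIM, HIDDEN)


def get_original_weight(loaded_weight: torch.Tensor, head_dim: int) -> torch.Tensor:
    """Keep one copy of each duplicated KV head block (same logic as the patch)."""
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    dim = loaded_weight.shape[1]
    for i in range(n_kv_head):
        # Move the first copy of head i (rows 2*i*head_dim : (2*i+1)*head_dim)
        # into the compacted slot i*head_dim : (i+1)*head_dim.
        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
            2 * i * head_dim : (2 * i + 1) * head_dim, :
        ]
    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
    return original_kv_weight


recovered = get_original_weight(duplicated.clone(), HEAD_DIM)
assert torch.equal(recovered, unique.reshape(N_KV_HEADS * HEAD_DIM, HIDDEN))
print("de-duplicated KV weight shape:", tuple(recovered.shape))  # (32, 32)
```

A single in-place pass is enough because step `i` writes block `i` while reading block `2*i`, and no later step `j > i` reads a block that an earlier step has already overwritten (earlier writes only touch blocks `< j`, while reads come from blocks `2*j >= j`).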