diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py index 5357e4b9878..b5a7c06163c 100644 --- a/python/sglang/srt/managers/controller/model_runner.py +++ b/python/sglang/srt/managers/controller/model_runner.py @@ -15,7 +15,6 @@ BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.decode import _grouped_size_compiled_for_decode_kernels -from torch.nn.parameter import Parameter from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig from vllm.distributed import ( @@ -23,7 +22,6 @@ init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.layers.linear import QKVParallelLinear from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config @@ -40,18 +38,6 @@ logger = logging.getLogger("srt.model_runner") -def is_llama3_405b_fp8(model_config): - if ( - model_config.hf_config.architectures[0] == "LlamaForCausalLM" - and model_config.hf_config.hidden_size == 16384 - and model_config.hf_config.intermediate_size == 53248 - and model_config.hf_config.num_hidden_layers == 126 - and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8" - ): - return True - return False - - class ModelRunner: def __init__( self, @@ -132,9 +118,6 @@ def load_model(self): seed=42, skip_tokenizer_init=True, ) - if is_llama3_405b_fp8(self.model_config): - self.model_config.hf_config.num_key_value_heads = 8 - vllm_model_config.hf_config.num_key_value_heads = 8 self.dtype = vllm_model_config.dtype if self.model_config.model_overide_args is not None: vllm_model_config.hf_config.update(self.model_config.model_overide_args) @@ -387,39 +370,5 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: return model_arch_name_to_cls[model_arch] -def get_original_weight(loaded_weight, head_dim): - n_kv_head = loaded_weight.shape[0] // (2 * head_dim) - dim = loaded_weight.shape[1] - for i in range(n_kv_head): - loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[ - 2 * i * head_dim : (2 * i + 1) * head_dim, : - ] - original_kv_weight = loaded_weight[: n_kv_head * head_dim, :] - assert original_kv_weight.shape == (n_kv_head * head_dim, dim) - return original_kv_weight - - -def get_weight_loader_srt(weight_loader): - def weight_loader_srt( - self, - param: Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, - ): - if ( - loaded_shard_id in ["k", "v"] - and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2 - ): - loaded_weight = get_original_weight(loaded_weight, self.head_size) - - weight_loader(self, param, loaded_weight, loaded_shard_id) - - return weight_loader_srt - - # Monkey patch model loader setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) -original_weight_loader = QKVParallelLinear.weight_loader -setattr( - QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader) -)