Skip to content

Commit

Permalink
fix: fp8 config (#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ying1123 authored Jul 25, 2024
1 parent fded674 commit d63f13c
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions python/sglang/srt/managers/controller/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from torch.nn.parameter import Parameter
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
Expand All @@ -38,6 +40,18 @@
logger = logging.getLogger("srt.model_runner")


def is_llama3_405b_fp8(model_config):
if (
model_config.hf_config.architectures[0] == "LlamaForCausalLM"
and model_config.hf_config.hidden_size == 16384
and model_config.hf_config.intermediate_size == 53248
and model_config.hf_config.num_hidden_layers == 126
and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
):
return True
return False


class ModelRunner:
def __init__(
self,
Expand Down Expand Up @@ -118,6 +132,9 @@ def load_model(self):
seed=42,
skip_tokenizer_init=True,
)
if is_llama3_405b_fp8(self.model_config):
self.model_config.hf_config.num_key_value_heads = 8
vllm_model_config.hf_config.num_key_value_heads = 8
self.dtype = vllm_model_config.dtype
if self.model_config.model_overide_args is not None:
vllm_model_config.hf_config.update(self.model_config.model_overide_args)
Expand Down Expand Up @@ -370,5 +387,39 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
return model_arch_name_to_cls[model_arch]


def get_original_weight(loaded_weight, head_dim):
n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
dim = loaded_weight.shape[1]
for i in range(n_kv_head):
loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
2 * i * head_dim : (2 * i + 1) * head_dim, :
]
original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
return original_kv_weight


def get_weight_loader_srt(weight_loader):
def weight_loader_srt(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
):
if (
loaded_shard_id in ["k", "v"]
and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
):
loaded_weight = get_original_weight(loaded_weight, self.head_size)

weight_loader(self, param, loaded_weight, loaded_shard_id)

return weight_loader_srt


# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
original_weight_loader = QKVParallelLinear.weight_loader
setattr(
QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
)

0 comments on commit d63f13c

Please sign in to comment.