From f10d9e18d40fe634d3708af19deb2c37ba408e28 Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com"
Date: Fri, 27 Dec 2024 04:03:51 +0000
Subject: [PATCH 1/2] updated

---
 vllm/model_executor/layers/fused_moe/layer.py      | 17 ++++++++++++++---
 .../layers/quantization/awq_marlin.py              |  6 +++++-
 .../compressed_tensors_moe.py                      | 14 +++++++++++---
 .../layers/quantization/experts_int8.py            |  6 +++++-
 vllm/model_executor/layers/quantization/fp8.py     |  1 -
 .../layers/quantization/gptq_marlin.py             |  6 +++++-
 6 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 01ffac4550f28..c54043b2552c0 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -41,9 +41,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         raise NotImplementedError
 
     @abstractmethod
-    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              router_logits: torch.Tensor, top_k: int, renormalize: bool,
-              use_grouped_topk: bool) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         raise NotImplementedError
 
 
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 4d1a837d11585..b3adb3dbcb440 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -445,6 +445,8 @@ def apply(
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -454,7 +456,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index dad04017d3212..e7056725d3d2f 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -208,8 +208,9 @@ def apply(
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -220,7 +221,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return fused_experts(x,
                              layer.w13_weight,
@@ -481,7 +484,10 @@ def apply(
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -490,7 +496,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 97297970d9317..ace9c9f7f31f6 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -104,6 +104,8 @@ def apply(
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
 
@@ -115,7 +117,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return fused_experts(x,
                              layer.w13_weight,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 4362468c1db69..918ccca9c59b0 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -608,7 +608,6 @@ def apply(
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe import fused_experts
 
         topk_weights, topk_ids = FusedMoE.select_experts(
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index a3e58bf1b2a4c..9b430fcc7f184 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -537,6 +537,8 @@ def apply(
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # The input must currently be float16
         orig_dtype = x.dtype
@@ -550,7 +552,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=None)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,

From 60e173ce059e12dee9c54e910de6608564b0e513 Mon Sep 17 00:00:00 2001
From: "rshaw@neuralmagic.com"
Date: Fri, 27 Dec 2024 04:09:27 +0000
Subject: [PATCH 2/2] make tests pass

---
 vllm/model_executor/layers/fused_moe/layer.py           | 4 ++--
 vllm/model_executor/layers/quantization/awq_marlin.py   | 4 ++--
 .../compressed_tensors/compressed_tensors_moe.py        | 8 ++++----
 vllm/model_executor/layers/quantization/experts_int8.py | 4 ++--
 vllm/model_executor/layers/quantization/fp8.py          | 2 +-
 vllm/model_executor/layers/quantization/gptq_marlin.py  | 4 ++--
 6 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index c54043b2552c0..b108cbd52c218 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -48,7 +48,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
@@ -90,7 +90,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index b3adb3dbcb440..c28fd0c6737e0 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -440,10 +440,10 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index e7056725d3d2f..5fd6b017f444b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -203,10 +203,10 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
@@ -479,10 +479,10 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index ace9c9f7f31f6..209f12c6dfec9 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -99,10 +99,10 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 918ccca9c59b0..7f779ac8d3b3e 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -601,7 +601,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 9b430fcc7f184..a006d729cc627 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -532,10 +532,10 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
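
For context (this note and the sketch below are not part of the patches): both commits thread two routing knobs, scoring_func and e_score_correction_bias, from each quantization backend's apply() into FusedMoE.select_experts. The standalone Python sketch below only approximates what those knobs usually mean for top-k expert selection (softmax vs. sigmoid scoring, plus an additive per-expert bias used for selection only); it is illustrative, not vLLM's select_experts implementation, and the helper name select_experts_sketch is invented here.

# Minimal sketch of top-k routing with scoring_func / e_score_correction_bias.
# Assumptions: bias affects which experts are picked, original scores are kept
# as routing weights; this is not vLLM's actual FusedMoE.select_experts.
from typing import Optional

import torch


def select_experts_sketch(
    router_logits: torch.Tensor,  # [num_tokens, num_experts]
    top_k: int,
    renormalize: bool,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,  # [num_experts]
):
    # Score experts with either softmax or sigmoid.
    if scoring_func == "softmax":
        scores = torch.softmax(router_logits, dim=-1)
    elif scoring_func == "sigmoid":
        scores = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if e_score_correction_bias is not None:
        # Bias is applied only when *selecting* experts; the unbiased scores
        # are gathered back as the routing weights.
        selection_scores = scores + e_score_correction_bias.unsqueeze(0)
        topk_ids = torch.topk(selection_scores, k=top_k, dim=-1).indices
        topk_weights = scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids


if __name__ == "__main__":
    logits = torch.randn(4, 8)
    bias = torch.zeros(8)
    w, ids = select_experts_sketch(logits, top_k=2, renormalize=True,
                                   scoring_func="sigmoid",
                                   e_score_correction_bias=bias)
    print(w.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])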