diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 3d822fc0c7f99..da0ce1885dbb2 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -38,7 +38,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         raise NotImplementedError
 
@@ -65,22 +65,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -289,13 +291,20 @@ def __init__(
             self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None
 
-        self.quant_method.create_weights(
-            layer=self,
-            num_experts=num_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size_per_partition,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader)
+        moe_quant_params = {
+            "num_experts": num_experts,
+            "hidden_size": hidden_size,
+            "intermediate_size_per_partition":
+            self.intermediate_size_per_partition,
+            "params_dtype": params_dtype,
+            "weight_loader": self.weight_loader,
+        }
+        # need full intermediate size pre-sharding for WNA16 act order
+        if (self.quant_method.__class__.__name__ ==
+                "CompressedTensorsWNA16MoEMethod"):
+            moe_quant_params["intermediate_size_full"] = intermediate_size
+
+        self.quant_method.create_weights(layer=self, **moe_quant_params)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
@@ -312,19 +321,30 @@ def _load_per_tensor_weight_scale(self, shard_id: str,
         elif shard_id == "w2":
             param_data[expert_id] = loaded_weight
 
-    def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
+    def _load_model_weight_or_group_weight_scale(self,
+                                                 shard_dim: int,
                                                  expert_data: torch.Tensor,
                                                  shard_id: str,
                                                  loaded_weight: torch.Tensor,
-                                                 tp_rank: int):
-        # Load grouped weight scales for group quantization
-        # or model weights
+                                                 tp_rank: int,
+                                                 load_full_w2: bool = False):
+        """
+        Load grouped weight scales for group quantization or model weights
+            :param shard_dim: dimension to shard
+            :param expert_data: parameter for a particular expert
+            :param shard_id: either w1, w2, or w3
+            :param loaded_weight: checkpoint weight to load into the param
+            :param tp_rank: tensor parallel rank
+            :param load_full_w2: whether to load w2 in full (no TP sharding)
+        """
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            # In the case where we have actorder/g_idx, we do not partition the
+            # w2 scales, as indicated by `load_full` argument, for all tp cases
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
-                          tp_rank=tp_rank)
+                          tp_rank=tp_rank,
+                          load_full=load_full_w2)
         elif shard_id in ("w1", "w3"):
             self._load_w13(shard_id=shard_id,
                            shard_dim=shard_dim,
@@ -364,15 +384,21 @@ def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
             expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
         expert_data.copy_(loaded_weight)
 
-    def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
-                 shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):
+    def _load_w2(self,
+                 expert_data: torch.Tensor,
+                 shard_dim: int,
+                 loaded_weight: torch.Tensor,
+                 tp_rank: int,
+                 load_full: bool = False):
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
-                                             shard_size)
+        if not load_full:
+            loaded_weight = loaded_weight.narrow(shard_dim,
+                                                 shard_size * tp_rank,
+                                                 shard_size)
 
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
@@ -387,8 +413,7 @@ def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                     shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
 
         if shard_id == "w2":
-            self._load_w2(shard_id=shard_id,
-                          shard_dim=shard_dim,
+            self._load_w2(shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)
@@ -416,7 +441,7 @@ def weight_loader(self, param: torch.nn.Parameter,
         ]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
-        # dimension intermediate_size is used.
+        # dimension intermediate_size_per_partition is used.
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
@@ -424,11 +449,11 @@ def weight_loader(self, param: torch.nn.Parameter,
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
-        # should be whatever dimension intermediate_size is
+        # should be whatever dimension intermediate_size_per_partition is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -480,7 +505,8 @@ def weight_loader(self, param: torch.nn.Parameter,
                     shard_dim=shard_dim,
                     loaded_weight=loaded_weight,
                     expert_data=expert_data,
-                    tp_rank=tp_rank)
+                    tp_rank=tp_rank,
+                    load_full_w2=getattr(param, "load_full_w2", False))
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                 self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                    param=param,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index c28fd0c6737e0..0c3c9816878e9 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -303,7 +303,7 @@ def __init__(self, quant_config: AWQMarlinConfig):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         extra_weight_attrs.update({
             "is_transposed":
@@ -312,17 +312,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             FusedMoeWeightScaleSupported.GROUP.value,
         })
 
-        w13_qweight = Parameter(torch.empty(num_experts,
-                                            hidden_size,
-                                            2 * intermediate_size //
-                                            self.quant_config.pack_factor,
-                                            dtype=torch.int32),
-                                requires_grad=False)
+        w13_qweight = Parameter(
+            torch.empty(num_experts,
+                        hidden_size,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qweight", w13_qweight)
         set_weight_attrs(w13_qweight, extra_weight_attrs)
 
         w2_qweight = Parameter(torch.empty(num_experts,
-                                           intermediate_size,
+                                           intermediate_size_per_partition,
                                            hidden_size //
                                            self.quant_config.pack_factor,
                                            dtype=torch.int32),
@@ -331,13 +332,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_qweight, extra_weight_attrs)
 
         num_groups_w13 = hidden_size // self.quant_config.group_size
-        num_groups_w2 = intermediate_size // self.quant_config.group_size
+        num_groups_w2 = (intermediate_size_per_partition //
+                         self.quant_config.group_size)
 
         # WEIGHT_SCALES
         # Allocate 2 scales for w1 and w3 respectively.
         w13_scales = Parameter(torch.empty(num_experts,
                                            num_groups_w13,
-                                           intermediate_size * 2,
+                                           intermediate_size_per_partition * 2,
                                            dtype=params_dtype),
                                requires_grad=False)
         layer.register_parameter("w13_scales", w13_scales)
@@ -353,12 +355,13 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
 
         # WEIGHT_ZERO_POINT
         # Allocate 2 zero points for w1 and w3 respectively.
-        w13_qzeros = Parameter(torch.empty(num_experts,
-                                           num_groups_w13,
-                                           2 * intermediate_size //
-                                           self.quant_config.pack_factor,
-                                           dtype=torch.int32),
-                               requires_grad=False)
+        w13_qzeros = Parameter(
+            torch.empty(num_experts,
+                        num_groups_w13,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
+                        dtype=torch.int32),
+            requires_grad=False)
         layer.register_parameter("w13_qzeros", w13_qzeros)
         set_weight_attrs(w13_qzeros, extra_weight_attrs)
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 4fb8fd84e92d4..e1c45f4e42e41 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -13,6 +13,7 @@
     FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
@@ -75,24 +76,26 @@ def __init__(
         self.static_input_scales = not self.input_quant.dynamic
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -254,6 +257,7 @@ def __init__(
         self.packed_factor = 32 // config.num_bits
         self.strategy = config.strategy
         self.group_size = config.group_size
+        self.actorder = config.actorder
         assert config.symmetric, (
             "Only symmetric quantization is supported for MoE")
 
@@ -266,9 +270,16 @@ def __init__(
             f"{WNA16_SUPPORTED_BITS}")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
+        assert params_dtype == torch.float16, (
+            "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+        )
+
+        intermediate_size_full = extra_weight_attrs.pop(
+            "intermediate_size_full")
+
         # Will transpose the loaded weight along the
         # intermediate and hidden dim sizes. Will
         # shard for TP along the transposed dims
@@ -276,35 +287,45 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             "is_transposed": True,
             "quant_method": self.strategy
         })
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    hidden_size //
-                                                    self.packed_factor,
-                                                    2 * intermediate_size,
-                                                    dtype=torch.int32),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size // self.packed_factor,
+            2 * intermediate_size_per_partition,
+            dtype=torch.int32),
                                         requires_grad=False)
         layer.register_parameter("w13_weight_packed", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   intermediate_size //
-                                                   self.packed_factor,
-                                                   hidden_size,
-                                                   dtype=torch.int32),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            intermediate_size_per_partition // self.packed_factor,
+            hidden_size,
+            dtype=torch.int32),
                                        requires_grad=False)
         layer.register_parameter("w2_weight_packed", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+        # In the case where we have actorder/g_idx,
+        # we do not partition the w2 scales
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (intermediate_size_full
+                          if load_full_w2 else intermediate_size_per_partition)
+
+        self.is_k_full = (not self.actorder) or (
+            intermediate_size_per_partition == intermediate_size_full)
+
         if self.strategy == "channel":
             num_groups_w2 = num_groups_w13 = 1
             self.group_size = -1
         else:
-            num_groups_w2 = intermediate_size // self.group_size
+            num_groups_w2 = w2_scales_size // self.group_size
             num_groups_w13 = hidden_size // self.group_size
 
-        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                  num_groups_w13,
-                                                  2 * intermediate_size,
-                                                  dtype=params_dtype),
+        w13_scale = torch.nn.Parameter(torch.ones(
+            num_experts,
+            num_groups_w13,
+            2 * intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w13_weight_scale", w13_scale)
         set_weight_attrs(w13_scale, extra_weight_attrs)
@@ -316,6 +337,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
                                       requires_grad=False)
         layer.register_parameter("w2_weight_scale", w2_scale)
         set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
 
         w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)
@@ -335,18 +357,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w13_g_idx", w13_g_idx)
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
         set_weight_attrs(w13_g_idx, extra_weight_attrs)
 
         w2_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
         )
-        layer.register_parameter("w2_g_idx", w2_g_idx)
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
         set_weight_attrs(w2_g_idx, extra_weight_attrs)
 
         w13_g_idx_sort_indices = torch.nn.Parameter(
@@ -364,7 +386,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         w2_g_idx_sort_indices = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -422,24 +444,55 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
         size_k2 = layer.w2_weight_packed.shape[2]
         size_k13 = layer.w13_weight_packed.shape[2]
 
-        num_experts = layer.w13_g_idx.shape[0]
-        device = layer.w13_g_idx.device
-        layer.w13_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
-        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
-            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
-            requires_grad=False,
-        )
+        num_experts = layer.w13_weight_g_idx.shape[0]
+        device = layer.w13_weight_g_idx.device
+
+        # when running models with grouped act order,
+        # resort to g_idx values provided in checkpoint
+        if self.actorder == "group":
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
+
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w13_weight_g_idx[e]).to(torch.int32)
+                w2_g_idx_sort_indices[e] = torch.argsort(
+                    layer.w2_weight_g_idx[e]).to(torch.int32)
+                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
+                    w13_g_idx_sort_indices[e]]
+                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][
+                    w2_g_idx_sort_indices[e]]
+
+            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices",
+                              w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices",
+                              w2_g_idx_sort_indices)
+
+        else:
+            layer.w13_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32,
+                            device=device),
+                requires_grad=False,
+            )
 
         marlin_w13_qweight = ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
@@ -511,9 +564,9 @@ def apply(
             router_logits,
             topk_weights,
             topk_ids,
-            g_idx1=layer.w13_g_idx,
-            g_idx2=layer.w2_g_idx,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
             num_bits=self.num_bits,
-        )
+            is_k_full=self.is_k_full)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index 61d1c911cd1ad..2e1b5e3c2d3b1 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -62,7 +62,7 @@ def create_weights(self, layer: torch.nn.Module, input_size: int,
                        **kwargs):
 
         assert params_dtype == torch.float16, (
-            "float16 is required for marlin24 compressd models. Set dtype=torch.float16"  # noqa: E501
+            "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
         )
 
         pack_factor = 32 // self.quant_type.size_bits
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 209f12c6dfec9..100cbfa4c9598 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -52,7 +52,7 @@ def __init__(self, quant_config: ExpertsInt8Config):
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
         int8_dtype = torch.int8
 
@@ -64,26 +64,29 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         extra_weight_attrs['weight_loader'] = wrapped_weight_loader
 
         # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=int8_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=int8_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=int8_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=int8_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
-        w13_scale = torch.nn.Parameter(torch.zeros(num_experts,
-                                                   2 * intermediate_size,
-                                                   dtype=torch.float32),
+        w13_scale = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            dtype=torch.float32),
                                        requires_grad=False)
         layer.register_parameter("w13_scale", w13_scale)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 26dd5df4e55b2..21d4355b36ab0 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -386,8 +386,8 @@ def __init__(self, quant_config: Fp8Config):
         self.block_quant = self.quant_config.weight_block_size is not None
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
-                       intermediate_size: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+                       intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -402,30 +402,34 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             # scales, the output_size of the weights for both the gate and up
             # layers must be divisible by block_n.
             # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
+            if intermediate_size_per_partition % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
+                    f"{intermediate_size_per_partition} is not divisible by "
                     f"weight quantization block_n = {block_n}.")
-            if (tp_size > 1 and intermediate_size % block_k != 0):
+            if (tp_size > 1
+                    and intermediate_size_per_partition % block_k != 0):
                 # Required by row parallel
-                raise ValueError(f"The input_size of down's weight = "
-                                 f"{intermediate_size} is not divisible by "
-                                 f"weight quantization block_k = {block_k}.")
+                raise ValueError(
+                    f"The input_size of down's weight = "
+                    f"{intermediate_size_per_partition} is not divisible by "
+                    f"weight quantization block_k = {block_k}.")
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
@@ -446,7 +450,8 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
             w13_weight_scale = torch.nn.Parameter(
                 torch.ones(
                     num_experts,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    2 * ((intermediate_size_per_partition + block_n - 1) //
+                         block_n),
                     (hidden_size + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
@@ -456,7 +461,7 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                 torch.ones(
                     num_experts,
                     (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
+                    (intermediate_size_per_partition + block_k - 1) // block_k,
                     dtype=torch.float32,
                 ),
                 requires_grad=False,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 2dbfca9b07690..4dc4b052b0410 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -317,7 +317,7 @@ def create_weights(
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -326,7 +326,8 @@ def create_weights(
         # Supports only sym for now (no zp)
         if self.quant_config.group_size != -1:
             scales_size13 = hidden_size // self.quant_config.group_size
-            scales_size2 = intermediate_size // self.quant_config.group_size
+            scales_size2 = (intermediate_size_per_partition //
+                            self.quant_config.group_size)
             strategy = FusedMoeWeightScaleSupported.GROUP.value
         else:
             scales_size13 = 1
@@ -342,7 +343,7 @@ def create_weights(
             torch.empty(
                 num_experts,
                 hidden_size // self.quant_config.pack_factor,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -353,7 +354,8 @@ def create_weights(
         w2_qweight = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size // self.quant_config.pack_factor,
+                intermediate_size_per_partition //
+                self.quant_config.pack_factor,
                 hidden_size,
                 dtype=torch.int32,
             ),
@@ -365,7 +367,7 @@ def create_weights(
         w13_scales = torch.nn.Parameter(
             torch.empty(num_experts,
                         scales_size13,
-                        2 * intermediate_size,
+                        2 * intermediate_size_per_partition,
                         dtype=torch.half),
             requires_grad=False,
         )
@@ -385,7 +387,8 @@ def create_weights(
         w13_qzeros = torch.nn.Parameter(
             torch.empty(num_experts,
                         scales_size13,
-                        2 * intermediate_size // self.quant_config.pack_factor,
+                        2 * intermediate_size_per_partition //
+                        self.quant_config.pack_factor,
                         dtype=params_dtype),
             requires_grad=False,
         )
@@ -414,7 +417,7 @@ def create_weights(
         w2_g_idx = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -435,7 +438,7 @@ def create_weights(
         w2_g_idx_sort_indices = torch.nn.Parameter(
             torch.empty(
                 num_experts,
-                intermediate_size,
+                intermediate_size_per_partition,
                 dtype=torch.int32,
             ),
             requires_grad=False,
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 3e19247300808..68a3954540763 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -60,24 +60,26 @@ def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
         self.static_input_scales = not self.input_quant.get("is_dynamic")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
-                       hidden_size: int, intermediate_size: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
 
         params_dtype = torch.float8_e4m3fn
 
         # WEIGHTS
-        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                    2 * intermediate_size,
-                                                    hidden_size,
-                                                    dtype=params_dtype),
+        w13_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size,
+            dtype=params_dtype),
                                         requires_grad=False)
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
-        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
-                                                   hidden_size,
-                                                   intermediate_size,
-                                                   dtype=params_dtype),
+        w2_weight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            dtype=params_dtype),
                                        requires_grad=False)
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)