diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 14e06df1f6c8d..578271b11fcf3 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -35,10 +35,6 @@ def register_fake(fn):
 
 
 # activation ops
-def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
-    torch.ops._C.silu_and_mul(out, x)
-
-
 def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_and_mul(out, x)
 
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 34d65ed51ef3f..46d4670bfe4f9 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -10,6 +10,7 @@
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils import LazyDict
 
 
@@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
         return: (num_tokens, d) or (batch_size, seq_len, d)
     """
 
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike():
+            self.op = torch.ops._C.silu_and_mul
+        elif current_platform.is_xpu():
+            import intel_extension_for_pytorch as ipex
+            self.op = ipex.llm.functional.silu_and_mul
+
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm import _custom_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out
 
     def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
-        from vllm._ipex_ops import ipex_ops as ops
-
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        ops.silu_and_mul(out, x)
+        self.op(out, x)
         return out
 
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 4741d69de11ac..87993267c05b5 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -4,7 +4,6 @@
 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
 from vllm.scalar_type import scalar_types
@@ -301,7 +300,8 @@ def fused_marlin_moe(
         False,
     )
 
-    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+    torch.ops._C.silu_and_mul(intermediate_cache2,
+                              intermediate_cache1.view(-1, 2 * N))
 
     intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
         intermediate_cache2,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 4101facbe7874..1bb6bc753d37c 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 use_int8_w8a16=use_int8_w8a16,
                                 block_shape=block_shape)
 
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+        torch.ops._C.silu_and_mul(intermediate_cache2,
+                                  intermediate_cache1.view(-1, N))
 
         invoke_fused_moe_kernel(intermediate_cache2,
                                 w2,
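
For reference, the math that the fused kernel computes is already spelled out in SiluAndMul.forward_native above. The short sketch below is not part of the patch; the helper name silu_and_mul_native is made up for illustration. It restates that reference computation and the shape contract that the preallocated out buffers in forward_cuda/forward_xpu (and the intermediate_cache2 buffers in the MoE paths) rely on.

# Hedged sketch (not part of the patch): a PyTorch-native reference for the
# silu_and_mul kernel that the diff now invokes directly via torch.ops._C.
# Mirrors SiluAndMul.forward_native: split the last dimension in half, apply
# SiLU to the first half, and multiply element-wise by the second half.
import torch
import torch.nn.functional as F


def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


# Usage: an input of shape (num_tokens, 2 * d) yields (num_tokens, d),
# matching the output_shape allocated in forward_cuda / forward_xpu.
x = torch.randn(4, 16)
out = silu_and_mul_native(x)
assert out.shape == (4, 8)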