diff --git a/CMakeLists.txt b/CMakeLists.txt
index ead539993d98c..87f9007620dc4 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -220,6 +220,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/fp8/common.cu"
   "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
+  "csrc/quantization/activation_kernels.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
   "csrc/torch_bindings.cpp")
diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
index ba9f40a230c8e..f80fa0667e7be 100644
--- a/csrc/core/math.hpp
+++ b/csrc/core/math.hpp
@@ -1,7 +1,28 @@
+#pragma once
+
 #include <climits>
 #include <iostream>
 
 inline uint32_t next_pow_2(uint32_t const num) {
   if (num <= 1) return num;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
\ No newline at end of file
+}
+
+template <typename A, typename B>
+static inline constexpr auto div_ceil(A a, B b) {
+  return (a + b - 1) / b;
+}
+
+// Round a down to the next multiple of b. The caller is responsible for making
+// sure that b is non-zero
+template <typename T>
+inline constexpr T round_to_previous_multiple_of(T a, T b) {
+  return a % b == 0 ? a : (a / b) * b;
+}
+
+// Round a up to the next multiple of b. The caller is responsible for making
+// sure that b is non-zero
+template <typename T>
+inline constexpr T round_to_next_multiple_of(T a, T b) {
+  return a % b == 0 ? a : ((a / b) + 1) * b;
+}
diff --git a/csrc/ops.h b/csrc/ops.h
index 346898964010d..eba4a93ed9d9b 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -88,6 +88,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
+                        torch::Tensor& scale);
+
 void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu
new file mode 100644
index 0000000000000..9edebd9c3caae
--- /dev/null
+++ b/csrc/quantization/activation_kernels.cu
@@ -0,0 +1,113 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cmath>
+#include "fp8/common.cuh"
+#include "../core/math.hpp"
+#include "../cuda_compat.h"
+#include "../dispatch_utils.h"
+
+namespace vllm {
+
+template <typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T)(((float)x) / (1.0f + expf((float)-x)));
+}
+
+// Activation and gating kernel template.
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_quant_kernel(
+    FP8_TYPE* __restrict__ out,          // [..., d]
+    const scalar_t* __restrict__ input,  // [..., 2, d]
+    const float* scale, const int d) {
+  const int32_t blocks_per_token = gridDim.y;
+
+  const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);
+
+  // We don't expect the hidden dimension to exceed 32 bits so int32 should
+  // be safe here.
+  const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
+  const int32_t elems_per_block =
+      round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
+  const int32_t block_start = blockIdx.y * elems_per_block;
+  int32_t block_end = block_start + elems_per_block;
+  block_end = block_end > d ? d : block_end;
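+  // [block_start, block_end) is the slice of this token's d outputs handled
+  // by this block (one of blocks_per_token blocks along grid.y); its start is
+  // aligned to the 128-bit load width.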
+
+  // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
+  // is very large
+  const int64_t token_idx = blockIdx.x;
+  const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
+  const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
+  FP8_TYPE* __restrict__ out_ptr = out + token_idx * d;
+
+  // 128-bit vectorized code
+  const int32_t vec_loop_end =
+      round_to_previous_multiple_of(elems_per_128bit_load, block_end);
+  const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
+  const int32_t vec_start_idx = block_start / elems_per_128bit_load;
+
+  const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
+  const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
+  int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);
+
+  float inverted_scale = 1 / *scale;
+#pragma unroll
+  for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
+       vec_idx += blockDim.x) {
+    const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
+    const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
+    using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
+    using scalar_64bit_vec_t = std::array<FP8_TYPE, elems_per_128bit_load>;
+
+    scalar_64bit_vec_t out_vec;
+    const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
+    const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);
+
+#pragma unroll
+    for (int i = 0; i < elems_per_128bit_load; i++) {
+      out_vec[i] = scaled_fp8_conversion<true>(ACT_FN(x_vec[i]) * y_vec[i],
+                                               inverted_scale);
+    }
+
+    out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
+  }
+
+  // Scalar cleanup code
+  if (block_end > vec_loop_end) {
+    for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
+         idx += blockDim.x) {
+      const scalar_t x = VLLM_LDG(&x_ptr[idx]);
+      const scalar_t y = VLLM_LDG(&y_ptr[idx]);
+      out_ptr[idx] = scaled_fp8_conversion<true>(ACT_FN(x) * y, inverted_scale);
+    }
+  }
+}
+}  // namespace vllm
+
+// Launch activation, gating, and quantize kernel.
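+// The grid is (num_tokens, blocks_per_token): grid.y assigns more blocks per
+// token for small batches (4 for <= 16 tokens, 2 for 17-32, otherwise 1) so
+// that small batches still spread work across the GPU.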
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                             \
+  int d = input.size(-1) / 2;                                             \
+  int64_t num_tokens = input.numel() / input.size(-1);                    \
+  dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4);   \
+  dim3 block(std::min(d, 512));                                           \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));       \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();           \
+  VLLM_DISPATCH_FLOATING_TYPES(                                           \
+      input.scalar_type(), "act_and_mul_kernel", [&] {                    \
+        vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>>        \
+            <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),        \
+                                         input.data_ptr<scalar_t>(),      \
+                                         scale.data_ptr<float>(), d);     \
+      });
+
+void silu_and_mul_quant(torch::Tensor& out,    // [..., d]
+                        torch::Tensor& input,  // [..., 2 * d]
+                        torch::Tensor& scale) {
+  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
+  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
+              input.dtype() == torch::kBFloat16);
+  TORCH_CHECK(input.size(-1) % 2 == 0);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index ec63170d511f0..7a9bee717de08 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -52,9 +52,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Activation ops
   // Activation function used in SwiGLU.
-  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
+  ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  ops.def(
+      "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
+  ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
+
   ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
   ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
 
diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py
index ea3aaee9565ec..1c8fe53baa0fb 100644
--- a/tests/compile/test_functionalization.py
+++ b/tests/compile/test_functionalization.py
@@ -3,6 +3,7 @@
 
 import vllm.envs as envs
 from vllm import LLM, SamplingParams
+from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
 from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
                                      kFp8DynamicTokenSym, kFp8StaticTensorSym)
@@ -15,18 +16,13 @@
 OPS_IN_MODEL = [
     torch.ops._C.rotary_embedding.default,
     torch.ops._C.fused_add_rms_norm.default,
-    torch.ops._C.silu_and_mul.default,
 ]
 
 RMS_OP = torch.ops._C.rms_norm.default
 
-RMS_QUANT_OPS = {
-    "static_fp8": [
-        torch.ops._C.rms_norm_static_fp8_quant.default,
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    ],
-}
+SILU_MUL_OP = torch.ops._C.silu_and_mul.default
+SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
 
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -51,8 +47,13 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
                                           enable_reshape=True)
     reshape_pass = RedundantReshapesPass(config)
     fusion_pass = FusionPass.instance(config)
+    act_quant_fusion_pass = ActivationQuantFusionPass.instance(config)
 
-    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
+    passes = [
+        reshape_pass,
+        fusion_pass,
+        act_quant_fusion_pass,
+    ] if do_fusion else [reshape_pass]
     func_pass = FixFunctionalizationPass(config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)
@@ -75,6 +76,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
         model_runner.model = torch.compile(orig_model,
                                            fullgraph=True,
                                            backend=backend_no_func)
+
         gen_no_func = llm.generate(prompts, sampling_params)
 
         for output_func, output_no_func in zip(gen_func, gen_no_func):
@@ -84,7 +86,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
     # and replaced by fused quantized ops in RMS_QUANT_OPS.
     rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
                ] if do_fusion else [RMS_OP]
-    ops = OPS_IN_MODEL + rms_ops
+    silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
+        quant_key == kFp8StaticTensorSym else [
+            SILU_MUL_OP
+        ]
+
+    ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
 
     for op in ops:
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
new file mode 100644
index 0000000000000..7a6fb8725420c
--- /dev/null
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -0,0 +1,73 @@
+import pytest
+import torch
+
+import vllm.envs as envs
+from vllm._custom_ops import scaled_fp8_quant
+from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
+from vllm.compilation.fusion import find_auto_fn, find_auto_fn_maybe
+from vllm.compilation.reshapes import RedundantReshapesPass
+from vllm.config import CompilationConfig
+from vllm.model_executor.layers.activation import SiluAndMul
+
+from .backend import TestBackend
+
+
+class TestModel(torch.nn.Module):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.silu_and_mul = SiluAndMul()
+        self.scale = torch.rand(1, dtype=torch.float32)
+
+    def forward(self, x):
+        y = self.silu_and_mul(x)
+        x2 = scaled_fp8_quant(y, self.scale)
+        return x2
+
+
+@pytest.mark.parametrize("num_tokens", [256])
+@pytest.mark.parametrize("hidden_size", [64])
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
+                    reason="Only test on CUDA")
+def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
+    torch.set_default_device("cuda")
+    torch.set_default_dtype(torch.float16)
+
+    # Reshape pass is needed for the fusion pass to work
+    config = CompilationConfig.PassConfig(enable_fusion=True,
+                                          enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
+    fusion_pass = ActivationQuantFusionPass.instance(config)
+
+    backend = TestBackend(reshape_pass, fusion_pass)
+    model = TestModel()
+
+    # First dimension dynamic
+    x = torch.rand(num_tokens, hidden_size)
+    torch._dynamo.mark_dynamic(x, 0)
+
+    result = model(x)
+
+    model2 = torch.compile(model, backend=backend)
+    result2 = model2(x)
+
+    # Check that it gives the same answer
+    torch.testing.assert_close(result[0].to(dtype=torch.float16),
+                               result2[0].to(dtype=torch.float16),
+                               atol=1e-3,
+                               rtol=1e-3)
+
+    # Check substitution worked
+    pre_nodes = backend.graph_pre_pass.nodes
+    post_nodes = backend.graph_post_pass.nodes
+
+    silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
+    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
+
+    # In pre-nodes, fp8 quant should be present and fused kernels should not
+    assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
+    find_auto_fn(pre_nodes, fp8_quant)
+
+    # In post-nodes, fused kernels should be present and fp8 quant should not
+    find_auto_fn(post_nodes, silu_and_mul_quant)
+    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py
new file mode 100644
index 0000000000000..fdb86c13adbb7
--- /dev/null
+++ b/tests/kernels/test_fused_quant_activation.py
@@ -0,0 +1,68 @@
+import pytest
+import torch
+
+import vllm._custom_ops as ops
+from tests.kernels.utils import opcheck
+from vllm.model_executor.layers.activation import SiluAndMul
+
+DTYPES = [torch.bfloat16, torch.float16]
+QUANT_DTYPES = [torch.float8_e4m3fn]
+NUM_TOKENS = [1, 17, 86, 1234, 3045]  # Arbitrary values for testing
+HIDDEN_SIZES = [16, 48, 128, 1562, 4096]  # Arbitrary values for testing
+SEEDS = [0]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+
+def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
+             scale: torch.Tensor) -> torch.Tensor:
+    silu_and_mul_out = silu_and_mul.forward_native(x)
+    out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale)
+    return out
+
+
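+# Calls the fused CUDA op directly. The caller allocates the fp8 output with
+# half the input's hidden size, since the input packs the gate and up halves
+# side by side.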
+def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    out_shape = (x.shape[0], x.shape[1] // 2)
+    out = torch.empty(out_shape,
+                      dtype=torch.torch.float8_e4m3fn,
+                      device=x.device)
+    torch.ops._C.silu_and_mul_quant(out, x, scale)
+    return out
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_silu_and_mul(
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    quant_dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+
+    layer = SiluAndMul()
+
+    # Make inputs
+    scale = (torch.randn((1), device=device, dtype=torch.float32))
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+
+    ref_out = ref_impl(layer, x, scale)
+    ops_out = ops_impl(x, scale)
+
+    assert ref_out.dtype == quant_dtype
+    assert ops_out.dtype == quant_dtype
+    assert ref_out.shape == ops_out.shape
+    assert torch.allclose(ref_out.to(dtype=torch.float32),
+                          ops_out.to(dtype=torch.float32))
+    opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
new file mode 100644
index 0000000000000..9aa50749fc544
--- /dev/null
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -0,0 +1,104 @@
+from typing import Optional
+
+import torch
+from torch._higher_order_ops.auto_functionalize import auto_functionalized
+from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
+                                             register_replacement)
+
+from vllm.config import CompilationConfig
+from vllm.logger import init_logger
+
+from .vllm_inductor_pass import VllmInductorPass
+
+logger = init_logger(__name__)
+
+
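+# The pattern below is silu_and_mul writing into result_silu_mul followed by a
+# static fp8 quant of that intermediate into result; the replacement computes
+# the same result with the fused silu_and_mul_quant kernel.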
+def silu_mul_pattern_static(result: torch.Tensor,
+                            result_silu_mul: torch.Tensor, input: torch.Tensor,
+                            scale: torch.Tensor):
+    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
+                              result=result_silu_mul,
+                              input=input)
+    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
+                              result=result,
+                              input=at1[1],
+                              scale=scale)
+    return at2[1]
+
+
+def silu_mul_replacement_static(result: torch.Tensor,
+                                result_silu_mul: torch.Tensor,
+                                input: torch.Tensor, scale: torch.Tensor):
+    at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default,
+                             result=result,
+                             input=input,
+                             scale=scale)
+    return at[1]
+
+
+def empty_bf16(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
+
+
+def empty_fp8(*args, **kwargs):
+    fp8 = torch.float8_e4m3fn
+    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
+
+
+def empty_fp32(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
+
+
+class ActivationQuantFusionPass(VllmInductorPass):
+    """
+    This pass fuses a pre-defined set of custom ops into fused ops.
+    It uses the torch pattern matcher to find the patterns and replace them.
+
+    Because patterns can only be registered once, the pass is a singleton.
+    This will be addressed in a future version of PyTorch:
+    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
+    """
+
+    _instance: 'Optional[ActivationQuantFusionPass]' = None
+
+    @classmethod
+    def instance(cls, config: CompilationConfig.PassConfig):
+        """
+        Get the singleton instance of the ActivationQuantFusionPass.
+        If the instance exists, the config is updated but
+        initialization is not repeated.
+        """
+        if cls._instance is None:
+            cls._instance = ActivationQuantFusionPass(config)
+        else:
+            cls._instance.config = config
+        return cls._instance
+
+    def __init__(self, config: CompilationConfig.PassConfig):
+        assert self.__class__._instance is None, \
+            "ActivationQuantFusionPass singleton instance already exists"
+        super().__init__(config)
+
+        self.patterns: PatternMatcherPass = PatternMatcherPass(
+            pass_name="activation_quant_fusion_pass")
+
+        inputs = [
+            empty_fp8(5, 4),  # Quant output
+            empty_bf16(5, 4),  # Silu_and_mul output
+            empty_bf16(5, 4),  # Input
+            empty_fp32(1, 1)  # Scale
+        ]
+        register_replacement(silu_mul_pattern_static,
+                             silu_mul_replacement_static, inputs, fwd_only,
+                             self.patterns)
+
+    def __call__(self, graph: torch.fx.Graph):
+        self.begin()
+        self.dump_graph(graph, "before_act_quant_fusion")
+
+        count = self.patterns.apply(graph)
+        logger.debug("Replaced %s patterns in ActivationQuantFusionPass",
+                     count)
+
+        self.dump_graph(graph, "after_act_quant_fusion")
+        self.end_and_log()
diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py
index e15d7b315c50f..5525be85222d7 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/fix_functionalization.py
@@ -66,18 +66,25 @@ def __call__(self, graph: torch.fx.Graph):
                 self.defunctionalize(graph, node, mutated_args)
             elif at_target in [
                     torch.ops._C.rms_norm.default,
-                    torch.ops._C.rms_norm_static_fp8_quant.default
+                    torch.ops._C.rms_norm_static_fp8_quant.default,
             ]:
                 mutated_args = {1: 'result'}
                 self.defunctionalize(graph, node, mutated_args)
-
+            # For some reason we need to specify the args for both
+            # silu_and_mul and silu_and_mul_quant. The kwargs
+            # pathway gets the wrong answer.
             elif at_target == torch.ops._C.silu_and_mul.default:
-                mutated_args = {1: 'out'}
-                # Because we have an 'out', need to specify args directly
+                mutated_args = {1: 'result'}
+                self.defunctionalize(graph,
+                                     node,
+                                     mutated_args,
+                                     args=('result', 'input'))
+            elif at_target == torch.ops._C.silu_and_mul_quant.default:
+                mutated_args = {1: 'result'}
                 self.defunctionalize(graph,
                                      node,
                                      mutated_args,
-                                     args=('out', 'input'))
+                                     args=('result', 'input', 'scale'))
             else:
                 continue  # skip the count
 
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py
index 34f5f355798b2..f9310e59802cc 100644
--- a/vllm/compilation/pass_manager.py
+++ b/vllm/compilation/pass_manager.py
@@ -5,6 +5,7 @@
 from vllm.config import CompilationConfig
 from vllm.logger import init_logger
 
+from .activation_quant_fusion import ActivationQuantFusionPass
 from .fix_functionalization import FixFunctionalizationPass
 from .fusion import FusionPass
 from .inductor_pass import InductorPass
@@ -46,6 +47,7 @@ def configure(self, pass_config: CompilationConfig.PassConfig):
 
         if pass_config.enable_fusion:
             self.passes += [FusionPass.instance(pass_config)]
+            self.passes += [ActivationQuantFusionPass.instance(pass_config)]
 
         self.fix_functionalization = FixFunctionalizationPass(pass_config)