From 2b416bf22e3c5a3034a0a46f0e5139f837b05e03 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 23 Jan 2025 10:33:39 -0800 Subject: [PATCH 1/2] feat: use sgl-kernel by default --- python/pyproject.toml | 2 +- python/sglang/srt/layers/activation.py | 14 ++++++++--- python/sglang/srt/layers/layernorm.py | 31 ++++++++++++++++++------- python/sglang/srt/layers/sampler.py | 25 ++++++++++++++------ python/sglang/srt/models/deepseek_v2.py | 14 +++++++++-- python/sglang/srt/models/minicpm3.py | 14 ++++++++--- python/sglang/srt/utils.py | 2 ++ 7 files changed, 77 insertions(+), 25 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 80cc0e9dc60..8e514ba8e46 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,7 +27,7 @@ runtime_common = [ ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.2.post14", "torch", "vllm==0.6.4.post1", + "sgl-kernel>=0.0.2.post16", "torch", "vllm==0.6.4.post1", "flashinfer==0.1.6" ] diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index ebb0652c5d2..7d7916cc5c8 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -20,10 +20,18 @@ import torch.nn as nn import torch.nn.functional as F -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import ( + enable_use_sgl_kernel_first, + is_cuda_available, + is_flashinfer_available, +) -if is_flashinfer_available(): - from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +if enable_use_sgl_kernel_first: + if is_cuda_available(): + from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +else: + if is_flashinfer_available(): + from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from vllm.model_executor.custom_op import CustomOp diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index bd95b9bccce..20f79305371 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,15 +19,28 @@ import torch import torch.nn as nn -from sglang.srt.utils import is_flashinfer_available - -if is_flashinfer_available(): - from flashinfer.norm import ( - fused_add_rmsnorm, - gemma_fused_add_rmsnorm, - gemma_rmsnorm, - rmsnorm, - ) +from sglang.srt.utils import ( + enable_use_sgl_kernel_first, + is_cuda_available, + is_flashinfer_available, +) + +if enable_use_sgl_kernel_first: + if is_cuda_available(): + from sgl_kernel import ( + fused_add_rmsnorm, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + rmsnorm, + ) +else: + if is_flashinfer_available(): + from flashinfer.norm import ( + fused_add_rmsnorm, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + rmsnorm, + ) from vllm.model_executor.custom_op import CustomOp diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 3173d533d16..bce66ba8b78 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -12,17 +12,28 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import ( crash_on_warnings, + enable_use_sgl_kernel_first, get_bool_env_var, + is_cuda_available, is_flashinfer_available, ) -if is_flashinfer_available(): - from flashinfer.sampling import ( - min_p_sampling_from_probs, - top_k_renorm_prob, - top_k_top_p_sampling_from_probs, - top_p_renorm_prob, - ) +if enable_use_sgl_kernel_first: + if is_cuda_available(): + from sgl_kernel import ( + min_p_sampling_from_probs, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, + ) +else: + if is_flashinfer_available(): + from flashinfer.sampling import ( + min_p_sampling_from_probs, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, + ) logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 17d7fcf8924..651792b484b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -60,8 +60,18 @@ is_hip_ = is_hip() -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +from sglang.srt.utils import ( + enable_use_sgl_kernel_first, + is_cuda_available, + is_flashinfer_available, +) + +if enable_use_sgl_kernel_first: + if is_cuda_available(): + from sgl_kernel import bmm_fp8 +else: + if is_flashinfer_available(): + from flashinfer import bmm_fp8 class DeepseekV2MLP(nn.Module): diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 118be8ff6c8..a09a26eaf8a 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -40,10 +40,18 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import ( + enable_use_sgl_kernel_first, + is_cuda_available, + is_flashinfer_available, +) -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if enable_use_sgl_kernel_first: + if is_cuda_available(): + from sgl_kernel import bmm_fp8 +else: + if is_flashinfer_available(): + from flashinfer import bmm_fp8 class MiniCPM3MLP(nn.Module): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 23dcb43d2d9..b0538193295 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -66,6 +66,8 @@ show_time_cost = False time_infos = {} +enable_use_sgl_kernel_first = bool(int(os.getenv("ENABLE_USE_SGL_KERNEL_FIRST", "1"))) + def is_hip() -> bool: """Return whether it is HIP on the AMD ROCm platform.""" From a1008a1dd5a66f4e8a12be8763a787bac2f5fc80 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 23 Jan 2025 20:52:43 -0800 Subject: [PATCH 2/2] upd --- .github/workflows/pr-test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index c5eeeee3c14..5b9af3ca2e5 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -51,6 +51,7 @@ jobs: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 1-gpu-runner strategy: + fail-fast: false matrix: range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100] steps: