Allow disabling flashinfer sampling kernel (#778)
merrymercy authored Jul 28, 2024
1 parent 30db99b commit 752e643
Showing 6 changed files with 41 additions and 26 deletions.
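
The commit adds a --disable-flashinfer-sampling server option, moves the shared global_server_args_dict from server.py into infer_batch.py so layer and model code can read the flags directly, and gates the flashinfer sampling path on the new flag. A hedged usage sketch, assuming the in-process sgl.Runtime wrapper forwards its keyword arguments to ServerArgs:

import sglang as sgl

# Keep flashinfer attention kernels but disable the flashinfer sampling kernel.
runtime = sgl.Runtime(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # illustrative model choice
    disable_flashinfer_sampling=True,
)
runtime.shutdown()
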
7 changes: 5 additions & 2 deletions python/sglang/srt/layers/radix_attention.py
@@ -7,8 +7,11 @@
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 
 
 class RadixAttention(nn.Module):
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/token_attention.py
@@ -5,7 +5,7 @@
 import triton
 import triton.language as tl
 
-from sglang.srt.server import global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
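
token_attention.py keeps its import-time check of attention_reduce_in_fp32; only the import source changes. For context, a hedged sketch of the full module-level pattern (the else branch is assumed, not shown in this hunk):

import triton.language as tl

from sglang.srt.managers.controller.infer_batch import global_server_args_dict

# Pick the accumulation dtype for the Triton kernels once, at import time.
if global_server_args_dict.get("attention_reduce_in_fp32", False):
    REDUCE_TRITON_TYPE = tl.float32
else:
    REDUCE_TRITON_TYPE = tl.float16
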
9 changes: 8 additions & 1 deletion python/sglang/srt/managers/controller/infer_batch.py
@@ -17,6 +17,13 @@
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -687,7 +694,7 @@ def sample(self, logits: torch.Tensor):
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
 
-        if True:
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
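
This is the core behavioral change: the batched flashinfer sampling path is now gated on the new flag instead of being unconditional, and when the flag is set, sampling presumably falls back to a non-flashinfer path (the else branch is outside this hunk). The following is only a hedged sketch of what a torch-only top-k/top-p fallback can look like; top_ks and top_ps are illustrative per-request tensors, not names taken from the diff.

import torch

def torch_sample_fallback(probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    # Sort probabilities per row so top-k / top-p masks are easy to apply.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p,
    # and tokens ranked at or beyond top_k; the best token always survives.
    top_p_mask = (cumsum - sorted_probs) > top_ps.unsqueeze(-1)
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(sorted_probs)
    top_k_mask = ranks >= top_ks.unsqueeze(-1)
    sorted_probs = sorted_probs.masked_fill(top_p_mask | top_k_mask, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    # Sample in the sorted space, then map back to original token ids.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)
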
16 changes: 14 additions & 2 deletions python/sglang/srt/managers/controller/model_runner.py
@@ -25,7 +25,12 @@
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -60,7 +65,13 @@ def __init__(
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(self.model_config)
-        monkey_patch_vllm_dummy_weight_loader()
+        global_server_args_dict.update(
+            {
+                "disable_flashinfer": server_args.disable_flashinfer,
+                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+            }
+        )
 
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
@@ -108,6 +119,7 @@ def load_model(self):
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+        monkey_patch_vllm_dummy_weight_loader()
        device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
         vllm_model_config = VllmModelConfig(
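
ModelRunner now populates the shared dictionary from ServerArgs at startup, and the monkey_patch_vllm_dummy_weight_loader() call moves into load_model(). With the dictionary living in infer_batch.py, any layer can read a flag without importing the server module. A minimal sketch of that access pattern (not part of the diff; the helper name is illustrative):

from sglang.srt.managers.controller.infer_batch import global_server_args_dict

def use_flashinfer_sampling() -> bool:
    # Read at call time, after ModelRunner.__init__ has updated the dict,
    # rather than caching the value at import time.
    return not global_server_args_dict["disable_flashinfer_sampling"]
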
13 changes: 0 additions & 13 deletions python/sglang/srt/server.py
@@ -65,9 +65,6 @@
 app = FastAPI()
 tokenizer_manager = None
 
-# Put some args for easily access
-global_server_args_dict = {}
-
 
 @app.get("/health")
 async def health() -> Response:
@@ -150,14 +147,6 @@ def available_models():
     return ModelList(data=model_cards)
 
 
-def _set_global_server_args(server_args: ServerArgs):
-    global global_server_args_dict
-    global_server_args_dict = {
-        "disable_flashinfer": server_args.disable_flashinfer,
-        "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-    }
-
-
 def _set_torch_compile_config():
     # The following configurations are for torch compile optimizations
     import torch._dynamo.config
@@ -213,8 +202,6 @@ def launch_server(
     if server_args.enable_torch_compile:
         _set_torch_compile_config()
 
-    _set_global_server_args(server_args)
-
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
20 changes: 13 additions & 7 deletions python/sglang/srt/server_args.py
@@ -52,13 +52,14 @@ class ServerArgs:
 
     # Optimization/debug options
     disable_flashinfer: bool = False
+    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
-    attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
+    attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
 
     # Distributed args
@@ -303,7 +304,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer inference kernels.",
+            help="Disable flashinfer attention kernels.",
         )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action="store_true",
+            help="Disable flashinfer sampling kernels.",
+        )
         parser.add_argument(
             "--disable-radix-cache",
@@ -331,15 +337,15 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="Optimize the model with torch.compile, experimental feature.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--enable-p2p-check",
             action="store_true",
-            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
-            "This only affects Triton attention kernels",
+            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
         parser.add_argument(
-            "--enable-p2p-check",
+            "--attention-reduce-in-fp32",
             action="store_true",
-            help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
+            help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16. "
+            "This only affects Triton attention kernels",
         )
         parser.add_argument(
             "--efficient-weight-load",
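
The new --disable-flashinfer-sampling flag maps one-to-one onto the dataclass field added above. A hedged sketch of exercising it end to end, assuming the usual from_cli_args helper exists on ServerArgs and that --model-path is the only other required argument:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model-path", "meta-llama/Llama-2-7b-chat-hf", "--disable-flashinfer-sampling"]
)
server_args = ServerArgs.from_cli_args(args)  # assumed helper; construct ServerArgs manually otherwise
assert server_args.disable_flashinfer_sampling is True
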
