From 04ec6ba2ac7a4e4beee8be9dc15bc1922544ca82 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Thu, 25 Jul 2024 03:04:21 -0700
Subject: [PATCH] Fix dockerfile and triton cache manager (#720)

---
 docker/Dockerfile           |  8 --------
 python/sglang/srt/server.py |  6 ++++++
 python/sglang/srt/utils.py  | 45 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index f519d48eccd..abaf645c094 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -23,18 +23,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
 RUN apt-get update -y \
     && apt-get install -y python3-pip git curl sudo
 
-# Workaround for https://github.com/openai/triton/issues/2507 and
-# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
-# this won't be needed for future versions of this docker image
-# or future versions of triton.
-RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
-
 WORKDIR /sgl-workspace
 
 RUN pip3 --no-cache-dir install --upgrade pip \
     && pip3 --no-cache-dir install "sglang[all]" \
-    && pip3 --no-cache-dir uninstall -y triton triton-nightly \
-    && pip3 --no-cache-dir install --no-deps --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly \
     && pip3 --no-cache-dir install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
 
 ENV DEBIAN_FRONTEND=interactive
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index b0fca27f675..fbbdd06cbdd 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -52,6 +52,7 @@
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
+    maybe_set_triton_cache_manager,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -201,6 +202,11 @@ def launch_server(
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
+
+    if server_args.tp_size // server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 6ada031d237..a7a9f26d4cf 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -18,10 +18,15 @@
 import requests
 import torch
 import torch.distributed as dist
-import triton
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from starlette.middleware.base import BaseHTTPMiddleware
+from triton.runtime.cache import (
+    FileCacheManager,
+    default_cache_dir,
+    default_dump_dir,
+    default_override_dir,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -460,6 +465,44 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
 
     setattr(GroupCoordinator, "all_gather", all_gather)
 
 
+def maybe_set_triton_cache_manager() -> None:
+    """Set environment variable to tell Triton to use a
+    custom cache manager"""
+    cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
+    if cache_manger is None:
+        manager = "sglang.srt.utils:CustomCacheManager"
+        logger.info("Setting Triton cache manager to: %s", manager)
+        os.environ["TRITON_CACHE_MANAGER"] = manager
+
+
+class CustomCacheManager(FileCacheManager):
+    # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
+    def __init__(self, key, override=False, dump=False):
+
+        self.key = key
+        self.lock_path = None
+        if dump:
+            self.cache_dir = default_dump_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+            self.lock_path = os.path.join(self.cache_dir, "lock")
+            os.makedirs(self.cache_dir, exist_ok=True)
+        elif override:
+            self.cache_dir = default_override_dir()
+            self.cache_dir = os.path.join(self.cache_dir, self.key)
+        else:
+            # create cache directory if it doesn't exist
+            self.cache_dir = (
+                os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
+            )
+            if self.cache_dir:
+                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
+                self.cache_dir = os.path.join(self.cache_dir, self.key)
+                self.lock_path = os.path.join(self.cache_dir, "lock")
+                os.makedirs(self.cache_dir, exist_ok=True)
+            else:
+                raise RuntimeError("Could not create or locate cache dir")
+
+
 API_KEY_HEADER_NAME = "X-API-Key"
 