Skip to content

Commit

Permalink
Fix dockerfile and triton cache manager (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Jul 25, 2024
1 parent d63f13c commit 04ec6ba
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
8 changes: 0 additions & 8 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
RUN apt-get update -y \
&& apt-get install -y python3-pip git curl sudo

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

WORKDIR /sgl-workspace

RUN pip3 --no-cache-dir install --upgrade pip \
&& pip3 --no-cache-dir install "sglang[all]" \
&& pip3 --no-cache-dir uninstall -y triton triton-nightly \
&& pip3 --no-cache-dir install --no-deps --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly \
&& pip3 --no-cache-dir install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/

ENV DEBIAN_FRONTEND=interactive
6 changes: 6 additions & 0 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
maybe_set_triton_cache_manager,
set_ulimit,
)
from sglang.utils import get_exception_traceback
Expand Down Expand Up @@ -201,6 +202,11 @@ def launch_server(
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
)

if server_args.tp_size // server_args.dp_size > 1:
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()

if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
Expand Down
45 changes: 44 additions & 1 deletion python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@
import requests
import torch
import torch.distributed as dist
import triton
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
default_dump_dir,
default_override_dir,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -460,6 +465,44 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
setattr(GroupCoordinator, "all_gather", all_gather)


def maybe_set_triton_cache_manager() -> None:
    """Select the custom Triton cache manager unless one is already configured.

    When the ``TRITON_CACHE_MANAGER`` environment variable is already present
    it is left untouched. Otherwise it is set to
    ``sglang.srt.utils:CustomCacheManager`` and the choice is logged at INFO
    level.
    """
    env_key = "TRITON_CACHE_MANAGER"
    if env_key in os.environ:
        # A cache manager has been chosen explicitly; do not override it.
        return

    custom_manager = "sglang.srt.utils:CustomCacheManager"
    logger.info("Setting Triton cache manager to: %s", custom_manager)
    os.environ[env_key] = custom_manager


class CustomCacheManager(FileCacheManager):
    """File cache manager whose default cache directory is suffixed with the PID.

    Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py

    In the default mode (neither ``dump`` nor ``override``) the base cache
    directory gets ``_<pid>`` appended before the key sub-directory is
    created -- presumably so that concurrently running worker processes do
    not write into the same cache directory (TODO confirm against the Triton
    issue this works around).
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve ``cache_dir`` and ``lock_path`` for ``key``.

        Args:
            key: Name of the sub-directory created under the chosen base
                directory.
            override: If true (and ``dump`` is false), use the override
                directory. No directory is created and ``lock_path`` stays
                ``None``.
            dump: If true, use the dump directory. Takes precedence over
                ``override``.

        Raises:
            RuntimeError: In default mode, when neither ``TRITON_CACHE_DIR``
                nor ``default_cache_dir()`` yields a non-empty path.
        """
        # NOTE(review): the base-class initializer is not invoked; every
        # attribute assigned here is set up by this constructor alone.
        self.key = key
        self.lock_path = None

        if override and not dump:
            # Override mode only records the path.
            self.cache_dir = os.path.join(default_override_dir(), key)
            return

        if dump:
            base_dir = default_dump_dir()
        else:
            # A blank or whitespace-only TRITON_CACHE_DIR counts as unset.
            base_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
            if not base_dir:
                raise RuntimeError("Could not create or locate cache dir")
            # Make the default cache location specific to this process.
            base_dir = f"{base_dir}_{os.getpid()}"

        # create cache directory if it doesn't exist
        self.cache_dir = os.path.join(base_dir, key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)


API_KEY_HEADER_NAME = "X-API-Key"


Expand Down

0 comments on commit 04ec6ba

Please sign in to comment.