Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dockerfile and triton cache manager #720

Merged
merged 3 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
RUN apt-get update -y \
&& apt-get install -y python3-pip git curl sudo

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

WORKDIR /sgl-workspace

RUN pip3 --no-cache-dir install --upgrade pip \
&& pip3 --no-cache-dir install "sglang[all]" \
&& pip3 --no-cache-dir uninstall -y triton triton-nightly \
zhyncs marked this conversation as resolved.
Show resolved Hide resolved
&& pip3 --no-cache-dir install --no-deps --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly \
&& pip3 --no-cache-dir install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/

ENV DEBIAN_FRONTEND=interactive
6 changes: 6 additions & 0 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
maybe_set_triton_cache_manager,
set_ulimit,
)
from sglang.utils import get_exception_traceback
Expand Down Expand Up @@ -201,6 +202,11 @@ def launch_server(
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
)

if server_args.tp_size // server_args.dp_size > 1:
# FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
maybe_set_triton_cache_manager()

if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
Expand Down
45 changes: 44 additions & 1 deletion python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@
import requests
import torch
import torch.distributed as dist
import triton
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
from starlette.middleware.base import BaseHTTPMiddleware
from triton.runtime.cache import (
FileCacheManager,
default_cache_dir,
default_dump_dir,
default_override_dir,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -460,6 +465,44 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
setattr(GroupCoordinator, "all_gather", all_gather)


def maybe_set_triton_cache_manager() -> None:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
if cache_manger is None:
manager = "sglang.srt.utils:CustomCacheManager"
logger.info("Setting Triton cache manager to: %s", manager)
os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
# Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
def __init__(self, key, override=False, dump=False):

self.key = key
self.lock_path = None
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = (
os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir()
)
if self.cache_dir:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")


API_KEY_HEADER_NAME = "X-API-Key"


Expand Down
Loading