diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a8dd628b9cd6f..87655530cead4 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -145,6 +145,7 @@ def wrap_inductor(graph: fx.GraphModule, example_inputs, additional_inductor_config, compilation_config: CompilationConfig, + vllm_backend: "VllmBackend", graph_index: int = 0, num_graphs: int = 1, runtime_shape: Optional[int] = None, @@ -176,7 +177,7 @@ def wrap_inductor(graph: fx.GraphModule, # see https://github.com/pytorch/pytorch/issues/138980 graph = copy.deepcopy(graph) - cache_data = compilation_config.inductor_hash_cache + cache_data = vllm_backend.inductor_hash_cache if (runtime_shape, graph_index) in cache_data: # we compiled this graph before # so we can directly lookup the compiled graph via hash @@ -196,7 +197,7 @@ def wrap_inductor(graph: fx.GraphModule, hash_str, example_inputs, True, False) assert inductor_compiled_graph is not None, ( "Inductor cache lookup failed. Please remove" - f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again." # noqa + f"the cache file {cache_data.cache_file_path} and try again." # noqa ) # Inductor calling convention (function signature): @@ -354,7 +355,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): def __init__(self, module: torch.fx.GraphModule, compile_submod_names: List[str], vllm_config: VllmConfig, - graph_pool): + graph_pool, vllm_backend: "VllmBackend"): super().__init__(module) from torch._guards import detect_fake_mode self.fake_mode = detect_fake_mode() @@ -362,6 +363,7 @@ def __init__(self, module: torch.fx.GraphModule, self.compilation_config = vllm_config.compilation_config self.graph_pool = graph_pool self.vllm_config = vllm_config + self.vllm_backend = vllm_backend def run(self, *args): fake_args = [ @@ -389,6 +391,7 @@ def call_module(self, target: torch.fx.node.Target, args, self.compilation_config.inductor_compile_config, self.compilation_config, + self.vllm_backend, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None, @@ -397,7 +400,7 @@ def call_module(self, target: torch.fx.node.Target, self.module.__dict__[target] = PiecewiseBackend( submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_general_shape) + compiled_graph_for_general_shape, self.vllm_backend) compilation_counter.num_piecewise_capturable_graphs_seen += 1 @@ -430,6 +433,7 @@ class VllmBackend: post_grad_passes: Sequence[Callable] sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] + inductor_hash_cache: InductorHashCache def __init__( self, @@ -472,6 +476,53 @@ def configure_post_pass(self): def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: + if not self.compilation_config.cache_dir: + # no provided cache dir, generate one based on the known factors + # that affects the compilation. if none of the factors change, + # the cache dir will be the same so that we can reuse the compiled + # graph. + + # 1. factors come from the vllm_config (it mainly summarizes how the + # model is created) + vllm_config = self.vllm_config + config_hash = vllm_config.compute_hash() + + # 2. factors come from the code files that are traced by Dynamo ( + # it mainly summarizes how the model is used in forward pass) + forward_code_files = list( + sorted(self.compilation_config.traced_files)) + self.compilation_config.traced_files.clear() + logger.debug( + "Traced files (to be considered for compilation cache):\n%s", + "\n".join(forward_code_files)) + hash_content = [] + for filepath in forward_code_files: + hash_content.append(filepath) + with open(filepath) as f: + hash_content.append(f.read()) + import hashlib + code_hash = hashlib.md5( + "\n".join(hash_content).encode()).hexdigest() + + # combine the two hashes to generate the cache dir + hash_key = hashlib.md5( + f"{config_hash}_{code_hash}".encode()).hexdigest()[:10] + cache_dir = os.path.join( + envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key, + f"rank_{vllm_config.parallel_config.rank}") + else: + cache_dir = self.compilation_config.cache_dir + os.makedirs(cache_dir, exist_ok=True) + + disabled = envs.VLLM_DISABLE_COMPILE_CACHE + self.inductor_hash_cache: InductorHashCache = InductorHashCache( + cache_dir, disabled=disabled) + if disabled: + logger.info("vLLM's torch.compile cache is disabled.") + else: + logger.info("Using cache directory: %s for vLLM's torch.compile", + cache_dir) + # when dynamo calls the backend, it means the bytecode # transform and analysis are done compilation_counter.num_graphs_seen += 1 @@ -507,8 +558,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.vllm_config, - self.graph_pool).run(*example_inputs) + self.vllm_config, self.graph_pool, + self).run(*example_inputs) self._called = True @@ -577,7 +628,8 @@ class PiecewiseBackend: def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, graph_pool: Any, piecewise_compile_index: int, total_piecewise_compiles: int, sym_shape_indices: List[int], - compiled_graph_for_general_shape: Callable): + compiled_graph_for_general_shape: Callable, + vllm_backend: VllmBackend): """ The backend for piecewise compilation. It mainly handles the compilation and cudagraph capturing. @@ -597,6 +649,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles + self.vllm_backend = vllm_backend self.is_first_graph = piecewise_compile_index == 0 self.is_last_graph = ( @@ -634,7 +687,7 @@ def check_for_ending_compilation(self): if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile # save the hash of the inductor graph for the next run - self.compilation_config.inductor_hash_cache.save_to_file() + self.vllm_backend.inductor_hash_cache.save_to_file() end_monitoring_torch_compile(self.vllm_config) def __call__(self, *args) -> Any: @@ -662,6 +715,7 @@ def __call__(self, *args) -> Any: args, self.compilation_config.inductor_compile_config, self.compilation_config, + self.vllm_backend, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, runtime_shape=runtime_shape, diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 805a217ee6ca1..10513111ea7f1 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,8 +1,10 @@ import inspect from typing import Callable, Dict, List, Optional, TypeVar, Union, overload +from unittest.mock import patch import torch import torch.nn as nn +from torch._dynamo.symbolic_convert import InliningInstructionTranslator from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher @@ -196,7 +198,31 @@ def __call__(self, *args, **kwargs): # we need to control all the compilation of the model. torch._dynamo.eval_frame.remove_from_cache( self.original_code_object) - return self.compiled_callable(*args, **kwargs) + + # collect all relevant files traced by Dynamo, + # so that the compilation cache can trigger re-compilation + # properly when any of these files change. + + # 1. the file containing the top-level forward function + self.vllm_config.compilation_config.traced_files.add( + self.original_code_object.co_filename) + + # 2. every time Dynamo sees a function call, it will inline + # the function by calling InliningInstructionTranslator.inline_call + # we hijack this function to know all the functions called + # during Dynamo tracing, and their corresponding files + inline_call = InliningInstructionTranslator.inline_call + + def patched_inline_call(parent, func, args, kwargs): + code = func.get_code() + self.vllm_config.compilation_config.traced_files.add( + code.co_filename) + return inline_call(parent, func, args, kwargs) + + with patch.object(InliningInstructionTranslator, 'inline_call', + patched_inline_call): + output = self.compiled_callable(*args, **kwargs) + return output # usually, capturing the model once is enough, and then we can # dispatch to the compiled code directly, without going through diff --git a/vllm/config.py b/vllm/config.py index 535cbe97a311a..6dabeb3861af2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3,7 +3,6 @@ import enum import hashlib import json -import os import sys import warnings from contextlib import contextmanager @@ -2778,9 +2777,8 @@ def model_post_init(self, __context: Any) -> None: # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr + traced_files: Set[str] = PrivateAttr compilation_time: float = PrivateAttr - # should be InductorHashCache, but Pydantic does not support it - inductor_hash_cache: Any = PrivateAttr # Per-model forward context # Mainly used to store attention cls @@ -2818,6 +2816,7 @@ def __repr__(self) -> str: "compilation_time", "bs_to_padded_graph_size", "pass_config", + "traced_files", } return self.model_dump_json(exclude=exclude, exclude_unset=True) @@ -2877,6 +2876,7 @@ def model_post_init(self, __context: Any) -> None: self.enabled_custom_ops = Counter() self.disabled_custom_ops = Counter() + self.traced_files = set() self.static_forward_context = {} self.compilation_time = 0.0 @@ -2899,29 +2899,6 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: # merge with the config use_inductor assert self.level == CompilationLevel.PIECEWISE - if not self.cache_dir: - # no provided cache dir, generate one based on the known factors - # that affects the compilation. if none of the factors change, - # the cache dir will be the same so that we can reuse the compiled - # graph. - hash_key = vllm_config.compute_hash() - cache_dir = os.path.join( - envs.VLLM_CACHE_ROOT, "torch_compile_cache", hash_key, - f"rank_{vllm_config.parallel_config.rank}") - os.makedirs(cache_dir, exist_ok=True) - self.cache_dir = cache_dir - - disabled = envs.VLLM_DISABLE_COMPILE_CACHE - from vllm.compilation.backends import InductorHashCache - self.inductor_hash_cache: InductorHashCache = InductorHashCache( - self.cache_dir, disabled=disabled) - if disabled: - logger.info("vLLM's torch.compile cache is disabled.") - else: - logger.info( - "Using cache directory: %s for vLLM's torch.compile", - self.cache_dir) - from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) diff --git a/vllm/sequence.py b/vllm/sequence.py index 0157abbd2eed5..5857f656dfc10 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1108,6 +1108,13 @@ class IntermediateTensors: tensors: Dict[str, torch.Tensor] + def __init__(self, tensors): + # manually define this function, so that + # Dynamo knows `IntermediateTensors()` comes from this file. + # Otherwise, dataclass will generate this function by evaluating + # a string, and we will lose the information about the source file. + self.tensors = tensors + def __getitem__(self, key: Union[str, slice]): if isinstance(key, str): return self.tensors[key]