From 8e9488544961e5863e1ee490d64b65f328f367b4 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 19 Nov 2024 10:27:47 -0800
Subject: [PATCH 1/5] add time profiling for memory profiling

Signed-off-by: youkaichao
---
 vllm/worker/worker.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index d3ca6d9d0b17e..3a65621261947 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -1,6 +1,7 @@
 """A GPU worker class."""
 import gc
 import os
+import time
 from typing import Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
@@ -189,6 +190,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         torch.cuda.reset_peak_memory_stats()
 
         free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
+        start_time = time.time()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -229,6 +231,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
 
+        end_time = time.time()
+        logger.info("Memory profiling took %.2f seconds",
+                    end_time - start_time)
         logger.info(
             "Memory profiling results: total_gpu_memory=%.2fGiB"
             " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"

From f6b95aa7e3601e5b2326a5969886c5198cf92ef1 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 19 Nov 2024 10:46:47 -0800
Subject: [PATCH 2/5] splitting ops

Signed-off-by: youkaichao
---
 tests/compile/piecewise/test_simple.py    |  2 +-
 tests/compile/piecewise/test_toy_llama.py |  4 ++--
 vllm/compilation/backends.py              |  2 +-
 vllm/config.py                            | 10 ++++++----
 4 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py
index 0e40e3b4ebc96..0db12d6b6a43c 100644
--- a/tests/compile/piecewise/test_simple.py
+++ b/tests/compile/piecewise/test_simple.py
@@ -79,7 +79,7 @@ def test_simple_piecewise_compile():
     vllm_config = VllmConfig(compilation_config=CompilationConfig(
         level=CompilationLevel.PIECEWISE,
         use_cudagraph=True,
-        non_cudagraph_ops=["silly.attention"],
+        splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
     ))
     with set_current_vllm_config(vllm_config):
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index 356d119a40334..cfe661b8871e0 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -258,7 +258,7 @@ def run_model(llama_config,
             use_cudagraph=True,
         )
         if split_attn:
-            compilation_config.non_cudagraph_ops = ["silly.attention"]
+            compilation_config.splitting_ops = ["silly.attention"]
     else:
         compilation_config = CompilationConfig(
             level=CompilationLevel.NO_COMPILATION, )
@@ -378,7 +378,7 @@ def benchmark():
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
-            non_cudagraph_ops=["silly.attention"],
+            splitting_ops=["silly.attention"],
         )
     else:
         compilation_config = CompilationConfig(
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 0cf1e3a95fcba..416cffd326489 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -447,7 +447,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.add_passes_to_config()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, self.compilation_configs.non_cudagraph_ops)
+            graph, self.compilation_configs.splitting_ops)
 
         from torch._dynamo.utils import lazy_format_graph_code
         logger.debug("%s", lazy_format_graph_code("before split", self.graph))
diff --git a/vllm/config.py b/vllm/config.py
index e69cbd3eb402a..d9d4fad5df65d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2089,6 +2089,7 @@ class CompilationConfig(BaseModel):
         - 'none,+op1,+op2' to enable only op1 and op2
         By default, all custom ops are enabled when running without Inductor
         and disabled when running with Inductor (compile_level >= Inductor).
+    - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
     - CudaGraph capture:
         - use_cudagraph: whether to use cudagraph inside compilation.
             - False: cudagraph inside compilation is not used.
@@ -2149,6 +2150,11 @@
     level: int = 0
     backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
+    splitting_ops: List[str] = Field(default_factory=lambda: [
+        "vllm.unified_flash_attention",
+        "vllm.unified_flash_infer",
+        "vllm.unified_v1_flash_attention",
+    ])
 
     use_inductor: bool = True
     inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
@@ -2157,7 +2163,6 @@
     inductor_passes: Dict[str, str] = Field(default_factory=dict)
 
     use_cudagraph: bool = False
-    non_cudagraph_ops: List[str] = Field(default_factory=list)
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[List[int]] = None
     cudagraph_copy_inputs: bool = False
@@ -2348,9 +2353,6 @@ def __post_init__(self):
             # and avoid any potential issues with the inductor.
             self.compilation_config.custom_ops = ["none"]
             self.compilation_config.use_cudagraph = True
-            self.compilation_config.non_cudagraph_ops = [
-                "vllm.unified_v1_flash_attention"
-            ]
             self.compilation_config.use_inductor = True
             self.compilation_config.enable_fusion = False

From a545b677e75033ac97ee107439811da43dfd8be6 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 19 Nov 2024 11:13:20 -0800
Subject: [PATCH 3/5] merge logging

Signed-off-by: youkaichao
---
 vllm/worker/worker.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 3a65621261947..745a59362a58a 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -232,14 +232,14 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         num_cpu_blocks = max(num_cpu_blocks, 0)
 
         end_time = time.time()
-        logger.info("Memory profiling took %.2f seconds",
-                    end_time - start_time)
         logger.info(
-            "Memory profiling results: total_gpu_memory=%.2fGiB"
+            "Memory profiling results: duration=%.2f seconds,"
+            "total_gpu_memory=%.2fGiB"
             " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
             " memory_usage_post_profile=%.2fGiB"
             " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
-            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            " gpu_memory_utilization=%.2f", end_time - start_time,
+            total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
             (peak_memory - non_torch_allocations) / (1024**3),
             total_allocated_bytes / (1024**3),

From 1a43f59d1bba630ae47068a83642d116cf1b43f7 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 19 Nov 2024 11:16:28 -0800
Subject: [PATCH 4/5] fix format

Signed-off-by: youkaichao
---
 vllm/worker/worker.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 745a59362a58a..80fd7bc3b67cc 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -233,12 +233,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
 
         end_time = time.time()
         logger.info(
-            "Memory profiling results: duration=%.2f seconds,"
-            "total_gpu_memory=%.2fGiB"
-            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
-            " memory_usage_post_profile=%.2fGiB"
-            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
-            " gpu_memory_utilization=%.2f", end_time - start_time,
+            "Memory profiling results: "
+            "duration=%.2f seconds, "
+            "total_gpu_memory=%.2fGiB, "
+            "initial_memory_usage=%.2fGiB, "
+            "peak_torch_memory=%.2fGiB, "
+            "memory_usage_post_profile=%.2fGiB, "
+            "non_torch_memory=%.2fGiB, "
+            "kv_cache_size=%.2fGiB, "
+            "gpu_memory_utilization=%.2f.", end_time - start_time,
             total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
             (peak_memory - non_torch_allocations) / (1024**3),

From af3887d178a294f20d349698898abd86a5aa957a Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 19 Nov 2024 11:25:17 -0800
Subject: [PATCH 5/5] fix comments

Signed-off-by: youkaichao
---
 vllm/config.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index d9d4fad5df65d..3d0c616868225 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2094,9 +2094,10 @@ class CompilationConfig(BaseModel):
         - use_cudagraph: whether to use cudagraph inside compilation.
             - False: cudagraph inside compilation is not used.
             - True: cudagraph inside compilation is used. It requires
-                that all input buffers have fixed addresses.
-                Note that this is orthogonal to the cudagraph capture out
-                side of compilation.
+                that all input buffers have fixed addresses, and all
+                splitting ops write their outputs to input buffers.
+                Note that this is orthogonal to the cudagraph capture logic
+                outside of compilation.
                 TODO: move outside cudagraph logic into compilation.
                 torch.compile will handle cudagraph capture logic in the future.
        - cudagraph_capture_sizes: sizes to capture cudagraph.
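
Usage sketch (illustrative, not part of the patches): after this series, piecewise compilation is configured through splitting_ops rather than non_cudagraph_ops. The snippet below mirrors the updated tests/compile/piecewise/test_simple.py; the import paths are assumptions (the diffs above only show that CompilationConfig lives in vllm/config.py), and "silly.attention" is the toy custom op registered by those tests.

# Minimal sketch, assuming these names are importable from vllm.config.
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

# Piecewise compilation: the graph is split at every op listed in
# splitting_ops (formerly non_cudagraph_ops), and cudagraphs are captured
# for the subgraphs in between. Leaving splitting_ops at its new default
# splits at vLLM's unified attention ops instead of this toy op.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    splitting_ops=["silly.attention"],
    cudagraph_copy_inputs=True,
))
# The tests then activate this config with set_current_vllm_config(vllm_config)
# before building and running the model.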