[7/N] torch.compile, reduce compilation time #10460

Merged · 5 commits · Nov 20, 2024 · Changes shown from 2 commits

2 changes: 1 addition & 1 deletion tests/compile/piecewise/test_simple.py
@@ -79,7 +79,7 @@ def test_simple_piecewise_compile():
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
- non_cudagraph_ops=["silly.attention"],
+ splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):
4 changes: 2 additions & 2 deletions tests/compile/piecewise/test_toy_llama.py
@@ -258,7 +258,7 @@ def run_model(llama_config,
use_cudagraph=True,
)
if split_attn:
- compilation_config.non_cudagraph_ops = ["silly.attention"]
+ compilation_config.splitting_ops = ["silly.attention"]
else:
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
@@ -378,7 +378,7 @@ def benchmark():
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
- non_cudagraph_ops=["silly.attention"],
+ splitting_ops=["silly.attention"],
)
else:
compilation_config = CompilationConfig(
2 changes: 1 addition & 1 deletion vllm/compilation/backends.py
@@ -447,7 +447,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
self.add_passes_to_config()

self.split_gm, self.piecewise_graphs = split_graph(
- graph, self.compilation_configs.non_cudagraph_ops)
+ graph, self.compilation_configs.splitting_ops)

from torch._dynamo.utils import lazy_format_graph_code
logger.debug("%s", lazy_format_graph_code("before split", self.graph))
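For intuition about what splitting means here, below is a toy sketch, not vLLM's actual `split_graph`, of grouping a traced FX graph's nodes into pieces at a set of splitting ops so that each splitting op ends up isolated in its own piece. The helper name `group_nodes_by_piece` and the `Toy` module are made up for illustration.

```python
from typing import List

import torch
import torch.fx as fx


def group_nodes_by_piece(gm: fx.GraphModule,
                         splitting_ops: List[str]) -> List[List[fx.Node]]:
    """Group graph nodes into pieces, isolating each splitting op.

    Each splitting op (e.g. attention) becomes its own piece, so the pieces
    around it can be compiled and captured independently while the splitting
    op itself runs eagerly.
    """
    pieces: List[List[fx.Node]] = [[]]
    for node in gm.graph.nodes:
        is_split_point = (node.op == "call_function" and
                          any(name in str(node.target) for name in splitting_ops))
        if is_split_point:
            pieces.append([node])  # the splitting op sits alone in its own piece
            pieces.append([])      # subsequent nodes start a fresh piece
        else:
            pieces[-1].append(node)
    return [piece for piece in pieces if piece]


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = fx.symbolic_trace(Toy())
print([[n.name for n in piece] for piece in group_nodes_by_piece(gm, ["relu"])])
# e.g. [['x'], ['relu'], ['add', 'output']]
```
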
10 changes: 6 additions & 4 deletions vllm/config.py
@@ -2089,6 +2089,7 @@ class CompilationConfig(BaseModel):
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
+ - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation.
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
@@ -2149,6 +2150,11 @@
level: int = 0
backend: str = ""
custom_ops: List[str] = Field(default_factory=list)
+ splitting_ops: List[str] = Field(default_factory=lambda: [
+     "vllm.unified_flash_attention",
+     "vllm.unified_flash_infer",
+     "vllm.unified_v1_flash_attention",
+ ])

Review thread on lines +2154 to +2158 (resolved):
  Collaborator: What are these default values for?
  Member (author): These are the default splitting ops: Inductor cannot optimize them anyway, so they are good candidates at which to split the model graph.

use_inductor: bool = True
inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None
@@ -2157,7 +2163,6 @@
inductor_passes: Dict[str, str] = Field(default_factory=dict)

use_cudagraph: bool = False
- non_cudagraph_ops: List[str] = Field(default_factory=list)
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
cudagraph_copy_inputs: bool = False
@@ -2348,9 +2353,6 @@ def __post_init__(self):
# and avoid any potential issues with the inductor.
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
- self.compilation_config.non_cudagraph_ops = [
-     "vllm.unified_v1_flash_attention"
- ]
self.compilation_config.use_inductor = True
self.compilation_config.enable_fusion = False

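For reference, here is a minimal usage sketch of the renamed field, mirroring the tests in this PR. The class names, fields, and the test-only `silly.attention` op come from the diffs above; the import path is assumed rather than shown in this PR.

```python
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

# Piecewise compilation: split the graph at the attention op so the pieces
# around it can be Inductor-compiled and CUDA-graph-captured, while the
# attention op itself runs outside the captured graphs.
compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    splitting_ops=["silly.attention"],  # test-only op; omit to keep the new defaults
)
vllm_config = VllmConfig(compilation_config=compilation_config)
```
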
5 changes: 5 additions & 0 deletions vllm/worker/worker.py
@@ -1,6 +1,7 @@
"""A GPU worker class."""
import gc
import os
+ import time
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
@@ -189,6 +190,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
torch.cuda.reset_peak_memory_stats()

free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
+ start_time = time.time()

# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
@@ -229,6 +231,9 @@
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)

+ end_time = time.time()
+ logger.info("Memory profiling took %.2f seconds",
+             end_time - start_time)
logger.info(
"Memory profiling results: total_gpu_memory=%.2fGiB"
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
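The worker change above simply brackets the profiling run with wall-clock timestamps and logs the difference. A self-contained sketch of the same pattern as a reusable helper (not part of this PR; `run_profiling_forward_pass` is a hypothetical stand-in for the dummy forward pass) could look like this:

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def log_elapsed(label: str):
    """Log the wall-clock time spent in the wrapped block."""
    start_time = time.time()
    try:
        yield
    finally:
        logger.info("%s took %.2f seconds", label, time.time() - start_time)


# Usage, analogous to determine_num_available_blocks():
# with log_elapsed("Memory profiling"):
#     run_profiling_forward_pass()  # hypothetical stand-in for the dummy forward pass
```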