Skip to content

Commit

Permalink
Add logging to fix_functionalization
Browse files Browse the repository at this point in the history
Signed-off-by: luka <[email protected]>
  • Loading branch information
ProExpertProg committed Nov 13, 2024
1 parent b051eb2 commit af751fc
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion vllm/compilation/functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized

from vllm.compilation.inductor_pass import VllmInductorPass, is_func
from vllm.logger import init_logger

from .inductor_pass import VllmInductorPass, is_func

logger = init_logger(__name__)


class FixFunctionalizationPass(VllmInductorPass):
Expand All @@ -23,6 +27,7 @@ def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, "before_fix_functionalization")

self.nodes_to_remove: List[torch.fx.Node] = []
count = 0
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
Expand Down Expand Up @@ -72,13 +77,20 @@ def __call__(self, graph: torch.fx.Graph):
node,
mutated_args,
args=('out', 'input'))
else:
continue # skip the count

count += 1

self.dump_graph(graph, "before_fix_functionalization_cleanup")

# Remove the nodes all at once
count_removed = len(self.nodes_to_remove)
for node in self.nodes_to_remove:
graph.erase_node(node)

logger.debug("De-functionalized %s nodes, removed %s nodes", count,
count_removed)
self.dump_graph(graph, "after_fix_functionalization")

def _remove(self, node_or_nodes: Union[torch.fx.Node,
Expand Down

0 comments on commit af751fc

Please sign in to comment.