diff --git a/functorch/compile/__init__.py b/functorch/compile/__init__.py index 96b853cd2e27e..e7548a5ff6b91 100644 --- a/functorch/compile/__init__.py +++ b/functorch/compile/__init__.py @@ -25,7 +25,6 @@ from torch._functorch.partitioners import ( default_partition, draw_graph, - draw_joint_graph, min_cut_rematerialization_partition, ) from torch._functorch.python_key import pythonkey_decompose diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ffa71a7e905b5..cfbd96e7368d9 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -4835,70 +4835,6 @@ def f(a, b, c, d): self.assertEqual(get_num_ins_outs(fw_graph), (4, 2)) self.assertEqual(get_num_ins_outs(bw_graph), (2, 4)) - @unittest.skipIf(not USE_NETWORKX, "networkx not available") - def test_min_cut_partitioner_recomputable_ops(self): - def f(x): - return x * x * x - - recomputable_ops = [] - partition_fn = partial( - min_cut_rematerialization_partition, recomputable_ops=recomputable_ops - ) - - fw_graph, bw_graph = get_fw_bw_graph( - f, [torch.randn(3, requires_grad=True)], partition_fn - ) - # Expected forward graph: - # opcode name target args kwargs - # ------------- --------- --------------- -------------------------- -------- - # placeholder primals_1 primals_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} - # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} - # output output output ([mul_1, primals_1, mul],) {} - self.assertEqual(get_num_ins_outs(fw_graph), (1, 3)) - # Expected backward graph: - # opcode name target args kwargs - # ------------- ---------- --------------- ----------------------- -------- - # placeholder primals_1 primals_1 () {} - # placeholder mul mul () {} - # placeholder tangents_1 tangents_1 () {} - # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} - # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} - # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} - # call_function add aten.add.Tensor (mul_2, mul_4) {} - # call_function add_1 aten.add.Tensor (add, mul_4) {} - # output output output ([add_1],) {} - self.assertEqual(get_num_ins_outs(bw_graph), (3, 1)) - - recomputable_ops = [torch.ops.aten.mul] - partition_fn = partial( - min_cut_rematerialization_partition, recomputable_ops=recomputable_ops - ) - fw_graph, bw_graph = get_fw_bw_graph( - f, [torch.randn(3, requires_grad=True)], partition_fn - ) - # Expected forward graph: - # opcode name target args kwargs - # ------------- --------- --------------- ---------------------- -------- - # placeholder primals_1 primals_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} - # call_function mul_1 aten.mul.Tensor (mul, primals_1) {} - # output output output ([mul_1, primals_1],) {} - self.assertEqual(get_num_ins_outs(fw_graph), (1, 2)) - # Expected backward graph: - # opcode name target args kwargs - # ------------- ---------- --------------- ----------------------- -------- - # placeholder primals_1 primals_1 () {} - # placeholder tangents_1 tangents_1 () {} - # call_function mul aten.mul.Tensor (primals_1, primals_1) {} # RECOMPUTED - # call_function mul_2 aten.mul.Tensor (tangents_1, mul) {} - # call_function mul_3 aten.mul.Tensor (tangents_1, primals_1) {} - # call_function mul_4 aten.mul.Tensor (mul_3, primals_1) {} - # call_function add aten.add.Tensor (mul_2, mul_4) {} - # call_function add_1 aten.add.Tensor (add, mul_4) {} - # output output output ([add_1],) {} - 
self.assertEqual(get_num_ins_outs(bw_graph), (2, 1)) - def test_contiguous(self): # The test simulates the condition where transpose followed by view # happens in the backward pass. diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index ffa37e59f04df..c9c750835a9f1 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -1,6 +1,8 @@ # mypy: ignore-errors +from typing import Callable + import torch import torch.fx as fx from torch.utils import _pytree as pytree @@ -9,7 +11,7 @@ aten = torch.ops.aten -def get_aten_target(node): +def get_aten_target(node: fx.Node) -> Callable: if hasattr(node.target, "overloadpacket"): return node.target.overloadpacket return node.target diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 5a43cd5e7bf33..d104247b3f63d 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - import copy import functools import heapq @@ -9,7 +7,10 @@ import operator import os from collections import defaultdict -from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union +from dataclasses import dataclass, replace +from typing import Callable, Dict, List, Optional, Set, Tuple, Union + +import sympy import torch import torch._inductor.inductor_prims @@ -28,19 +29,84 @@ from . import config from .compile_utils import fx_graph_cse, get_aten_target -if TYPE_CHECKING: - import sympy - AOT_PARTITIONER_DEBUG = config.debug_partitioner log = logging.getLogger(__name__) +aten = torch.ops.aten +prims = torch.ops.prims + + +@dataclass +class OpTypes: + """Class for keeping track of different operator categories""" + + fusible_ops: Set[Callable] + compute_intensive_ops: Set[Callable] + random_ops: Set[Callable] + view_ops: Set[Callable] + recomputable_ops: Set[Callable] + + def is_fusible(self, node: fx.Node): + return get_aten_target(node) in self.fusible_ops + + def is_compute_intensive(self, node: fx.Node): + return get_aten_target(node) in self.compute_intensive_ops + + def is_random(self, node: fx.Node): + return get_aten_target(node) in self.random_ops + + def is_view(self, node: fx.Node): + return get_aten_target(node) in self.view_ops + + def is_recomputable(self, node: fx.Node): + return get_aten_target(node) in self.recomputable_ops + + +@dataclass +class NodeInfo: + # Be careful about iterating over these explicitly, as their order may not + # be deterministic + inputs: List[fx.Node] + _required_fw_nodes: Set[fx.Node] + required_bw_nodes: Set[fx.Node] + unclaimed_nodes: Set[fx.Node] + fw_order: Dict[fx.Node, int] + + @property + def required_fw_nodes(self) -> List[fx.Node]: + return sorted( + (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] + ) + + def is_required_fw(self, n: fx.Node) -> bool: + return n in self._required_fw_nodes -def must_recompute(node): + def is_required_bw(self, n: fx.Node) -> bool: + return n in self.required_bw_nodes + + def is_unclaimed(self, n: fx.Node) -> bool: + return n in self.unclaimed_nodes + + def get_fw_order(self, n: fx.Node) -> int: + assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" 
+ return self.fw_order[n] + + +@dataclass +class MinCutOptions: + ban_if_used_far_apart: bool + ban_if_long_fusible_chains: bool + ban_if_materialized_backward: bool + ban_if_not_in_allowlist: bool + ban_if_reduction: bool + + +def must_recompute(node: fx.Node) -> bool: return node.meta.get("recompute", False) -def has_recomputable_ops(fx_g): +def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: found = False for node in fx_g.graph.nodes: if must_recompute(node): @@ -48,7 +114,7 @@ def has_recomputable_ops(fx_g): return False -def has_recomputable_rng_ops(fx_g): +def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: for node in fx_g.graph.nodes: if ( must_recompute(node) @@ -59,7 +125,7 @@ def has_recomputable_rng_ops(fx_g): return False -def sym_node_size(node): +def sym_node_size(node: fx.Node) -> int: if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): return 1 assert isinstance(node.meta["val"], torch.SymFloat) @@ -74,7 +140,9 @@ def __repr__(self): InvalidNode = InvalidNodeBase() -def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): +def _extract_graph_with_inputs_outputs( + joint_graph: fx.Graph, inputs: List[fx.Node], outputs: List[fx.Node] +) -> fx.Graph: """ Given a graph, extracts out a subgraph that takes the specified nodes as inputs and returns the specified outputs. @@ -136,36 +204,38 @@ def _extract_graph_with_inputs_outputs(joint_graph, inputs, outputs): return new_graph -def _is_primal(node): +def _is_primal(node: fx.Node) -> bool: return ( node.op == "placeholder" - and "tangents" not in node.target + and "tangents" not in str(node.target) and not _is_bwd_seed_offset(node) and not _is_fwd_seed_offset(node) ) -def _is_tangent(node): - return node.op == "placeholder" and "tangents" in node.target +def _is_tangent(node: fx.Node) -> bool: + return node.op == "placeholder" and "tangents" in str(node.target) -def _is_bwd_seed_offset(node): +def _is_bwd_seed_offset(node: fx.Node) -> bool: return node.op == "placeholder" and ( - "bwd_seed" in node.target or "bwd_base_offset" in node.target + "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) ) -def _is_fwd_seed_offset(node): +def _is_fwd_seed_offset(node: fx.Node) -> bool: return node.op == "placeholder" and ( - "fwd_seed" in node.target or "fwd_base_offset" in node.target + "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) ) -def _is_backward_state(node): +def _is_backward_state(node: fx.Node) -> bool: return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) -def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): +def _extract_fwd_bwd_outputs( + joint_module: fx.GraphModule, *, num_fwd_outputs +) -> Tuple[List[fx.Node], List[fx.Node]]: outputs = pytree.arg_tree_leaves( *(node.args for node in joint_module.graph.find_nodes(op="output")) ) @@ -174,7 +244,7 @@ def _extract_fwd_bwd_outputs(joint_module: fx.GraphModule, *, num_fwd_outputs): return fwd_outputs, bwd_outputs -def _remove_by_name(saved_values, name): +def _remove_by_name(saved_values: List[fx.Node], name: str): for saved_value in saved_values: if saved_value.name == name: saved_values.remove(saved_value) @@ -182,8 +252,12 @@ def _remove_by_name(saved_values, name): def _extract_fwd_bwd_modules( - joint_module: fx.GraphModule, saved_values, saved_sym_nodes, *, num_fwd_outputs -): + joint_module: fx.GraphModule, + saved_values: List[fx.Node], + saved_sym_nodes: List[fx.Node], + *, + num_fwd_outputs: int, +) -> Tuple[fx.GraphModule, 
fx.GraphModule]: fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( joint_module, num_fwd_outputs=num_fwd_outputs ) @@ -359,14 +433,10 @@ def default_partition( ) -def _prod(x): - s = 1 - for i in x: - s *= i - return s +INT_INF = int(1e6) -def _tensor_nbytes(numel, dtype): +def _tensor_nbytes(numel: int, dtype) -> int: return numel * dtype.itemsize @@ -374,10 +444,7 @@ def _size_of(node: fx.Node) -> int: if "val" in node.meta: val = node.meta["val"] if isinstance(val, py_sym_types): - if isinstance(val, torch.SymInt): - return 1 - else: - return 999999 + return 1 # NB: The fallback values here are meaningless, maybe we should respect # torch._inductor.config.unbacked_symint_fallback (but this is a # layering violation) @@ -391,28 +458,18 @@ def _size_of(node: fx.Node) -> int: return _tensor_nbytes(hint_int(val.numel(), fallback=4098), val.dtype) raise RuntimeError(f"Unknown metadata type {type(val)}") - - # Only needed since we don't always trace with fake tensors. - if "tensor_meta" in node.meta: - metadata = node.meta["tensor_meta"] - # TODO: What is to_size_hint suppose to be? - numel = _prod(map(to_size_hint, metadata.shape)) # noqa: F821 - dtype = metadata.dtype - else: - return 0 - - return _tensor_nbytes(numel, dtype) + raise RuntimeError("We should always have `val` metadata on the nodes") # Used for some investigative purposes -def _count_ops(graph): +def _count_ops(graph: fx.Graph): from collections import defaultdict - cnt = defaultdict(int) + cnt: Dict[str, int] = defaultdict(int) for node in graph.nodes: if node.op == "call_function": cnt[node.target.__name__] += 1 - print(sorted(cnt.items(), key=operator.itemgetter(1), reverse=True)) + print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) @functools.lru_cache(None) @@ -433,14 +490,14 @@ def pointwise_ops(): return ops -def sort_depths(args, depth_map): +def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]: arg_depths = { arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) } - return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True) + return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) -def reordering_to_mimic_autograd_engine(gm): +def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: """ This pass finds the first bwd node in the graph (by looking at users of tangents) and then reorders the graph by walking from this node to all the @@ -464,7 +521,7 @@ def reordering_to_mimic_autograd_engine(gm): """ new_graph = fx.Graph() - env = {} + env: Dict[fx.Node, fx.Node] = {} # Add new placeholder nodes in the order specified by the inputs for node in gm.graph.find_nodes(op="placeholder"): @@ -517,7 +574,12 @@ def insert_node_in_graph(node): return new_gm -def functionalize_rng_ops(joint_module, fw_module, bw_module, num_sym_nodes): +def functionalize_rng_ops( + joint_module: fx.GraphModule, + fw_module: fx.GraphModule, + bw_module: fx.GraphModule, + num_sym_nodes: int, +) -> Tuple[fx.GraphModule, fx.GraphModule]: # During user-driven activation checkpointing, we have to ensure that a rng # op in fwd yields the same output as the recomputed rng op in the bwd. 
To # do this, we use functionalize wrappers to wrap the random ops and share @@ -591,11 +653,15 @@ def get_sample_rng_state(device): run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state run_with_rng_state = torch._prims.rng_prims.run_with_rng_state - + bw_tangent_start_node = None for node in bw_module.graph.find_nodes(op="placeholder"): if "tangent" in node.name: bw_tangent_start_node = node break + if bw_tangent_start_node is None: + raise RuntimeError( + "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this" + ) fw_rng_state_outputs = [] for base_node, node_pair in recomputable_rng_ops_map.items(): @@ -665,7 +731,7 @@ def get_sample_rng_state(device): return fw_module, bw_module -def cleanup_recompute_tags(joint_module): +def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: """ If there are two consecutive checkpointed blocks with no operator in between, we would still want to stash the tensor at the boundary of @@ -683,332 +749,50 @@ def cleanup_recompute_tags(joint_module): return joint_module -def min_cut_rematerialization_partition( - joint_module: fx.GraphModule, - _joint_inputs, - compiler="inductor", - recomputable_ops=None, - *, - num_fwd_outputs, -) -> Tuple[fx.GraphModule, fx.GraphModule]: - """ - Partitions the joint graph such that the backward recomputes the forward. - Recomputing helps in trading off memory bandwidth with computation. - - To create the fwd and bwd graph, we copy the joint graph, manually set the - outputs to just original forward or backward outputs. And then we run the - resulting graphs through dead code elimination. - - .. warning:: - This API is experimental and likely to change. - - Args: - joint_module(fx.GraphModule): The joint forward and backward graph. This - is the result of AOT Autograd tracing. - _joint_inputs: The inputs to the joint graph. This is unused. - compiler: This option determines the default set of recomputable ops. - Currently, there are two options: ``nvfuser`` and ``inductor``. - recomputable_ops: This is an optional set of recomputable ops. If this - is not None, then this set of ops will be used instead of the - default set of ops. - num_fwd_outputs: The number of outputs from the forward graph. - - Returns: - Returns the generated forward and backward Fx graph modules. 
- """ - try: - import networkx as nx - except ImportError as e: - raise RuntimeError( - "Need networkx installed to perform smart recomputation " "heuristics" - ) from e - - joint_module.graph.eliminate_dead_code() - joint_module.recompile() - - fx_g = joint_module.graph - - # add the CSE pass - if config.cse: - cse_graph = fx_graph_cse(fx_g) - joint_module.graph = cse_graph - joint_graph = joint_module.graph - - graph_has_recomputable_ops = has_recomputable_ops(joint_module) - graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) - if graph_has_recomputable_ops: - joint_module = cleanup_recompute_tags(joint_module) - - name_to_node = {} - for node in joint_module.graph.nodes: - name_to_node[node.name] = node - - def classify_nodes(joint_module): - required_bw_nodes = set() - for node in joint_module.graph.nodes: - if node.op == "placeholder" and "tangents" in node.target: - required_bw_nodes.add(node) - if node in required_bw_nodes: - required_bw_nodes.update(node.users) - - primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) - fwd_seed_offset_inputs = list( - filter(_is_fwd_seed_offset, joint_module.graph.nodes) - ) - inputs = primal_inputs + fwd_seed_offset_inputs - fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( - joint_module, num_fwd_outputs=num_fwd_outputs - ) - required_bw_nodes.update( - o for o in bwd_outputs if o is not None and o.op != "output" - ) - forward_only_graph = _extract_graph_with_inputs_outputs( - joint_module.graph, inputs, fwd_outputs - ) - required_fw_nodes = { - name_to_node[node.name] - for node in forward_only_graph.nodes - if node.op != "output" - } - unclaimed_nodes = { - node - for node in joint_module.graph.nodes - if node not in required_fw_nodes and node not in required_bw_nodes - } - return ( - fwd_outputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - inputs, - ) - - ( - orig_fw_outputs, - required_fw_nodes, - required_bw_nodes, - unclaimed_nodes, - inputs, - ) = classify_nodes(joint_module) - - # networkx blows up on graphs with no required backward nodes - # Since there's nothing to partition anyway, and the default partitioner can "handle" - # this case, send our graph over to the default partitioner. 
- if len(required_bw_nodes) == 0: - return default_partition( - joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs - ) - - def is_fusible(a, b): - # We can perform "memory fusion" into a cat, but cat cannot be a - # producer to a fusion - if get_aten_target(b) == aten.cat: - return True - return get_aten_target(a) in fusible_ops and get_aten_target(b) in fusible_ops - - fw_order = 0 - for node in joint_module.graph.nodes: - if node in required_fw_nodes: - node.fw_order = fw_order - fw_order += 1 - - for node in reversed(joint_module.graph.nodes): - if node not in required_fw_nodes: - node.dist_from_bw = 0 - else: - node.dist_from_bw = int(1e9) - for user in node.users: - node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) - - aten = torch.ops.aten - prims = torch.ops.prims - - # compiler == "nvfuser" is the default set of recomputable ops - default_recomputable_ops = [ - aten.add, - aten.sub, - aten.div, - aten.atan2, - aten.mul, - aten.max, - aten.min, - aten.pow, - aten.remainder, - aten.fmod, - aten.__and__, - aten.__or__, - aten.__xor__, - aten.__lshift__, - aten.__rshift__, - aten.eq, - aten.ne, - aten.ge, - aten.gt, - aten.le, - aten.lt, - aten.abs, - aten.bitwise_not, - aten.ceil, - aten.floor, - aten.frac, - aten.neg, - aten.relu, - aten.round, - aten.silu, - aten.trunc, - aten.log, - aten.log10, - aten.log1p, - aten.log2, - aten.lgamma, - aten.exp, - aten.expm1, - aten.erf, - aten.erfc, - aten.cos, - aten.acos, - aten.cosh, - aten.sin, - aten.asin, - aten.sinh, - aten.tan, - aten.atan, - aten.tanh, - aten.atanh, - aten.sqrt, - aten.rsqrt, - aten.reciprocal, - aten.sigmoid, - aten.softplus, - aten.threshold, - aten.threshold_backward, - aten.clamp, - aten.where, - aten.lerp, - aten.addcmul, - aten.gelu, - aten.gelu_backward, - aten.sum, - aten.mean, - aten._grad_sum_to_size, - aten.sum_to_size, - aten.amax, - aten.to, - aten.type_as, - operator.getitem, - aten.squeeze, - aten.unsqueeze, - aten.rsub, - aten._to_copy, - ] # noqa: E501,B950 - view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] - if compiler == "inductor": - default_recomputable_ops += [ - prims.div, - prims.convert_element_type, - aten.clone, - aten._to_copy, - aten.full_like, - prims.var, - prims.sum, - aten.var, - aten.std, - prims.broadcast_in_dim, - aten.select, - aten._unsafe_view, - aten.view, - aten.expand, - aten.slice, - aten.reshape, - aten.broadcast_tensors, - aten.scalar_tensor, - aten.ones, - aten.new_zeros, - aten.lift_fresh_copy, - aten.arange, - aten.triu, - aten.var_mean, - aten.isinf, - aten.any, - aten.full, - aten.as_strided, - aten.zeros, - aten.argmax, - aten.maximum, - prims.iota, - prims._low_memory_max_pool2d_offsets_to_indices, - ] # noqa: E501,B950 - view_ops += [ - aten.view, - aten.slice, - aten.t, - prims.broadcast_in_dim, - aten.expand, - aten.as_strided, - aten.permute, - ] - # Natalia said that we should allow recomputing indexing :) - default_recomputable_ops += [aten.index, aten.gather] - default_recomputable_ops += view_ops - - default_recomputable_ops += pointwise_ops() - - default_recomputable_ops += [ - aten.zeros_like, - ] - - default_recomputable_ops += [method_to_operator(m) for m in magic_methods] - recomputable_ops = ( - set(recomputable_ops) - if recomputable_ops is not None - else set(default_recomputable_ops) - ) - - random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] - compute_intensive_ops = [ - aten.mm, - aten.convolution, - aten.convolution_backward, - aten.bmm, - aten.addmm, - aten._scaled_dot_product_flash_attention, - 
aten._scaled_dot_product_efficient_attention, - aten.upsample_bilinear2d, - ] # noqa: E501,B950 +def get_saved_values( + joint_graph: fx.Graph, + node_info: NodeInfo, + min_cut_options: MinCutOptions, + dont_ban=None, +): + if dont_ban is None: + dont_ban = set() + op_types = get_default_op_list() - fusible_ops = recomputable_ops | set(random_ops) if AOT_PARTITIONER_DEBUG: joint_module_ops = { str(node.target._overloadpacket) - for node in joint_module.graph.nodes + for node in joint_graph.nodes if node.op == "call_function" and hasattr(node.target, "_overloadpacket") } - ops_ignored = joint_module_ops - {str(i) for i in recomputable_ops} + ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} print("Ops banned from rematerialization: ", ops_ignored) print() - BAN_IF_USED_FAR_APART = config.ban_recompute_used_far_apart - BAN_IF_LONG_FUSIBLE_CHAINS = config.ban_recompute_long_fusible_chains - BAN_IF_MATERIALIZED_BACKWARDS = config.ban_recompute_materialized_backward - BAN_IF_NOT_IN_ALLOWLIST = config.ban_recompute_not_in_allowlist - BAN_IF_REDUCTION = config.ban_recompute_reductions + def is_fusible(a, b): + # We can perform "memory fusion" into a cat, but cat cannot be a + # producer to a fusion + if get_aten_target(b) == aten.cat: + return True + return op_types.is_fusible(a) and op_types.is_fusible(b) - if config.aggressive_recomputation: - BAN_IF_MATERIALIZED_BACKWARDS = False - BAN_IF_USED_FAR_APART = False - BAN_IF_LONG_FUSIBLE_CHAINS = False - BAN_IF_NOT_IN_ALLOWLIST = False + try: + import networkx as nx + except ImportError as e: + raise RuntimeError( + "Need networkx installed to perform smart recomputation " "heuristics" + ) from e def is_materialized_backwards(node): - if get_aten_target(node) in view_ops: + if op_types.is_view(node): return False cur_nodes = {node} while len(cur_nodes) > 0: cur = cur_nodes.pop() for user in cur.users: - if user not in required_fw_nodes and not is_fusible(cur, user): + if not node_info.is_required_fw(user) and not is_fusible(cur, user): return True - if get_aten_target(user) in view_ops: + if op_types.is_view(user): cur_nodes.add(user) return False @@ -1020,17 +804,15 @@ def should_ban_recomputation(node): return False if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: return False - # NB: "recompute" == 0 means that must save this node. if node.meta.get("recompute", None) == 0: return True - if BAN_IF_NOT_IN_ALLOWLIST: - if get_aten_target(node) not in recomputable_ops: + if min_cut_options.ban_if_not_in_allowlist: + if not op_types.is_recomputable(node): return True else: - ignored_ops = random_ops + compute_intensive_ops - if get_aten_target(node) in ignored_ops: + if op_types.is_random(node) or op_types.is_compute_intensive(node): return True # If a node *must* be materialized in the backwards pass, then we @@ -1038,7 +820,9 @@ def should_ban_recomputation(node): # general, the assumption we make is that recomputing a node in the # backwards pass is "free". However, if a node must be materialized # in the backwards pass, then recomputing it is never free. - if is_materialized_backwards(node) and BAN_IF_MATERIALIZED_BACKWARDS: + if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( + node + ): log.info("materialized backwards: %s %s", node, tuple(node.users)) return True @@ -1046,16 +830,15 @@ def should_ban_recomputation(node): # modification appears to have made this heuristic a lot less critical # for performance. 
# NB: As of PR #121692, this hack no longer seems necessary. - if not graph_has_recomputable_ops: - if compiler == "inductor" and node.dist_from_bw > config.max_dist_from_bw: - return True + if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: + return True # If the output of an op is 4x smaller (arbitrary choice), # then we don't allow recomputation. The idea here is that for # things like reductions, saving the output of the reduction is very # cheap/small, and it makes sure we don't do things like recompute # normalizations in the backwards. - if BAN_IF_REDUCTION: + if min_cut_options.ban_if_reduction: input_tensors_size = sum( _size_of(i) for i in node.args if isinstance(i, fx.Node) ) @@ -1069,9 +852,14 @@ def is_materialized(node): return not all(is_fusible(node, user) for user in node.users) - def get_node_weight(node) -> int: + def get_node_weight(node) -> float: mem_sz = _size_of(node) + if isinstance(node.meta["val"], py_sym_types): + # We never want to save symfloats + if not isinstance(node.meta["val"], torch.SymInt): + return INT_INF + # Heuristic to bias towards nodes closer to the backwards pass # Complete guess about current value mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) @@ -1084,6 +872,11 @@ def get_node_weight(node) -> int: banned_nodes = set() def ban_recomputation_if_allowed(node): + if op_types.is_view(node): + return False + if node in dont_ban: + return False + # breakpoint() # This bans recomputation of the node unless we've been forced not to by # user annotation # NB: "recompute" > 0 means that user annotation has asked us to @@ -1106,8 +899,8 @@ def ban_recomputation_if_allowed(node): if node.op == "output": continue - if node in required_bw_nodes: - if node not in inputs: + if node in node_info.required_bw_nodes: + if node not in node_info.inputs: nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) continue # If someone saves a input for backward as-is and backward @@ -1126,7 +919,7 @@ def ban_recomputation_if_allowed(node): # If a node can't be recomputed (too expensive or involves randomness), # we prevent it from being recomputed by adding an inf edge to the source # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. - if node in required_fw_nodes and should_ban_recomputation(node): + if node_info.is_required_fw(node) and should_ban_recomputation(node): ban_recomputation_if_allowed(node) # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. @@ -1135,12 +928,13 @@ def ban_recomputation_if_allowed(node): ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) if is_sym_node(node): - weight = sym_node_size(node) + weight = float(sym_node_size(node)) elif is_non_tensor_node: - weight = 0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + weight = ( + 0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf + ) else: weight = get_node_weight(node) - # Creates the weights on the "node" edge nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) for user in node.users: @@ -1168,35 +962,40 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: Finds the first unfusible node in the chain of nodes starting from `start_nodes` and returns its position. 
""" - sorted_nodes = [] + sorted_nodes: List[Tuple[int, fx.Node, bool]] = [] for n in start_nodes: - heapq.heappush(sorted_nodes, (n.fw_order, n, True)) + heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) while len(sorted_nodes) > 0: _, node, node_is_fusible = heapq.heappop(sorted_nodes) if not node_is_fusible: - return node.fw_order + return node_info.get_fw_order(node) for user in node.users: - if user in required_fw_nodes: - if user.fw_order > max_range: + if node_info.is_required_fw(user): + if node_info.get_fw_order(user) > max_range: continue heapq.heappush( - sorted_nodes, (user.fw_order, user, is_fusible(node, user)) + sorted_nodes, + (node_info.get_fw_order(user), user, is_fusible(node, user)), ) return max_range - if BAN_IF_USED_FAR_APART: - for used_node in required_fw_nodes: + if min_cut_options.ban_if_used_far_apart: + for used_node in node_info.required_fw_nodes: orders = [ - user.fw_order for user in used_node.users if user in required_fw_nodes + node_info.get_fw_order(user) + for user in used_node.users + if user in node_info.required_fw_nodes + ] + fw_users = [ + user for user in used_node.users if node_info.is_required_fw(user) ] - fw_users = [user for user in used_node.users if user in required_fw_nodes] if len(orders) > 0: first_unfusible_use = find_first_unfusible(fw_users, max(orders)) for user in tuple(used_node.users): if ( - user in required_fw_nodes - and user.fw_order > first_unfusible_use + user in node_info.required_fw_nodes + and node_info.get_fw_order(user) > first_unfusible_use and is_fusible(used_node, user) ): if user in banned_nodes: @@ -1204,10 +1003,10 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: log.info( "used above/below fusible %s:(%s) -> %s -> %s:(%s)", used_node, - used_node.fw_order, + node_info.get_fw_order(used_node), first_unfusible_use, user, - user.fw_order, + node_info.get_fw_order(user), ) ban_recomputation_if_allowed(user) @@ -1222,47 +1021,51 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 - if BAN_IF_LONG_FUSIBLE_CHAINS: + if min_cut_options.ban_if_long_fusible_chains: visited = set() for start_node in joint_graph.nodes: - if start_node not in required_fw_nodes: + if start_node not in node_info.required_fw_nodes: continue - fusible = [(start_node.fw_order, start_node)] - start_order = start_node.fw_order + fusible = [(node_info.get_fw_order(start_node), start_node)] + start_order = node_info.get_fw_order(start_node) while len(fusible) > 0: _, cur = heapq.heappop(fusible) if cur in visited: continue visited.add(cur) # 100 is arbitrary choice to try and prevent degenerate cases - if cur.fw_order > start_order + 100 and len(fusible) == 0: + if ( + node_info.get_fw_order(cur) > start_order + 100 + and len(fusible) == 0 + ): log.info( "too long %s %s %s %s", cur, start_node, - cur.fw_order, - start_node.fw_order, + node_info.get_fw_order(cur), + node_info.get_fw_order(start_node), ) ban_recomputation_if_allowed(cur) break for user in cur.users: if ( - user in required_fw_nodes + user in node_info.required_fw_nodes and is_fusible(cur, user) and user not in banned_nodes ): - heapq.heappush(fusible, (user.fw_order, user)) + heapq.heappush(fusible, (node_info.get_fw_order(user), user)) try: cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") except Exception: print("Failed to compute min-cut on following graph:") 
print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) + visualize_min_cut_graph(nx_graph) raise reachable, non_reachable = partition - cutset = set() + cutset: Set[Tuple[str, str]] = set() for u, nbrs in ((n, nx_graph[n]) for n in reachable): cutset.update((u, v) for v in nbrs if v in non_reachable) @@ -1272,14 +1075,347 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: node_name = node_in[:-3] cut_nodes.add(node_name) + name_to_node = get_name_to_node(joint_graph) # To make this stuff deterministic - node_idx = {node: idx for idx, node in enumerate(joint_module.graph.nodes)} + node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} saved_values = sorted( (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] ) + return saved_values, banned_nodes + + +def visualize_min_cut_graph(nx_graph): + import networkx as nx + import pydot + + dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() + dot_graph = pydot.graph_from_dot_data(dot_format)[0] + for edge in dot_graph.get_edges(): + weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] + # Set edge label to weight + edge.set_label(str(weight)) + # Color edges with weight 'inf' as red + if weight == float("inf"): + edge.set_color("red") + print("Visualizing the failed graph to min_cut_failed.svg") + dot_graph.write_svg("min_cut_failed.svg") + + +def get_default_op_list() -> OpTypes: + default_recomputable_ops: List[Callable] = [ + aten.add, + aten.sub, + aten.div, + aten.atan2, + aten.mul, + aten.max, + aten.min, + aten.pow, + aten.remainder, + aten.fmod, + aten.__and__, + aten.__or__, + aten.__xor__, + aten.__lshift__, + aten.__rshift__, + aten.eq, + aten.ne, + aten.ge, + aten.gt, + aten.le, + aten.lt, + aten.abs, + aten.bitwise_not, + aten.ceil, + aten.floor, + aten.frac, + aten.neg, + aten.relu, + aten.round, + aten.silu, + aten.trunc, + aten.log, + aten.log10, + aten.log1p, + aten.log2, + aten.lgamma, + aten.exp, + aten.expm1, + aten.erf, + aten.erfc, + aten.cos, + aten.acos, + aten.cosh, + aten.sin, + aten.asin, + aten.sinh, + aten.tan, + aten.atan, + aten.tanh, + aten.atanh, + aten.sqrt, + aten.rsqrt, + aten.reciprocal, + aten.sigmoid, + aten.softplus, + aten.threshold, + aten.threshold_backward, + aten.clamp, + aten.where, + aten.lerp, + aten.addcmul, + aten.gelu, + aten.gelu_backward, + aten.sum, + aten.mean, + aten._grad_sum_to_size, + aten.sum_to_size, + aten.amax, + aten.to, + aten.type_as, + operator.getitem, + aten.squeeze, + aten.unsqueeze, + aten.rsub, + aten._to_copy, + ] # noqa: E501,B950 + recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] + recomputable_view_ops += [ + aten.view, + aten.slice, + aten.t, + prims.broadcast_in_dim, + aten.expand, + aten.as_strided, + aten.permute, + ] + view_ops = recomputable_view_ops + default_recomputable_ops += [ + prims.div, + prims.convert_element_type, + aten.clone, + aten._to_copy, + aten.full_like, + prims.var, + prims.sum, + aten.var, + aten.std, + prims.broadcast_in_dim, + aten.select, + aten._unsafe_view, + aten.view, + aten.expand, + aten.slice, + aten.reshape, + aten.broadcast_tensors, + aten.scalar_tensor, + aten.ones, + aten.new_zeros, + aten.lift_fresh_copy, + aten.arange, + aten.triu, + aten.var_mean, + aten.isinf, + aten.any, + aten.full, + aten.as_strided, + aten.zeros, + aten.argmax, + aten.maximum, + prims.iota, + prims._low_memory_max_pool2d_offsets_to_indices, + ] # noqa: E501,B950 + # Natalia said that we should allow recomputing indexing :) + 
default_recomputable_ops += [aten.index, aten.gather]
+    default_recomputable_ops += view_ops
+
+    default_recomputable_ops += pointwise_ops()
+
+    default_recomputable_ops += [
+        aten.zeros_like,
+    ]
+
+    default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
+    recomputable_ops = set(default_recomputable_ops)
+
+    random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like]
+    compute_intensive_ops = [
+        aten.mm,
+        aten.convolution,
+        aten.convolution_backward,
+        aten.bmm,
+        aten.addmm,
+        aten._scaled_dot_product_flash_attention,
+        aten._scaled_dot_product_efficient_attention,
+        aten.upsample_bilinear2d,
+    ]  # noqa: E501,B950
+
+    fusible_ops = recomputable_ops | set(random_ops)
+    return OpTypes(
+        set(fusible_ops),
+        set(compute_intensive_ops),
+        set(random_ops),
+        set(view_ops),
+        set(recomputable_ops),
+    )
+
+
+def get_name_to_node(graph: fx.Graph):
+    name_to_node = {}
+    for node in graph.nodes:
+        name_to_node[node.name] = node
+    return name_to_node
+
+
+def choose_saved_values_set(
+    joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1
+) -> List[fx.Node]:
+    min_cut_options = MinCutOptions(
+        ban_if_used_far_apart=config.ban_recompute_used_far_apart,
+        ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains,
+        ban_if_materialized_backward=config.ban_recompute_materialized_backward,
+        ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist,
+        ban_if_reduction=config.ban_recompute_reductions,
+    )
+
+    if config.aggressive_recomputation:
+        min_cut_options = replace(
+            min_cut_options,
+            ban_if_used_far_apart=False,
+            ban_if_long_fusible_chains=False,
+            ban_if_materialized_backward=False,
+            ban_if_not_in_allowlist=False,
+        )
+
+    if memory_budget == 0:
+        return node_info.inputs
+
+    runtime_optimized_saved_values, _ = get_saved_values(
+        joint_graph,
+        node_info,
+        min_cut_options,
+    )
+    return runtime_optimized_saved_values
+
+
+def min_cut_rematerialization_partition(
+    joint_module: fx.GraphModule,
+    _joint_inputs,
+    compiler="inductor",
+    *,
+    num_fwd_outputs,
+) -> Tuple[fx.GraphModule, fx.GraphModule]:
+    """
+    Partitions the joint graph such that the backward recomputes the forward.
+    Recomputing helps in trading off memory bandwidth with computation.
+
+    To create the fwd and bwd graph, we copy the joint graph, manually set the
+    outputs to just original forward or backward outputs. And then we run the
+    resulting graphs through dead code elimination.
+
+    .. warning::
+        This API is experimental and likely to change.
+
+    Args:
+        joint_module(fx.GraphModule): The joint forward and backward graph. This
+            is the result of AOT Autograd tracing.
+        _joint_inputs: The inputs to the joint graph. This is unused.
+        compiler: Kept for backwards compatibility; the default set of
+            recomputable ops no longer depends on this option.
+        num_fwd_outputs: The number of outputs from the forward graph.
+
+    Returns:
+        Returns the generated forward and backward Fx graph modules.
+ """ + + joint_module.graph.eliminate_dead_code() + joint_module.recompile() + + fx_g = joint_module.graph + + # add the CSE pass + if config.cse: + cse_graph = fx_graph_cse(fx_g) + joint_module.graph = cse_graph + joint_graph = joint_module.graph + + graph_has_recomputable_ops = has_recomputable_ops(joint_module) + graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) + if graph_has_recomputable_ops: + joint_module = cleanup_recompute_tags(joint_module) + + def classify_nodes(joint_module): + name_to_node = get_name_to_node(joint_module.graph) + required_bw_nodes = set() + for node in joint_module.graph.nodes: + if node.op == "placeholder" and "tangents" in node.target: + required_bw_nodes.add(node) + if node in required_bw_nodes: + for user in node.users: + required_bw_nodes.add(user) + + primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) + fwd_seed_offset_inputs = list( + filter(_is_fwd_seed_offset, joint_module.graph.nodes) + ) + inputs = primal_inputs + fwd_seed_offset_inputs + fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( + joint_module, num_fwd_outputs=num_fwd_outputs + ) + required_bw_nodes.update( + o for o in bwd_outputs if o is not None and o.op != "output" + ) + forward_only_graph = _extract_graph_with_inputs_outputs( + joint_module.graph, inputs, fwd_outputs + ) + required_fw_nodes: Set[fx.Node] = { + name_to_node[node.name] + for node in forward_only_graph.nodes + if node.op != "output" + } + unclaimed_nodes = { + node + for node in joint_module.graph.nodes + if node not in required_fw_nodes and node not in required_bw_nodes + } + fw_cnt = 0 + fw_order = {} + for node in joint_module.graph.nodes: + if node in required_fw_nodes: + fw_order[node] = fw_cnt + fw_cnt += 1 + return NodeInfo( + inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order + ) + + node_info = classify_nodes(joint_module) + + # networkx blows up on graphs with no required backward nodes + # Since there's nothing to partition anyway, and the default partitioner can "handle" + # this case, send our graph over to the default partitioner. 
+ if len(node_info.required_bw_nodes) == 0: + return default_partition( + joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs + ) + + for node in reversed(joint_module.graph.nodes): + if node.op == "output": + node.dist_from_bw = int(1e9) + elif node not in node_info.required_fw_nodes: + node.dist_from_bw = 0 + else: + node.dist_from_bw = int(1e9) + for user in node.users: + node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) + + saved_values = choose_saved_values_set(joint_graph, node_info, memory_budget=1) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols fw_module, bw_module = _extract_fwd_bwd_modules( joint_module, @@ -1312,7 +1448,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: } remat_nodes = fw_module_nodes & bw_module_nodes - counts = defaultdict(int) + counts: Dict[str, int] = defaultdict(int) for node in fw_module.graph.nodes: if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): counts[str(node.target._overloadpacket)] += 1 @@ -1321,7 +1457,7 @@ def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: ) print( "Count of Ops Rematerialized: ", - sorted(counts.items(), key=operator.itemgetter(1), reverse=True), + sorted(counts.items(), key=lambda x: x[1], reverse=True), ) return fw_module, bw_module @@ -1331,7 +1467,7 @@ def draw_graph( fname: str, figname: str = "fx_graph", clear_meta: bool = True, - prog: Union[str, List[str]] = None, + prog: Optional[Union[str, List[str]]] = None, parse_stack_trace: bool = False, dot_graph_shape: Optional[str] = None, ) -> None: @@ -1357,13 +1493,3 @@ def draw_graph( write_method(fname) else: write_method(fname, prog=prog) - - -def draw_joint_graph( - graph: torch.fx.GraphModule, - joint_inputs, - file_name: str = "full_graph.png", - dot_graph_shape: Optional[str] = None, -): - draw_graph(graph, file_name, dot_graph_shape=dot_graph_shape) - return default_partition(graph, joint_inputs)
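Usage sketch (illustrative only, not part of the patch): with this change, min_cut_rematerialization_partition no longer accepts a recomputable_ops argument; the recompute allowlist is now built internally by get_default_op_list(). Assuming the functorch.compile entry points (aot_function, min_cut_rematerialization_partition) behave as in current releases and that networkx is installed, a caller might exercise the partitioner roughly like this:

    import torch
    from functorch.compile import aot_function, min_cut_rematerialization_partition

    def print_compile(gm, example_inputs):
        # A no-op "compiler": print the partitioned graph and run it eagerly.
        gm.print_readable()
        return gm.forward

    def f(x):
        return torch.sin(x).cos().sum()

    compiled_f = aot_function(
        f,
        fw_compiler=print_compile,
        bw_compiler=print_compile,
        partition_fn=min_cut_rematerialization_partition,
    )

    x = torch.randn(8, requires_grad=True)
    compiled_f(x).backward()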