From 0f45f621b1c5c6dfbbd41ed08dedea6684032c69 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 6 Aug 2024 16:28:04 -0700
Subject: [PATCH 1/2] address TODOs as 2D recompiles is fixed

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 03540552..06cdbcc5 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -185,9 +185,6 @@ def apply_tp(
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-        # TODO: remove cache_size_limit adjustment after 2D compile is fixed
-        torch._dynamo.config.cache_size_limit = 10000
-
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
@@ -282,10 +279,9 @@ def apply_ac(model: nn.Module, ac_config):
 def apply_compile(model: nn.Module):
     """Apply torch.compile to each transformer block."""
 
-    # the following flag can be used to to accelarate per-TransformerBlock compilation
-    # TODO(bdhirsh): turning it off because it's currently not working with 2D
+    # the following flag is used to accelerate per-TransformerBlock compilation
     # TODO(anijain): remove it after it's enabled in pytorch by default
-    # torch._dynamo.config.inline_inbuilt_nn_modules = True
+    torch._dynamo.config.inline_inbuilt_nn_modules = True
 
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)

From eada944d1981e0627ba08e28ce2806f6fc803b32 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 7 Aug 2024 10:30:12 -0700
Subject: [PATCH 2/2] Update on "address TODOs as 2D recompiles is fixed"

This PR
- enables `inline_inbuilt_nn_modules` for block-level compilation
- removes the `cache_size_limit` adjustment in async TP

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallelize_llama.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 06cdbcc5..a300c644 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -277,17 +277,15 @@ def apply_ac(model: nn.Module, ac_config):
 
 
 def apply_compile(model: nn.Module):
-    """Apply torch.compile to each transformer block."""
-
-    # the following flag is used to accelerate per-TransformerBlock compilation
-    # TODO(anijain): remove it after it's enabled in pytorch by default
-    torch._dynamo.config.inline_inbuilt_nn_modules = True
-
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info("Compiled each TransformerBlock with torch.compile")
+    logger.info("Compiling each TransformerBlock with torch.compile")
 
     return model
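
For readers who want to try the block-level compilation pattern outside of torchtitan, below is a minimal, self-contained sketch of what `apply_compile` does after this patch series. `ToyBlock` and `ToyModel` are hypothetical stand-ins for torchtitan's Llama `TransformerBlock` and model; only the compile loop (`named_children()`, `torch.compile(..., fullgraph=True)`, `register_module()`) mirrors the actual diff above.

```python
# Minimal sketch of per-TransformerBlock compilation (not part of the patches above).
# ToyBlock and ToyModel are hypothetical stand-ins; the compile loop mirrors
# apply_compile() from parallelize_llama.py.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 64):
        super().__init__()
        # Blocks live in a container module, so named_children() yields
        # (layer_id, block) pairs just like in apply_compile().
        self.layers = nn.ModuleDict({str(i): ToyBlock(dim) for i in range(n_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x


def apply_compile(model: nn.Module) -> nn.Module:
    # Compile each block separately with fullgraph=True; since all blocks share
    # the same structure, per-block compilation keeps compile time low compared
    # to compiling the whole model at once.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, torch.compile(block, fullgraph=True))
    return model


if __name__ == "__main__":
    model = apply_compile(ToyModel())
    out = model(torch.randn(2, 16, 64))  # (batch, seq, dim)
    print(out.shape)  # torch.Size([2, 16, 64])
```

The `torch._dynamo.config.inline_inbuilt_nn_modules` flag toggled in PATCH 1/2 is deliberately not set in this sketch; whether it needs to be set by hand depends on the PyTorch version in use.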