some compile-related improvements
ghstack-source-id: 7c4a65c26a8f573222f0a14448ba8258ed893028
Pull Request resolved: #443
tianyu-l committed Jul 10, 2024
1 parent 3fca883 commit 2f23216
Showing 2 changed files with 13 additions and 12 deletions.
9 changes: 9 additions & 0 deletions test_runner.py
@@ -153,6 +153,15 @@ def build_test_list():
"1D compile",
"1d_compile",
),
OverrideDefinitions(
[
[
"--training.compile --model.norm_type=rmsnorm --selective_ac_option=op",
],
],
"1D compile with selective op AC",
"1d_compile_sac_op",
),
OverrideDefinitions(
[
[
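The new entry adds an integration test that exercises torch.compile together with selective per-op activation checkpointing in one run. As a rough illustration of the pattern (a minimal sketch: the real OverrideDefinitions fields and command construction live in test_runner.py and may differ), each inner list of flags becomes one training invocation appended to a base command:

# Hypothetical sketch, not the actual test_runner.py implementation;
# field names and the base command are illustrative assumptions.
from dataclasses import dataclass, field


@dataclass
class OverrideDefinitions:
    override_args: list[list[str]] = field(default_factory=list)
    test_descr: str = "default"
    test_name: str = "default"


def build_commands(defn: OverrideDefinitions, base_cmd: str) -> list[str]:
    # One shell command per inner list of override flags.
    return [" ".join([base_cmd, *args]) for args in defn.override_args]


sac_test = OverrideDefinitions(
    [
        [
            "--training.compile --model.norm_type=rmsnorm --selective_ac_option=op",
        ],
    ],
    "1D compile with selective op AC",
    "1d_compile_sac_op",
)
# -> ['./run_llama_train.sh --training.compile --model.norm_type=rmsnorm --selective_ac_option=op']
print(build_commands(sac_test, "./run_llama_train.sh"))
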
16 changes: 4 additions & 12 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -432,22 +432,14 @@ def apply_compile(model, job_config: JobConfig):
"fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
)

# NOTE(anijain): enable the following flag to accelarate compilation
torch._dynamo.config.inline_inbuilt_nn_modules = True

for layer_id, transformer_block in model.layers.named_children():
# turn on per-transformer block compile after AC wrapping and before FSDP
# TODO: dynamic shape have some issues so we turn it off for now.
# TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
# compile time.
# torch._dynamo.config.inline_inbuilt_nn_modules = True
transformer_block = torch.compile(transformer_block, dynamic=False)
transformer_block = torch.compile(transformer_block, fullgraph=True)
model.layers.register_module(layer_id, transformer_block)

ac_config = job_config.activation_checkpoint
if ac_config.mode == "selective" and ac_config.selective_ac_option == "op":
# some temp flags for torch.compile enablement + SAC
torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = (
True
)

logger.info("Compiled each TransformerBlock with torch.compile")
return model

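The loop now compiles each TransformerBlock separately (after AC wrapping, before FSDP) with fullgraph=True, and per the commit's NOTE enables inline_inbuilt_nn_modules to accelerate compilation. Below is a minimal, self-contained sketch of the same pattern on a toy model; the Block module and dimensions are illustrative stand-ins, not torchtitan's code:

import torch
import torch.nn as nn


class Block(nn.Module):
    # Toy stand-in for a TransformerBlock.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 2):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): Block() for i in range(n_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers.values():
            x = layer(x)
        return x


model = ToyModel()

# Flag enabled by this commit (available in recent PyTorch builds): have Dynamo
# trace into built-in nn.Module code rather than specializing on each instance,
# which the commit notes accelerates compiling many identical blocks.
torch._dynamo.config.inline_inbuilt_nn_modules = True

for layer_id, block in model.layers.named_children():
    # fullgraph=True makes compilation fail loudly on any graph break.
    model.layers.register_module(layer_id, torch.compile(block, fullgraph=True))

out = model(torch.randn(4, 16))  # first call triggers compilation of each block
print(out.shape)

Replacing each child in place (rather than wrapping the whole model) keeps per-block compile artifacts reusable and composes with the AC wrappers already applied to each block.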
