From 652d7b80b6f62dd963336d7ad9fc9b399907fb7d Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Tue, 3 Dec 2024 18:54:47 -0800 Subject: [PATCH 1/3] Add checkpoint load step --- torchtitan/config_manager.py | 7 ++++++- train.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index e7bca6f1c..084dc7a82 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -480,7 +480,12 @@ def __init__(self): 0 is the default value. """, ) - + self.parser.add_argument( + "--checkpoint.load_step", + type=int, + default=0, + help="Load the checkpoint at the specified step. If 0, load the latest checkpoint.", + ) # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", diff --git a/train.py b/train.py index 9e8b1fa81..53c813f1f 100644 --- a/train.py +++ b/train.py @@ -206,7 +206,7 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint_loaded = checkpoint.load() + checkpoint_loaded = checkpoint.load(step=job_config.checkpoint.load_step) if parallel_dims.pp_enabled and not checkpoint_loaded: # TODO: fix this by allowing each rank to set their own seed From 8417d00be142bae7aee15c4e3e9d1d1035cf0024 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Wed, 4 Dec 2024 00:45:30 -0800 Subject: [PATCH 2/3] Update config_manager.py --- torchtitan/config_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 084dc7a82..d17e263d5 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -483,8 +483,8 @@ def __init__(self): self.parser.add_argument( "--checkpoint.load_step", type=int, - default=0, - help="Load the checkpoint at the specified step. If 0, load the latest checkpoint.", + default=-1, + help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.", ) # activation checkpointing configs self.parser.add_argument( From 3229fe3b018afa83f31b9872f6006f2ac6cb52df Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 5 Dec 2024 11:42:07 -0800 Subject: [PATCH 3/3] add docs --- docs/checkpoint.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 3f2c8cb24..72e6a021f 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -49,7 +49,8 @@ export_dtype = "bfloat16" enable_checkpoint = true folder = "checkpoint" interval_type = "steps" -interval = 5 +interval = 10 +load_step = 5 model_weights_only = true export_dtype = "bfloat16" ```