diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 3f2c8cb24..72e6a021f 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -49,7 +49,8 @@ export_dtype = "bfloat16" enable_checkpoint = true folder = "checkpoint" interval_type = "steps" -interval = 5 +interval = 10 +load_step = 5 model_weights_only = true export_dtype = "bfloat16" ``` diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index e7bca6f1c..d17e263d5 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -480,7 +480,12 @@ def __init__(self): 0 is the default value. """, ) - + self.parser.add_argument( + "--checkpoint.load_step", + type=int, + default=-1, + help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.", + ) # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", diff --git a/train.py b/train.py index 9e8b1fa81..53c813f1f 100644 --- a/train.py +++ b/train.py @@ -206,7 +206,7 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint_loaded = checkpoint.load() + checkpoint_loaded = checkpoint.load(step=job_config.checkpoint.load_step) if parallel_dims.pp_enabled and not checkpoint_loaded: # TODO: fix this by allowing each rank to set their own seed