diff --git a/test/test_job_config.py b/test/test_job_config.py
index d2ba02093..e4ef04ba0 100644
--- a/test/test_job_config.py
+++ b/test/test_job_config.py
@@ -43,3 +43,8 @@ def test_job_config_file_cmd_overrides(self):
             ]
         )
         assert config.job.dump_folder == "/tmp/test_tt/"
+
+    def test_print_help(self):
+        config = JobConfig()
+        parser = config.parser
+        parser.print_help()
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 6d00634df..0e6555080 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -42,6 +42,7 @@ class JobConfig:
     def __init__(self):
         # main parser
         self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
+
         self.parser.add_argument(
             "--job.config_file",
             type=str,
@@ -154,10 +155,9 @@ def __init__(self):
         self.parser.add_argument(
             "--training.dataset_path",
             type=str,
-            help=(
-                "Path to the dataset in the file system. If provided, data will be"
-                "loaded from this path instead of downloaded.",
-            ),
+            help="""
+                Path to the dataset in the file system. If provided, data will be
+                loaded from this path instead of downloaded.""",
         )
         self.parser.add_argument(
             "--training.batch_size", type=int, default=8, help="batch size"
@@ -212,6 +212,22 @@ def __init__(self):
             action="store_true",
             help="Whether to compile the model.",
         )
+        self.parser.add_argument(
+            "--training.fp8_linear",
+            type=str,
+            default="",
+            choices=[
+                "dynamic",
+                "",
+            ],  # TODO: add "delayed" option back in when supported
+            help="Type of fp8 linear quantization to apply to the model",
+        )
+        self.parser.add_argument(
+            "--training.gc_freq",
+            type=int,
+            default=50,
+            help="Python garbage control scheduling interval, in steps",
+        )
 
         # checkpoint configs
         self.parser.add_argument(
@@ -223,65 +239,49 @@ def __init__(self):
             "--checkpoint.folder",
             type=str,
             default="checkpoint",
-            help=(
-                "The folder to store the checkpoints. "
-                "When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}."
-            ),
+            help="""
+                The folder to store the checkpoints.
+                When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.interval_type",
             type=str,
             default="steps",
-            help=(
-                "The checkpointing interval unit of measurement. "
-                "The default value is steps."
-            ),
+            help="""
+                The checkpointing interval unit of measurement.
+                The default value is steps.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.interval",
             type=int,
             default=500,
-            help=(
-                "Checkpointing interval. The unit of measurement is in seconds or "
-                "steps depending on --checkpoint.interval_type."
-            ),
+            help="""
+                Checkpointing interval. The unit of measurement is in seconds or
+                steps depending on --checkpoint.interval_type.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.model_weights_only",
             action="store_true",
-            help=(
-                "When model_weights_only=True, only model weights will be saved at the end of training. "
-                "With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion. "
-                "When model_weights_only=False, the full checkpoint will be saved. "
-                "A full checkpoint includes model, optimizer and train_state, which can be used to resume training. "
-                "The default value is false."
-            ),
+            help="""
+                When model_weights_only=True, only model weights will be saved at the end of training.
+                With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
+                When model_weights_only=False, the full checkpoint will be saved.
+                A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
+                The default value is false.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.export_dtype",
             type=str,
             default="float32",
-            help=(
-                "Converts to the specified precision when training completes and model_weights_only=true. "
-                "Currently supports float32, float16, and bfloat16. "
-                "The default value is float32."
-            ),
-        )
-        self.parser.add_argument(
-            "--training.fp8_linear",
-            type=str,
-            default="",
-            choices=[
-                "dynamic",
-                "",
-            ],  # TODO: add "delayed" option back in when supported
-            help="Type of fp8 linear quantization to apply to the model",
-        )
-        self.parser.add_argument(
-            "--training.gc_freq",
-            type=int,
-            default=50,
-            help="Python garbage control scheduling interval, in steps",
+            help="""
+                Converts to the specified precision when training completes and model_weights_only=true.
+                Currently supports float32, float16, and bfloat16.
+                The default value is float32.
+            """,
        )
 
         # activation checkpointing
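
Why the removed help strings were broken: the trailing comma inside `help=(...)` turns the implicitly concatenated string literals into a one-element tuple, so argparse fails as soon as it tries to render the help text; the concatenation also dropped the space between "will be" and "loaded". A minimal repro, not part of the patch, using only stdlib argparse:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--training.dataset_path",
        type=str,
        help=(
            "Path to the dataset in the file system. If provided, data will be"
            "loaded from this path instead of downloaded.",  # stray comma -> tuple, not str
        ),
    )
    # Raises while rendering help (typically a TypeError from argparse's
    # "%"-formatting of the help value; the exact exception varies by
    # Python version) -- exactly what the new test_print_help guards against.
    parser.print_help()

The triple-quoted replacements stay plain strings, and argparse's default HelpFormatter collapses their embedded newlines and indentation when rendering, so the displayed help output is unchanged.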