Fix multi-line string usage (#244)
Summary: use `"""` for multi-line strings instead of tuple syntax, which breaks argparse.
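
For context (a minimal standalone sketch, not code from this repo): wrapping the help text in parentheses with a trailing comma produces a one-element tuple rather than implicitly concatenated strings, and argparse then fails when it formats the help, e.g. in `parser.print_help()`. A triple-quoted string keeps `help` an ordinary `str`:

```python
import argparse

# Broken variant: the trailing comma makes `help` a one-element tuple instead
# of an implicitly concatenated str, so rendering the help text fails.
broken = argparse.ArgumentParser()
broken.add_argument(
    "--training.dataset_path",
    type=str,
    help=(
        "Path to the dataset in the file system. If provided, data will be"
        "loaded from this path instead of downloaded.",
    ),
)
try:
    broken.print_help()
except TypeError as exc:  # argparse raises when interpolating a non-str help
    print(f"help formatting failed: {exc}")

# Fixed variant: a triple-quoted string is an ordinary str, so help renders.
fixed = argparse.ArgumentParser()
fixed.add_argument(
    "--training.dataset_path",
    type=str,
    help="""
        Path to the dataset in the file system. If provided, data will be
        loaded from this path instead of downloaded.""",
)
fixed.print_help()
```

Note that argparse's default `HelpFormatter` collapses the extra whitespace from the triple-quoted string, so the help still renders as normally wrapped text.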

Test Plan:
```
============================= test session starts ==============================
platform linux -- Python 3.10.14, pytest-8.1.1, pluggy-1.4.0 -- /home/gnadathur/local/a/pytorch-env/bin/python
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase(PosixPath('/data/users/gnadathur/a/torchtitan/.hypothesis/examples'))
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /data/users/gnadathur/a/torchtitan
configfile: pyproject.toml
plugins: hypothesis-6.100.1, benchmark-4.0.0, typeguard-4.2.1, cov-5.0.0, hydra-core-1.3.2
collecting ... collected 6 items

test/test_job_config.py::TestJobConfig::test_command_line_args PASSED [ 16%]
test/test_job_config.py::TestJobConfig::test_job_config_file PASSED [ 33%]
test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist PASSED [ 50%]
test/test_job_config.py::TestJobConfig::test_empty_config_file PASSED [ 66%]
test/test_job_config.py::TestJobConfig::test_job_config_file_cmd_overrides PASSED [ 83%]
test/test_job_config.py::TestJobConfig::test_print_help PASSED [100%]

---------- coverage: platform linux, python 3.10.14-final-0 ----------
Coverage XML written to file coverage.xml

============================= slowest 20 durations =============================
0.00s call     test/test_job_config.py::TestJobConfig::test_print_help
0.00s call     test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s call     test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s call     test/test_job_config.py::TestJobConfig::test_job_config_file_cmd_overrides
0.00s call     test/test_job_config.py::TestJobConfig::test_empty_config_file
0.00s call     test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s setup    test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s teardown test/test_job_config.py::TestJobConfig::test_command_line_args
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file_cmd_overrides
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_config_file
0.00s teardown test/test_job_config.py::TestJobConfig::test_print_help
0.00s setup    test/test_job_config.py::TestJobConfig::test_job_file_does_not_exist
0.00s setup    test/test_job_config.py::TestJobConfig::test_empty_config_file
0.00s setup    test/test_job_config.py::TestJobConfig::test_print_help
0.00s teardown test/test_job_config.py::TestJobConfig::test_job_config_file_cmd_overrides
0.00s teardown test/test_job_config.py::TestJobConfig::test_empty_config_file
============================== 6 passed in 0.19s ===============================
```
gnadathur authored Apr 16, 2024
1 parent 78b843b commit ce0fff0
Showing 2 changed files with 49 additions and 44 deletions.
5 changes: 5 additions & 0 deletions test/test_job_config.py
```diff
@@ -43,3 +43,8 @@ def test_job_config_file_cmd_overrides(self):
             ]
         )
         assert config.job.dump_folder == "/tmp/test_tt/"
+
+    def test_print_help(self):
+        config = JobConfig()
+        parser = config.parser
+        parser.print_help()
```
88 changes: 44 additions & 44 deletions torchtitan/config_manager.py
```diff
@@ -42,6 +42,7 @@ class JobConfig:
     def __init__(self):
         # main parser
         self.parser = argparse.ArgumentParser(description="torchtitan arg parser.")
+
         self.parser.add_argument(
             "--job.config_file",
             type=str,
@@ -154,10 +155,9 @@ def __init__(self):
         self.parser.add_argument(
             "--training.dataset_path",
             type=str,
-            help=(
-                "Path to the dataset in the file system. If provided, data will be"
-                "loaded from this path instead of downloaded.",
-            ),
+            help="""
+                Path to the dataset in the file system. If provided, data will be
+                loaded from this path instead of downloaded.""",
         )
         self.parser.add_argument(
             "--training.batch_size", type=int, default=8, help="batch size"
@@ -212,6 +212,22 @@ def __init__(self):
             action="store_true",
             help="Whether to compile the model.",
         )
+        self.parser.add_argument(
+            "--training.fp8_linear",
+            type=str,
+            default="",
+            choices=[
+                "dynamic",
+                "",
+            ],  # TODO: add "delayed" option back in when supported
+            help="Type of fp8 linear quantization to apply to the model",
+        )
+        self.parser.add_argument(
+            "--training.gc_freq",
+            type=int,
+            default=50,
+            help="Python garbage control scheduling interval, in steps",
+        )

         # checkpoint configs
         self.parser.add_argument(
@@ -223,65 +239,49 @@ def __init__(self):
             "--checkpoint.folder",
             type=str,
             default="checkpoint",
-            help=(
-                "The folder to store the checkpoints. "
-                "When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}."
-            ),
+            help="""
+                The folder to store the checkpoints.
+                When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.interval_type",
             type=str,
             default="steps",
-            help=(
-                "The checkpointing interval unit of measurement. "
-                "The default value is steps."
-            ),
+            help="""
+                The checkpointing interval unit of measurement.
+                The default value is steps.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.interval",
             type=int,
             default=500,
-            help=(
-                "Checkpointing interval. The unit of measurement is in seconds or "
-                "steps depending on --checkpoint.interval_type."
-            ),
+            help="""
+                Checkpointing interval. The unit of measurement is in seconds or
+                steps depending on --checkpoint.interval_type.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.model_weights_only",
             action="store_true",
-            help=(
-                "When model_weights_only=True, only model weights will be saved at the end of training. "
-                "With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion. "
-                "When model_weights_only=False, the full checkpoint will be saved. "
-                "A full checkpoint includes model, optimizer and train_state, which can be used to resume training. "
-                "The default value is false."
-            ),
+            help="""
+                When model_weights_only=True, only model weights will be saved at the end of training.
+                With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
+                When model_weights_only=False, the full checkpoint will be saved.
+                A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
+                The default value is false.
+            """,
         )
         self.parser.add_argument(
             "--checkpoint.export_dtype",
             type=str,
             default="float32",
-            help=(
-                "Converts to the specified precision when training completes and model_weights_only=true. "
-                "Currently supports float32, float16, and bfloat16. "
-                "The default value is float32."
-            ),
-        )
-        self.parser.add_argument(
-            "--training.fp8_linear",
-            type=str,
-            default="",
-            choices=[
-                "dynamic",
-                "",
-            ],  # TODO: add "delayed" option back in when supported
-            help="Type of fp8 linear quantization to apply to the model",
-        )
-        self.parser.add_argument(
-            "--training.gc_freq",
-            type=int,
-            default=50,
-            help="Python garbage control scheduling interval, in steps",
+            help="""
+                Converts to the specified precision when training completes and model_weights_only=true.
+                Currently supports float32, float16, and bfloat16.
+                The default value is float32.
+            """,
         )

         # activation checkpointing
```
