-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcli.py
76 lines (70 loc) · 4.35 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
from typing import Optional
def _bool_type(value):
    """
    Convert a command-line string to a boolean, for use as an argparse ``type``.

    Arguments:
        value: String to interpret. The comparison is case-insensitive.

    Returns:
        ``True`` if ``value`` is "true", ``False`` if ``value`` is "false".

    Raises:
        argparse.ArgumentTypeError: If ``value`` is neither "true" nor "false".
    """
    lowered = value.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False
    # argparse only converts ArgumentTypeError, TypeError and ValueError raised by
    # a ``type`` callable into a clean usage error. Any other exception (such as
    # KeyError) escapes from parse_args() as a raw traceback.
    raise argparse.ArgumentTypeError(f"`{value}` can't be interpreted as a boolean.")
def parse_args(argv: Optional[list[str]] = None) -> dict:
    """
    Parse some useful CLI arguments for use in training scripts.

    Options declared without a default are returned as ``None`` when they are
    not given on the command line. The boolean switches ``--train``, ``--test``
    and ``--predict`` take an explicit value (``true`` or ``false``), whereas
    ``--timings`` is a plain flag. ``--classes`` yields a list of lists of
    integers: each space-separated token is split on commas.

    Arguments:
        argv: List of argument values. Defaults to ``sys.argv[1:]``.

    Returns:
        A dictionary of the argument values, keyed by option name without the
        leading dashes.
    """
    # The text is passed as ``description``; the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the program name in the
    # usage line instead of describing the program.
    parser = argparse.ArgumentParser(description="Fit ML model")

    # Output and data locations.
    parser.add_argument("--run_dir", type=str, default="./", help="Path to directory where output files are saved. Default = ./")
    parser.add_argument("--dataset", type=str, help="Name of dataset to use.")
    parser.add_argument("--data_dir", type=str, help="Path to directory where the data is stored.")
    parser.add_argument("--urls_train", type=str, help="Webdataset URLs for the training set shards.")
    parser.add_argument("--urls_val", type=str, help="Webdataset URLs for the validation set shards.")
    parser.add_argument("--urls_test", type=str, help="Webdataset URLs for the test set shards.")

    # Basic optimisation settings.
    parser.add_argument("--epochs", type=int, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, help="Number of samples per batch per parallel rank.")
    parser.add_argument("--lr", type=float, help="Learning rate.")

    # Element classes: every token such as "1,6,7" becomes one list of ints.
    parser.add_argument(
        "--classes",
        type=lambda s: [int(n) for n in s.split(",")],
        nargs="+",
        help="Lists of classes for categorizing chemical elements. Elements in each list are separated by commas, and lists are separated by spaces.",
    )
    parser.add_argument("--class_colors", type=str, nargs="+", help="Colors for each class of chemical elements.")

    # Geometry of the position distribution grid and related cutoffs.
    parser.add_argument(
        "--z_lims", type=float, nargs=2, help="Lower and upper limit for position distribution grid in the z direction."
    )
    parser.add_argument(
        "--box_res", type=float, nargs=3, help="Voxel resolution in the position distribution grid in each direction in Ångströms."
    )
    parser.add_argument("--edge_cutoff", type=float, help="Edge cutoff distance in Ångströms.")
    parser.add_argument("--afm_cutoff", type=float, help="AFM region cutoff around each atom position in Ångströms.")
    parser.add_argument("--zmin", type=float, help="Lowest atom z-coordinate to include in Ångströms (top atom is at 0).")
    parser.add_argument("--peak_std", type=float, help="Position distribution grid peak standard deviation in Ångströms.")

    # Loss configuration and initial model state.
    parser.add_argument(
        "--loss_weights", type=float, nargs="+", default=[1.0], help="Weights for each loss component. Default = [1.0]"
    )
    parser.add_argument(
        "--loss_labels", type=str, nargs="+", default=["MSE"], help='Labels for each loss component. Default = ["MSE"]'
    )
    parser.add_argument(
        "--load_weights", type=str, default="", help="Path to a saved model state to load in the beginning of the training."
    )
    parser.add_argument("--random_seed", type=int, default=0, help="Set random seed for reproducibility. Default = 0.")

    # Stage switches; these take an explicit "true"/"false" value.
    parser.add_argument("--train", type=_bool_type, default=True, help="Enable training (true or false). Default = true.")
    parser.add_argument("--test", type=_bool_type, default=True, help="Enable testing (true or false). Default = true.")
    parser.add_argument("--predict", type=_bool_type, default=True, help="Enable prediction (true or false). Default = true.")

    # Runtime, logging and evaluation settings.
    parser.add_argument(
        "--num_workers", type=int, default=1, help="Number of parallel workers for data loading per parallel rank. Default = 1."
    )
    parser.add_argument(
        "--print_interval", type=int, default=10, help="Number of batches between printing training loss etc. Default = 10."
    )
    parser.add_argument("--comm_backend", type=str, default="nccl", help="Parallel communications backend. Default = nccl.")
    parser.add_argument("--timings", action="store_true", help="Enable printing timings during training.")
    parser.add_argument("--pred_batches", type=int, default=3, help="Number of prediction batches. Default = 3.")
    parser.add_argument(
        "--avg_best_epochs", type=int, default=3, help="Number of epochs to average the best validation loss over. Default = 3."
    )

    args = parser.parse_args(argv)
    return vars(args)