forked from lowe-lab-ucl/cellx-predict
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path run_training.py
126 lines (107 loc) · 2.82 KB
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import ast
from pathlib import Path
from cellxpredict.config import Models, config_from_args
from cellxpredict.train import train
# Pick the root directory under which the default data/model/log folders live.
# If a "container" entry exists next to this script (presumably the layout used
# when running inside a container image — confirm), it becomes the root;
# otherwise the script's own (resolved) directory is used.
current_path = Path(__file__).parent.resolve()
container_path = current_path / "container"
DEFAULT_PATH = container_path if container_path.exists() else current_path
def parse_literal_sequence(text):
    """Parse a command-line string into a non-empty list/tuple of positive ints.

    Used as the argparse ``type`` for ``--input_shape`` and ``--layers``. The
    string is evaluated with :func:`ast.literal_eval`, so only Python literals
    are accepted (no arbitrary code is executed). The container type the user
    wrote is preserved: ``"(64, 64, 2)"`` yields a tuple, ``"[8, 16]"`` a list.

    Args:
        text: the raw command-line value.

    Returns:
        The parsed tuple or list.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a valid Python literal,
            or is not a non-empty list/tuple of positive integers. argparse
            reports this as a usage error rather than a traceback.
    """
    try:
        value = ast.literal_eval(text)
    # argparse only converts TypeError/ValueError/ArgumentTypeError into usage
    # errors; literal_eval can also raise SyntaxError (and, for pathological
    # input, MemoryError/RecursionError), so normalise them all here.
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
        raise argparse.ArgumentTypeError(
            f"could not parse {text!r} as a Python literal"
        ) from err

    # bool is a subclass of int; reject it explicitly so "(True, 2)" fails.
    is_positive_int_sequence = (
        isinstance(value, (list, tuple))
        and len(value) > 0
        and all(
            isinstance(item, int) and not isinstance(item, bool) and item > 0
            for item in value
        )
    )
    if not is_positive_int_sequence:
        raise argparse.ArgumentTypeError(
            f"expected a non-empty list/tuple of positive integers, got {text!r}"
        )
    return value


def build_parser():
    """Build the command-line parser for the training script.

    Directory defaults are rooted at the module-level ``DEFAULT_PATH``; the
    ``--model`` choices are the lower-cased names of ``Models`` members.

    Returns:
        A configured :class:`argparse.ArgumentParser`.
    """
    parser = argparse.ArgumentParser(description="Train TauVAE models")
    parser.add_argument(
        "--model",
        choices=[m.name.lower() for m in Models],
        type=str,
        required=True,
        help="name of the model to train",
    )
    parser.add_argument(
        "--src_dir",
        type=Path,
        default=DEFAULT_PATH / "data",
        help="path to the data directory",
    )
    parser.add_argument(
        "--model_dir",
        type=Path,
        default=DEFAULT_PATH / "models",
        help="path to the model output directory",
    )
    parser.add_argument(
        "--log_dir",
        type=Path,
        default=DEFAULT_PATH / "logs",
        help="path to the TensorBoard log directory",
    )
    parser.add_argument(
        "--input_shape",
        type=parse_literal_sequence,
        default=(64, 64, 2),
        help="input shape of the image data (W, H, C)",
    )
    parser.add_argument(
        "--layers",
        type=parse_literal_sequence,
        default=[8, 16, 32, 64],
        help="encoder layers list",
    )
    parser.add_argument(
        "--latent_dims",
        type=int,
        default=32,
        help="number of dimensions in latent space embedding",
    )
    parser.add_argument(
        "--input_dtype",
        type=str,
        default="uint8",
        help="input dtype of the image data",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="training mini-batch size",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=50,
        help="training epochs",
    )
    parser.add_argument(
        "--max_iterations_fraction",
        type=float,
        default=0.9,
        help="fraction (0-1) of training steps before capacity reaches max value",
    )
    parser.add_argument(
        "--capacity",
        type=int,
        default=50,
        help="network capacity",
    )
    parser.add_argument(
        "--num_outputs",
        type=int,
        default=3,
        help="number of outputs",
    )
    parser.add_argument(
        "--use_probabilistic_encoder",
        action="store_true",
        help="use a probabilistic encoder while training model",
    )
    parser.add_argument(
        "--noise",
        type=float,
        default=1.0,
        help="amplitude of noise when using a probabilistic encoder",
    )
    return parser


def main(argv=None):
    """Parse ``argv``, build a training config from it and run training.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default).
    """
    args = build_parser().parse_args(argv)
    config = config_from_args(args)
    train(config)


if __name__ == "__main__":
    main()