-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_for_signifigance.py
55 lines (44 loc) · 1.66 KB
/
run_for_signifigance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import importlib
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger
from auto_mixer.datasets.mimic_cxr.mimic_cxr import MIMICCXRDataModule
from auto_mixer.runner import find_architecture
def main():
    """Search for the best architecture on MIMIC-CXR, then fully train it.

    Side effects: seeds all RNGs, reads auto_mixer/cfg/train.yml, logs the
    run to Weights & Biases, and prints the final test metrics to stdout.
    """
    datamodule = MIMICCXRDataModule(batch_size=64, num_workers=4)
    datamodule.setup()
    pl.seed_everything(42)
    # Single run for now; the loop makes repeating for significance trivial.
    for run in range(1):
        fusion_function, best_model = find_architecture(datamodule)
        cfg = OmegaConf.load("auto_mixer/cfg/train.yml")
        wandb_logger = WandbLogger(project='auto-mixer', name=f'mimic_cxr - {run}')
        wandb_logger.experiment.config.update(cfg)
        trainer = pl.Trainer(
            callbacks=build_callbacks(cfg.callbacks),
            devices=torch.cuda.device_count(),
            log_every_n_steps=cfg.log_interval_steps,
            logger=wandb_logger,
            max_epochs=cfg.epochs
        )
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        print("\nStarting full training...\n")
        print(best_model)
        trainer.fit(best_model, train_loader, val_loader)
        # NOTE(review): evaluation reuses the validation loader — presumably
        # no held-out test split is wired up here; confirm with the datamodule.
        results = trainer.test(best_model, val_loader)[0]
        print(results)
def build_callbacks(callbacks_cfg):
    """Instantiate one callback per entry of the callbacks config list."""
    return [build_callback(entry) for entry in callbacks_cfg]
def build_callback(cfg):
    """Instantiate a single callback from its config entry.

    Args:
        cfg: object with a ``class_path`` attribute (dotted import path of
            the callback class, e.g. ``"pytorch_lightning.callbacks.EarlyStopping"``)
            and an optional ``init_args`` mapping of constructor kwargs.

    Returns:
        The instantiated callback object.

    Raises:
        ModuleNotFoundError / AttributeError: if ``class_path`` cannot be resolved.
    """
    module_path, class_name = cfg.class_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    cls = getattr(module, class_name)
    # Tolerate entries that omit init_args entirely (or set it to null):
    # such callbacks are simply default-constructed instead of crashing.
    init_args = getattr(cfg, "init_args", None) or {}
    return cls(**init_args)
# Entry guard: allows importing this module without triggering a training run.
if __name__ == '__main__':
    main()