-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
41 lines (34 loc) · 1.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import wandb
import torch
from dataloader import get_train_val_loader
from config import get_config
from model import Model
from trainer import Trainer
# Start a Weights & Biases run as soon as this module is imported: this is a
# module-level statement, so merely importing this file creates a run.
# NOTE(review): "CTC" is passed positionally. In the wandb releases I know of,
# the first positional parameter of wandb.init is `job_type`, not `project`;
# confirm which was intended and pass it by keyword (project="CTC" or
# job_type="CTC") so the meaning does not depend on the installed wandb version.
wandb.init("CTC")
def _collect_parameters(model):
    """Return the parameters of ``model`` and of every net in ``model.nets``, without duplicates.

    Parameters are gathered in order (``model`` first, then each entry of
    ``model.nets``) and de-duplicated by object identity. If ``model.nets`` is a
    registered submodule container, its parameters are already yielded by
    ``model.parameters()``; listing them twice in one optimizer group makes
    PyTorch warn and lets the optimizer step those tensors more than once.
    """
    seen = set()
    params = []
    for module in (model, *model.nets):
        for param in module.parameters():
            if id(param) not in seen:
                seen.add(id(param))
                params.append(param)
    return params


def main(config):
    """Seed RNGs, build the data loaders and model, then train.

    Args:
        config: configuration object from ``get_config()``. Fields read here:
            ``random_seed``, ``use_gpu``, ``num_heads``, ``num_channels``,
            ``batch_norm``, ``skip``, ``init_strategy``, ``num_filters``,
            ``task``, ``noise`` and ``init_lr``. The whole object is also
            handed to ``Trainer`` and logged to wandb.
    """
    # setup: seed for reproducibility and restrict torch to one CPU thread
    torch.manual_seed(config.random_seed)
    torch.set_num_threads(1)
    if config.use_gpu:
        torch.cuda.manual_seed(config.random_seed)

    # get data-loaders; pinned host memory only pays off when batches are
    # copied to a GPU, so tie it to use_gpu instead of always enabling it
    train_dataset, val_dataset, num_leds = get_train_val_loader(
        config, pin_memory=bool(config.use_gpu)
    )

    # create a model
    model = Model(
        config.num_heads,
        num_leds,
        config.num_channels,
        batch_norm=config.batch_norm,
        skip=config.skip,
        # misspelled keyword kept on purpose: it must match Model's parameter name
        initilization_strategy=config.init_strategy,
        num_filters=config.num_filters,
        task=config.task,
        noise=config.noise,
    )
    if config.use_gpu:
        model.cuda()
        # noise_layer and each net are moved explicitly; presumably they are not
        # all registered submodules of Model — TODO confirm
        model.noise_layer.cuda()
        for net in model.nets:
            net.cuda()

    # setup optimizer over the model's and every net's parameters (deduplicated)
    optimizer = torch.optim.Adam(_collect_parameters(model), lr=config.init_lr)

    trainer = Trainer(model, optimizer, train_dataset, val_dataset, config)
    wandb.config.update(config)
    trainer.train()
if __name__ == "__main__":
    config, unparsed = get_config()
    if unparsed:
        # Arguments get_config() did not recognise (e.g. a mistyped
        # hyper-parameter flag) would otherwise be dropped silently and the run
        # would fall back to the default value; make that visible.
        import sys

        print("Warning: ignoring unrecognised arguments:", unparsed, file=sys.stderr)
    main(config)