train_stage1.py
import sys
import os
import time
import random

from progress.bar import Bar
from collections import OrderedDict
from util import *
from PIL import Image
from data import get_loader, Test_Dataset
from test import test_model

import torch
from torch.nn import utils
from base.framework_factory import load_framework

# Imported explicitly for the multi-scale branch below, in case `from util import *`
# does not already provide these names.
import numpy as np
import torch.nn.functional as F

torch.set_printoptions(precision=5)
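
# Usage: python train_stage1.py <net_name>
# The first command-line argument selects the network configuration loaded by
# base.framework_factory.load_framework.
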
def main():
    # Loading model
    if len(sys.argv) > 1:
        net_name = sys.argv[1]
    else:
        print('Need model name!')
        return

    config, model, optim, sche, model_loss, saver = load_framework(net_name)

    # Loading datasets
    config['vals'] = [config['trset'], ]
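    # 'vals' holds only the training set name, so test_model below evaluates on the training split.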
    train_loader = get_loader(config)
    test_sets = OrderedDict()
    for set_name in config['vals']:
        test_sets[set_name] = Test_Dataset(name=set_name, config=config)

    debug = config['debug']
    num_epoch = config['epoch']
    num_iter = len(train_loader)
    ave_batch = config['ave_batch']
    trset = config['trset']
    batch_idx = 0
    model.zero_grad()
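
    # Gradients are accumulated over `ave_batch` mini-batches before each optimizer step.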
    for epoch in range(1, num_epoch + 1):
        model.train()
        torch.cuda.empty_cache()

        if epoch == 1:
            test_model(model, test_sets, config, 0)

        print('Current LR: {:.6f}.'.format(optim.param_groups[1]['lr']))
        bar = Bar('{:10}-{:8} | epoch {:2}:'.format(net_name, config['sub'], epoch), max=num_iter)

        st = time.time()
        loss_count = 0
        optim.zero_grad()
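        # The LR scheduler is stepped once per epoch, before the batch loop.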
        sche.step()
        #print(optim.state_dict())
        for i, pack in enumerate(train_loader, start=1):
            cur_it = i + (epoch - 1) * num_iter
            total_it = epoch * num_iter

            images, gts = pack
            images, gts = images.float().cuda(), gts.float().cuda()
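
            # Optional multi-scale training: shift the input resolution by a random multiple of 64 pixels.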
            if config['multi']:
                scales = [-1, 0, 1]
                #scales = [-2, -1, 0, 1, 2]
                input_size = config['size']
                input_size += int(np.random.choice(scales, 1) * 64)
                #input_size += int(np.random.choice(scales, 1) * 32)
                images = F.upsample(images, size=(input_size, input_size), mode='bilinear', align_corners=True)
                gts = F.upsample(gts, size=(input_size, input_size), mode='nearest')

            Y = model(images, 'train')
            #print(torch.max(Y['final']), torch.min(Y['final']))
            loss = model_loss(Y, config) / ave_batch
            loss_count += loss.data
            loss.backward()

            batch_idx += 1
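            # Step the optimizer every `ave_batch` mini-batches, optionally clipping the global gradient norm first.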
            if batch_idx == ave_batch:
                if config['clip_gradient']:
                    utils.clip_grad_norm_(model.parameters(), config['clip_gradient'])
                optim.step()
                optim.zero_grad()
                batch_idx = 0

            Bar.suffix = '{:4}/{:4} | loss: {:1.5f}, time: {}.'.format(i, num_iter, round(float(loss_count / i), 5), round(time.time() - st, 3))
            bar.next()
        bar.finish()
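
        # Evaluate at the end of every epoch.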
        test_model(model, test_sets, config, epoch)


if __name__ == "__main__":
    main()