# train.py -- PyTorch CIFAR10 training script (forked from g1910/HyperNetworks).
import torch
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn as nn
import argparse
import torch.optim as optim
from primary_net import PrimaryNetwork
########### Data Loader ###############
# Normalization shared by the train and test pipelines
# (first tuple: mean, second tuple: std, one entry per channel).
_normalize = transforms.Normalize(
    (0.4914, 0.4822, 0.4465),
    (0.2023, 0.1994, 0.2010),
)

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])

# Test pipeline: tensor conversion and normalization only, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

# Both splits live under the same root and are downloaded on demand.
_data_root = '../data'
trainset = torchvision.datasets.CIFAR10(
    root=_data_root,
    train=True,
    download=True,
    transform=transform_train,
)
testset = torchvision.datasets.CIFAR10(
    root=_data_root,
    train=False,
    download=True,
    transform=transform_test,
)

# Only the training loader shuffles; batch size and worker count are shared.
_loader_kwargs = dict(batch_size=128, num_workers=4)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
#############################
# Command-line interface: a single boolean flag that requests resuming
# from a previously saved checkpoint.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument(
    '--resume', '-r',
    action='store_true',
    help='resume from checkpoint',
)
args = parser.parse_args()
############
# Build the network and, when --resume is given, restore its weights and the
# best accuracy recorded in the checkpoint before moving the model to the GPU.
net = PrimaryNetwork()
best_accuracy = 0.
if args.resume:
    # Deserialise onto the CPU: without map_location the tensors are restored
    # to the device they were saved from, which fails when that device is not
    # present. load_state_dict copies the values into the model's own
    # parameters, which are moved to the GPU by net.cuda() below.
    ckpt = torch.load('./hypernetworks_cifar_paper.pth', map_location='cpu')
    net.load_state_dict(ckpt['net'])
    # 'acc' may have been stored as a 0-dim tensor; keep a plain float so later
    # comparisons and re-saves do not depend on tensor semantics.
    best_accuracy = float(ckpt['acc'])
net.cuda()
# ---- Optimisation hyper-parameters ----
# The scheduler is stepped once per training batch, so `milestones` and
# `max_iter` are counted in batches rather than epochs.
max_iter = 1000000
milestones = [168000, 336000, 400000, 450000, 550000, 600000]
learning_rate = 0.002
weight_decay = 0.0005

# ---- Loss, optimizer and step-wise learning-rate schedule ----
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    net.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)
# The learning rate is multiplied by 0.5 each time a milestone is reached.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=optimizer,
    milestones=milestones,
    gamma=0.5,
)

# ---- Progress bookkeeping ----
print_freq = 50  # number of batches between loss printouts
total_iter, epochs = 0, 0
def _evaluate(model, loader):
    """Return the top-1 accuracy of `model` over `loader`, in percent.

    The model is switched to evaluation mode and autograd is disabled for the
    duration of the pass; the previous train/eval mode is restored afterwards.
    Inputs are moved to the GPU, predictions are compared on the CPU.

    Returns a Python float (0.0 if the loader yields no samples).
    """
    was_training = model.training
    model.eval()  # matters for layers such as BatchNorm/Dropout, if present
    correct = 0
    total = 0
    try:
        with torch.no_grad():  # no graph needed for evaluation
            for images, targets in loader:
                outputs = model(images.cuda())
                _, predicted = torch.max(outputs.cpu(), 1)
                total += targets.size(0)
                # .item() keeps the counter a plain int instead of a tensor
                correct += (predicted == targets).sum().item()
    finally:
        model.train(was_training)
    if total == 0:
        return 0.0
    return (100. * correct) / total


# Main loop: train one epoch at a time, evaluate on the test set after each
# epoch, and checkpoint whenever the test accuracy improves. The max_iter
# bound is only checked between epochs, so the last epoch always completes.
while total_iter < max_iter:
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader, 0):
        inputs, labels = inputs.cuda(), labels.cuda()

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # schedule advances per batch, not per epoch

        # loss is a 0-dim tensor; .item() extracts the Python float
        # (indexing it with [0] raises on current PyTorch releases).
        running_loss += loss.item()
        if i % print_freq == (print_freq - 1):
            print("[Epoch %d, Total Iterations %6d] Loss: %.4f" % (epochs + 1, total_iter + 1, running_loss/print_freq))
            running_loss = 0.0

        total_iter += 1

    epochs += 1

    accuracy = _evaluate(net, testloader)
    print('After epoch %d, accuracy: %.4f %%' % (epochs, accuracy))

    if accuracy > best_accuracy:
        print('Saving model...')
        state = {
            'net': net.state_dict(),
            'acc': accuracy,  # plain float, see _evaluate
        }
        torch.save(state, './hypernetworks_cifar_paper.pth')
        best_accuracy = accuracy

print('Finished Training')