base_classifier_train.py
# -*- coding: utf-8 -*-
"""
Trains a ResNet classifier (resnet18/34/50) on CIFAR-10, CIFAR-100, or Tiny ImageNet.
Based on the training setup from:
Xie, S., Girshick, R., Dollár, P., Tu, Z., & He, K. (2016).
Aggregated residual transformations for deep neural networks.
arXiv preprint arXiv:1611.05431.
"""
from __future__ import division
import logging

import numpy as np
import hydra
from omegaconf import DictConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision

# Use the local ResNet variants (models.py), not torchvision's ImageNet ones:
# from torchvision.models import resnet18, resnet34, resnet50
from models import resnet18, resnet34, resnet50
from utils import cal_parameters, get_dataset, AverageMeter

logger = logging.getLogger(__name__)
def get_model_for_tiny_imagenet(name='resnet18', n_classes=200):
    """Build an ImageNet-pretrained torchvision ResNet and adapt its head for Tiny ImageNet."""
    classifier = getattr(torchvision.models, name)(pretrained=True)
    # Tiny ImageNet's 64x64 inputs yield smaller feature maps than ImageNet's
    # 224x224, so pool adaptively before the new classification head.
    classifier.avgpool = nn.AdaptiveAvgPool2d(1)
    classifier.fc = nn.Linear(classifier.fc.in_features, n_classes)
    return classifier
def get_model(name='resnet18', n_classes=10):
    """Get the proper model from the local models module."""
    model_dict = {'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50}
    assert name in model_dict, '{} not available, choose from {}'.format(name, list(model_dict))
    classifier = model_dict[name](n_classes=n_classes)
    return classifier
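

# A minimal usage sketch of the two factories above (argument values here are
# hypothetical): the CIFAR path builds a local models.py ResNet from scratch,
# while the Tiny ImageNet path fine-tunes an ImageNet-pretrained torchvision ResNet.
#   cifar_net = get_model('resnet18', n_classes=10)
#   tiny_net = get_model_for_tiny_imagenet('resnet18', n_classes=200)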
def run_epoch(classifier, data_loader, args, optimizer=None):
    """
    Run one epoch.
    :param classifier: torch.nn.Module representing the classifier.
    :param data_loader: dataloader
    :param args: config object; only args.device is used here.
    :param optimizer: if None, run inference; if an optimizer is given, train and optimize.
    :return: mean loss and mean accuracy over this epoch.
    """
    if optimizer:
        classifier.train()
    else:
        classifier.eval()

    loss_meter = AverageMeter('loss')
    acc_meter = AverageMeter('Acc')

    for batch_idx, (x, y) in enumerate(data_loader):
        x, y = x.to(args.device), y.to(args.device)
        # Disable autograd during inference to save memory and compute.
        with torch.set_grad_enabled(optimizer is not None):
            output = classifier(x)
            loss = F.cross_entropy(output, y)

        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        loss_meter.update(loss.item(), x.size(0))
        acc = (output.argmax(dim=1) == y).float().mean().item()
        acc_meter.update(acc, x.size(0))

    return loss_meter.avg, acc_meter.avg
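

# utils.AverageMeter is not shown in this file; run_epoch relies only on the
# interface used above. A minimal sketch of a compatible implementation
# (an assumption, not the actual utils.py code):
#   class AverageMeter:
#       def __init__(self, name):
#           self.name, self.sum, self.count, self.avg = name, 0.0, 0, 0.0
#       def update(self, val, n=1):
#           # val is a per-sample average over a batch of size n
#           self.sum += val * n
#           self.count += n
#           self.avg = self.sum / self.count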
def train(classifier, train_loader, test_loader, args):
    optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                weight_decay=args.decay, nesterov=True)

    best_train_loss = np.inf
    for epoch in range(1, args.epochs + 1):
        # Step learning-rate schedule: decay by gamma at the listed milestone epochs.
        if epoch in args.schedule:
            args.learning_rate *= args.gamma
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.learning_rate

        train_loss, train_acc = run_epoch(classifier, train_loader, args, optimizer=optimizer)
        logger.info('Epoch: {}, training loss: {:.4f}, acc: {:.4f}.'.format(epoch, train_loss, train_acc))

        test_loss, test_acc = run_epoch(classifier, test_loader, args)
        logger.info('Test loss: {:.4f}, acc: {:.4f}.'.format(test_loss, test_acc))

        if train_loss < best_train_loss:
            best_train_loss = train_loss
            save_name = '{}.pth'.format(args.classifier_name)
            # # if use cuda and n_gpu > 1
            # if next(classifier.parameters()).is_cuda and args.n_gpu > 1:
            #     state = classifier.module.state_dict()
            # else:
            #     state = classifier.state_dict()
            state = classifier.state_dict()
            torch.save(state, save_name)
            logger.info('==> New optimal training loss, saving checkpoint ...')
@hydra.main(config_path='configs/base_config.yaml')
def run(args: DictConfig) -> None:
    cuda_available = torch.cuda.is_available()
    torch.manual_seed(args.seed)
    device = 'cuda' if cuda_available and args.device == 'cuda' else 'cpu'
    # Keep args.device consistent with the device actually used; run_epoch reads it.
    args.device = device

    n_classes = args.get(args.dataset).n_classes
    if args.dataset == 'tiny_imagenet':
        # Fine-tune a pretrained model: fewer epochs, smaller learning rate.
        args.epochs = 20
        args.learning_rate = 0.001
        classifier = get_model_for_tiny_imagenet(args.classifier_name, n_classes).to(device)
        args.data_dir = 'tiny_imagenet'
    else:
        classifier = get_model(name=args.classifier_name, n_classes=n_classes).to(device)

    # if device == 'cuda' and args.n_gpu > 1:
    #     classifier = torch.nn.DataParallel(classifier, device_ids=list(range(args.n_gpu)))

    logger.info('Base classifier name: {}, # parameters: {}'.format(args.classifier_name, cal_parameters(classifier)))

    data_dir = hydra.utils.to_absolute_path(args.data_dir)
    train_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=True, crop_flip=True)
    test_data = get_dataset(data_name=args.dataset, data_dir=data_dir, train=False, crop_flip=False)

    train_loader = DataLoader(dataset=train_data, batch_size=args.n_batch_train, shuffle=True)
    test_loader = DataLoader(dataset=test_data, batch_size=args.n_batch_test, shuffle=False)

    if args.inference:
        # Load the saved checkpoint onto CPU first, then evaluate only.
        save_name = '{}.pth'.format(args.classifier_name)
        classifier.load_state_dict(torch.load(save_name, map_location=lambda storage, loc: storage))
        loss, acc = run_epoch(classifier, test_loader, args)
        logger.info('Inference loss: {:.4f}, acc: {:.4f}'.format(loss, acc))
    else:
        train(classifier, train_loader, test_loader, args)
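

# configs/base_config.yaml is not included in this snippet. A hypothetical
# config covering every key this script reads (all values are assumptions):
#   seed: 42
#   device: cuda
#   dataset: cifar10
#   cifar10:
#     n_classes: 10
#   classifier_name: resnet18
#   data_dir: data
#   n_batch_train: 128
#   n_batch_test: 256
#   epochs: 200
#   learning_rate: 0.1
#   momentum: 0.9
#   decay: 5e-4
#   gamma: 0.1
#   schedule: [60, 120, 160]
#   inference: false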
if __name__ == '__main__':
    run()