-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
372 lines (309 loc) · 14.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
from pathlib import Path
import argparse
import os
import sys
import random
import subprocess
import time
import json
import math
import numpy as np
from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms
from utils import gather_from_all, GaussianBlur, Solarization
# Command-line interface of the training script.  Defaults are the values the
# rest of this file reads from the parsed namespace.
parser = argparse.ArgumentParser(description='RotNet Training')
parser.add_argument('--data', type=Path, metavar='DIR',
                    help='path to dataset')
parser.add_argument('--workers', default=8, type=int, metavar='N',
                    help='number of data loader workers')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--batch-size', default=4096, type=int, metavar='N',
                    help='mini-batch size')
parser.add_argument('--learning-rate', default=4.8, type=float, metavar='LR',
                    help='base learning rate')
parser.add_argument('--weight-decay', default=1e-6, type=float, metavar='W',
                    help='weight decay')
parser.add_argument('--print-freq', default=10, type=int, metavar='N',
                    help='print frequency')
parser.add_argument('--checkpoint-dir', type=Path,
                    metavar='DIR', help='path to checkpoint directory')
parser.add_argument('--rotation', default=0.4, type=float,
                    help="coefficient of rotation loss")
# Kept as a string here; main() splits it on ',' into a list of floats that is
# passed as the ``scale`` range of the rotation view's RandomResizedCrop.
parser.add_argument('--scale', default='0.05,0.14', type=str, metavar='MIN,MAX',
                    help='comma-separated min,max crop scale range used by '
                         'the random resized crop of the rotation view')
def main():
    """Parse the CLI, derive the distributed topology and spawn the workers.

    One ``main_worker`` process is spawned per visible CUDA device.  Under
    SLURM (``SLURM_JOB_ID`` set) the rendezvous host is the first host printed
    by ``scontrol show hostnames`` and rank / world size are derived from
    ``SLURM_NODEID`` / ``SLURM_NNODES``; otherwise a single-node job is set up
    on a random localhost port.

    Exits through ``parser.error`` when required paths are missing or
    ``--scale`` is malformed.

    Raises:
        RuntimeError: if no CUDA device is visible.
    """
    args = parser.parse_args()

    # Fail fast in the parent process: without these checks a missing path
    # only surfaces inside the spawned workers as an error on ``None``.
    if args.data is None or args.checkpoint_dir is None:
        parser.error('--data and --checkpoint-dir are required')

    try:
        args.scale = [float(x) for x in args.scale.split(',')]
    except ValueError:
        parser.error('--scale expects comma-separated floats, e.g. 0.05,0.14')
    if len(args.scale) != 2:
        # Transform reads exactly args.scale[0] and args.scale[1].
        parser.error('--scale expects exactly two values, e.g. 0.05,0.14')

    args.ngpus_per_node = torch.cuda.device_count()
    if args.ngpus_per_node < 1:
        # With zero devices no worker would be spawned and the script would
        # return without training anything.
        raise RuntimeError('no CUDA devices found; at least one GPU is required')

    if 'SLURM_JOB_ID' in os.environ:
        # Argument list (not a shell string) so the node list is passed to
        # scontrol verbatim.
        cmd = ['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]
        stdout = subprocess.check_output(cmd)
        host_name = stdout.decode().splitlines()[0]
        # Rank of this node's first process; main_worker adds the local GPU index.
        args.rank = int(os.getenv('SLURM_NODEID')) * args.ngpus_per_node
        args.world_size = int(os.getenv('SLURM_NNODES')) * args.ngpus_per_node
        args.dist_url = f'tcp://{host_name}:58478'
    else:
        # single-node distributed training
        args.rank = 0
        args.dist_url = f'tcp://localhost:{random.randrange(49152, 65535)}'
        args.world_size = args.ngpus_per_node

    torch.multiprocessing.spawn(main_worker, (args,), args.ngpus_per_node)
def main_worker(gpu, args):
    """Run one training process bound to the local GPU ``gpu``.

    Joins the NCCL process group, opens the stats log on rank 0, runs the
    training loop and guarantees the stats log is closed afterwards.

    Args:
        gpu: local device index of this process (first argument supplied by
            ``torch.multiprocessing.spawn``).
        args: parsed CLI namespace.  ``args.rank`` arrives as the rank of the
            node's first process and is offset by ``gpu`` in place.
    """
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    # Only rank 0 writes the stats log; the other ranks carry ``None``.
    stats_file = None
    if args.rank == 0:
        args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
        print(' '.join(sys.argv))
        print(' '.join(sys.argv), file=stats_file)

    try:
        _train(gpu, args, stats_file)
    finally:
        # Close the log even when training raises.
        if stats_file is not None:
            stats_file.close()


def _train(gpu, args, stats_file):
    """Build model/optimizer/data, optionally resume, then train and checkpoint.

    ``stats_file`` is an open text file on rank 0 and is not used on other ranks.
    """
    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    model = SimCLR(args).cuda(gpu)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

    # lr starts at 0 and is overwritten every step by adjust_learning_rate.
    optimizer = LARS(model.parameters(), lr=0, weight_decay=args.weight_decay,
                     weight_decay_filter=exclude_bias_and_norm,
                     lars_adaptation_filter=exclude_bias_and_norm)

    # automatically resume from checkpoint if it exists
    checkpoint_path = args.checkpoint_dir / 'checkpoint.pth'
    if checkpoint_path.is_file():
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    else:
        start_epoch = 0

    dataset = torchvision.datasets.ImageFolder(args.data / 'train', Transform(args))
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, drop_last=True)
    assert args.batch_size % args.world_size == 0
    per_device_batch_size = args.batch_size // args.world_size
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=per_device_batch_size, num_workers=args.workers,
        pin_memory=True, sampler=sampler)

    start_time = time.time()
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)
        # ``step`` is a global step counter so the LR schedule continues
        # correctly after a resume.
        for step, ((y1, y2, y3), labels) in enumerate(loader, start=epoch * len(loader)):
            y1 = y1.cuda(gpu, non_blocking=True)
            y2 = y2.cuda(gpu, non_blocking=True)
            if args.rotation:
                y3 = y3.cuda(gpu, non_blocking=True)
                rotated_images, rotated_labels = rotate_images(y3, gpu)

            lr = adjust_learning_rate(args, optimizer, loader, step)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                # NOTE(review): ``labels`` is not moved to the GPU explicitly;
                # this presumably relies on DistributedDataParallel placing
                # forward inputs on its device - confirm.
                loss, acc = model.forward(y1, y2, labels)
                if args.rotation:
                    # The rotation head is called on the wrapped module directly.
                    logits = model.module.forward_rotation(rotated_images)
                    rot_loss = torch.nn.functional.cross_entropy(logits, rotated_labels)
                    loss += args.rotation * rot_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if step % args.print_freq == 0:
                # Pre-divide so that the sum reduced onto rank 0 is the mean
                # accuracy over all processes.
                torch.distributed.reduce(acc.div_(args.world_size), 0)
                if args.rank == 0:
                    print(f'epoch={epoch}, step={step}, loss={loss.item()}, acc={acc.item()}')
                    stats = dict(epoch=epoch, step=step, learning_rate=lr,
                                 loss=loss.item(), acc=acc.item(),
                                 time=int(time.time() - start_time))
                    print(json.dumps(stats), file=stats_file)

        if args.rank == 0:
            # save checkpoint
            state = dict(epoch=epoch + 1, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, args.checkpoint_dir / 'checkpoint.pth')

    if args.rank == 0:
        # save final model
        torch.save(dict(backbone=model.module.backbone.state_dict(),
                        projector=model.module.projector.state_dict(),
                        head=model.module.online_head.state_dict()),
                   args.checkpoint_dir / 'resnet50.pth')
def adjust_learning_rate(args, optimizer, loader, step):
    """Set and return the learning rate for global step ``step``.

    The schedule is a linear warm-up over the first 10 epochs' worth of steps
    followed by a cosine decay from ``args.learning_rate`` down to 0.1% of it
    at ``args.epochs * len(loader)`` steps.  The value is written to every
    parameter group of ``optimizer``.
    """
    steps_per_epoch = len(loader)
    total_steps = args.epochs * steps_per_epoch
    warmup = 10 * steps_per_epoch
    # The base rate is used as given (no batch-size scaling is applied).
    peak_lr = args.learning_rate

    if step < warmup:
        lr = peak_lr * step / warmup
    else:
        progress = step - warmup
        decay_span = total_steps - warmup
        cosine = 0.5 * (1 + math.cos(math.pi * progress / decay_span))
        floor_lr = peak_lr * 0.001
        lr = peak_lr * cosine + floor_lr * (1 - cosine)

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
class SimCLR(nn.Module):
    """ResNet-50 encoder with a contrastive projector, an online linear
    classifier and, when ``args.rotation`` is truthy, a 4-way rotation head.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Encoder: ResNet-50 with its final fc layer replaced by an identity.
        encoder = torchvision.models.resnet50(zero_init_residual=True)
        encoder.fc = nn.Identity()
        self.backbone = encoder

        # Projector MLP: (Linear -> BatchNorm -> ReLU) blocks followed by a
        # final Linear -> BatchNorm without activation.
        dims = [2048, 2048, 2048, 128]
        blocks = []
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            blocks += [
                nn.Linear(d_in, d_out, bias=False),
                nn.BatchNorm1d(d_out),
                nn.ReLU(inplace=True),
            ]
        blocks += [
            nn.Linear(dims[-2], dims[-1], bias=False),
            nn.BatchNorm1d(dims[-1]),
        ]
        self.projector = nn.Sequential(*blocks)

        # Linear classifier over backbone features (fed detached features in
        # forward()).
        self.online_head = nn.Linear(2048, 1000)

        if args.rotation:
            # Rotation predictor: maps backbone features to 4 logits.
            self.rotation_projector = nn.Sequential(
                # first layer
                nn.Linear(2048, 2048),
                nn.LayerNorm(2048),
                nn.ReLU(inplace=True),
                # second layer
                nn.Linear(2048, 2048),
                nn.LayerNorm(2048),
                nn.ReLU(inplace=True),
                # output layer
                nn.Linear(2048, 128),
                nn.LayerNorm(128),
                nn.Linear(128, 4),
            )

    def forward(self, y1, y2, labels):
        """Return ``(loss, acc)`` for two augmented views and their labels.

        ``loss`` is the symmetrised infoNCE loss of the projected views plus
        the cross-entropy of the online classifier; ``acc`` is the online
        classifier's top-1 accuracy on ``y1``.
        """
        feats1 = self.backbone(y1)
        feats2 = self.backbone(y2)

        # projection
        proj1 = self.projector(feats1)
        proj2 = self.projector(feats2)
        contrastive = infoNCE(proj1, proj2) / 2 + infoNCE(proj2, proj1) / 2

        # The classifier sees detached features, so its loss does not
        # back-propagate into the backbone.
        class_logits = self.online_head(feats1.detach())
        cls_loss = torch.nn.functional.cross_entropy(class_logits, labels)
        predictions = class_logits.argmax(dim=1)
        acc = (predictions == labels).sum() / class_logits.size(0)

        return contrastive + cls_loss, acc

    def forward_rotation(self, x):
        """Return the rotation-head logits for the batch ``x``."""
        return self.rotation_projector(self.backbone(x))
def infoNCE(nn, p, temperature=0.2):
    """Return the InfoNCE (cross-entropy over similarities) loss of two batches.

    Row ``i`` of ``nn`` is treated as the positive match of row ``i`` of ``p``;
    every other row of ``p`` acts as a negative.

    Args:
        nn: 2-D batch of embeddings.  NOTE: inside this function the name
            shadows the ``torch.nn`` alias imported at file level, which is why
            ``torch.nn.functional`` is always spelled out here.  The parameter
            name is kept so keyword callers keep working.
        p: 2-D batch of embeddings paired row-by-row with ``nn``.
        temperature: divisor applied to the similarity logits.

    Returns:
        Scalar loss tensor.
    """
    nn = torch.nn.functional.normalize(nn, dim=1)
    p = torch.nn.functional.normalize(p, dim=1)
    # presumably gathers the embeddings of every process - verify against
    # utils.gather_from_all
    nn = gather_from_all(nn)
    p = gather_from_all(p)

    logits = (nn @ p.T) / temperature
    # Build the targets directly on the logits' device rather than creating
    # them on the host and copying them to whatever the current CUDA device is.
    targets = torch.arange(p.shape[0], dtype=torch.long, device=logits.device)
    return torch.nn.functional.cross_entropy(logits, targets)
class LARS(optim.Optimizer):
    """SGD with momentum, weight decay and layer-wise adaptive rate scaling.

    ``weight_decay_filter`` and ``lars_adaptation_filter`` are optional
    predicates on a parameter; when a predicate returns true the corresponding
    treatment is skipped for that parameter.
    """

    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=None, lars_adaptation_filter=None):
        super().__init__(params, dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            eta=eta,
            weight_decay_filter=weight_decay_filter,
            lars_adaptation_filter=lars_adaptation_filter,
        ))

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that has a gradient."""
        for group in self.param_groups:
            skip_decay = group['weight_decay_filter']
            skip_lars = group['lars_adaptation_filter']

            for param in group['params']:
                update = param.grad
                if update is None:
                    continue

                # L2 weight decay folded into the gradient.
                if skip_decay is None or not skip_decay(param):
                    update = update.add(param, alpha=group['weight_decay'])

                # Rescale by eta * ||w|| / ||update||; the ratio falls back to
                # 1 when either norm is zero.
                if skip_lars is None or not skip_lars(param):
                    weight_norm = torch.norm(param)
                    grad_norm = torch.norm(update)
                    ratio = torch.where(
                        (weight_norm > 0.) & (grad_norm > 0),
                        group['eta'] * weight_norm / grad_norm,
                        torch.ones_like(weight_norm))
                    update = update.mul(ratio)

                # Momentum buffer, created lazily per parameter.
                state = self.state[param]
                if 'mu' not in state:
                    state['mu'] = torch.zeros_like(param)
                velocity = state['mu']
                velocity.mul_(group['momentum']).add_(update)

                param.add_(velocity, alpha=-group['lr'])
def exclude_bias_and_norm(p):
    """Return True for one-dimensional parameters.

    Used as the LARS filter predicate: parameters for which this returns True
    are exempt from weight decay and LARS rescaling.
    """
    is_one_dimensional = p.ndim == 1
    return is_one_dimensional
class Transform:
    """Callable that turns one image into three augmented views.

    The first two views are 224-pixel bicubic random-resized crops that differ
    only in their blur / solarization probabilities; the third is a 96-pixel
    crop whose scale range comes from ``args.scale`` and is the view fed to the
    rotation task by the training loop.
    """

    def __init__(self, args):
        self.transform = self._build(
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            blur_p=1.0, solarize_p=0.0)
        self.transform_prime = self._build(
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            blur_p=0.1, solarize_p=0.2)
        self.transform_rotation = self._build(
            transforms.RandomResizedCrop(96, scale=(args.scale[0], args.scale[1])),
            blur_p=0.1, solarize_p=0.0)

    @staticmethod
    def _build(crop, blur_p, solarize_p):
        """Assemble the shared augmentation pipeline around ``crop``."""
        color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                              saturation=0.2, hue=0.1)
        return transforms.Compose([
            crop,
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, x):
        """Return the tuple ``(view, view_prime, rotation_view)`` for ``x``."""
        first = self.transform(x)
        second = self.transform_prime(x)
        third = self.transform_rotation(x)
        return first, second, third
# rotation
def rotate_images(images, gpu, single=False):
    """Rotate a batch by multiples of 90 degrees and return rotation labels.

    Args:
        images: 4-D batch ``(N, C, H, W)``; the image planes must be square
            because rotated and unrotated images share one output tensor.
        gpu: target device for the outputs - a CUDA device index, a
            ``torch.device``/device string, or ``None`` to keep the outputs on
            the device ``images`` already lives on.
        single: when true, each image gets ONE random rotation drawn with
            ``random.randint(0, 3)`` and the output has ``N`` images; otherwise
            every image is emitted in all four rotations (output has ``4 * N``
            images, grouped by rotation).

    Returns:
        ``(rotated_images, labels)`` where ``labels`` is an int64 tensor and
        label ``k`` means a rotation equal to ``torch.rot90(image, k)`` over
        the two spatial dimensions.  ``images`` itself is never modified and
        the output keeps the dtype of ``images``.
    """
    if gpu is None:
        device = images.device
    elif isinstance(gpu, int):
        device = torch.device('cuda', gpu)
    else:
        device = torch.device(gpu)

    nimages = images.shape[0]

    if single:
        ks = [random.randint(0, 3) for _ in range(nimages)]
        # Work on a copy so the caller's tensor is left untouched.
        rotated = images.clone()
        for i, k in enumerate(ks):
            rotated[i] = torch.rot90(images[i], k, [1, 2])
        labels = torch.tensor(ks, dtype=torch.long, device=device)
        return rotated.to(device), labels

    # rotate images all 4 ways at once, directly on the target device
    images = images.to(device, non_blocking=True)
    rotated_images = torch.cat([
        images,                            # rotate 0
        images.flip(3).transpose(2, 3),    # rotate 90
        images.flip(3).flip(2),            # rotate 180
        images.transpose(2, 3).flip(3),    # rotate 270
    ])
    # [0]*N + [1]*N + [2]*N + [3]*N, matching the concatenation order above.
    rot_classes = torch.arange(4, dtype=torch.long, device=device).repeat_interleave(nimages)
    return rotated_images, rot_classes
# Script entry point: only launch training when the file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()