-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_singlenode.py
107 lines (90 loc) · 3.97 KB
/
train_singlenode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import argparse
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
import wandb
# Initialise the wandb run at import time. mode="disabled" turns all wandb
# logging calls into no-ops while keeping the wandb.* API usable.
# NOTE(review): this reads os.environ["USER"] eagerly and raises KeyError if
# USER is unset — confirm the target environment always defines it.
wandb.init(mode="disabled", project="bird_example", entity=os.environ["USER"], name="bird_example_singlenode")
def parse_args(argv=None):
    """Parse command-line options for single-node training.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse falls back to ``sys.argv[1:]`` (the original
            behavior); passing an explicit list makes the parser testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--out_path', default='./bird_data/', type=str,
                        help='root directory containing train/ and valid/ image folders')
    parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size per GPU')
    parser.add_argument('--gpu', default=None, type=int,
                        help='CUDA device index to run on (default: current device)')
    parser.add_argument('--start_epoch', default=0, type=int,
                        help='start epoch number (useful on restarts)')
    parser.add_argument('--epochs', default=10, type=int, help='number of total epochs to run')
    parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                        help='number of data loading workers (default: 12)')
    return parser.parse_args(argv)
def main(args):
    """Train a 400-way bird classifier head on a frozen ResNet-152 backbone.

    Loads ImageNet-pretrained ResNet-152, replaces its final FC layer, freezes
    every other parameter, then trains on ``<out_path>/train`` and validates on
    ``<out_path>/valid`` (ImageFolder layout) after each epoch.

    Args:
        args: argparse.Namespace from ``parse_args()`` — uses ``lr``,
            ``batch_size``, ``out_path``, ``workers``, ``start_epoch``,
            ``epochs``.
    """
    print("Number of GPUS: %d" % torch.cuda.device_count())
    # wandb.config must be updated in place: rebinding the module attribute
    # (``wandb.config = {...}``) replaces the config object and the values
    # are never attached to the run.
    wandb.config.update({
        "learning_rate": args.lr,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
    })
    model = torchvision.models.resnet152(weights=torchvision.models.ResNet152_Weights.IMAGENET1K_V1)
    model.fc = nn.Linear(model.fc.in_features, 400)
    # Freeze the backbone; only the freshly initialised head is trained.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc.parameters():
        param.requires_grad = True
    model.cuda()
    # Hand only the trainable parameters to the optimizer — frozen params
    # never receive gradients but would otherwise still carry Adam state.
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=args.lr)
    transform = transforms.Compose([transforms.ToTensor()])
    # os.path.join is robust to out_path with or without a trailing slash
    # (the old string concatenation broke without one).
    train_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(args.out_path, "train"), transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, drop_last=True)
    val_dataset = torchvision.datasets.ImageFolder(
        root=os.path.join(args.out_path, "valid"), transform=transform)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, drop_last=False)
    criterion = torch.nn.CrossEntropyLoss()
    ### main loop ###
    for epoch in range(args.start_epoch, args.epochs):
        model.train()
        print("Epoch ", epoch)
        for i, batch in enumerate(train_loader):
            optimizer.zero_grad()
            image_batch = batch[0].cuda()
            annotation_batch = batch[1].cuda()
            output = model(image_batch)
            loss = criterion(output, annotation_batch)
            loss.backward()
            optimizer.step()
        validate(val_loader, model, epoch)
def validate(val_loader, model, epoch):
    """Evaluate ``model`` on ``val_loader``; log accuracy and sample table.

    Computes top-1 classification accuracy over the whole loader and logs it
    to wandb, together with an image/prediction/ground-truth table built from
    the first validation batch.

    Args:
        val_loader: DataLoader over an ImageFolder-style dataset (must expose
            ``.dataset.classes`` for index -> class-name lookup).
        model: CUDA-resident classifier returning per-class logits.
        epoch: current epoch number (currently unused inside the function).
    """
    # eval() is essential here: the caller leaves the model in train() mode,
    # so without it BatchNorm would normalise with batch statistics AND
    # corrupt its running averages with validation data. The training loop
    # calls model.train() again at the top of each epoch.
    model.eval()
    tp_sum = 0
    cnt = 0.0
    first_batch_images = None
    first_batch_predictions = None
    first_batch_gt = None
    # no_grad() skips autograd-graph construction for the forward passes,
    # reducing memory use during evaluation.
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            image_batch = batch[0].cuda()
            annotation_batch = batch[1].cuda()
            output = model(image_batch)
            labels = torch.argmax(output, dim=1)
            tp_sum += torch.sum((labels == annotation_batch).float())
            cnt += image_batch.shape[0]
            if i == 0:
                first_batch_images = image_batch
                first_batch_predictions = labels
                first_batch_gt = annotation_batch
    pred_table = wandb.Table(columns=["Image", "Prediction", "GT"])
    if first_batch_images is not None:  # guard: empty loader yields no batch
        for j in range(first_batch_images.shape[0]):
            pred_class_name = val_loader.dataset.classes[first_batch_predictions[j].item()]
            gt_class_name = val_loader.dataset.classes[first_batch_gt[j].item()]
            row = [wandb.Image(first_batch_images[j].detach().cpu().numpy().transpose((1, 2, 0))),
                   pred_class_name, gt_class_name]
            pred_table.add_data(*row)
    print("Cls. Acc: ", (tp_sum / cnt).item())
    wandb.log({"Val. CA": tp_sum / cnt, "Val. Table": pred_table})
if __name__ == '__main__':
    args = parse_args()
    # torch.cuda.device makes args.gpu the current CUDA device for the
    # duration of the block; per the torch docs it is a no-op when the
    # argument is None (the --gpu default).
    with torch.cuda.device(args.gpu):
        main(args)