-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathimagenet_predict.py
92 lines (73 loc) · 2.85 KB
/
imagenet_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
import argparse
import resnet_imgnt as resnet
from prog_bar import progress_bar
import utils_color as utils
# Command-line interface for the evaluation script.
# --keep and --model are mandatory; everything else has a default.
parser = argparse.ArgumentParser(description='ImageNet Evaluation')
_CLI_OPTIONS = (
    ('--keep', dict(required=True, type=int, help='pixels to keep')),
    ('--model', dict(required=True, help='checkpoint to predict')),
    ('--alpha', dict(default=0.05, type=float, help='Predict to 1-alpha probability')),
    ('--seed', dict(default=0, type=int, help='random seed')),
    ('--samples', dict(default=10000, type=int, help='number of samples')),
    ('--valpath', dict(default='imagenet-val/val', type=str,
                       help='Path to ImageNet validation set')),
)
# Register options in the same order as listed so --help output is unchanged.
for _flag, _kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()
# Directory layout (relative to the working directory):
#   checkpoints/  - trained model checkpoints, selected with --model
#   accuracies/   - evaluation results written at the end of the run
checkpoint_dir = 'checkpoints'
acc_dir = 'accuracies'
# Create the results directory up front. exist_ok avoids a check-then-create race.
os.makedirs(acc_dir, exist_ok=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Apply --seed so the sampler order and any torch-RNG-based randomness in
# prediction are reproducible. Previously the flag was parsed but never used.
torch.manual_seed(args.seed)
# Subset of validation-set indices to evaluate (consumed by the sampler below).
test_indices = torch.load('imagenet_indices.pth')
# Data
print('==> Preparing data..')
valdir = args.valpath
# Standard ImageNet channel statistics.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
# One image per batch. The sampler restricts evaluation to `test_indices`
# and sets the iteration order. DataLoader forbids shuffle=True together
# with a custom sampler, so shuffle stays False.
testloader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, val_transform),
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    sampler=torch.utils.data.sampler.SubsetRandomSampler(test_indices),
)
# Model
print('==> Building model..')
net = resnet.resnet50()
net = net.to(device)
if device == 'cuda':
    # Spread inference over all visible GPUs. The input size is fixed, so
    # cuDNN autotuning is worthwhile.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# Restore trained weights from <checkpoint_dir>/<--model>.
# Use explicit checks rather than `assert`, which is stripped under `python -O`.
if not os.path.isdir(checkpoint_dir):
    raise FileNotFoundError('Error: no checkpoint directory found!')
resume_file = os.path.join(checkpoint_dir, args.model)
if not os.path.isfile(resume_file):
    raise FileNotFoundError('Error: checkpoint file {} not found!'.format(resume_file))
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
checkpoint = torch.load(resume_file, map_location=device)
state_dict = checkpoint['net']
# A checkpoint saved from a DataParallel-wrapped model presumably has every
# key prefixed with 'module.' (TODO confirm against the training script).
# Normalise to bare keys and load into the underlying module, so the same
# file works whether or not `net` is wrapped (e.g. on CPU it is not).
if state_dict and all(key.startswith('module.') for key in state_dict):
    state_dict = {key[len('module.'):]: value for key, value in state_dict.items()}
base_net = net.module if isinstance(net, torch.nn.DataParallel) else net
base_net.load_state_dict(state_dict)
net.eval()
# Evaluation: run utils.predict on every sampled validation image and tally
# how many predictions match the label and how many are -1 (abstentions).
tot = 0
correct = 0
abstain = 0
with torch.no_grad():
    for batch_idx, (inputs, targets) in enumerate(testloader):
        inputs = inputs.to(device)
        # The comparison below is against CPU targets, as in the original
        # `targets.cpu()`, so `predicted` is expected to be a CPU label tensor.
        predicted = utils.predict(inputs, net, args.keep, args.samples,
                                  args.alpha, sub_batch=1000)
        correct += (predicted == targets).sum()
        abstain += (predicted == -1).sum()
        tot += predicted.shape[0]
        progress_bar(batch_idx, len(testloader),
                     'Acc: %.3f%% (%d/%d)' % (100. * correct / tot, correct, tot))

# Persist the raw counts. The file name encodes the model and the --alpha
# and --samples settings used for this run.
out = {
    'total': tot,
    'correct': correct,
    'abstain': abstain,
}
out_name = '{}_alpha_{}_samples_{}.pth'.format(args.model, args.alpha, args.samples)
torch.save(out, os.path.join(acc_dir, out_name))