-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathProcessor.py
123 lines (105 loc) · 4.3 KB
/
Processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import gc
import torch
import numpy as np
import tqdm
import collections
from model.shift_gcn import Model
from torch.autograd import Variable
from utils import *
import shutil
import inspect
from tqdm import tqdm
import torch.nn as nn
class Processor():
    """Evaluation driver: builds a model and a test DataLoader from ``arg`` and
    reports top-1 accuracy.

    Attributes read from ``arg``: ``model``, ``model_args``, ``work_dir``,
    ``weights``, ``device``, ``feeder``, ``test_feeder_args`` and
    ``test_batch_size``. ``device`` is handed to ``.cuda()`` and
    ``torch.device()``, so it must be a value both accept (e.g. a GPU index).
    """

    def __init__(self, arg):
        super(Processor, self).__init__()
        self.arg = arg
        self.model = self.load_model()
        self.data_loader = self.load_data()
        self.loss = nn.CrossEntropyLoss().cuda(self.arg.device)
        self.best_acc = 0  # not read anywhere in this class

    def start(self):
        """Evaluate the test split, writing per-sample result files under ./output/.

        Returns the accuracy reported by ``eval``.
        """
        wf = './output/wrong.txt'
        # NOTE: right.txt receives a line for EVERY sample, not only correct ones.
        rf = './output/right.txt'
        return self.eval(epoch=0, loader_name=['test'], wrong_file=wf, result_file=rf)

    def eval(self, epoch, save_score=False, loader_name=None, wrong_file=None, result_file=None):
        """Run the model over ``self.data_loader`` and print top-1 accuracy.

        Args:
            epoch: Unused; kept for call compatibility.
            save_score: Unused; kept for call compatibility.
            loader_name: Iterable of loader names, default ``['test']``. Only one
                loader is built by this class, so every name re-evaluates
                ``self.data_loader``.
            wrong_file: Optional path. One line ``index,predicted,true`` is
                written per misclassified sample.
            result_file: Optional path. One line ``predicted,true`` is written
                per sample.

        Returns:
            The value of ``self.data_loader.dataset.top_k(score, 1)`` for the
            last loader name, or ``None`` if ``loader_name`` is empty.

        Raises:
            ValueError: if the loader yields no batches.
        """
        if loader_name is None:
            # Avoid a shared mutable default argument.
            loader_name = ['test']
        f_w = self._open_output(wrong_file)
        f_r = self._open_output(result_file)
        accuracy = None
        try:
            self.model.eval()
            for _name in loader_name:
                loss_value = []  # per-batch losses; collected but not reported
                score_frag = []
                for data, label, index in tqdm(self.data_loader):
                    # torch.no_grad() replaces the removed Variable(volatile=True).
                    with torch.no_grad():
                        data = data.float().cuda(self.arg.device)
                        label = label.long().cuda(self.arg.device)
                        output = self.model(data)
                        loss = self.loss(output, label)
                    score_frag.append(output.cpu().numpy())
                    loss_value.append(loss.cpu().numpy())
                    if f_w is None and f_r is None:
                        continue
                    _, predict_label = torch.max(output, 1)
                    predict = predict_label.cpu().tolist()
                    true = label.cpu().tolist()
                    for i, (pred, gt) in enumerate(zip(predict, true)):
                        if f_r is not None:
                            f_r.write('{},{}\n'.format(pred, gt))
                        if pred != gt and f_w is not None:
                            f_w.write('{},{},{}\n'.format(
                                self._format_index(index[i]), pred, gt))
                if not score_frag:
                    raise ValueError('evaluation loader yielded no batches')
                score = np.concatenate(score_frag)
                accuracy = self.data_loader.dataset.top_k(score, 1)
                print('Eval Accuracy: ', accuracy)
        finally:
            # Always release the output files, even if evaluation fails midway.
            for handle in (f_w, f_r):
                if handle is not None:
                    handle.close()
        return accuracy

    @staticmethod
    def _open_output(path):
        """Open ``path`` for writing, creating its parent directory; ``None`` passes through."""
        if path is None:
            return None
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        return open(path, 'w', encoding='utf-8')

    @staticmethod
    def _format_index(idx):
        """Render one sample identifier from a batch as plain text.

        If the loader collated the identifiers into a tensor (PyTorch's default
        collation does this for ints), convert to Python values first so the
        file gets ``5`` rather than ``tensor(5)``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return str(idx)

    @staticmethod
    def _filter_compatible_weights(model_dict, checkpoint):
        """Return the tensor entries of ``checkpoint`` whose key exists in
        ``model_dict`` with an identical shape, preserving checkpoint order.

        Unknown keys, shape mismatches and non-tensor entries are skipped.
        """
        compatible = collections.OrderedDict()
        for name, value in checkpoint.items():
            target = model_dict.get(name)
            if target is None or not torch.is_tensor(value):
                continue
            if target.shape == value.shape:
                compatible[name] = value
        return compatible

    def load_model(self):
        """Build the model class named by ``arg.model`` and load matching weights.

        The model's source file is copied into ``arg.work_dir``. If
        ``arg.weights`` exists, every checkpoint entry whose name and shape
        match the model's state dict is loaded (non-strict). Otherwise a
        warning is printed and the freshly constructed model is returned;
        previously ``None`` was returned, which made ``eval`` crash.
        """
        model_cls = import_class(self.arg.model)
        shutil.copy2(inspect.getfile(model_cls), self.arg.work_dir)
        model = model_cls(**self.arg.model_args).cuda(self.arg.device)
        # Load pretrained weights.
        checkpoint_file = self.arg.weights
        if not os.path.exists(checkpoint_file):
            print("=> no checkpoint found at '{}'; using freshly constructed weights".format(checkpoint_file))
            return model
        print("=> loading checkpoint '{}'".format(checkpoint_file))
        device = torch.device(self.arg.device)
        # map_location places tensors on the target device regardless of where they were saved.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model_dict = model.state_dict()
        compatible = self._filter_compatible_weights(model_dict, checkpoint)
        model_dict.update(compatible)
        model.load_state_dict(model_dict, strict=False)
        del checkpoint, compatible, model_dict
        gc.collect()
        return model

    def load_data(self):
        """Build the (unshuffled, single-process) test DataLoader.

        The dataset class is resolved from ``arg.feeder`` and constructed with
        ``arg.test_feeder_args``; batches have size ``arg.test_batch_size``.
        """
        feeder_cls = import_class(self.arg.feeder)
        return torch.utils.data.DataLoader(
            dataset=feeder_cls(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=0,
            drop_last=False)