test.py
import warnings
import torch
import random
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
import time
import sys
import os

try:
    sys.path.insert(0, '.')
    from model.base import Model
    from dataset.base import Dataset
    from metric import metrics
    from tools import parse, get_logger, print_dct
finally:
    pass

warnings.filterwarnings('ignore')


def test():
    args = parse()
    args_str = print_dct(args)

    # Seed Python and PyTorch RNGs so evaluation is reproducible.
    torch.manual_seed(args['test']['seed'])
    random.seed(args['test']['seed'])

    bs = args['dataset']['batch_size']
    device_str = args['test']['device']
    device = torch.device(device_str)

    model = Model(args['model'])
    loss_fn = CrossEntropyLoss()

    # Collect every checkpoint saved for this model.
    ckpt_path = f"{args['test']['checkpoint']}{args['model']['name'].lower()}/"
    checkpoints = [ckpt_path + ckpt for ckpt in os.listdir(ckpt_path)]

    log_path = f"{args['test']['log_path']}{args['model']['name'].lower()}/"
    os.makedirs(log_path, exist_ok=True)
    logger = get_logger(log_path + 'test.log')

    model = model.to(device)
    logger.info(f'\n{args_str}')
    logger.info(f'\n{str(model)}')
    logger.info(f'\n{str(loss_fn)}')

    # Only the test split is used here; the train/val splits are discarded.
    _, _, test_dataset = Dataset(args['dataset'])
    test_data_loader = DataLoader(
        test_dataset,
        batch_size=bs,
        shuffle=False
    )

    start = time.time()
    for checkpoint in checkpoints:
        # Evaluate each checkpoint over the full test set.
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        model.eval()

        test_loss = 0.0
        with torch.no_grad():
            for batch in test_data_loader:
                imgs, labels = batch
                imgs = imgs.to(device)
                labels = labels.to(device)
                preds = model(imgs)
                loss = loss_fn(preds, labels)
                test_loss += loss.item()
        test_loss /= len(test_data_loader)

        print_str = f'{device_str} '
        print_str += f'ckpt: {checkpoint} '
        print_str += f'test_loss: {test_loss:.4f} '

        # Append every metric returned for the test set to the log line.
        test_metrics = metrics(model, test_data_loader, device)
        for i, (name, metr) in enumerate(test_metrics):
            if i == len(test_metrics) - 1:
                print_str += f'test_{name}: {metr:.3f}'
            else:
                print_str += f'test_{name}: {metr:.3f} '
        logger.info(print_str)

    end = time.time()
    print_str = f'total_time: {(end - start):.3f} sec'
    logger.info(print_str)


if __name__ == '__main__':
    test()
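
For reference, below is a minimal sketch of the configuration dictionary that parse() would need to return for this script, based only on the keys read above; every value shown is an illustrative assumption, not taken from the repository.

# Illustrative configuration sketch (assumed values); the real dict comes from parse().
args = {
    'model': {'name': 'ResNet18'},        # assumed; the full sub-dict is passed to Model(...)
    'dataset': {'batch_size': 64},        # assumed; the full sub-dict is passed to Dataset(...)
    'test': {
        'seed': 42,                       # RNG seed for torch and random
        'device': 'cuda:0',               # or 'cpu'
        'checkpoint': 'checkpoints/',     # prefix; the lowercased model name is appended
        'log_path': 'logs/',              # prefix; the lowercased model name is appended
    },
}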