-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
72 lines (59 loc) · 1.88 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Script for multi-gpu training."""
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.utils.data
from utils.opt import cfg, logger, opt
from utils.metrics import NullWriter
from trainer import train, validate
from model.fcoct import FCOCT
from utils.dataset import Hc
from model.criterion import SummaryLoss
from tensorboardX import SummaryWriter
def _init_fn(worker_id):
np.random.seed(opt.seed)
random.seed(opt.seed)
def setup_seed(seed):
    """Fix every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and torch (CPU, the current CUDA
    device, and all CUDA devices). Also puts cuDNN into deterministic,
    non-autotuning mode so repeated runs select the same kernels.

    Args:
        seed: Seed value shared by all generators.
    """
    random.seed(seed)
    # Only processes started after this point see the variable; the running
    # interpreter's hash seed was already fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def main():
    """Evaluate a checkpoint: set up logging, load weights, run validation.

    All settings come from the module-level ``opt`` and ``cfg`` objects
    imported from ``utils.opt``. The model is built with ``FCOCT(opt, cfg)``.
    Weights are loaded from ``opt.checkpoint`` and evaluated with
    ``validate(..., test=True)`` under ``torch.no_grad()``.
    """
    if opt.seed is not None:
        setup_seed(opt.seed)

    if opt.log:
        cfg_file_name = os.path.basename(opt.cfg)
        log_dir = os.path.join('./exp', '{}-{}'.format(opt.exp_id, cfg_file_name))
        # logging.FileHandler does not create parent directories; without
        # this an unseen exp_id/cfg pair crashes with FileNotFoundError.
        os.makedirs(log_dir, exist_ok=True)
        filehandler = logging.FileHandler(os.path.join(log_dir, 'training.log'))
        streamhandler = logging.StreamHandler()
        logger.setLevel(logging.INFO)
        logger.addHandler(filehandler)
        logger.addHandler(streamhandler)
    else:
        # NOTE(review): NullWriter presumably discards writes, which would
        # silence all print() output below. Confirm in utils.metrics.
        null_writer = NullWriter()
        sys.stdout = null_writer

    logger.info('******************************')
    logger.info(opt)
    logger.info('******************************')
    logger.info(cfg)
    logger.info('******************************')

    m = FCOCT(opt, cfg)
    print(f'Loading model from {opt.checkpoint}...')
    state_dict = torch.load(opt.checkpoint, map_location='cpu')
    # strict=False tolerates key mismatches. Report them so a partially
    # initialised model is not evaluated without any warning.
    missing, unexpected = m.load_state_dict(state_dict, strict=False)
    if missing:
        print(f'Missing keys in checkpoint: {missing}')
    if unexpected:
        print(f'Unexpected keys in checkpoint: {unexpected}')
    m.cuda()
    # Put dropout / batch-norm layers into inference mode. This is harmless
    # if validate() also calls eval(); that code is not visible here.
    m.eval()

    with torch.no_grad():
        err, std = validate(m, opt, cfg, test=True)
    if opt.log:
        logger.info(f'##### gt results: {err} std: {std} #####')
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not on import.
    main()