-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
47 lines (38 loc) · 1.62 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# config
import sys
import warnings
import time
import numpy as np
import torch
from config import config
from eval.Estimator import Estimator
from net.networks import *
from options.test_options import TestOptions
#from Dataset.DatasetConstructor import TrainDatasetConstructor,EvalDatasetConstructor
from Dataset.DatasetConstructor_fast import TrainDatasetConstructor,EvalDatasetConstructor
# Checkpoint evaluated by this script (hard-coded path to the saved weights).
test_model_name = 'output/HRNet_relu_aspp/JSTL_large_4/MAE_60.62_MSE_324.36_mae_59.85_8.99_91.11_73.17_mse_96.74_15.22_160.36_509.46_Ep_380'

# Parse the test-time options, then force the values this script relies on.
opt = TestOptions().parse()
# Per the original author's notes, the test code only supports a single
# thread and a batch size of one.
opt.nThreads = opt.batch_size = 1
# Flipping is switched off for evaluation.
opt.is_flip = 0

# Derive the settings object from the (overridden) options; it supplies the
# evaluation paths, mode, dataset name and device used further down.
setting = config(opt)
# Evaluation data pipeline.
# The dataset is built from the values carried by `setting`; the loader then
# serves it one sample per batch.
eval_dataset = EvalDatasetConstructor(
    setting.eval_num,
    setting.eval_img_path,
    setting.eval_gt_map_path,
    mode=setting.mode,
    dataset_name=setting.dataset_name,
    device=setting.device,
)
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=1)
# Model construction: build the network named in the options and pass it
# through init_net together with the requested GPU ids.
net = init_net(define_net(opt.net_name), gpu_ids=opt.gpu_ids)

# Restore the checkpoint weights, mapped onto the configured device.
# NOTE(review): weights are loaded through `net.module`, so init_net is
# expected to return a wrapper exposing `.module` - confirm for CPU-only runs.
net.module.load_state_dict(
    torch.load(test_model_name, map_location=str(setting.device))
)

# Summed (not averaged) squared-error loss, moved to the evaluation device.
criterion = torch.nn.MSELoss(reduction='sum')
criterion = criterion.to(setting.device)

# Run the evaluation pass; the estimator returns MAE, RMSE, loss and the
# elapsed time, in that order.
# NOTE(review): the meaning of the second positional argument (True) is
# defined by Estimator.evaluate - check there before changing it.
estimator = Estimator(setting, eval_loader, criterion=criterion)
validate_MAE, validate_RMSE, validate_loss, time_cost = estimator.evaluate(net, True)

# Emit a one-line summary and flush so it appears immediately.
report_template = 'loss = {}, eval_mae = {}, eval_rmse = {}, time cost eval = {}s\n'
sys.stdout.write(
    report_template.format(validate_loss, validate_MAE, validate_RMSE, time_cost)
)
sys.stdout.flush()