-
Notifications
You must be signed in to change notification settings - Fork 112
/
Copy pathtest.py
64 lines (53 loc) · 2.46 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
import config
from utils import logger,common
from dataset.dataset_lits_test import Test_Datasets,to_one_hot_3d
import SimpleITK as sitk
import os
import numpy as np
from models import ResUNet
from utils.metrics import DiceAverage
from collections import OrderedDict
def predict_one_img(model, img_dataset, args):
    """Run the model over one test volume and score the result against its label.

    Every item yielded by ``img_dataset`` is fed through ``model`` one at a
    time (``batch_size=1``); each output is handed to
    ``img_dataset.update_result`` and the full prediction is then obtained
    from ``img_dataset.recompone_result()``.

    NOTE: this function reads the module-level name ``device``, which is only
    assigned in the ``__main__`` section of this script.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        img_dataset: Dataset for a single volume. Must expose ``label``,
            ``update_result`` and ``recompone_result``.
        args: Parsed options; ``n_labels`` and ``postprocess`` are read.

    Returns:
        A ``(test_dice, pred)`` tuple. ``test_dice`` is an ``OrderedDict``
        with key ``'Dice_liver'`` (class index 1) and, when
        ``args.n_labels == 3``, ``'Dice_tumor'`` (class index 2). ``pred`` is
        a SimpleITK image built from the uint8 argmax label map with its
        leading axis squeezed away.
    """
    dataloader = DataLoader(dataset=img_dataset, batch_size=1, num_workers=0, shuffle=False)
    model.eval()
    # Keep the meter under its own name: the returned dict used to overwrite
    # it, which broke the later read of ``avg[2]`` in the three-label case.
    dice_meter = DiceAverage(args.n_labels)
    target = to_one_hot_3d(img_dataset.label, args.n_labels)

    with torch.no_grad():
        for data in tqdm(dataloader, total=len(dataloader)):
            data = data.to(device)
            output = model(data)
            # Disabled: restore the spatial resolution to the original size.
            # output = nn.functional.interpolate(output, scale_factor=(1//args.slice_down_scale,1//args.xy_down_scale,1//args.xy_down_scale), mode='trilinear', align_corners=False)
            img_dataset.update_result(output.detach().cpu())

    pred = img_dataset.recompone_result()
    pred = torch.argmax(pred, dim=1)  # class index with the highest score along dim 1

    pred_img = common.to_one_hot_3d(pred, args.n_labels)
    dice_meter.update(pred_img, target)

    test_dice = OrderedDict({'Dice_liver': dice_meter.avg[1]})
    if args.n_labels == 3:
        test_dice['Dice_tumor'] = dice_meter.avg[2]

    pred = np.asarray(pred.numpy(), dtype='uint8')
    if args.postprocess:
        pass  # TODO: post-processing is not implemented yet
    pred = sitk.GetImageFromArray(np.squeeze(pred, axis=0))

    return test_dice, pred
if __name__ == '__main__':
    # Script entry point: load the best checkpoint of an experiment, run it on
    # every test volume, log the Dice scores and write each prediction to disk.
    args = config.args
    save_path = os.path.join('./experiments', args.save)
    # NOTE: `device` must remain a module-level name; predict_one_img reads it.
    device = torch.device('cpu' if args.cpu else 'cuda')

    # model info
    model = ResUNet(in_channel=1, out_channel=args.n_labels, training=False).to(device)
    model = torch.nn.DataParallel(model, device_ids=args.gpu_id)  # multi-GPU
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on GPU can still be loaded when running with args.cpu.
    ckpt = torch.load('{}/best_model.pth'.format(save_path), map_location=device)
    model.load_state_dict(ckpt['net'])

    test_log = logger.Test_Logger(save_path, "test_log")

    # data info
    result_save_path = '{}/result'.format(save_path)
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(result_save_path, exist_ok=True)

    datasets = Test_Datasets(args.test_data_path, args=args)
    for img_dataset, file_idx in datasets:
        test_dice, pred_img = predict_one_img(model, img_dataset, args)
        test_log.update(file_idx, test_dice)
        sitk.WriteImage(pred_img, os.path.join(result_save_path, 'result-' + file_idx + '.gz'))