-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcal_from_mask.py
130 lines (97 loc) · 4.13 KB
/
cal_from_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision.transforms as transforms
import torch.nn.functional as F
#from utils.metrics import ROCMetric
from utils.data import *
from PIL import Image
import os
import os.path as osp
import scipy.io as scio
import numpy as np
import cv2
from utils.evaluation.roc_cruve import ROCMetric
from utils.evaluation.my_pd_fa import my_PD_FA
from utils.evaluation.TPFNFP import SegmentationMetricTPFNFP
class Dataset_mat(Data.Dataset):
    """Pair saved detection maps (``.mat``) with their ground-truth masks (``.png``).

    Item ``i`` is ``(rstImg, mask)``:

    - ``rstImg`` is the array stored under key ``'T'`` in ``<mat_dir>/<name>.mat``.
    - ``mask`` is ``<mask_dir>/<name>.png`` converted to grayscale and scaled so
      its maximum is 1.

    Both are resized to ``base_size x base_size``.

    Args:
        dataset: one of ``'NUDT-SIRST'``, ``'IRSTD-1K'``, ``'SIRST-aug'``.
        base_size: side length that both arrays are resized to.
        thre: stored as ``self.thre``; ``__getitem__`` does not apply it
            (items return the raw detection map).

    Raises:
        NotImplementedError: for an unknown ``dataset`` name.
    """

    # dataset name -> (directory of .mat detection maps, directory of .png masks)
    _DIRS = {
        'NUDT-SIRST': ('./eval/matData/NUDT-SIRST', './datasets/NUDT-SIRST/test/masks'),
        'IRSTD-1K': ('./eval/matData/IRSTD-1k', './datasets/IRSTD-1k/test/masks'),
        'SIRST-aug': ('./eval/matData/sirst_aug', './datasets/sirst_aug/test/masks'),
    }

    def __init__(self, dataset, base_size=256, thre=0.):
        self.base_size = base_size
        self.dataset = dataset
        if dataset not in self._DIRS:
            raise NotImplementedError(dataset)
        self.mat_dir, self.mask_dir = self._DIRS[dataset]
        self.file_names = self._list_mat_stems(self.mat_dir)
        self.thre = thre
        # Not used by __getitem__ (which resizes with cv2); kept as public attributes.
        self.mat_transform = transforms.Resize((base_size, base_size), interpolation=Image.BILINEAR)
        self.mask_transform = transforms.Resize((base_size, base_size), interpolation=Image.NEAREST)

    @staticmethod
    def _list_mat_stems(mat_dir):
        """Return the sorted base names of the ``.mat`` files in *mat_dir*.

        Other files (e.g. OS metadata files) are skipped. Sorting makes the
        item order deterministic across platforms.
        """
        stems = []
        for entry in os.listdir(mat_dir):
            stem, ext = osp.splitext(entry)
            # Exact match: __getitem__ re-opens the file as '<stem>.mat'.
            if ext == '.mat':
                stems.append(stem)
        return sorted(stems)

    @staticmethod
    def _normalize_mask(mask):
        """Return *mask* as float64 divided by its maximum.

        An all-zero mask is returned as zeros. Dividing would turn it into
        all-NaN (0/0).
        """
        mask = np.asarray(mask, dtype=np.float64)
        peak = mask.max()
        if peak == 0:
            return mask
        return mask / peak

    def __getitem__(self, i):
        name = self.file_names[i]
        mask_path = osp.join(self.mask_dir, name) + ".png"
        mat_path = osp.join(self.mat_dir, name) + ".mat"
        rstImg = np.asarray(scio.loadmat(mat_path)['T'])
        # NOTE(review): np.fromfile + imdecode is presumably used instead of
        # cv2.imread to tolerate non-ASCII paths; flag -1 keeps the file's channels.
        mask = cv2.imdecode(np.fromfile(mask_path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise ValueError('cannot decode mask image: {:s}'.format(mask_path))
        if mask.ndim == 3:
            code = cv2.COLOR_BGRA2GRAY if mask.shape[2] == 4 else cv2.COLOR_BGR2GRAY
            mask = cv2.cvtColor(mask, code)
        mask = self._normalize_mask(mask)
        size = (self.base_size, self.base_size)
        rstImg = cv2.resize(rstImg, dsize=size, interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps the mask values from being blended.
        mask = cv2.resize(mask, dsize=size, interpolation=cv2.INTER_NEAREST)
        return rstImg, mask

    def __len__(self):
        return len(self.file_names)
def cal_fpr_tpr(dataname, nbins=200, fileName=None):
    """Evaluate the saved detection maps of one dataset against its masks.

    Every item of ``Dataset_mat(dataname)`` is fed to three metrics:
    ``ROCMetric``, ``my_PD_FA`` and ``SegmentationMetricTPFNFP``. The
    results are printed and optionally appended to a text report. ``tpr``,
    ``fpr``, Pd and Fa are saved to
    ``./eval/IndicatorResult/matResult/<dataname>.mat``.

    Args:
        dataname: dataset key accepted by ``Dataset_mat``.
        nbins: ``bins`` argument forwarded to ``ROCMetric``.
        fileName: text file the results are appended to. When ``None``, the
            results are only printed.
    """
    import contextlib

    thre = 0.5
    baseSize = 256
    # Open the report once and close it even if evaluation raises.
    report = open(fileName, mode='a+') if fileName is not None else contextlib.nullcontext()
    with report as f:
        def log(line):
            """Print *line* and, if a report file is open, append it there too."""
            print(line)
            if f is not None:
                f.write(line + '\n')

        log('Running data: {:s}'.format(dataname))
        dataset = Dataset_mat(dataname, base_size=baseSize, thre=thre)
        roc = ROCMetric(bins=nbins)
        eval_PD_FA = my_PD_FA()
        eval_mIoU_P_R_F = SegmentationMetricTPFNFP(nclass=1)
        for idx in range(len(dataset)):
            rstImg, mask = dataset[idx]
            roc.update(pred=rstImg, label=mask)
            eval_PD_FA.update(rstImg, mask)
            eval_mIoU_P_R_F.update(labels=mask, preds=rstImg)
        fpr, tpr, auc = roc.get()
        pd_value, fa_value = eval_PD_FA.get()
        miou, prec, recall, fscore = eval_mIoU_P_R_F.get()
        log('AUC: %.6f' % (auc))
        log('Pd: %.6f, Fa: %.8f' % (pd_value, fa_value))
        log('mIoU: %.6f, Prec: %.6f, Recall: %.6f, fscore: %.6f' % (miou, prec, recall, fscore))
        if f is not None:
            # Blank line separating this dataset's section from the next one.
            f.write('\n')

    save_dict = {'tpr': tpr, 'fpr': fpr, 'Our Pd': pd_value, 'Our Fa': fa_value}
    matDir = './eval/IndicatorResult/matResult/'
    os.makedirs(matDir, exist_ok=True)
    matFile = osp.join(matDir, '{:s}.mat'.format(dataname))
    scio.savemat(matFile, save_dict)
if __name__ == '__main__':
    # Evaluate every supported dataset into one fresh text report.
    data_list = ['NUDT-SIRST', 'IRSTD-1K', 'SIRST-aug']
    fileDir = './eval/IndicatorResult/txtResult/'
    fileName = osp.join(fileDir, 'mat_result.txt')
    os.makedirs(fileDir, exist_ok=True)
    # Truncate the report once; cal_fpr_tpr then appends one section per dataset.
    with open(fileName, mode='w+'):
        pass
    for data in data_list:
        cal_fpr_tpr(dataname=data, nbins=200, fileName=fileName)