-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathvalidation.py
88 lines (76 loc) · 3.22 KB
/
validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
from torch import nn
import torch
import tqdm
import math
def val_multi(model: nn.Module, criterion, valid_loader, num_classes, batch_size, device_ids):
    """Evaluate a multi-class segmentation model and report loss, IoU and Dice.

    Runs ``model`` over every ``(inputs, targets)`` batch of ``valid_loader``
    without gradient tracking. It accumulates a confusion matrix of
    ``argmax`` predictions vs. targets and prints the mean loss together
    with the average IoU and Dice over the foreground classes.

    Args:
        model: Network to evaluate. It is switched to ``eval()`` mode and
            left in that mode on return.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        valid_loader: Iterable of ``(inputs, targets)`` batches that exposes
            a sized ``.dataset``.
        num_classes: Number of classes, including class 0, which is treated
            as background and excluded from the metrics.
        batch_size: Batch size of ``valid_loader``. It is only used to size
            the progress bar.
        device_ids: Sequence of CUDA device ids. Batches are moved to
            ``device_ids[0]``. If it is empty or ``None``, batches stay on
            the device the loader produced them on.

    Returns:
        ``(average_dice, average_iou)`` over classes ``1..num_classes-1``.

    Notes:
        The reported loss is the unweighted mean of per-batch losses. If the
        loader yields no batches, the loss is reported as NaN.
        ``outputs`` is assumed to hold class scores along axis 1, e.g.
        (N, C, H, W). TODO confirm against the model.
    """
    device = device_ids[0] if device_ids else None
    losses = []
    # uint64: a uint32 accumulator silently wraps after ~4.3e9 pixels, which
    # the background cell of a large validation set can easily exceed.
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.uint64)
    total_steps = math.ceil(len(valid_loader.dataset) / batch_size)

    with torch.no_grad():
        model.eval()
        # Context manager guarantees the bar is closed even if a batch raises.
        with tqdm.tqdm(total=total_steps) as tq:
            tq.set_description('Test')
            for inputs, targets in valid_loader:
                targets = targets.long()
                if device is not None:
                    inputs = inputs.cuda(device)
                    targets = targets.cuda(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                losses.append(loss.item())
                output_classes = outputs.detach().cpu().numpy().argmax(axis=1)
                target_classes = targets.detach().cpu().numpy()
                confusion_matrix += calculate_confusion_matrix_from_arrays(
                    output_classes, target_classes, num_classes)
                tq.set_postfix(loss='{0:.3f}'.format(np.mean(losses)))
                tq.update(1)

    confusion_matrix = confusion_matrix[1:, 1:]  # exclude background
    valid_loss = float(np.mean(losses)) if losses else float('nan')
    ious = {'iou_{}'.format(cls + 1): iou
            for cls, iou in enumerate(calculate_iou(confusion_matrix))}
    dices = {'dice_{}'.format(cls + 1): dice
             for cls, dice in enumerate(calculate_dice(confusion_matrix))}
    average_iou = np.mean(list(ious.values()))
    average_dices = np.mean(list(dices.values()))
    print('Valid loss: {:.4f}, average IoU: {:.4f}, average Dice: {:.4f}'.format(valid_loss, average_iou, average_dices))
    return average_dices, average_iou
def calculate_confusion_matrix_from_arrays(prediction, ground_truth, nr_labels):
    """Build an ``nr_labels x nr_labels`` confusion matrix from label arrays.

    Args:
        prediction: Array of predicted class labels (any shape).
        ground_truth: Array of true class labels with the same number of
            elements as ``prediction``.
        nr_labels: Number of classes. Only labels in ``[0, nr_labels)``
            are counted.

    Returns:
        ``np.uint32`` array where entry ``[t, p]`` counts the positions whose
        true label is ``t`` and whose predicted label is ``p``. A pair is
        skipped entirely if either label is outside ``[0, nr_labels)``,
        e.g. an ignore index such as 255.

    Raises:
        ValueError: If the two inputs have different numbers of elements.
    """
    truth = np.asarray(ground_truth).ravel()
    pred = np.asarray(prediction).ravel()
    if truth.shape != pred.shape:
        raise ValueError(
            'prediction and ground_truth must have the same number of '
            'elements, got {} and {}'.format(pred.size, truth.size))

    # Half-open range check on BOTH ends. The previous np.histogramdd
    # version closed its last bin, so a label equal to nr_labels was
    # silently counted as class nr_labels - 1.
    valid = (truth >= 0) & (truth < nr_labels) & (pred >= 0) & (pred < nr_labels)
    # Encode each (truth, pred) pair as one flat index into the matrix.
    flat_index = nr_labels * truth[valid].astype(np.int64) + pred[valid].astype(np.int64)
    counts = np.bincount(flat_index, minlength=nr_labels * nr_labels)
    return counts.reshape(nr_labels, nr_labels).astype(np.uint32)
def calculate_iou(confusion_matrix):
    """Return the per-class IoU (Jaccard index) computed from a confusion matrix.

    Args:
        confusion_matrix: Square matrix with true labels on rows and
            predicted labels on columns.

    Returns:
        List of floats, one per class: ``TP / (TP + FP + FN)``. A class
        with no true and no predicted occurrences gets 0.0.
    """
    # Work in float64 so large unsigned counts cannot wrap during the sums.
    matrix = np.asarray(confusion_matrix, dtype=np.float64)
    true_positives = np.diag(matrix)
    # Column sum = everything predicted as the class, row sum = everything
    # labelled as the class; each includes the true positives once.
    false_positives = matrix.sum(axis=0) - true_positives
    false_negatives = matrix.sum(axis=1) - true_positives
    denom = true_positives + false_positives + false_negatives
    ious = np.zeros_like(true_positives)
    np.divide(true_positives, denom, out=ious, where=denom != 0)
    return ious.tolist()
def calculate_dice(confusion_matrix):
    """Return the per-class Dice coefficient computed from a confusion matrix.

    Args:
        confusion_matrix: Square matrix with true labels on rows and
            predicted labels on columns.

    Returns:
        List of floats, one per class: ``2*TP / (2*TP + FP + FN)``. A class
        with no true and no predicted occurrences gets 0.0.
    """
    # float64 avoids the previous `2 * true_positives` wrap-around on
    # uint32 matrices, where counts >= 2**31 produced dice 0.
    matrix = np.asarray(confusion_matrix, dtype=np.float64)
    true_positives = np.diag(matrix)
    false_positives = matrix.sum(axis=0) - true_positives
    false_negatives = matrix.sum(axis=1) - true_positives
    denom = 2 * true_positives + false_positives + false_negatives
    dices = np.zeros_like(true_positives)
    np.divide(2 * true_positives, denom, out=dices, where=denom != 0)
    return dices.tolist()