-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkfold.py
140 lines (112 loc) · 4.14 KB
/
kfold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import json
from glob import glob
import datareader
import random
import numpy as np
import itertools
from tqdm import tqdm
import os
import json
from collections import OrderedDict
class KFold():
    """Access precomputed cross-validation folds stored in a JSON file.

    The JSON file must contain a list of folds, each fold being a list of
    file names. ``self.folds`` holds that list and ``self.n_splits`` its
    length.
    """

    def __init__(self, path='data/folds.json'):
        """Load the fold definition from ``path``.

        Args:
            path: Path to the JSON file with the list of folds.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        with open(path, 'r') as json_f:
            self.folds = json.load(json_f)
        self.n_splits = len(self.folds)

    def get_fold_split(self, fold_id):
        """Return ``(train_filenames, test_filenames)`` for one fold.

        Args:
            fold_id: Index into ``self.folds`` of the fold used for testing.

        Returns:
            A tuple of two lists. The test list is the stored fold itself;
            the train list holds every file name from all folds that is not
            in the test fold. The train list is built from a set, so its
            order is unspecified and duplicates are dropped.
        """
        test_filenames = self.folds[fold_id]
        all_filenames = set(itertools.chain.from_iterable(self.folds))
        train_filenames = list(all_filenames - set(test_filenames))
        return train_filenames, test_filenames

    def get_last_common_epoch(self, folds_dir):
        """Return the highest epoch that has a checkpoint in every fold.

        Fold directories are the entries of ``folds_dir`` matching
        ``fold_*``; checkpoints are the ``*.pth.tar`` files inside them.
        The epoch number is read from the file name as the integer between
        the first ``=`` and the following ``.`` (e.g. ``epoch=12.pth.tar``).

        Args:
            folds_dir: Directory that contains the ``fold_*`` directories.

        Returns:
            The largest epoch number present in all folds, or ``0`` when
            there are no fold directories, no parsable checkpoints, or no
            epoch shared by every fold (a message is printed in each case).
        """
        fold_dirs = glob(os.path.join(folds_dir, 'fold_*'))
        if not fold_dirs:
            print('No fold dirs found to get common epoch')
            return 0

        epochs_in_folds = {}
        for fold in fold_dirs:
            epochs = set()
            for checkpoint in glob(os.path.join(fold, '*.pth.tar')):
                filename = os.path.basename(checkpoint)
                try:
                    epochs.add(int(filename.split('=')[1].split('.')[0]))
                except (IndexError, ValueError):
                    # IndexError: no '=' in the name; ValueError: the token
                    # after '=' is not an integer (e.g. 'best=final').
                    # Either way the file carries no epoch number.
                    continue
            epochs_in_folds[fold] = epochs

        # Check every fold rather than just the first one glob() returned,
        # so the printed message does not depend on directory listing order.
        if not any(epochs_in_folds.values()):
            print('No checkpoints found to get common epoch')
            return 0

        common_epochs = set.intersection(*epochs_in_folds.values())
        if not common_epochs:
            print('Can\'t find common epoch')
            return 0
        return max(common_epochs)
# Logger should be tied to fold
# Logger should store dicts with epoch as key and thresholds and val scores as values
# Logger should be able to set data for current (specified) epoch
# Logger should be able to read log and return values for highest score
class FoldLogger():
    """JSON-backed log of per-epoch values for a single fold.

    The log is stored in ``<fold_dir>/log.json`` as a mapping from the
    epoch number (as a string) to a dict of logged values. An existing log
    file is loaded on construction; otherwise the log starts empty and the
    file is created on the first ``log_epoch`` call.
    """

    def __init__(self, fold_dir):
        """Bind the logger to ``fold_dir`` and load its log if one exists.

        Args:
            fold_dir: Directory that holds (or will hold) ``log.json``.
        """
        self.log_filename = 'log.json'
        self.log = OrderedDict()
        self.log_file_path = os.path.join(fold_dir, self.log_filename)
        try:
            self.read()
        except FileNotFoundError:
            # No log yet: keep the empty in-memory log.
            pass

    def read(self):
        """Replace the in-memory log with the contents of the log file.

        Raises:
            FileNotFoundError: If the log file does not exist.
        """
        with open(self.log_file_path, 'r') as log_f:
            self.log = json.load(log_f)

    def write(self):
        """Write the in-memory log to the log file, overwriting it."""
        with open(self.log_file_path, 'w') as log_f:
            json.dump(self.log, log_f, indent=4)

    def log_epoch(self, epoch, data):
        """Merge ``data`` into the entry for ``epoch`` and persist the log.

        Values already logged for the epoch are kept; keys present in
        ``data`` overwrite the stored ones.

        Args:
            epoch: Training epoch index (stored under ``str(epoch)``).
            data: Dict of values to record for this epoch.
        """
        assert data is not None
        key = str(epoch)
        # dict.update() returns None, so the merge must not be assigned from
        # its result. `or {}` also recovers from `null` entries that an
        # earlier buggy merge may have left in an existing log file.
        merged = dict(self.log.get(key) or {})
        merged.update(data)
        self.log[key] = merged
        self.write()

    def get_best_epoch(self):
        """Return ``(epoch, data)`` of the entry with the highest ``'score'``.

        Entries without a ``'score'`` value are ignored. On ties the entry
        that comes first in the log wins.

        Returns:
            A tuple of the epoch number as ``int`` and its data dict, or
            ``(None, None)`` when no entry has a score.
        """
        scored = [
            (epoch, data)
            for epoch, data in self.log.items()
            if isinstance(data, dict) and 'score' in data
        ]
        if not scored:
            return None, None
        epoch, data = max(scored, key=lambda item: item[1]['score'])
        return int(epoch), data
def split_round_robin(rating, n_splits):
    """Deal the items of ``rating`` into ``n_splits`` folds in turn.

    Item ``k`` of ``rating`` goes to fold ``k % n_splits``, so consecutive
    items land in different folds.

    Args:
        rating: Sequence of ``(image_path, area)`` pairs, already in the
            order in which they should be dealt out.
        n_splits: Number of folds to create; must be at least 1.

    Returns:
        A list of ``n_splits`` lists of image paths.

    Raises:
        ValueError: If ``n_splits`` is smaller than 1.
    """
    if n_splits < 1:
        raise ValueError('n_splits must be at least 1')
    folds = [[] for _ in range(n_splits)]
    for idx, (image_path, _area) in enumerate(rating):
        folds[idx % n_splits].append(image_path)
    return folds


# Script entry point: build the folds file from the training dataset.
# Images are ranked by mask area and dealt round-robin into the folds so
# that each fold receives a similar spread of mask sizes.
if __name__ == '__main__':
    n_splits = 5
    dst_json = 'data/folds.json'
    # NOTE(review): not used anywhere in this script — confirm whether
    # balancing of empty masks per fold is still planned.
    empty_masks_part_per_fold = 0.1

    dataset = datareader.SIIMDataset('data/dicom-images-train', 'data/train-rle.csv', ([768], [768]))

    rating = []
    for image_dict, target_dict in tqdm(dataset):
        mask = target_dict['mask'].numpy().astype(np.float32)
        # Mean mask value; equals the covered fraction of the image if the
        # mask is binary — TODO confirm against datareader.
        area = np.sum(mask) / mask.size
        rating.append((image_dict['image_path'], area))

    # Largest masks first.
    rating.sort(key=lambda x: x[1], reverse=True)
    for _, area in rating:
        print('Area', area)

    folds = split_round_robin(rating, n_splits)

    for i, fold in enumerate(folds):
        print('Fold', i, 'size:', len(fold))

    with open(dst_json, 'w') as f:
        json.dump(folds, f, indent=4)