-
Notifications
You must be signed in to change notification settings - Fork 77
/
Copy pathdataset_volume.py
executable file
·269 lines (227 loc) · 11.3 KB
/
dataset_volume.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
from __future__ import print_function, division
from typing import Optional, List
import numpy as np
import random
import torch
import torch.utils.data
from ..augmentation import Compose
from ..utils import count_volume, crop_volume, relabel, seg_to_targets, seg_to_weights
TARGET_OPT_TYPE = List[str]
WEIGHT_OPT_TYPE = List[List[str]]
AUGMENTOR_TYPE = Optional[Compose]
class VolumeDataset(torch.utils.data.Dataset):
"""
Dataset class for volumetric image datasets. At training time, subvolumes are randomly sampled from all the large
input volumes with (optional) rejection sampling to increase the frequency of foreground regions in a batch. At inference
time, subvolumes are yielded in a sliding-window manner with overlap to counter border artifacts.
Args:
volume (list): list of image volumes.
label (list, optional): list of label volumes. Default: None
valid_mask (list, optional): list of valid masks. Default: None
valid_ratio (float): volume ratio threshold for valid samples. Default: 0.5
sample_volume_size (tuple, int): model input size.
sample_label_size (tuple, int): model output size.
sample_stride (tuple, int): stride size for sampling.
augmentor (connectomics.data.augmentation.composition.Compose, optional): data augmentor for training. Default: None
target_opt (list): list the model targets generated from segmentation labels.
weight_opt (list): list of options for generating pixel-wise weight masks.
mode (str): ``'train'``, ``'val'`` or ``'test'``. Default: ``'train'``
do_2d (bool): load 2d samples from 3d volumes. Default: False
iter_num (int): total number of training iterations (-1 for inference). Default: -1
reject_size_thres (int): threshold to decide if a sampled volumes contains foreground objects. Default: 0
reject_p (float): probability of rejecting non-foreground volumes. Default: 0.95
Note:
For relatively small volumes, the total number of possible subvolumes can be smaller than the total number
of samples required in training (the product of total iterations and mini-natch size), which raises *StopIteration*.
Therefore the dataset length is also decided by the training settings.
"""
background: int = 0 # background label index
def __init__(self,
volume: list,
label: Optional[list] = None,
valid_mask: Optional[list] = None,
valid_ratio: float = 0.5,
sample_volume_size: tuple = (8, 64, 64),
sample_label_size: tuple = (8, 64, 64),
sample_stride: tuple = (1, 1, 1),
augmentor: AUGMENTOR_TYPE = None,
target_opt: TARGET_OPT_TYPE = ['1'],
weight_opt: WEIGHT_OPT_TYPE = [['1']],
mode: str = 'train',
do_2d: bool = False,
iter_num: int = -1,
reject_size_thres: int = 0,
reject_p: float = 0.95):
assert mode in ['train', 'test']
self.mode = mode
self.do_2d = do_2d
if self.do_2d:
assert (sample_volume_size[0]==1) * (sample_label_size[0]==1)
# data format
self.volume = volume
self.label = label
self.augmentor = augmentor
# target and weight options
self.target_opt = target_opt
self.weight_opt = weight_opt
# rejection samping
self.reject_size_thres = reject_size_thres
self.reject_p = reject_p
# dataset: channels, depths, rows, cols
self.volume_size = [np.array(x.shape) for x in self.volume] # volume size, could be multi-volume input
self.sample_volume_size = np.array(sample_volume_size).astype(int) # model input size
if self.label is not None:
self.sample_label_size = np.array(sample_label_size).astype(int) # model label size
assert np.array_equal(self.augmentor.sample_size, self.sample_label_size)
self.label_vol_ratio = self.sample_label_size / self.sample_volume_size
# compute number of samples for each dataset (multi-volume input)
self.sample_stride = np.array(sample_stride).astype(int)
self.sample_size = [count_volume(self.volume_size[x], self.sample_volume_size, self.sample_stride)
for x in range(len(self.volume_size))]
# total number of possible inputs for each volume
self.sample_num = np.array([np.prod(x) for x in self.sample_size])
self.sample_num_a = np.sum(self.sample_num)
self.sample_num_c = np.cumsum([0] + list(self.sample_num))
# handle partially labeled volume
self.valid_mask = valid_mask
self.valid_ratio = valid_ratio
if mode=='test': # for test
self.sample_size_test = [np.array([np.prod(x[1:3]), x[2]]) for x in self.sample_size]
# For relatively small volumes, the total number of samples can be generated is smaller
# than the number of samples required for training (i.e., iteration * batch size). Thus
# we let the __len__() of the dataset return the larger value among the two during training.
if iter_num < 0: # inference mode
self.iter_num = self.sample_num_a
else: # training mode
self.iter_num = max(iter_num, self.sample_num_a)
print('Total number of samples to be generated: ', self.iter_num)
def __len__(self):
# total number of possible samples
return self.iter_num
def __getitem__(self, index):
# orig input: keep uint/int format to save cpu memory
# output sample: need np.float32
vol_size = self.sample_volume_size
if self.mode in ['train','val']:
# train/val mode
sample = self._rejection_sampling(vol_size)
pos, out_volume, out_label, out_valid = sample
if self.do_2d:
out_volume = np.squeeze(out_volume)
out_label = np.squeeze(out_label)
if out_valid is not None:
out_valid = np.squeeze(out_valid)
out_volume = np.expand_dims(out_volume, 0)
# output list
out_target = seg_to_targets(out_label, self.target_opt)
out_weight = seg_to_weights(out_target, self.weight_opt, out_valid)
return pos, out_volume, out_target, out_weight
elif self.mode == 'test':
# test mode
pos = self._get_pos_test(index)
out_volume = (crop_volume(self.volume[pos[0]], vol_size, pos[1:])/255.0).astype(np.float32)
if self.do_2d:
out_volume = np.squeeze(out_volume)
return pos, np.expand_dims(out_volume, 0)
#######################################################
# Position Calculator
#######################################################
def _index_to_dataset(self, index):
return np.argmax(index < self.sample_num_c) - 1 # which dataset
def _index_to_location(self, index, sz):
# index -> z,y,x
# sz: [y*x, x]
pos = [0, 0, 0]
pos[0] = np.floor(index/sz[0])
pz_r = index % sz[0]
pos[1] = int(np.floor(pz_r/sz[1]))
pos[2] = pz_r % sz[1]
return pos
def _get_pos_test(self, index):
pos = [0, 0, 0, 0]
did = self._index_to_dataset(index)
pos[0] = did
index2 = index - self.sample_num_c[did]
pos[1:] = self._index_to_location(index2, self.sample_size_test[did])
# if out-of-bound, tuck in
for i in range(1, 4):
if pos[i] != self.sample_size[pos[0]][i-1]-1:
pos[i] = int(pos[i] * self.sample_stride[i-1])
else:
pos[i] = int(self.volume_size[pos[0]][i-1]-self.sample_volume_size[i-1])
return pos
def _get_pos_train(self, vol_size):
# random: multithread
# np.random: same seed
pos = [0, 0, 0, 0]
# pick a dataset
did = self._index_to_dataset(random.randint(0,self.sample_num_a-1))
pos[0] = did
# pick a position
tmp_size = count_volume(self.volume_size[did], vol_size, self.sample_stride)
tmp_pos = [random.randint(0,tmp_size[x]-1) * self.sample_stride[x] for x in range(len(tmp_size))]
pos[1:] = tmp_pos
return pos
#######################################################
# Volume Sampler
#######################################################
def _rejection_sampling(self, vol_size):
"""Rejection sampling to filter out samples without required number
of foreground pixels or valid ratio.
"""
while True:
sample = self._random_sampling(vol_size)
pos, out_volume, out_label, out_valid = sample
if self.augmentor is not None:
if out_valid is not None:
assert 'valid_mask' in self.augmentor.additional_targets.keys(), \
"Need to specify the 'valid_mask' option in additional_targets " \
"of the data augmentor when training with partial annotation."
data = {'image': out_volume,
'label': out_label,
'valid_mask': out_valid}
augmented = self.augmentor(data)
out_volume, out_label = augmented['image'], augmented['label']
out_valid = augmented['valid_mask']
if self._is_valid(out_valid) and self._is_fg(out_label):
break
return pos, out_volume, out_label, out_valid
def _random_sampling(self, vol_size):
"""Randomly sample a subvolume from all the volumes.
"""
pos = self._get_pos_train(vol_size)
out_volume = (crop_volume(self.volume[pos[0]], vol_size, pos[1:])/255.0).astype(np.float32)
# position in the label and valid mask
pos_l = np.round(pos[1:]*self.label_vol_ratio)
out_label = crop_volume(self.label[pos[0]], self.sample_label_size, pos_l)
# For warping: cv2.remap requires input to be float32.
# Make labels index smaller. Otherwise uint32 and float32 are not
# the same for some values.
out_label = relabel(out_label.copy()).astype(np.float32)
if self.valid_mask is not None:
out_valid = crop_volume(self.label[pos[0]],
self.sample_label_size, pos_l)
out_valid = (out_valid!=0).astype(np.float32)
else:
out_valid = None
return pos, out_volume, out_label, out_valid
def _is_valid(self, out_valid):
"""Decide whether the sampled region is valid or not using
the corresponding valid mask.
"""
if self.valid_mask is None:
return True
ratio = float(out_valid.sum()) / np.prod(np.array(out_valid.shape))
return ratio > self.valid_ratio
def _is_fg(self, out_label):
"""Decide whether the sample belongs to a foreground decided
by the rejection sampling criterion.
"""
size_thres = self.reject_size_thres
p = self.reject_p
if self.reject_size_thres > 0:
temp = out_label.copy().astype(int)
temp = (temp!=self.background).astype(int).sum()
if temp < size_thres and random.random() < p:
return False
return True