-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathseries_finder.py
306 lines (239 loc) · 8.96 KB
/
series_finder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
from matplotlib import pyplot as plt
from argparse import ArgumentParser
import itertools
import sys
import traceback
import attr
import pdb
import toolz
from l3finder.ingest import find_subjects, separate_series, \
construct_series_for_subjects_without_sagittals
from l3finder.exclude import filter_axial_series, filter_sagittal_series, \
load_series_to_skip_pickle_file, remove_series_to_skip
from l3finder.output import output_l3_images_to_h5, output_images
from l3finder.predict import make_predictions_for_sagittal_mips
from l3finder.preprocess import create_sagittal_mip, preprocess_images, \
group_mips_by_dimension, create_sagittal_mips_from_series
from util.reify import reify
from util.investigate import load_subject_ids_to_investigate
def pcs_debugger(type, value, tb):
    """Print an uncaught exception's traceback, then open pdb post-mortem on it.

    Installed as ``sys.excepthook`` right below this definition. If there is
    no interactive terminal to debug from (stderr or stdin is not a TTY), or
    the interpreter is already interactive (``sys.ps1`` exists), it defers to
    the default hook. This way unattended/batch runs exit with a traceback
    instead of blocking forever on a pdb prompt nobody can answer.

    Args:
        type: Exception class. The name matches the ``sys.excepthook``
            signature and shadows the builtin only inside this function.
        value: Exception instance.
        tb: Traceback object for the exception.
    """
    has_terminal = (
        sys.stderr is not None and sys.stderr.isatty()
        and sys.stdin is not None and sys.stdin.isatty()
    )
    if hasattr(sys, "ps1") or not has_terminal:
        sys.__excepthook__(type, value, tb)
        return
    traceback.print_exception(type, value, tb)
    # Debug the traceback we were handed. pdb.pm() would instead rely on
    # sys.last_traceback having been populated by the interpreter.
    pdb.post_mortem(tb)


# NOTE: importing this module replaces the process-wide excepthook.
sys.excepthook = pcs_debugger
@attr.s
class SagittalMIP:
    """Wraps one preprocessed sagittal image (presumably a MIP) and exposes
    the series and subject it came from.

    The only stored field is ``source_preprocessed_image``. Everything else is
    derived from it on access.
    """

    source_preprocessed_image = attr.ib()

    @property
    def series(self):
        """The series the wrapped preprocessed image was built from."""
        return self.source_preprocessed_image.source_series

    @property
    def preprocessed_image(self):
        """Alias for ``source_preprocessed_image``."""
        return self.source_preprocessed_image

    @property
    def subject_id(self):
        """``id_`` of the subject that owns the source series."""
        return self.series.subject.id_
def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Exits (via ``parser.error``, status 2) when ``--cache_intermediate_results``
    is given without ``--cache_dir``, because there would be nowhere to cache.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--dicom_dir',
        required=True,
        help='Root directory containing dicoms in format output by Tim\'s '
             'script. That is subject_1/accession_xyz/series{sagittal & '
             'axial}. The accession directory should contain both a sagittal '
             'image series and an axial image series. '
    )
    parser.add_argument(
        '--new_tim_dicom_dir_structure',
        action='store_true',
        help='Gets subjects in the format used for the 10000 pt dataset versus that used for the 380 pt dataset'
    )
    parser.add_argument(
        '--model_path',
        required=True,
        help='Path to .h5 model trained using '
             'https://github.com/fk128/ct-slice-detection Unet model. '
    )
    parser.add_argument(
        '--output_directory',
        required=True,
        help='Path to directory where output files will be saved. Will be '
             'created if it does not exist '
    )
    parser.add_argument(
        '--show_plots',
        action='store_true',
        # Previously a copy of --output_directory's help; this is a flag that
        # main() forwards to output_images as should_plot.
        help='If set, show plots while outputting images'
    )
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help='Overwrite files within target folder'
    )
    parser.add_argument(
        '--save_plots',
        action='store_true',
        help='If true, will save side-by-side plot of predicted L3 and the '
             'axial slice at that level '
    )
    parser.add_argument(
        "--series_to_skip_pickle_file",
        help="Pickle file of series dumped from identify_broken_dicoms.py"
    )
    parser.add_argument(
        '--cache_intermediate_results',
        action='store_true',
        help='If true, will cache the results of some of the longer running '
             'computations in the passed --cache_dir. Requires --cache_dir.'
    )
    parser.add_argument(
        "--cache_dir",
        help="Directory to store cached files in. If none given, no caching"
    )
    args = parser.parse_args()
    # Caching was requested but there is no directory to cache into.
    if args.cache_intermediate_results and not args.cache_dir:
        parser.error("--cache_intermediate_results requires --cache_dir")
    return args
def main():
    """Script entry point.

    Parses the CLI options, runs ``series_finder`` with them, hands the
    resulting images to ``output_images``, and returns those images.
    """
    cli_args = parse_args()
    l3_images = series_finder(config=vars(cli_args))

    print("Outputting images")
    output_options = {
        "output_directory": cli_args.output_directory,
        "should_plot": cli_args.show_plots,
        "should_overwrite": cli_args.overwrite,
        "should_save_plots": cli_args.save_plots,
    }
    output_images(l3_images, args=output_options)

    return l3_images
def series_finder(config):
    """Find, separate and filter the DICOM series for every subject.

    Args:
        config: Mapping of options, normally ``vars(parse_args())``. Keys read:
            ``"dicom_dir"`` (required);
            ``"new_tim_dicom_dir_structure"`` (defaults to False);
            ``"valid_ids"`` (optional; when present, only subjects whose
            ``id_`` is in it are kept);
            ``"series_to_skip_pickle_file"`` (optional).

    Raises:
        NotImplementedError: always, once filtering is done. The stage that
            turns the filtered series into L3 images is not written yet, so
            there is nothing to return to ``main``. The previous code called
            an undefined ``filter_sma_and_l3_images(sma_images)`` here and
            failed with a NameError.
    """
    print("Finding subjects")
    subjects = list(
        find_subjects(
            config["dicom_dir"],
            new_tim_dir_structure=config.get("new_tim_dicom_dir_structure", False)
        )
    )
    print('Subjects found: ', len(subjects))

    # Optional allow-list of subject ids. parse_args() defines no such option,
    # so vars(args) from main() never has "valid_ids"; indexing it directly
    # raised KeyError. A set gives O(1) membership per subject.
    valid_ids = config.get("valid_ids")
    if valid_ids is not None:
        allowed_ids = set(valid_ids)
        subjects = [subject for subject in subjects if subject.id_ in allowed_ids]
    print('Valid Subjects found: ', len(subjects))

    print("Finding series")
    series = list(flatten(s.find_series() for s in subjects))
    exclusions = []

    print("Separating series")
    sagittal_series, axial_series, excluded_series = separate_series(series)

    print("Filtering out unwanted series")
    axial_series, ax_exclusions = filter_axial_series(axial_series)
    exclusions.extend(ax_exclusions)

    # Reconstruct sagittal series for subjects that have none.
    constructed_sagittals = construct_series_for_subjects_without_sagittals(
        subjects, sagittal_series, axial_series
    )
    print(
        "Series separated\n",
        len(sagittal_series), "sagittal series.",
        len(axial_series), "axial series.",
        len(excluded_series), "excluded series.",
        len(constructed_sagittals), "constructed series.",
    )

    if config.get("series_to_skip_pickle_file", False):
        print("Removing unwanted series")
        # NOTE(review): by its name this loads a pickle; only pass trusted files.
        series_to_skip = load_series_to_skip_pickle_file(
            config["series_to_skip_pickle_file"])
        sagittal_series = remove_series_to_skip(series_to_skip, sagittal_series)
        axial_series = remove_series_to_skip(series_to_skip, axial_series)

    sagittal_series.extend(constructed_sagittals)
    sagittal_series, sag_exclusions = filter_sagittal_series(sagittal_series)
    exclusions.extend(sag_exclusions)

    # TODO: filter non-correlated axial/sagittal images, then build MIPs,
    # run predictions and pair them with axial series (build_l3_images).
    # Until that exists, fail loudly instead of with a NameError.
    raise NotImplementedError(
        "series_finder: filtering of non-correlated axial/sagittal images and "
        "L3 image construction are not implemented yet"
    )
def flatten(sequence):
    """Lazily flatten one level of nesting: an iterable of iterables becomes
    an iterator over their items.

    ``chain.from_iterable`` consumes the outer iterable on demand.
    ``chain(*sequence)`` forced the whole outer iterable (e.g. a generator of
    ``find_series()`` calls) to be materialized up front.
    """
    return itertools.chain.from_iterable(sequence)
def build_l3_images(axial_series, prediction_results):
    """Match axial series to L3 prediction results by subject id.

    Uses ``toolz.join`` (an inner join) keyed on the axial series' subject id
    and the prediction's input-MIP subject id. Every matching pair becomes one
    ``L3Image``, whose sagittal series is the prediction's input-MIP series.

    Returns:
        list of ``L3Image``.
    """
    def axial_subject_id(axial):
        return axial.subject.id_

    def prediction_subject_id(prediction):
        return prediction.input_mip.subject_id

    matched_pairs = toolz.join(
        leftkey=axial_subject_id,
        leftseq=axial_series,
        rightkey=prediction_subject_id,
        rightseq=prediction_results
    )

    l3_images = []
    for axial, prediction in matched_pairs:
        l3_images.append(
            L3Image(
                sagittal_series=prediction.input_mip.series,
                axial_series=axial,
                prediction_result=prediction,
            )
        )
    return l3_images
@attr.s
class L3Image:
    """An axial series paired with the sagittal series and prediction result
    that locate its L3 slice.

    Field order (axial_series, sagittal_series, prediction_result) defines the
    positional order of the ``attr``-generated ``__init__``.
    """

    axial_series = attr.ib()
    sagittal_series = attr.ib()
    prediction_result = attr.ib()

    @property
    def predicted_y_in_px(self):
        """Predicted y position, in pixels, from the prediction result."""
        return self.prediction_result.prediction.predicted_y_in_px

    @property
    def pixel_data(self):
        """Axial image at the predicted position, located using the sagittal
        series' z range."""
        return self.axial_series.image_at_pos_in_px(
            self.predicted_y_in_px,
            sagittal_z_pos_pair=self.sagittal_series.z_range_pair
        )

    @property
    def height_of_sagittal_image(self):
        """First element of the sagittal series' ``resolution``."""
        return self.sagittal_series.resolution[0]

    @property
    def number_of_axial_dicoms(self):
        """Number of DICOM files in the axial series."""
        return self.axial_series.number_of_dicoms

    @property
    def subject_id(self):
        """``id_`` of the subject owning the axial series."""
        return self.axial_series.subject.id_

    @reify
    def prediction_index(self):
        """``(index, metadata)`` of the axial slice at the predicted position.

        Decorated with ``reify``, which presumably caches the value after the
        first access.
        """
        slice_index, slice_metadata = self.axial_series.image_index_at_pos(
            self.predicted_y_in_px,
            sagittal_z_pos_pair=self.sagittal_series.z_range_pair
        )
        return slice_index, slice_metadata

    def as_csv_row(self):
        """Build one CSV row: axial id, predicted y, probability, sagittal and
        axial series paths, followed by the L3 slice metadata columns."""
        prediction = self.prediction_result.prediction
        _, slice_metadata = self.prediction_index
        return [
            self.axial_series.id_,
            prediction.predicted_y_in_px,
            prediction.probability,
            self.sagittal_series.series_path,
            self.axial_series.series_path,
            *slice_metadata.as_csv_row(),
        ]

    def free_pixel_data(self):
        """Release pixel data held by the axial, then the sagittal, series."""
        for held_series in (self.axial_series, self.sagittal_series):
            held_series.free_pixel_data()
if __name__ == "__main__":
    # Run the script. The result stays bound at module scope, presumably so it
    # can be inspected after the run (e.g. under `python -i`) — TODO confirm.
    l3_images = main()