-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_process.py
241 lines (197 loc) · 8.6 KB
/
test_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
"""
Convert DICOM images to PNG files that the model can load.
"""
import numpy as np
import pandas as pd
import glob
import cv2
import shutil
import ctypes
from pathlib import Path
from tqdm import tqdm
import pydicom
from pydicom.filebase import DicomBytesIO
import multiprocessing as mp
from joblib import Parallel, delayed
from sklearn.model_selection import StratifiedGroupKFold
from utils import load_config
import torch
import torch.nn.functional as F
from nvidia.dali import pipeline_def, types
from nvidia.dali.types import DALIDataType
from nvidia.dali.backend import TensorGPU, TensorListGPU
import nvidia.dali.fn as fn
import nvidia.dali.types as types
# Map each DALI element type to the torch dtype of the destination tensor that
# feed_ndarray() copies into (it asserts the two agree before copying).
# NOTE: UINT16 is deliberately paired with torch.int16 because older torch releases
# have no usable uint16 dtype. feed_ndarray() performs a raw memory copy, so uint16
# values >= 32768 read back as negative int16 numbers unless the caller
# reinterprets them as unsigned.
to_torch_type = dict((
    (types.DALIDataType.FLOAT, torch.float32),
    (types.DALIDataType.FLOAT64, torch.float64),
    (types.DALIDataType.FLOAT16, torch.float16),
    (types.DALIDataType.UINT8, torch.uint8),
    (types.DALIDataType.INT8, torch.int8),
    (types.DALIDataType.UINT16, torch.int16),
    (types.DALIDataType.INT16, torch.int16),
    (types.DALIDataType.INT32, torch.int32),
    (types.DALIDataType.INT64, torch.int64),
))
def feed_ndarray(dali_tensor, arr, cuda_stream=None):
    """
    Copy the contents of a DALI tensor into an already-allocated PyTorch tensor.

    Parameters
    ----------
    `dali_tensor` : nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU
        Source of the copy.
    `arr` : torch.Tensor
        Destination of the copy. Its dtype must equal
        ``to_torch_type[dali_tensor.dtype]`` and its size must equal
        ``dali_tensor.shape()``; both are checked with ``assert``.
    `cuda_stream` : torch.cuda.Stream, cudaStream_t or any value that can be cast to cudaStream_t.
        Stream used for a GPU-side copy. When ``None``, DALI picks an internal
        user stream. Passing PyTorch's current stream is usually what you want
        (e.g. when ``arr`` came from ``torch.empty(...)`` on that device).

    Returns
    -------
    torch.Tensor
        ``arr`` itself, now holding the copied data.
    """
    wanted_dtype = to_torch_type[dali_tensor.dtype]
    assert wanted_dtype == arr.dtype, (
        "The element type of DALI Tensor/TensorList"
        " doesn't match the element type of the target PyTorch Tensor: "
        "{} vs {}".format(wanted_dtype, arr.dtype)
    )
    assert dali_tensor.shape() == list(arr.size()), (
        "Shapes do not match: DALI tensor has size {0}, but PyTorch Tensor has size {1}"
        .format(dali_tensor.shape(), list(arr.size()))
    )
    raw_stream = types._raw_cuda_stream(cuda_stream)
    # copy_to_external() takes a raw destination address, wrapped as a C void pointer.
    destination = ctypes.c_void_p(arr.data_ptr())
    if not isinstance(dali_tensor, (TensorGPU, TensorListGPU)):
        # Host-side tensor: plain synchronous copy.
        dali_tensor.copy_to_external(destination)
        return arr
    gpu_stream = ctypes.c_void_p(raw_stream) if raw_stream is not None else None
    dali_tensor.copy_to_external(destination, gpu_stream, non_blocking=True)
    return arr
def process(f, save_folder=""):
    """Convert one DICOM file into a windowed, resized 8-bit PNG, entirely on the CPU.

    This path decodes with pydicom instead of the DALI pipeline. It previously
    called ``dicomsdl``, which this module never imports, so every call raised
    NameError.

    Parameters
    ----------
    f : str
        Path of the form ``.../<patient>/<image_id>.dcm``.
    save_folder : str
        Output directory prefix; must end with a path separator. When empty,
        the module-level ``SAVE_FOLDER`` is used, as before.

    Returns
    -------
    str
        Path of the PNG that was written (``<folder><patient>_<image_id>.png``).
    """
    path = Path(f)
    patient = path.parent.name
    dicom_id = path.stem
    dicom = pydicom.dcmread(f)
    # pixel_array decodes compressed transfer syntaxes only if a suitable pydicom
    # pixel-data handler is installed -- NOTE(review): confirm in the runtime env.
    # Cast to float32 first: torch.from_numpy rejects uint16 on many torch versions,
    # and the windowing/normalisation below works in floating point anyway.
    img = torch.from_numpy(dicom.pixel_array.astype(np.float32))
    img = process_dicom(img, dicom)
    # assumes a single-frame 2-D image (H, W) -- multi-frame data would not fit this view
    img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode="bilinear")[0, 0]
    img = (img * 255).clip(0, 255).to(torch.uint8).cpu().numpy()
    out_file_name = (save_folder or SAVE_FOLDER) + f"{patient}_{dicom_id}.png"
    cv2.imwrite(out_file_name, img)
    return out_file_name
def convert_dicom_to_jpg(file, save_folder=""):
    """Dump the compressed pixel stream of a DICOM file so DALI can decode it on the GPU.

    The bytes are written to ``<save_folder><patient>_<image>.jpg``. Here ``patient``
    is the parent directory name and ``image`` is the file name minus its last four
    characters (the ``.dcm`` suffix). Only two transfer syntaxes are handled:

    * ``1.2.840.10008.1.2.4.90`` (JPEG 2000): bytes from the first ``00 00 00 0C``
      onwards. This is the length field that the JP2 signature box starts with.
    * ``1.2.840.10008.1.2.4.70`` (JPEG lossless): bytes from the first
      ``FF D8 FF E0`` onwards, i.e. an SOI marker followed by an APP0 marker.

    Nothing is written for any other transfer syntax, or when the expected marker
    is not present in the pixel data. Always returns None.
    """
    patient = file.split('/')[-2]
    image = file.split('/')[-1][:-4]
    # A single parse is enough: the returned dataset already exposes PixelData.
    ds = pydicom.dcmread(file)
    markers = {
        '1.2.840.10008.1.2.4.90': b"\x00\x00\x00\x0C",  # JPEG 2000
        '1.2.840.10008.1.2.4.70': b"\xff\xd8\xff\xe0",  # JPEG lossless
    }
    marker = markers.get(str(ds.file_meta.TransferSyntaxUID))
    if marker is None:
        return None
    pixel_data = ds.PixelData
    offset = pixel_data.find(marker)
    if offset < 0:
        # find() returns -1 when the marker is missing; slicing from -1 would write a
        # 1-byte file that the decoder then fails on, so skip the file instead.
        return None
    with open(save_folder + f"{patient}_{image}.jpg", "wb") as binary_file:
        # memoryview avoids copying the (potentially large) tail of the pixel data.
        binary_file.write(memoryview(pixel_data)[offset:])
    return None
@pipeline_def
def jpg_decode_pipeline(jpgfiles):
    """Define a DALI graph that reads the files in ``jpgfiles`` and decodes them.

    ``@pipeline_def`` turns this function into a factory. Calling it with pipeline
    arguments such as ``batch_size``, ``num_threads`` and ``device_id`` returns a
    DALI Pipeline.

    The files are read on the CPU and decoded with ``device='mixed'``, so the decoded
    images come out on the GPU. They are requested as UINT16 with ``ANY_DATA`` output.
    The reader's labels are discarded.

    NOTE(review): the consumer assumes samples come out in ``jpgfiles`` order
    (the reader is not asked to shuffle) -- confirm against the DALI reader defaults.
    """
    encoded, _labels = fn.readers.file(files=jpgfiles)
    decoded = fn.experimental.decoders.image(
        encoded,
        device='mixed',
        output_type=types.ANY_DATA,
        dtype=DALIDataType.UINT16,
    )
    return decoded
def parse_window_element(elem):
    """Return the first numeric value of a DICOM window element as a float.

    Accepted inputs:

    * a real number (including int, float subclasses such as pydicom's DSfloat,
      and numpy scalars);
    * a string, possibly multi-valued with the DICOM backslash separator
      (``"40\\80"`` -> 40.0);
    * a sequence such as a list, tuple or pydicom MultiValue (its first item is used);
    * a ``pydicom.dataelem.DataElement`` (its ``.value`` is parsed).

    Returns None for None, bytes, empty strings, empty sequences and any other type.
    A non-numeric string raises ValueError from ``float()``.

    Previously, exact ``type(...) ==`` checks rejected ints, tuples and float
    subclasses. A DataElement holding a plain string was also indexed to its first
    character (``"40"`` -> 4.0).
    """
    import numbers
    from collections.abc import Sequence

    if elem is None or isinstance(elem, (bytes, bytearray)):
        return None
    if isinstance(elem, str):
        first = elem.split('\\', 1)[0].strip()
        return float(first) if first else None
    if isinstance(elem, numbers.Real):
        return float(elem)
    if isinstance(elem, Sequence):
        return parse_window_element(elem[0]) if len(elem) else None
    if isinstance(elem, pydicom.dataelem.DataElement):
        # Unwrap and recurse so str / number / MultiValue values are all handled above.
        return parse_window_element(elem.value)
    return None
def linear_window(data, center, width):
    """Clamp ``data`` into ``[center - width // 2, center + width // 2]``.

    The half-width uses floor division, so for odd or fractional widths the window
    is slightly narrower than ``width`` (e.g. width 401 -> +/-200). This is left
    unchanged on purpose: altering it would shift pixel values relative to any
    images previously preprocessed with this function.
    """
    half_width = width // 2
    return data.clamp(min=center - half_width, max=center + half_width)
def _window_value(dicom, keyword):
    """Return the parsed value of window tag ``keyword``, or None when the tag is absent."""
    try:
        elem = dicom[keyword]
    except KeyError:
        return None
    return parse_window_element(elem)


def process_dicom(img, dicom):
    """Window, normalise to [0, 1] and (for MONOCHROME1) invert a DICOM image tensor.

    Parameters
    ----------
    img : torch.Tensor
        Raw pixel values. Integer tensors are converted to float32 first.
    dicom :
        The image's dataset. It is read through ``dicom["WindowCenter"]``,
        ``dicom["WindowWidth"]`` and the ``PhotometricInterpretation`` attribute.

    Returns
    -------
    torch.Tensor
        Floating-point tensor scaled to [0, 1]. It is inverted when
        ``PhotometricInterpretation == "MONOCHROME1"``.

    Windowing is applied only when both window tags are present and parse to a
    number. A missing tag used to raise KeyError. A constant image now maps to
    all zeros instead of 0/0 = NaN.
    """
    try:
        invert = getattr(dicom, "PhotometricInterpretation", None) == "MONOCHROME1"
    except Exception:
        # Best effort: an unreadable attribute is treated as "do not invert".
        invert = False
    if not img.is_floating_point():
        img = img.float()
    center = _window_value(dicom, "WindowCenter")
    width = _window_value(dicom, "WindowWidth")
    if center is not None and width is not None:
        img = linear_window(img, center, width)
    lo = img.min()
    span = img.max() - lo
    # Clamping the divisor keeps a zero span from producing NaN; (img - lo) is then
    # exactly 0 everywhere, so the result is all zeros. Positive spans are unchanged.
    img = (img - lo) / span.clamp_min(torch.finfo(img.dtype).tiny)
    if invert:
        img = 1 - img
    return img
COMP_FOLDER = 'data/'
DATA_FOLDER = 'data/test_images/'
N_CORES = mp.cpu_count()
MIXED_PRECISION = False
DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'
# Run flags. These used to be assigned twice (DEBUG first True, then False);
# only the effective final values are kept.
RAM_CHECK = True
DEBUG = False

train_df = pd.read_csv('data/test.csv')
# Keeps every patient present in the CSV (an identity filter); narrow this list to
# process a subset. .copy() so the column assignment below does not act on a slice.
patient_filter = sorted(set(train_df.patient_id.unique()))
train_df = train_df[train_df.patient_id.isin(patient_filter)].copy()
cfg = load_config('config/default.yaml')
# Cross-validation fold assignment (disabled for this test-set run):
# split = StratifiedGroupKFold(cfg.folds)
# for k, (_, test_idx) in enumerate(split.split(train_df, train_df.cancer, groups=train_df.patient_id)):
#     train_df.loc[test_idx, 'split'] = k
# train_df.split = train_df.split.astype(int)
print(f'Len df : {len(train_df)}')
train_df['fns'] = train_df['patient_id'].astype(str) + '/' + train_df['image_id'].astype(str) + '.dcm'
print(train_df.head())
# The test CSV may not carry a label column; avoid a KeyError in that case.
y_pred = train_df['cancer'].values if 'cancer' in train_df.columns else None
print(type(y_pred))
SAVE_SIZE = int(cfg.image_size * 1.125)
SAVE_FOLDER = 'data/gen_test/'
Path(SAVE_FOLDER).mkdir(parents=True, exist_ok=True)
# Split the file list into chunks of roughly 2000 files. Integer arithmetic makes the
# final bound exactly the file count. The former float division + int truncation
# could round it down to n - 1 and silently drop the last file.
_n_files = len(train_df['fns'])
N_CHUNKS = max(1, _n_files // 2000)
CHUNKS = np.array(
    [(_n_files * k // N_CHUNKS, _n_files * (k + 1) // N_CHUNKS) for k in range(N_CHUNKS)]
).astype(int)
JPG_FOLDER = 'data/jpg/'
# GPU index shared by the DALI pipeline and the torch CUDA stream, derived from
# DEVICE so the two cannot drift apart ('cuda:1' -> 1; no index -> 0).
DALI_DEVICE_ID = torch.device(DEVICE).index or 0

for ttt, chunk in enumerate(CHUNKS):
    print(f'chunk {ttt} of {len(CHUNKS)} chunks')
    Path(JPG_FOLDER).mkdir(parents=True, exist_ok=True)
    try:
        # Step 1: extract the compressed pixel streams of this chunk to JPG_FOLDER.
        _ = Parallel(n_jobs=2)(
            delayed(convert_dicom_to_jpg)(f'{DATA_FOLDER}/{img}', save_folder=JPG_FOLDER)
            for img in train_df['fns'].tolist()[chunk[0]: chunk[1]]
        )
        jpgfiles = glob.glob(JPG_FOLDER + '*.jpg')
        if not jpgfiles:
            # No file in this chunk had a supported transfer syntax; DALI's file reader
            # cannot be built from an empty list.
            continue

        # Step 2: decode on the GPU, one image per pipe.run(), in jpgfiles order.
        pipe = jpg_decode_pipeline(jpgfiles, batch_size=1, num_threads=2, device_id=DALI_DEVICE_ID)
        pipe.build()
        for i, f in enumerate(tqdm(jpgfiles)):
            patient, dicom_id = Path(f).stem.split('_', 1)
            try:
                # Inside the try so one unreadable DICOM skips a single image
                # instead of aborting the whole run.
                dicom = pydicom.dcmread(DATA_FOLDER + f"/{patient}/{dicom_id}.dcm")
                out = pipe.run()
                # Dali -> Torch. The decoder emits UINT16, which to_torch_type pairs
                # with torch.int16, so the raw bits land in an int16 tensor.
                img = out[0][0]
                img_torch = torch.empty(img.shape(), dtype=torch.int16, device=DEVICE)
                feed_ndarray(img, img_torch, cuda_stream=torch.cuda.current_stream(device=DALI_DEVICE_ID))
                # Reinterpret those bits as unsigned: values >= 32768 would otherwise
                # come out negative and corrupt the windowing below.
                img = (img_torch.to(torch.int32) & 0xFFFF).float()
                # apply dicom preprocessing
                img = process_dicom(img, dicom)
                # resize the torch image
                img = F.interpolate(img.view(1, 1, img.size(0), img.size(1)), (SAVE_SIZE, SAVE_SIZE), mode='bilinear')[0, 0]
                img = (img * 255).clip(0, 255).to(torch.uint8).cpu().numpy()
                out_file_name = SAVE_FOLDER + f"{patient}_{dicom_id}.png"
                cv2.imwrite(out_file_name, img)
            except Exception as e:
                print(i, e)
                # Restart decoding at the next file so pipeline outputs stay aligned
                # with jpgfiles. Skip the rebuild after the last file, where the list
                # would be empty.
                remaining = jpgfiles[i + 1:]
                if remaining:
                    pipe = jpg_decode_pipeline(remaining, batch_size=1, num_threads=2, device_id=DALI_DEVICE_ID)
                    pipe.build()
                continue
    finally:
        # Always drop this chunk's intermediate files, even if a step above failed.
        shutil.rmtree(JPG_FOLDER, ignore_errors=True)
print('DALI Raw image load complete')