From 45db7be01d0ebcd4a91a071b4055eee4aac949cf Mon Sep 17 00:00:00 2001
From: David Christoffer Hansen
Date: Fri, 29 Sep 2017 09:58:27 +0200
Subject: [PATCH] Convert team grt123's identification and classification
 algorithms, based on the code from https://github.com/lfz/DSB2017

---
 .gitattributes                                     |   1 +
 compose/prediction/Dockerfile-dev                  |   7 +-
 prediction/requirements/base.txt                   |   1 +
 .../classify/assets/gtr123_model.ckpt              |   3 +
 .../algorithms/classify/src/gtr123_model.py        | 315 +++++++++++
 .../src/algorithms/classify/trained_model.py       |  40 +-
 .../identify/assets/dsb2017_detector.ckpt          |   3 +
 .../src/algorithms/identify/src/__init__.py        |   0
 .../algorithms/identify/src/gtr123_model.py        | 512 ++++++++++++++++++
 .../src/algorithms/identify/trained_model.py       |  48 +-
 prediction/src/preprocess/extract_lungs.py         | 270 +++++++++
 .../src/preprocess/gtr123_preprocess.py            |  61 +++
 prediction/src/tests/__init__.py                   |   1 +
 .../test_classify_trained_model_predict.py         |  14 +-
 prediction/src/tests/test_endpoints.py             |   5 +-
 .../test_identify_trained_model_predict.py         |  49 ++
 16 files changed, 1277 insertions(+), 53 deletions(-)
 create mode 100644 prediction/src/algorithms/classify/assets/gtr123_model.ckpt
 create mode 100644 prediction/src/algorithms/classify/src/gtr123_model.py
 create mode 100644 prediction/src/algorithms/identify/assets/dsb2017_detector.ckpt
 create mode 100644 prediction/src/algorithms/identify/src/__init__.py
 create mode 100644 prediction/src/algorithms/identify/src/gtr123_model.py
 create mode 100644 prediction/src/preprocess/extract_lungs.py
 create mode 100644 prediction/src/preprocess/gtr123_preprocess.py
 create mode 100644 prediction/src/tests/test_identify_trained_model_predict.py

diff --git a/.gitattributes b/.gitattributes
index 56ba7b58..831b4ff9 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -9,3 +9,4 @@ test/assets/* filter=lfs diff=lfs merge=lfs -text
 *.hd5 filter=lfs diff=lfs merge=lfs -text
 *.mhd filter=lfs diff=lfs merge=lfs -text
 *.raw filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
diff --git a/compose/prediction/Dockerfile-dev b/compose/prediction/Dockerfile-dev
index 7f6e4de0..c31be245 100644
--- a/compose/prediction/Dockerfile-dev
+++ b/compose/prediction/Dockerfile-dev
@@ -1,6 +1,9 @@
-FROM python:3.6
+FROM ubuntu:rolling
 ENV PYTHONUNBUFFERED 1
-
+RUN apt-get update && apt-get install -y tcl tk python3.6 python3.6-tk wget python-opencv
+RUN wget https://bootstrap.pypa.io/get-pip.py
+RUN python3.6 get-pip.py
+RUN ln -s /usr/bin/python3.6 /usr/local/bin/python
 # Requirements have to be pulled and installed here, otherwise caching won't work
 COPY ./prediction/requirements /requirements
 RUN pip install -r /requirements/local.txt
diff --git a/prediction/requirements/base.txt b/prediction/requirements/base.txt
index 2262807b..4b9c01f1 100644
--- a/prediction/requirements/base.txt
+++ b/prediction/requirements/base.txt
@@ -10,3 +10,4 @@ opencv-python==3.3.0.10
 pandas==0.20.3
 scikit-image==0.13.0
 SimpleITK==1.0.1
+http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
diff --git a/prediction/src/algorithms/classify/assets/gtr123_model.ckpt b/prediction/src/algorithms/classify/assets/gtr123_model.ckpt
new file mode 100644
index 00000000..516adbb2
--- /dev/null
+++ b/prediction/src/algorithms/classify/assets/gtr123_model.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb2c037dd55e2f49da95657f7a7851cbcfe2f2b516848ed03f8c5c820f3e16b4
+size 21621745
diff --git a/prediction/src/algorithms/classify/src/gtr123_model.py b/prediction/src/algorithms/classify/src/gtr123_model.py
new file mode 100644
index 00000000..854a1bb2
--- /dev/null
+++ b/prediction/src/algorithms/classify/src/gtr123_model.py
@@ -0,0 +1,315 @@
+import torch
+import numpy as np
+from torch.autograd import Variable
+from torch import nn
+import SimpleITK as sitk
+
+from src.preprocess.gtr123_preprocess import lum_trans, resample
+
+"""
+Classification model from team gtr123
+Code adapted from https://github.com/lfz/DSB2017
+"""
+config = {}
+
+config['crop_size'] = [96, 96, 96]
+config['scaleLim'] = [0.85, 1.15]
+config['radiusLim'] = [6, 100]
+
+config['stride'] = 4
+
+config['detect_th'] = 0.05
+config['conf_th'] = -1
+config['nms_th'] = 0.05
+config['filling_value'] = 160
+
+config['startepoch'] = 20
+config['lr_stage'] = np.array([50, 100, 140, 160])
+config['lr'] = [0.01, 0.001, 0.0001, 0.00001]
+config['miss_ratio'] = 1
+config['miss_thresh'] = 0.03
+config['anchors'] = [10, 30, 60]
+
+
+class PostRes(nn.Module):
+    """3D residual block with an optional strided 1x1x1 shortcut."""
+
+    def __init__(self, n_in, n_out, stride=1):
+        super(PostRes, self).__init__()
+        self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1)
+        self.bn1 = nn.BatchNorm3d(n_out)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm3d(n_out)
+
+        if stride != 1 or n_out != n_in:
+            self.shortcut = nn.Sequential(
+                nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride),
+                nn.BatchNorm3d(n_out))
+        else:
+            self.shortcut = None
+
+    def forward(self, x):
+
+        residual = x
+        if self.shortcut is not None:
+            residual = self.shortcut(x)
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class Net(nn.Module):
+    """Encoder-decoder feature network wrapped by CaseNet below."""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        # The first few layers consume the most memory, so use simple
+        # convolution to save memory. Call these layers preBlock, i.e., before
+        # the residual blocks of later layers.
+        self.preBlock = nn.Sequential(
+            nn.Conv3d(1, 24, kernel_size=3, padding=1),
+            nn.BatchNorm3d(24),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(24, 24, kernel_size=3, padding=1),
+            nn.BatchNorm3d(24),
+            nn.ReLU(inplace=True))
+
+        # 3 poolings, each pooling downsamples the feature map by a factor 2.
+        # 3 groups of blocks. The first block of each group has one pooling.
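+        # Editor's note in comment form: the loops below build a U-Net-style
+        # backbone in which the four 'forw' groups form the downsampling path
+        # and the two 'back' groups the upsampling path. The extra 3 input
+        # channels on the first 'back' group hold the normalized z/y/x
+        # coordinate grid that is concatenated to the features in forward().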
+        num_blocks_forw = [2, 2, 3, 3]
+        num_blocks_back = [3, 3]
+        self.featureNum_forw = [24, 32, 64, 64, 64]
+        self.featureNum_back = [128, 64, 64]
+
+        for i in range(len(num_blocks_forw)):
+            blocks = []
+
+            for j in range(num_blocks_forw[i]):
+                if j == 0:
+                    blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i + 1]))
+                else:
+                    blocks.append(PostRes(self.featureNum_forw[i + 1], self.featureNum_forw[i + 1]))
+
+            setattr(self, 'forw' + str(i + 1), nn.Sequential(*blocks))
+
+        for i in range(len(num_blocks_back)):
+            blocks = []
+
+            for j in range(num_blocks_back[i]):
+                if j == 0:
+                    if i == 0:
+                        addition = 3
+                    else:
+                        addition = 0
+
+                    blocks.append(PostRes(self.featureNum_back[i + 1] + self.featureNum_forw[i + 2] + addition,
+                                          self.featureNum_back[i]))
+                else:
+                    blocks.append(PostRes(self.featureNum_back[i], self.featureNum_back[i]))
+
+            setattr(self, 'back' + str(i + 2), nn.Sequential(*blocks))
+
+        self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2)
+        self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2, stride=2)
+
+        self.path1 = nn.Sequential(
+            nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+        self.path2 = nn.Sequential(
+            nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+
+        self.drop = nn.Dropout3d(p=0.2, inplace=False)
+        self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1),
+                                    nn.ReLU(),
+                                    # nn.Dropout3d(p = 0.3),
+                                    nn.Conv3d(64, 5 * len(config['anchors']), kernel_size=1))
+
+    def forward(self, x, coord):
+        """
+
+        Args:
+            x:
+            coord:
+
+        Returns:
+
+        """
+        out = self.preBlock(x)  # 16
+        out_pool, indices0 = self.maxpool1(out)
+        out1 = self.forw1(out_pool)  # 32
+        out1_pool, indices1 = self.maxpool2(out1)
+        out2 = self.forw2(out1_pool)  # 64
+        # out2 = self.drop(out2)
+        out2_pool, indices2 = self.maxpool3(out2)
+        out3 = self.forw3(out2_pool)  # 96
+        out3_pool, indices3 = self.maxpool4(out3)
+        out4 = self.forw4(out3_pool)  # 96
+        # out4 = self.drop(out4)
+
+        rev3 = self.path1(out4)
+        comb3 = self.back3(torch.cat((rev3, out3), 1))  # 96+96
+        # comb3 = self.drop(comb3)
+        rev2 = self.path2(comb3)
+
+        feat = self.back2(torch.cat((rev2, out2, coord), 1))  # 64+64
+        comb2 = self.drop(feat)
+        out = self.output(comb2)
+        size = out.size()
+        out = out.view(out.size(0), out.size(1), -1)
+        # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous()
+        out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5)
+        # out = out.view(-1, 5)
+        return feat, out
+
+
+class CaseNet(nn.Module):
+    """The classification Net from the gtr123 team - part of the winning algorithm for DSB2017"""
+
+    def __init__(self):
+        super(CaseNet, self).__init__()
+        self.NoduleNet = Net()
+        self.fc1 = nn.Linear(128, 64)
+        self.fc2 = nn.Linear(64, 1)
+        self.pool = nn.MaxPool3d(kernel_size=2)
+        self.dropout = nn.Dropout(0.5)
+        self.baseline = nn.Parameter(torch.Tensor([-30.0]).float())
+        self.Relu = nn.ReLU()
+
+    def forward(self, xlist, coordlist):
+        """
+
+        Args:
+            xlist: Image of size n x k x 1 x 96 x 96 x 96
+            coordlist: Coordinates of size n x k x 3 x 24 x 24 x 24
+
+        Returns:
+
+        """
+        xsize = xlist.size()
+        corrdsize = coordlist.size()
+        # xsize: n x k x 1 x 96 x 96 x 96; corrdsize: n x k x 3 x 24 x 24 x 24
+        # xlist = xlist.view(-1,xsize[2],xsize[3],xsize[4],xsize[5])
+        # coordlist = coordlist.view(-1,corrdsize[2],corrdsize[3],corrdsize[4],corrdsize[5])
+
+        noduleFeat, nodulePred = self.NoduleNet(xlist, coordlist)
+        nodulePred = nodulePred.contiguous().view(corrdsize[0], corrdsize[1], -1)
+
+        featshape = noduleFeat.size()  # nk x 128 x 24 x 24 x 24
+        centerFeat = self.pool(noduleFeat[:, :, featshape[2] // 2 - 1:featshape[2] // 2 + 1,
+                               featshape[3] // 2 - 1:featshape[3] // 2 + 1,
+                               featshape[4] // 2 - 1:featshape[4] // 2 + 1])
+        centerFeat = centerFeat[:, :, 0, 0, 0]
+        out = self.dropout(centerFeat)
+        out = self.Relu(self.fc1(out))
+        out = torch.sigmoid(self.fc2(out))
+        out = out.view(xsize[0], xsize[1])
+        base_prob = torch.sigmoid(self.baseline)
+        casePred = 1 - torch.prod(1 - out, dim=1) * (1 - base_prob.expand(out.size()[0]))
+        return nodulePred, casePred, out
+
+
+class SimpleCrop(object):
+    """Crop a fixed-size cube around a nodule location and build the matching coordinate grid."""
+
+    def __init__(self):
+        self.crop_size = config['crop_size']
+        self.scaleLim = config['scaleLim']
+        self.radiusLim = config['radiusLim']
+        self.stride = config['stride']
+        self.filling_value = config['filling_value']
+
+    def __call__(self, imgs, target):
+        crop_size = np.array(self.crop_size).astype('int')
+
+        start = (target[:3] - crop_size / 2).astype('int')
+        pad = [[0, 0]]
+
+        for i in range(3):
+            if start[i] < 0:
+                leftpad = -start[i]
+                start[i] = 0
+            else:
+                leftpad = 0
+            if start[i] + crop_size[i] > imgs.shape[i + 1]:
+                rightpad = start[i] + crop_size[i] - imgs.shape[i + 1]
+            else:
+                rightpad = 0
+
+            pad.append([leftpad, rightpad])
+
+        imgs = np.pad(imgs, pad, 'constant', constant_values=self.filling_value)
+        crop = imgs[:, start[0]:start[0] + crop_size[0], start[1]:start[1] + crop_size[1],
+                    start[2]:start[2] + crop_size[2]]
+
+        normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5
+        normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:])
+        xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], self.crop_size[0] // self.stride),
+                                 np.linspace(normstart[1], normstart[1] + normsize[1], self.crop_size[1] // self.stride),
+                                 np.linspace(normstart[2], normstart[2] + normsize[2], self.crop_size[2] // self.stride),
+                                 indexing='ij')
+        coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, ...]], 0).astype('float32')
+
+        return crop, coord
+
+
+def predict(image_itk, nodule_list, model_path="src/algorithms/classify/assets/gtr123_model.ckpt"):
+    """
+
+    Args:
+        image_itk: ITK dicom image
+        nodule_list: List of nodules
+        model_path: Path to the torch model (Default value = "src/algorithms/classify/assets/gtr123_model.ckpt")
+
+    Returns:
+        List of nodules and their 'concerning' probabilities
+
+    """
+    if not nodule_list:
+        return []
+    casenet = CaseNet()
+
+    casenet.load_state_dict(torch.load(model_path))
+    casenet.eval()
+
+    if torch.cuda.is_available():
+        casenet = torch.nn.DataParallel(casenet).cuda()
+    # else:
+    #     casenet = torch.nn.parallel.DistributedDataParallel(casenet)
+
+    image = sitk.GetArrayFromImage(image_itk)
+    spacing = np.array(image_itk.GetSpacing())[::-1]
+    image = lum_trans(image)
+    image = resample(image, spacing, np.array([1, 1, 1]), order=1)[0]
+
+    crop = SimpleCrop()
+
+    results = []
+    for nodule in nodule_list:
+        # Each nodule is a dict with voxel coordinates 'x', 'y' and 'z'
+        nod_location = np.array([np.float32(nodule[s]) for s in ["z", "y", "x"]])
+        nod_location *= spacing
+        cropped_image, coords = crop(image[np.newaxis], nod_location)
+        cropped_image = Variable(torch.from_numpy(cropped_image[np.newaxis]).float())
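+        # volatile=True marks inference-only Variables in the pre-0.4 PyTorch
+        # API (the torch 0.2 build pinned in requirements/base.txt), so no
+        # autograd state is kept for these tensors.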
+        cropped_image.volatile = True
+        coords = Variable(torch.from_numpy(coords[np.newaxis]).float())
+        coords.volatile = True
+        _, pred, _ = casenet(cropped_image, coords)
+        results.append(
+            {"x": nodule["x"], "y": nodule["y"], "z": nodule["z"], "p_concerning": float(pred.data.cpu().numpy())})
+
+    return results
diff --git a/prediction/src/algorithms/classify/trained_model.py b/prediction/src/algorithms/classify/trained_model.py
index d92b5be8..b9610425 100644
--- a/prediction/src/algorithms/classify/trained_model.py
+++ b/prediction/src/algorithms/classify/trained_model.py
@@ -7,10 +7,11 @@ for if nodules are concerning or not.
 """
-import numpy as np
-import keras.models
+from src.algorithms.classify.src import gtr123_model
 from src.preprocess.load_ct import load_ct, MetaData
+import SimpleITK as sitk
+
 
 def predict(dicom_path, centroids, model_path=None,
             preprocess_ct=None, preprocess_model_input=None):
@@ -43,23 +44,26 @@ def predict(dicom_path, centroids, model_path=None,
             'z': int,
             'p_concerning': float}
     """
-    if not len(centroids) or model_path is None:
-        return []
+    reader = sitk.ImageSeriesReader()
+    filenames = reader.GetGDCMSeriesFileNames(dicom_path)
+
+    if not filenames:
+        raise ValueError("The path contains neither .mhd nor .dcm files")
+
+    reader.SetFileNames(filenames)
+    image = reader.Execute()
 
-    model = keras.models.load_model(model_path)
-    ct_array, meta = load_ct(dicom_path)
-    if preprocess_ct is not None:
-        meta = MetaData(meta)
-        ct_array = preprocess_ct(ct_array, meta)
-        if not isinstance(ct_array, np.ndarray):
-            raise TypeError('The signature of preprocess_ct must be ' +
-                            'callable[list[DICOM], ndarray] -> ndarray')
+    if preprocess_ct:
+        meta = load_ct(dicom_path)[1]
+        voxel_data = preprocess_ct(image, MetaData(meta))
+    else:
+        voxel_data = image
 
-    patches = preprocess_model_input(ct_array, centroids)
-    predictions = model.predict(patches)
-    predictions = predictions.astype(np.float)
+    if preprocess_model_input:
+        preprocessed = preprocess_model_input(voxel_data, centroids)
+    else:
+        preprocessed = voxel_data
 
-    for i, centroid in enumerate(centroids):
-        centroid['p_concerning'] = predictions[i, 0]
+    model_path = model_path or "src/algorithms/classify/assets/gtr123_model.ckpt"
 
-    return centroids
+    return gtr123_model.predict(preprocessed, centroids, model_path)
diff --git a/prediction/src/algorithms/identify/assets/dsb2017_detector.ckpt b/prediction/src/algorithms/identify/assets/dsb2017_detector.ckpt
new file mode 100644
index 00000000..5f8c69f6
--- /dev/null
+++ b/prediction/src/algorithms/identify/assets/dsb2017_detector.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3a46182c1c4883e170d5b576d3a2645309f34817c53c16d4f7f29fbdcf799308
+size 21584388
diff --git a/prediction/src/algorithms/identify/src/__init__.py b/prediction/src/algorithms/identify/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/prediction/src/algorithms/identify/src/gtr123_model.py b/prediction/src/algorithms/identify/src/gtr123_model.py
new file mode 100644
index 00000000..f124339b
--- /dev/null
+++ b/prediction/src/algorithms/identify/src/gtr123_model.py
@@ -0,0 +1,512 @@
+import torch
+from torch import nn
+from torch.autograd import Variable
+from src.preprocess.extract_lungs import extract_lungs
+from src.preprocess.gtr123_preprocess import lum_trans, resample
+from scipy.special import expit
+
+import SimpleITK as sitk
+import numpy as np
+
+"""
+Detector model from team gtr123
+Code adapted from https://github.com/lfz/DSB2017
+"""
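+# A minimal usage sketch for this module (the series path is hypothetical;
+# assumes a GDCM-readable DICOM directory and the bundled checkpoint):
+#
+#     reader = sitk.ImageSeriesReader()
+#     reader.SetFileNames(reader.GetGDCMSeriesFileNames("/images/some_series"))
+#     nodules = predict(reader.Execute())
+#     # -> [{"x": ..., "y": ..., "z": ..., "p_nodule": ...}, ...]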
+
+config = {}
+config['anchors'] = [10.0, 30.0, 60.]
+config['channel'] = 1
+config['crop_size'] = [128, 128, 128]
+config['stride'] = 4
+config['max_stride'] = 16
+config['num_neg'] = 800
+config['th_neg'] = 0.02
+config['th_pos_train'] = 0.5
+config['th_pos_val'] = 1
+config['num_hard'] = 2
+config['bound_size'] = 12
+config['reso'] = 1
+config['sizelim'] = 6.  # mm
+config['sizelim2'] = 30
+config['sizelim3'] = 40
+config['aug_scale'] = True
+config['r_rand_crop'] = 0.3
+config['pad_value'] = 170
+
+__all__ = ["Net", "lum_trans", "resample", "GetPBB", "SplitComb"]
+
+
+class PostRes(nn.Module):
+    """3D residual block with an optional strided 1x1x1 shortcut."""
+    def __init__(self, n_in, n_out, stride=1):
+        super(PostRes, self).__init__()
+        self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1)
+        self.bn1 = nn.BatchNorm3d(n_out)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1)
+        self.bn2 = nn.BatchNorm3d(n_out)
+
+        if stride != 1 or n_out != n_in:
+            self.shortcut = nn.Sequential(
+                nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride),
+                nn.BatchNorm3d(n_out))
+        else:
+            self.shortcut = None
+
+    def forward(self, x):
+        """
+
+        Args:
+            x: Input feature map
+
+        Returns:
+            Output of the residual block
+
+        """
+        residual = x
+        if self.shortcut is not None:
+            residual = self.shortcut(x)
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        out += residual
+        out = self.relu(out)
+        return out
+
+
+class Net(nn.Module):
+    """The identification algorithm from Team grt123. Part of the winning algorithm."""
+
+    def __init__(self):
+        super(Net, self).__init__()
+        # The first few layers consume the most memory, so use simple convolution to save memory.
+        # Call these layers preBlock, i.e., before the residual blocks of later layers.
+        self.preBlock = nn.Sequential(
+            nn.Conv3d(1, 24, kernel_size=3, padding=1),
+            nn.BatchNorm3d(24),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(24, 24, kernel_size=3, padding=1),
+            nn.BatchNorm3d(24),
+            nn.ReLU(inplace=True))
+
+        # 3 poolings, each pooling downsamples the feature map by a factor 2.
+        # 3 groups of blocks. The first block of each group has one pooling.
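+        # Same encoder/decoder construction as the classifier backbone. The
+        # detection head defined at the end of __init__ emits
+        # 5 * len(config['anchors']) channels per voxel, reshaped in forward()
+        # into one (confidence, z, y, x, radius) vector per anchor; see GetPBB.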
+        num_blocks_forw = [2, 2, 3, 3]
+        num_blocks_back = [3, 3]
+        self.featureNum_forw = [24, 32, 64, 64, 64]
+        self.featureNum_back = [128, 64, 64]
+        for i in range(len(num_blocks_forw)):
+            blocks = []
+            for j in range(num_blocks_forw[i]):
+                if j == 0:
+                    blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i + 1]))
+                else:
+                    blocks.append(PostRes(self.featureNum_forw[i + 1], self.featureNum_forw[i + 1]))
+            setattr(self, 'forw' + str(i + 1), nn.Sequential(*blocks))
+
+        for i in range(len(num_blocks_back)):
+            blocks = []
+            for j in range(num_blocks_back[i]):
+                if j == 0:
+                    if i == 0:
+                        addition = 3
+                    else:
+                        addition = 0
+                    blocks.append(PostRes(self.featureNum_back[i + 1] + self.featureNum_forw[i + 2] + addition,
+                                          self.featureNum_back[i]))
+                else:
+                    blocks.append(PostRes(self.featureNum_back[i], self.featureNum_back[i]))
+            setattr(self, 'back' + str(i + 2), nn.Sequential(*blocks))
+
+        self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
+        self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2)
+        self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2, stride=2)
+
+        self.path1 = nn.Sequential(
+            nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+        self.path2 = nn.Sequential(
+            nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
+            nn.BatchNorm3d(64),
+            nn.ReLU(inplace=True))
+        self.drop = nn.Dropout3d(p=0.2, inplace=False)
+        self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1),
+                                    nn.ReLU(),
+                                    # nn.Dropout3d(p = 0.3),
+                                    nn.Conv3d(64, 5 * len(config['anchors']), kernel_size=1))
+
+    def forward(self, x, coord):
+        """
+
+        Args:
+            x:
+            coord:
+
+        Returns:
+
+        """
+        out = self.preBlock(x)  # 16
+        out_pool, indices0 = self.maxpool1(out)
+        out1 = self.forw1(out_pool)  # 32
+        out1_pool, indices1 = self.maxpool2(out1)
+        out2 = self.forw2(out1_pool)  # 64
+        # out2 = self.drop(out2)
+        out2_pool, indices2 = self.maxpool3(out2)
+        out3 = self.forw3(out2_pool)  # 96
+        out3_pool, indices3 = self.maxpool4(out3)
+        out4 = self.forw4(out3_pool)  # 96
+        # out4 = self.drop(out4)
+
+        rev3 = self.path1(out4)
+        comb3 = self.back3(torch.cat((rev3, out3), 1))  # 96+96
+        # comb3 = self.drop(comb3)
+        rev2 = self.path2(comb3)
+        feat = self.back2(torch.cat((rev2, out2, coord), 1))  # 64+64
+        comb2 = self.drop(feat)
+        out = self.output(comb2)
+        size = out.size()
+        out = out.view(out.size(0), out.size(1), -1)
+        # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous()
+        out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5)
+        # out = out.view(-1, 5)
+        return out
+
+
+class GetPBB(object):
+    """Turn the network output map into proposed bounding boxes."""
+    def __init__(self, stride=4, anchors=(10.0, 30.0, 60.)):
+        self.stride = stride
+        self.anchors = np.asarray(anchors)
+
+    def __call__(self, output, thresh=-3, ismask=False):
+        stride = self.stride
+        anchors = self.anchors
+        output = np.copy(output)
+        offset = (float(stride) - 1) / 2
+        output_size = output.shape
+        oz = np.arange(offset, offset + stride * (output_size[0] - 1) + 1, stride)
+        oh = np.arange(offset, offset + stride * (output_size[1] - 1) + 1, stride)
+        ow = np.arange(offset, offset + stride * (output_size[2] - 1) + 1, stride)
+
+        output[:, :, :, :, 1] = oz.reshape((-1, 1, 1, 1)) + output[:, :, :, :, 1] * anchors.reshape((1, 1, 1, -1))
+        output[:, :, :, :, 2] = oh.reshape((1, -1, 1, 1)) + output[:, :, :, :, 2] * anchors.reshape((1, 1, 1, -1))
+        output[:, :, :, :, 3] = ow.reshape((1, 1, -1, 1)) + output[:, :, :, :, 3] * anchors.reshape((1, 1, 1, -1))
+        output[:, :, :, :, 4] = np.exp(output[:, :, :, :, 4]) * anchors.reshape((1, 1, 1, -1))
+        mask = output[..., 0] > thresh
+        xx, yy, zz, aa = np.where(mask)
+
+        output = output[xx, yy, zz, aa]
+        if ismask:
+            return output, [xx, yy, zz, aa]
+        else:
+            return output
+
+
+class SplitComb(object):
+    """Split a volume into overlapping chunks and recombine the per-chunk network outputs."""
+    def __init__(self, side_len, max_stride, stride, margin, pad_value):
+        self.side_len = side_len
+        self.max_stride = max_stride
+        self.stride = stride
+        self.margin = margin
+        self.pad_value = pad_value
+
+    def split(self, data, side_len=None, max_stride=None, margin=None):
+        """
+
+        Args:
+            data:
+            side_len:  (Default value = None)
+            max_stride:  (Default value = None)
+            margin:  (Default value = None)
+
+        Returns:
+
+        """
+        if side_len is None:
+            side_len = self.side_len
+        if max_stride is None:
+            max_stride = self.max_stride
+        if margin is None:
+            margin = self.margin
+
+        assert (side_len > margin)
+        assert (side_len % max_stride == 0)
+        assert (margin % max_stride == 0)
+
+        splits = []
+        _, z, h, w = data.shape
+
+        nz = int(np.ceil(float(z) / side_len))
+        nh = int(np.ceil(float(h) / side_len))
+        nw = int(np.ceil(float(w) / side_len))
+
+        nzhw = [nz, nh, nw]
+        self.nzhw = nzhw
+
+        pad = [[0, 0],
+               [margin, nz * side_len - z + margin],
+               [margin, nh * side_len - h + margin],
+               [margin, nw * side_len - w + margin]]
+        data = np.pad(data, pad, 'edge')
+
+        for iz in range(nz):
+            for ih in range(nh):
+                for iw in range(nw):
+                    sz = iz * side_len
+                    ez = (iz + 1) * side_len + 2 * margin
+                    sh = ih * side_len
+                    eh = (ih + 1) * side_len + 2 * margin
+                    sw = iw * side_len
+                    ew = (iw + 1) * side_len + 2 * margin
+
+                    split = data[np.newaxis, :, sz:ez, sh:eh, sw:ew]
+                    splits.append(split)
+
+        splits = np.concatenate(splits, 0)
+        return splits, nzhw
+
+    def combine(self, output, nzhw=None, side_len=None, stride=None, margin=None):
+        """
+
+        Args:
+            output:
+            nzhw:  (Default value = None)
+            side_len:  (Default value = None)
+            stride:  (Default value = None)
+            margin:  (Default value = None)
+
+        Returns:
+
+        """
+
+        if side_len is None:
+            side_len = self.side_len
+        if stride is None:
+            stride = self.stride
+        if margin is None:
+            margin = self.margin
+        if nzhw is None:
+            # self.nzhw is stored by the most recent call to split()
+            nz, nh, nw = self.nzhw
+        else:
+            nz, nh, nw = nzhw
+        assert (side_len % stride == 0)
+        assert (margin % stride == 0)
+        side_len //= stride
+        margin //= stride
+
+        splits = []
+        for i in range(len(output)):
+            splits.append(output[i])
+
+        output = -1000000 * np.ones((
+            nz * side_len,
+            nh * side_len,
+            nw * side_len,
+            splits[0].shape[3],
+            splits[0].shape[4]), np.float32)
+
+        idx = 0
+        for iz in range(nz):
+            for ih in range(nh):
+                for iw in range(nw):
+                    sz = iz * side_len
+                    ez = (iz + 1) * side_len
+                    sh = ih * side_len
+                    eh = (ih + 1) * side_len
+                    sw = iw * side_len
+                    ew = (iw + 1) * side_len
+
+                    split = splits[idx][margin:margin + side_len, margin:margin + side_len, margin:margin + side_len]
+                    output[sz:ez, sh:eh, sw:ew] = split
+                    idx += 1
+
+        return output
+
+
+def split_data(imgs, split_comber, stride=4):
+    """Images tend to be too big to fit in memory on even very large systems. This function splits them into manageable
+    chunks.
+
+    Args:
+        imgs: Image volume of shape (1, z, y, x)
+        split_comber: SplitComb instance used to split the padded volume
+        stride: Output stride of the network (Default value = 4)
+
+    Returns:
+        Image chunks, coordinate chunks and the number of chunks per axis
+
+    """
+    nz, nh, nw = imgs.shape[1:]
+    pz = int(np.ceil(float(nz) / stride)) * stride
+    ph = int(np.ceil(float(nh) / stride)) * stride
+    pw = int(np.ceil(float(nw) / stride)) * stride
+    imgs = np.pad(imgs, [[0, 0], [0, pz - nz], [0, ph - nh], [0, pw - nw]], 'constant',
+                  constant_values=split_comber.pad_value)
+
+    xx, yy, zz = np.meshgrid(np.linspace(-0.5, 0.5, imgs.shape[1] // stride),
+                             np.linspace(-0.5, 0.5, imgs.shape[2] // stride),
+                             np.linspace(-0.5, 0.5, imgs.shape[3] // stride), indexing='ij')
+    coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, ...]], 0).astype('float32')
+    imgs, nzhw = split_comber.split(imgs)
+    coord2, nzhw2 = split_comber.split(coord,
+                                       side_len=split_comber.side_len // stride,
+                                       max_stride=split_comber.max_stride // stride,
+                                       margin=int(split_comber.margin // stride))
+    assert np.all(nzhw == nzhw2)
+    imgs = (imgs.astype(np.float32) - 128) / 128
+    return torch.from_numpy(imgs), torch.from_numpy(coord2), np.array(nzhw)
+
+
+def iou(box0, box1):
+    """
+
+    Args:
+        box0: Box given as (z, y, x, size)
+        box1: Box given as (z, y, x, size)
+
+    Returns:
+        Intersection over union
+
+    """
+
+    r0 = box0[3] / 2
+    s0 = box0[:3] - r0
+    e0 = box0[:3] + r0
+
+    r1 = box1[3] / 2
+    s1 = box1[:3] - r1
+    e1 = box1[:3] + r1
+
+    overlap = []
+    for i in range(len(s0)):
+        overlap.append(max(0, min(e0[i], e1[i]) - max(s0[i], s1[i])))
+
+    intersection = overlap[0] * overlap[1] * overlap[2]
+    union = box0[3] * box0[3] * box0[3] + box1[3] * box1[3] * box1[3] - intersection
+    return intersection / union
+
+
+def nms(predictions, nms_th=0.05):
+    """
+
+    Args:
+        predictions: Proposals from the network, one (probability, z, y, x, size) row per proposal
+        nms_th: IoU threshold above which overlapping proposals are suppressed (Default value = 0.05)
+
+    Returns:
+        The proposals that survive non-maximum suppression
+
+    """
+    if len(predictions) == 0:
+        return predictions
+
+    predictions = predictions[np.argsort(-predictions[:, 0])]
+    bboxes = [predictions[0]]
+    for i in np.arange(1, len(predictions)):
+        bbox = predictions[i]
+        flag = 1
+        for j in range(len(bboxes)):
+            if iou(bbox[1:5], bboxes[j][1:5]) >= nms_th:
+                flag = -1
+                break
+        if flag == 1:
+            bboxes.append(bbox)
+
+    bboxes = np.asarray(bboxes, np.float32)
+    return bboxes
+
+
+def filter_lungs(image, spacing=(1, 1, 1), fill_value=170):
+    """
+
+    Args:
+        image: Image in Hu units
+        spacing: Image spacing (Default value = (1, 1, 1))
+        fill_value: Hu value to use outside the lungs (Default value = 170)
+
+
+    Returns:
+        An image volume containing only lungs as well as the boolean mask.
+
+    """
+
+    mask = extract_lungs(image, spacing)
+
+    extracted = np.array(image)
+    extracted[np.logical_not(mask)] = fill_value
+
+    return extracted, mask
+
+
+def predict(image_itk, model_path="src/algorithms/identify/assets/dsb2017_detector.ckpt"):
+    """
+
+    Args:
+        image_itk: ITK Image in Hu units
+        model_path: Path to the file containing the model state
+            (Default value = "src/algorithms/identify/assets/dsb2017_detector.ckpt")
+
+    Returns:
+        List of nodule locations and probabilities
+
+    """
+
+    spacing = np.array(image_itk.GetSpacing())[::-1]
+    image = sitk.GetArrayFromImage(image_itk)
+    masked_image, mask = filter_lungs(image)
+    # masked_image = image
+    net = Net()
+    net.load_state_dict(torch.load(model_path)["state_dict"])
+    if torch.cuda.is_available():
+        net = torch.nn.DataParallel(net).cuda()
+
+    split_comber = SplitComb(side_len=int(144), margin=32, max_stride=16, stride=4, pad_value=170)
+
+    # We have to use small batches until the next release of PyTorch, as bigger ones will segfault for CPU
+    # split_comber = SplitComb(side_len=int(32), margin=16, max_stride=16, stride=4, pad_value=170)
+    # Transform image to the 0-255 range and resample to 1x1x1mm
+    imgs = lum_trans(masked_image)
+    imgs = resample(imgs, spacing, np.array([1, 1, 1]), order=1)[0]
+    imgs = imgs[np.newaxis, ...]
+
+    imgT, coords, nzhw = split_data(imgs, split_comber=split_comber)
+    results = []
+    # Loop over the image chunks
+    for img, coord in zip(imgT, coords):
+        var = Variable(img[np.newaxis])
+        var.volatile = True
+        coord = Variable(coord[np.newaxis])
+        coord.volatile = True
+        resvar = net(var, coord)
+        res = resvar.data.cpu().numpy()
+        results.append(res)
+
+    results = np.concatenate(results, 0)
+    results = split_comber.combine(results, nzhw=nzhw)
+    pbb = GetPBB()
+    # The first index of each proposal is the probability; then follow z, y, x and the radius
+    proposals, _ = pbb(results, ismask=True)
+
+    # proposals = proposals[proposals[:,4] < 40]
+    proposals = nms(proposals)
+    # Filter out proposals outside the actual lung
+    # prop_int = proposals[:, 1:4].astype(np.int32)
+    # wrong = [imgs[0, x[0], x[1], x[2]] > 180 for x in prop_int]
+    # proposals = proposals[np.logical_not(wrong)]
+
+    # Apply the sigmoid to get probabilities
+    proposals[:, 0] = expit(proposals[:, 0])
+    # Remove really weak proposals?
+    # proposals = proposals[proposals[:,0] > 0.5]
+
+    # Rescale back to image space coordinates
+    proposals[:, 1:4] /= spacing[np.newaxis]
+
+    return [{"x": int(p[3]), "y": int(p[2]), "z": int(p[1]), "p_nodule": float(p[0])} for p in proposals]
diff --git a/prediction/src/algorithms/identify/trained_model.py b/prediction/src/algorithms/identify/trained_model.py
index c4c91204..beebdcac 100644
--- a/prediction/src/algorithms/identify/trained_model.py
+++ b/prediction/src/algorithms/identify/trained_model.py
@@ -7,11 +7,9 @@ for where the centroids of nodules are in the DICOM image.
 """
-import glob
+import SimpleITK as sitk
 
-import dicom
-from src.preprocess.errors import EmptyDicomSeriesException
-from src.preprocess.lung_segmentation import save_lung_segments
+from src.algorithms.identify.src import gtr123_model
 
 from . import prediction
@@ -40,22 +38,32 @@ def predict(dicom_path):
             'z': int,
             'p_nodule': float}
     """
-    if dicom_path[-1] != '/':
-        dicom_path += '/'
-    dicom_files = glob.glob(dicom_path + "*.dcm")
-    if not dicom_files:
-        raise EmptyDicomSeriesException
-    patient_id = dicom.read_file(dicom_files[0]).SeriesInstanceUID
-    z, x, y = save_lung_segments(dicom_path, patient_id)
-    results_df = run_prediction(patient_id)
-    results_df['coord_x'] *= x
-    results_df['coord_y'] *= y
-    results_df['coord_z'] *= z
-    rescaled_results_df = results_df[['coord_x', 'coord_y', 'coord_z', 'nodule_chance']].copy()
-    rescaled_results_df.columns = ['x', 'y', 'z', 'p_nodule']
-    rescaled_results_df[['x', 'y', 'z']] = rescaled_results_df[['x', 'y', 'z']].astype(int)
-    rescaled_dict = rescaled_results_df.to_dict(orient='record')
-    return rescaled_dict
+
+    reader = sitk.ImageSeriesReader()
+    filenames = reader.GetGDCMSeriesFileNames(dicom_path)
+    if not filenames:
+        raise ValueError("The path contains neither .mhd nor .dcm files")
+
+    reader.SetFileNames(filenames)
+    image = reader.Execute()
+    result = gtr123_model.predict(image)
+    return result
+    # if dicom_path[-1] != '/':
+    #     dicom_path += '/'
+    # dicom_files = glob.glob(dicom_path + "*.dcm")
+    # if not dicom_files:
+    #     raise EmptyDicomSeriesException
+    # patient_id = dicom.read_file(dicom_files[0]).SeriesInstanceUID
+    # z, x, y = save_lung_segments(dicom_path, patient_id)
+    # results_df = run_prediction(patient_id)
+    # results_df['coord_x'] *= x
+    # results_df['coord_y'] *= y
+    # results_df['coord_z'] *= z
+    # rescaled_results_df = results_df[['coord_x', 'coord_y', 'coord_z', 'nodule_chance']].copy()
+    # rescaled_results_df.columns = ['x', 'y', 'z', 'p_nodule']
+    # rescaled_results_df[['x', 'y', 'z']] = rescaled_results_df[['x', 'y', 'z']].astype(int)
+    # rescaled_dict = rescaled_results_df.to_dict(orient='record')
+    # return rescaled_dict
 
 
 def run_prediction(patient_id, magnification=1, ext_name="luna_posnegndsb_v", version=1, holdout=1):
diff --git a/prediction/src/preprocess/extract_lungs.py b/prediction/src/preprocess/extract_lungs.py
new file mode 100644
index 00000000..5fe94b12
--- /dev/null
+++ b/prediction/src/preprocess/extract_lungs.py
@@ -0,0 +1,270 @@
+import numpy as np
+import scipy.ndimage
+from scipy.ndimage.morphology import binary_dilation, generate_binary_structure
+from skimage import measure
+from skimage.morphology import convex_hull_image
+
+
+def binarize_per_slice(image, spacing, intensity_th=-600, sigma=1, area_th=30, eccen_th=0.99, bg_patch_size=10):
+    """
+
+    :param image: 3D image volume in Hu units
+    :param spacing: Voxel spacing (z, y, x)
+    :param intensity_th: Anything below this threshold is considered air or lung
+    :param sigma: Sigma of the Gaussian smoothing applied before thresholding
+    :param area_th: Minimum in-plane component area (mm^2) that is kept
+    :param eccen_th: Maximum eccentricity of a kept component
+    :param bg_patch_size: Side length of the corner patch used to detect a padded background
+    :return: Boolean mask volume, binarized slice by slice
+    """
+    bw = np.zeros(image.shape, dtype=bool)
+
+    # prepare a mask, with all corner values set to nan
+    image_size = image.shape[1]
+    grid_axis = np.linspace(-image_size / 2 + 0.5, image_size / 2 - 0.5, image_size)
+    x, y = np.meshgrid(grid_axis, grid_axis)
+    d = (x ** 2 + y ** 2) ** 0.5
+    nan_mask = (d < image_size / 2).astype(float)
+    nan_mask[nan_mask == 0] = np.nan
+
+    for i in range(image.shape[0]):
+        # Check whether the corner pixels are identical; if so, apply the
+        # circular mask to the slice before Gaussian filtering
+        if len(np.unique(image[i, 0:bg_patch_size, 0:bg_patch_size])) == 1:
+            current_bw = scipy.ndimage.filters.gaussian_filter(np.multiply(image[i].astype('float32'), nan_mask), sigma,
+                                                               truncate=2.0) < intensity_th
+        else:
+            current_bw = scipy.ndimage.filters.gaussian_filter(image[i].astype('float32'), sigma,
+                                                               truncate=2.0) < intensity_th
+
+        # select proper components
+        label = measure.label(current_bw)
+        properties = measure.regionprops(label)
+        valid_label = set()
+
+        for prop in properties:
+            if prop.area * spacing[1] * spacing[2] > area_th and prop.eccentricity < eccen_th:
+                valid_label.add(prop.label)
+
+        current_bw = np.in1d(label, list(valid_label)).reshape(label.shape)
+        bw[i] = current_bw
+
+    return bw
+
+
+def _fill_hole(bw):
+    label = measure.label(~bw)
+    bg_labels = {label[0, 0, 0], label[0, 0, -1], label[0, -1, 0], label[0, -1, -1],
+                 label[-1, 0, 0], label[-1, 0, -1], label[-1, -1, 0], label[-1, -1, -1]}
+    bw = ~np.in1d(label, list(bg_labels)).reshape(label.shape)
+
+    return bw
+
+
+def _merge_background_labels(label, cut_num):
+    mid = int(label.shape[2] / 2)
+    bg_label = {label[0, 0, 0], label[0, 0, -1], label[0, -1, 0], label[0, -1, -1],
+                label[-1 - cut_num, 0, 0], label[-1 - cut_num, 0, -1], label[-1 - cut_num, -1, 0],
+                label[-1 - cut_num, -1, -1],
+                label[0, 0, mid], label[0, -1, mid], label[-1 - cut_num, 0, mid], label[-1 - cut_num, -1, mid]}
+    for l in bg_label:
+        label[label == l] = 0
+    return label
+
+
+def _remove_large_objects(label, spacing, area_th, dist_th):
+    # prepare a distance map for further analysis
+    x_axis = np.linspace(-label.shape[1] / 2 + 0.5, label.shape[1] / 2 - 0.5, label.shape[1]) * spacing[1]
+    y_axis = np.linspace(-label.shape[2] / 2 + 0.5, label.shape[2] / 2 - 0.5, label.shape[2]) * spacing[2]
+    x, y = np.meshgrid(x_axis, y_axis)
+    d = (x ** 2 + y ** 2) ** 0.5
+    vols = measure.regionprops(label)
+    valid_label = set()
+    # select components based on their area and distance to center axis on all slices
+    for vol in vols:
+        single_vol = label == vol.label
+        slice_area = np.zeros(label.shape[0])
+        min_distance = np.zeros(label.shape[0])
+        for i in range(label.shape[0]):
+            slice_area[i] = np.sum(single_vol[i]) * np.prod(spacing[1:3])
+            min_distance[i] = np.min(single_vol[i] * d + (1 - single_vol[i]) * np.max(d))
+
+        if np.average([min_distance[i] for i in range(label.shape[0]) if slice_area[i] > area_th]) < dist_th:
+            valid_label.add(vol.label)
+    return valid_label
+
+
+def all_slice_analysis(bw, spacing, cut_num=0, vol_limit=(0.68, 8.2), area_th=6e3, dist_th=62):
+    """
+
+    Args:
+        bw: Binary volume, created on a per-slice basis
+        spacing: Image spacing
+        cut_num: Number of top layers to be removed
+        vol_limit: Lower and upper volume limits (liters) for kept components
+        area_th: Minimum slice area (mm^2) entering the distance test
+        dist_th: Maximum average distance (mm) from the central axis
+
+    Returns:
+        The filtered mask and the number of valid components
+
+    """
+    # in some cases, several top layers need to be removed first
+    if cut_num > 0:
+        bw0 = np.copy(bw)
+        bw[-cut_num:] = False
+
+    label = measure.label(bw, connectivity=1)
+
+    # remove components attached to the corners
+    label = _merge_background_labels(label, cut_num)
+
+    # select components based on volume
+    properties = measure.regionprops(label)
+    for prop in properties:
+        if prop.area * spacing.prod() < vol_limit[0] * 1e6 or prop.area * spacing.prod() > vol_limit[1] * 1e6:
+            label[label == prop.label] = 0
+
+    if len(np.unique(label)) == 1:
+        return bw, 0
+
+    valid_label = _remove_large_objects(label, spacing, area_th=area_th, dist_th=dist_th)
+
+    bw = np.in1d(label, list(valid_label)).reshape(label.shape)
+
+    # fill back the parts removed earlier
+    if cut_num > 0:
+        # bw1 is bw with removed slices, bw2 is a dilated version of bw, part of their intersection is returned as
+        # final mask
+        bw1 = np.copy(bw)
+        bw1[-cut_num:] = bw0[-cut_num:]
+        bw2 = np.copy(bw)
+        bw2 = scipy.ndimage.binary_dilation(bw2, iterations=cut_num)
+        bw3 = bw1 & bw2
+        label = measure.label(bw, connectivity=1)
+        label3 = measure.label(bw3, connectivity=1)
+        l_list = set(np.unique(label)) - {0}
+        valid_l3 = set()
+        for l in l_list:
+            indices = np.nonzero(label == l)
+            l3 = label3[indices[0][0], indices[1][0], indices[2][0]]
+            if l3 > 0:
+                valid_l3.add(l3)
+        bw = np.in1d(label3, list(valid_l3)).reshape(label3.shape)
+
+    return bw, len(valid_label)
+
+
+def _extract_main(bw, cover=0.95):
+    for i in range(bw.shape[0]):
+        current_slice = bw[i]
+        label = measure.label(current_slice)
+        properties = measure.regionprops(label)
+        properties.sort(key=lambda x: x.area, reverse=True)
+        area = [prop.area for prop in properties]
+        count = 0
+        covered_area = 0
+        while covered_area < np.sum(area) * cover:
+            covered_area += area[count]
+            count += 1
+        keep_mask = np.zeros(current_slice.shape, dtype=bool)
+        for j in range(count):
+            bb = properties[j].bbox
+            keep_mask[bb[0]:bb[2], bb[1]:bb[3]] = keep_mask[bb[0]:bb[2], bb[1]:bb[3]] | properties[j].convex_image
+        bw[i] = bw[i] & keep_mask
+
+    label = measure.label(bw)
+    properties = measure.regionprops(label)
+    properties.sort(key=lambda x: x.area, reverse=True)
+    bw = label == properties[0].label
+
+    return bw
+
+
+def _fill_2d_hole(bw):
+    for i in range(bw.shape[0]):
+        current_slice = bw[i]
+        label = measure.label(current_slice)
+        properties = measure.regionprops(label)
+        for prop in properties:
+            bb = prop.bbox
+            current_slice[bb[0]:bb[2], bb[1]:bb[3]] = current_slice[bb[0]:bb[2], bb[1]:bb[3]] | prop.filled_image
+        bw[i] = current_slice
+
+    return bw
+
+
+def two_lung_only(bw, spacing, max_iter=22, max_ratio=4.8):
+    found_flag = False
+    iter_count = 0
+    bw0 = np.copy(bw)
+    # Erode until the two lungs are separate.
+    while not found_flag and iter_count < max_iter:
+        label = measure.label(bw, connectivity=2)
+        properties = measure.regionprops(label)
+        properties.sort(key=lambda x: x.area, reverse=True)
+        if len(properties) > 1 and properties[0].area / properties[1].area < max_ratio:
+            found_flag = True
+            bw1 = label == properties[0].label
+            bw2 = label == properties[1].label
+        else:
+            bw = scipy.ndimage.binary_erosion(bw)
+            iter_count = iter_count + 1
+
+    if found_flag:
+        d1 = scipy.ndimage.morphology.distance_transform_edt(np.logical_not(bw1), sampling=spacing)
+        d2 = scipy.ndimage.morphology.distance_transform_edt(np.logical_not(bw2), sampling=spacing)
+        bw1 = bw0 & (d1 < d2)
+        bw2 = bw0 & (d1 > d2)
+
+        bw1 = _extract_main(bw1)
+        bw2 = _extract_main(bw2)
+
+    else:
+        bw1 = bw0
+        bw2 = np.zeros(bw.shape).astype('bool')
+
+    bw1 = _fill_2d_hole(bw1)
+    bw2 = _fill_2d_hole(bw2)
+    bw = bw1 | bw2
+
+    return bw1, bw2, bw
+
+
+def extract_lungs(image, spacing):
+    """
+
+    :param image: Dicom image loaded as numpy array
+    :param spacing: Pixel spacing
+    :return: Boolean mask covering both lungs
+    """
+
+    spacing = np.array(spacing)
+
+    bw = binarize_per_slice(image, spacing)
+    flag = 0
+    cut_num = 0
+    cut_step = 2
+    bw0 = np.copy(bw)
+    while flag == 0 and cut_num < bw.shape[0]:
+        bw = np.copy(bw0)
+        bw, flag = all_slice_analysis(bw, spacing, cut_num=cut_num, vol_limit=[0.68, 7.5])
+        cut_num = cut_num + cut_step
+
+    bw = _fill_hole(bw)
+    bw1, bw2, bw = two_lung_only(bw, spacing)
+    return bw
+
+
+def process_mask(mask):
+    convex_mask = np.copy(mask)
+    for i_layer in range(convex_mask.shape[0]):
+        mask1 = np.ascontiguousarray(mask[i_layer])
+        if np.sum(mask1) > 0:
+            mask2 = convex_hull_image(mask1)
+            if np.sum(mask2) > 2 * np.sum(mask1):
+                mask2 = mask1
+        else:
+            mask2 = mask1
+        convex_mask[i_layer] = mask2
+    struct = generate_binary_structure(3, 1)
+    dilatedMask = binary_dilation(convex_mask, structure=struct, iterations=10)
+    return dilatedMask
diff --git a/prediction/src/preprocess/gtr123_preprocess.py b/prediction/src/preprocess/gtr123_preprocess.py
new file mode 100644
index 00000000..274712ec
--- /dev/null
+++ b/prediction/src/preprocess/gtr123_preprocess.py
@@ -0,0 +1,61 @@
+import numpy as np
+import warnings
+from scipy.ndimage import zoom
+
+"""
+Preprocessing tools used by the gtr123_models
+Code adapted from https://github.com/lfz/DSB2017
+"""
+
+
+def lum_trans(img):
+    """
+
+    Args:
+        img: Input image in Hu units
+
+    Returns: Image windowed to [-1200; 600] and scaled to 0-255
+
+    """
+    lungwin = np.array([-1200., 600.])
+    newimg = (img - lungwin[0]) / (lungwin[1] - lungwin[0])
+    newimg[newimg < 0] = 0
+    newimg[newimg > 1] = 1
+    return (newimg * 255).astype('uint8')
+
+
+def resample(imgs, spacing, new_spacing, order=2):
+    """
+
+    Args:
+        imgs: 3D image volume, or a 4D stack of such volumes
+        spacing: Input image voxel size
+        new_spacing: Output image voxel size
+        order: Spline interpolation order (Default value = 2)
+
+    Returns:
+        The resampled image and the actually achieved voxel size
+
+    """
+    if len(imgs.shape) == 3:
+        new_shape = np.round(imgs.shape * spacing / new_spacing)
+        true_spacing = spacing * imgs.shape / new_shape
+        resize_factor = new_shape / imgs.shape
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            imgs = zoom(imgs, resize_factor, mode='nearest', order=order)
+
+        return imgs, true_spacing
+    elif len(imgs.shape) == 4:
+        n = imgs.shape[-1]
+        newimg = []
+
+        for i in range(n):
+            channel = imgs[:, :, :, i]
+            newslice, true_spacing = resample(channel, spacing, new_spacing, order=order)
+            newimg.append(newslice)
+
+        newimg = np.transpose(np.array(newimg), [1, 2, 3, 0])
+        return newimg, true_spacing
+    else:
+        raise ValueError('Expected a 3D or 4D image array')
diff --git a/prediction/src/tests/__init__.py b/prediction/src/tests/__init__.py
index c5499f89..e22693eb 100644
--- a/prediction/src/tests/__init__.py
+++ b/prediction/src/tests/__init__.py
@@ -1,3 +1,4 @@
+import torch  # noqa  # pylint: disable=unused-import
 import sys
 import os
diff --git a/prediction/src/tests/test_classify_trained_model_predict.py b/prediction/src/tests/test_classify_trained_model_predict.py
index 672069ee..4c7abf83 100644
--- a/prediction/src/tests/test_classify_trained_model_predict.py
+++ b/prediction/src/tests/test_classify_trained_model_predict.py
@@ -1,8 +1,6 @@
 import pytest
 
 from ..algorithms.classify import trained_model
-from ..algorithms.classify.src.preprocess_patch import preprocess_LR3DCNN
-from ..preprocess import preprocess_ct
 
 
 @pytest.fixture
@@ -13,7 +11,7 @@ def dicom_path():
 
 @pytest.fixture
 def model_path():
-    yield '../classify_models/model.h5'
+    yield 'src/algorithms/classify/assets/gtr123_model.ckpt'
 
 
 def test_classify_predict_model_load(dicom_path, model_path):
@@ -21,21 +19,15 @@ def test_classify_predict_model_load(dicom_path, model_path):
                                       [],
                                       model_path,
                                       preprocess_ct=None,
-                                      preprocess_model_input=preprocess_LR3DCNN)
+                                      preprocess_model_input=None)
 
     assert len(predicted) == 0
 
 
 def test_classify_predict_inference(dicom_path, model_path):
-    params = preprocess_ct.Params(clip_lower=-1000,
-                                  clip_upper=400,
-                                  spacing=(.6, .6, .3))
-    preprocess = preprocess_ct.PreprocessCT(params)
     predicted = trained_model.predict(dicom_path,
                                       [{'x': 50, 'y': 50, 'z': 21}],
-                                      model_path,
-                                      preprocess_ct=preprocess,
-                                      preprocess_model_input=preprocess_LR3DCNN)
+                                      model_path)
 
     assert len(predicted) == 1
     assert isinstance(predicted[0]['p_concerning'], float)
diff --git a/prediction/src/tests/test_endpoints.py b/prediction/src/tests/test_endpoints.py
index 5f5ebdf8..b9c89bac 100644
--- a/prediction/src/tests/test_endpoints.py
+++ b/prediction/src/tests/test_endpoints.py
@@ -4,11 +4,12 @@
 Provides unit tests for the API endpoints.
 """
-import json
 import os
 from functools import partial
 
+import json
 import pytest
+
 from flask import url_for
 from src.algorithms import classify, identify, segment
 from src.factory import create_app
@@ -169,4 +170,4 @@ def test_other_error(client):
                         content_type='application/json')
     data = get_data(r)
     assert r.status_code == 500
-    assert "The specified path does not contain dcm-files." in data['error']
+    assert "The path contains neither .mhd nor .dcm files" in data['error']
diff --git a/prediction/src/tests/test_identify_trained_model_predict.py b/prediction/src/tests/test_identify_trained_model_predict.py
new file mode 100644
index 00000000..400801a2
--- /dev/null
+++ b/prediction/src/tests/test_identify_trained_model_predict.py
@@ -0,0 +1,49 @@
+import numpy as np
+import pytest
+
+from ..algorithms.identify import trained_model
+from ..tests.test_endpoints import skip_slow_test
+
+
+@pytest.fixture
+def dicom_path_001():
+    yield '../images/LIDC-IDRI-0001/1.3.6.1.4.1.14519.5.2.1.6279.6001.298806137288633453246975630178/' \
+          '1.3.6.1.4.1.14519.5.2.1.6279.6001.179049373636438705059720603192'
+
+
+@pytest.fixture
+def dicom_path_003():
+    yield "/images/LIDC-IDRI-0003/1.3.6.1.4.1.14519.5.2.1.6279.6001.101370605276577556143013894866/" \
+          "1.3.6.1.4.1.14519.5.2.1.6279.6001.170706757615202213033480003264"
+
+
+@pytest.fixture
+def nodule_locations_001():
+    yield {"x": 317, "y": 367, "z": 7}
+
+
+@pytest.fixture
+def nodule_locations_003():
+    yield {"x": 369, "y": 347, "z": 6}
+
+
+@pytest.mark.skipif(skip_slow_test, reason='Takes very long')
+def test_identify_nodules_001(dicom_path_001, nodule_locations_001):
+    predicted = trained_model.predict(dicom_path_001)
+
+    first = predicted[0]
+
+    dist = np.sqrt(np.sum([(first[s] - nodule_locations_001[s]) ** 2 for s in ["x", "y", "z"]]))
+
+    assert dist < 10
+
+
+@pytest.mark.skipif(skip_slow_test, reason='Takes very long')
+def test_identify_nodules_003(dicom_path_003, nodule_locations_003):
+    predicted = trained_model.predict(dicom_path_003)
+
+    first = predicted[0]
+
+    dist = np.sqrt(np.sum([(first[s] - nodule_locations_003[s]) ** 2 for s in ["x", "y", "z"]]))
+
+    assert dist < 10
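
A minimal end-to-end sketch of how the converted models are driven (the series path is hypothetical; in the running service these calls are wired through the Flask prediction endpoints):

    from src.algorithms.identify import trained_model as identify
    from src.algorithms.classify import trained_model as classify

    dicom_dir = '/images/LIDC-IDRI-0001/<study-uid>/<series-uid>'  # hypothetical
    candidates = identify.predict(dicom_dir)          # [{'x', 'y', 'z', 'p_nodule'}, ...]
    scored = classify.predict(dicom_dir, candidates)  # adds 'p_concerning' per centroid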