Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
Converted Team 'grt123' identification and classification algorithms …
Browse files Browse the repository at this point in the history
…based on the code from https://github.com/lfz/DSB2017
  • Loading branch information
dchansen authored and reubano committed Sep 29, 2017
1 parent f069a19 commit 45db7be
Show file tree
Hide file tree
Showing 16 changed files with 1,277 additions and 53 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ test/assets/* filter=lfs diff=lfs merge=lfs -text
*.hd5 filter=lfs diff=lfs merge=lfs -text
*.mhd filter=lfs diff=lfs merge=lfs -text
*.raw filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
7 changes: 5 additions & 2 deletions compose/prediction/Dockerfile-dev
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
FROM python:3.6
FROM ubuntu:rolling
ENV PYTHONUNBUFFERED 1

RUN apt-get update && apt-get install -y tcl tk python3.6 python3.6-tk wget python-opencv
RUN wget https://bootstrap.pypa.io/get-pip.py
RUN python3.6 get-pip.py
RUN ln -s /usr/bin/python3.6 /usr/local/bin/python
# Requirements have to be pulled and installed here, otherwise caching won't work
COPY ./prediction/requirements /requirements
RUN pip install -r /requirements/local.txt
Expand Down
1 change: 1 addition & 0 deletions prediction/requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ opencv-python==3.3.0.10
pandas==0.20.3
scikit-image==0.13.0
SimpleITK==1.0.1
http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp36-cp36m-manylinux1_x86_64.whl
3 changes: 3 additions & 0 deletions prediction/src/algorithms/classify/assets/gtr123_model.ckpt
Git LFS file not shown
315 changes: 315 additions & 0 deletions prediction/src/algorithms/classify/src/gtr123_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
import torch
import numpy as np
from torch.autograd import Variable
from torch import nn
import SimpleITK as sitk

from src.preprocess.gtr123_preprocess import lum_trans, resample

"""
Classification model from team gtr123
Code adapted from https://github.com/lfz/DSB2017
"""

# Hyper-parameters shared by the detector backbone and the classifier head.
config = {}

config['crop_size'] = [96, 96, 96]  # voxel extent of the cube cropped around each candidate
config['scaleLim'] = [0.85, 1.15]
config['radiusLim'] = [6, 100]

config['stride'] = 4  # downsampling ratio between input volume and output feature map

config['detect_th'] = 0.05
config['conf_th'] = -1
config['nms_th'] = 0.05
config['filling_value'] = 160  # intensity used to pad crops that extend past the volume

config['startepoch'] = 20
config['lr_stage'] = np.array([50, 100, 140, 160])
config['lr'] = [0.01, 0.001, 0.0001, 0.00001]
config['miss_ratio'] = 1
config['miss_thresh'] = 0.03
config['anchors'] = [10, 30, 60]


class PostRes(nn.Module):
    """3D residual block: two conv-BN stages plus an identity or projected shortcut.

    NOTE: attribute names (conv1, bn1, relu, conv2, bn2, shortcut) are
    state_dict keys of the released checkpoint — do not rename.
    """

    def __init__(self, n_in, n_out, stride=1):
        super(PostRes, self).__init__()
        # Main path: conv-BN-ReLU-conv-BN; `stride` controls spatial downsampling.
        self.conv1 = nn.Conv3d(n_in, n_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm3d(n_out)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(n_out, n_out, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(n_out)

        # Project the input when its shape differs from the output shape;
        # `None` signals that the plain identity can be used instead.
        needs_projection = stride != 1 or n_out != n_in
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv3d(n_in, n_out, kernel_size=1, stride=stride),
                nn.BatchNorm3d(n_out))
        else:
            self.shortcut = None

    def forward(self, x):
        """Apply the block to a 5-D tensor (N, C, D, H, W)."""
        skip = x if self.shortcut is None else self.shortcut(x)
        out = self.bn1(self.conv1(x))
        out = self.relu(out)
        out = self.bn2(self.conv2(out))
        out += skip
        return self.relu(out)


class Net(nn.Module):
    """Detector backbone from team gtr123: a 3D U-Net-style CNN built from
    PostRes residual blocks.  Produces per-anchor nodule proposals plus the
    128-channel feature map consumed by CaseNet.

    NOTE: attribute names (preBlock, forw*, back*, maxpool*, path*, output, ...)
    are state_dict keys of the released checkpoint — do not rename.
    """

    def __init__(self):
        super(Net, self).__init__()
        # The first few layers consumes the most memory, so use simple
        # convolution to save memory. Call these layers preBlock, i.e., before
        # the residual blocks of later layers.
        self.preBlock = nn.Sequential(
            nn.Conv3d(1, 24, kernel_size=3, padding=1),
            nn.BatchNorm3d(24),
            nn.ReLU(inplace=True),
            nn.Conv3d(24, 24, kernel_size=3, padding=1),
            nn.BatchNorm3d(24),
            nn.ReLU(inplace=True))

        # 3 poolings, each pooling downsamples the feature map by a factor 2.
        # 3 groups of blocks. The first block of each group has one pooling.
        num_blocks_forw = [2, 2, 3, 3]
        num_blocks_back = [3, 3]
        self.featureNum_forw = [24, 32, 64, 64, 64]
        self.featureNum_back = [128, 64, 64]

        # Encoder stages forw1..forw4: the first block of each stage changes
        # the channel count, the remaining blocks keep it fixed.
        for i in range(len(num_blocks_forw)):
            blocks = []

            for j in range(num_blocks_forw[i]):
                if j == 0:
                    blocks.append(PostRes(self.featureNum_forw[i], self.featureNum_forw[i + 1]))
                else:
                    blocks.append(PostRes(self.featureNum_forw[i + 1], self.featureNum_forw[i + 1]))

            setattr(self, 'forw' + str(i + 1), nn.Sequential(*blocks))

        # Decoder stages back2 and back3.  The first block of back2 (i == 0)
        # additionally receives the 3-channel coordinate grid concatenated in
        # forward(), hence the +3 on its input channel count.
        for i in range(len(num_blocks_back)):
            blocks = []

            for j in range(num_blocks_back[i]):
                if j == 0:
                    if i == 0:
                        addition = 3
                    else:
                        addition = 0

                    blocks.append(PostRes(self.featureNum_back[i + 1] + self.featureNum_forw[i + 2] + addition,
                                          self.featureNum_back[i]))
                else:
                    blocks.append(PostRes(self.featureNum_back[i], self.featureNum_back[i]))

            setattr(self, 'back' + str(i + 2), nn.Sequential(*blocks))

        self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2, return_indices=True)
        self.unmaxpool1 = nn.MaxUnpool3d(kernel_size=2, stride=2)
        self.unmaxpool2 = nn.MaxUnpool3d(kernel_size=2, stride=2)

        # Transposed-conv upsampling paths that merge decoder features with
        # the matching encoder stage in forward().
        self.path1 = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))

        self.path2 = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=2, stride=2),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True))

        self.drop = nn.Dropout3d(p=0.2, inplace=False)
        # Prediction head: 5 values per anchor at each output voxel
        # (presumably score + 4 box parameters — TODO confirm against DSB2017).
        self.output = nn.Sequential(nn.Conv3d(self.featureNum_back[0], 64, kernel_size=1),
                                    nn.ReLU(),
                                    # nn.Dropout3d(p = 0.3),
                                    nn.Conv3d(64, 5 * len(config['anchors']), kernel_size=1))

    def forward(self, x, coord):
        """Run the detector.

        Args:
            x: single-channel input volume batch (5-D tensor).
            coord: normalised 3-channel coordinate grid, concatenated into the
                decoder before back2.

        Returns:
            Tuple of (decoder feature map `feat`, predictions reshaped to
            (batch, D, H, W, n_anchors, 5)).
        """
        out = self.preBlock(x)  # 16
        out_pool, indices0 = self.maxpool1(out)
        out1 = self.forw1(out_pool)  # 32
        out1_pool, indices1 = self.maxpool2(out1)
        out2 = self.forw2(out1_pool)  # 64
        # out2 = self.drop(out2)
        out2_pool, indices2 = self.maxpool3(out2)
        out3 = self.forw3(out2_pool)  # 96
        out3_pool, indices3 = self.maxpool4(out3)
        out4 = self.forw4(out3_pool)  # 96
        # out4 = self.drop(out4)

        # Decoder: upsample and fuse with the corresponding encoder stage.
        rev3 = self.path1(out4)
        comb3 = self.back3(torch.cat((rev3, out3), 1))  # 96+96
        # comb3 = self.drop(comb3)
        rev2 = self.path2(comb3)

        feat = self.back2(torch.cat((rev2, out2, coord), 1))  # 64+64
        comb2 = self.drop(feat)
        out = self.output(comb2)
        size = out.size()
        # Move the channel axis to the end and split it into (anchor, 5).
        out = out.view(out.size(0), out.size(1), -1)
        # out = out.transpose(1, 4).transpose(1, 2).transpose(2, 3).contiguous()
        out = out.transpose(1, 2).contiguous().view(size[0], size[2], size[3], size[4], len(config['anchors']), 5)
        # out = out.view(-1, 5)
        return feat, out


class CaseNet(nn.Module):
    """The classification Net from the gtr123 team - part of the Winning algorithm for DSB2017.

    Wraps the NoduleNet detector backbone and adds a small fully-connected
    head that turns the feature map around each candidate into a per-nodule
    malignancy probability, combined into a case-level probability with a
    learned baseline.
    """

    def __init__(self):
        super(CaseNet, self).__init__()
        # Attribute names are state_dict keys of the released checkpoint —
        # do not rename.
        self.NoduleNet = Net()
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 1)
        self.pool = nn.MaxPool3d(kernel_size=2)
        self.dropout = nn.Dropout(0.5)
        self.baseline = nn.Parameter(torch.Tensor([-30.0]).float())
        self.Relu = nn.ReLU()

    def forward(self, xlist, coordlist):
        """
        Args:
            xlist: Image of size n x k x 1 x 96 x 96 x 96
            coordlist: Coordinates of size n x k x 3 x 24 x 24 x 24

        Returns:
            Tuple of (per-anchor nodule predictions, case-level probability,
            per-nodule probabilities of shape n x k).
        """
        xsize = xlist.size()
        coordsize = coordlist.size()

        noduleFeat, nodulePred = self.NoduleNet(xlist, coordlist)
        nodulePred = nodulePred.contiguous().view(coordsize[0], coordsize[1], -1)

        # Max-pool the 2x2x2 neighbourhood around the centre of the feature
        # map, i.e. the location of the nodule candidate itself.
        featshape = noduleFeat.size()  # nk x 128 x 24 x 24 x 24
        centerFeat = self.pool(noduleFeat[:, :, featshape[2] // 2 - 1:featshape[2] // 2 + 1,
                                          featshape[3] // 2 - 1:featshape[3] // 2 + 1,
                                          featshape[4] // 2 - 1:featshape[4] // 2 + 1])
        centerFeat = centerFeat[:, :, 0, 0, 0]
        out = self.dropout(centerFeat)
        out = self.Relu(self.fc1(out))
        out = torch.sigmoid(self.fc2(out))
        out = out.view(xsize[0], xsize[1])
        # Case probability = 1 - P(no nodule is concerning), folding in a
        # learned baseline probability.
        base_prob = torch.sigmoid(self.baseline)
        casePred = 1 - torch.prod(1 - out, dim=1) * (1 - base_prob.expand(out.size()[0]))
        return nodulePred, casePred, out


class SimpleCrop(object):
    """Extract a fixed-size cube around a candidate and build its normalised
    coordinate grid for the detector network."""

    def __init__(self):
        self.crop_size = config['crop_size']
        self.scaleLim = config['scaleLim']
        self.radiusLim = config['radiusLim']
        self.stride = config['stride']
        self.filling_value = config['filling_value']

    def __call__(self, imgs, target):
        """Crop `imgs` (C x Z x Y x X) around `target` (z, y, x, ...).

        Returns:
            Tuple of (cropped volume, float32 coordinate grid).  Regions of
            the crop falling outside the volume are padded with
            `filling_value`.
        """
        crop_size = np.array(self.crop_size).astype('int')

        # Top-left corner of the crop; may fall outside the volume on either side.
        start = (target[:3] - crop_size / 2).astype('int')
        pad = [[0, 0]]  # never pad the channel axis

        for i in range(3):
            if start[i] < 0:
                leftpad = -start[i]
                start[i] = 0
            else:
                leftpad = 0
            if start[i] + crop_size[i] > imgs.shape[i + 1]:
                rightpad = start[i] + crop_size[i] - imgs.shape[i + 1]
            else:
                rightpad = 0

            pad.append([leftpad, rightpad])

        imgs = np.pad(imgs, pad, 'constant', constant_values=self.filling_value)
        crop = imgs[:, start[0]:start[0] + crop_size[0], start[1]:start[1] + crop_size[1],
                    start[2]:start[2] + crop_size[2]]

        # Normalised coordinates of each output voxel, matching the network's
        # output stride.  BUG FIX: np.linspace requires an integer sample
        # count; the original passed the float crop_size / stride, which
        # raises a TypeError on modern numpy.
        normstart = np.array(start).astype('float32') / np.array(imgs.shape[1:]) - 0.5
        normsize = np.array(crop_size).astype('float32') / np.array(imgs.shape[1:])
        n_points = [side // self.stride for side in self.crop_size]
        xx, yy, zz = np.meshgrid(np.linspace(normstart[0], normstart[0] + normsize[0], n_points[0]),
                                 np.linspace(normstart[1], normstart[1] + normsize[1], n_points[1]),
                                 np.linspace(normstart[2], normstart[2] + normsize[2], n_points[2]),
                                 indexing='ij')
        coord = np.concatenate([xx[np.newaxis, ...], yy[np.newaxis, ...], zz[np.newaxis, ...]], 0).astype('float32')

        return crop, coord


def predict(image_itk, nodule_list, model_path="src/algorithms/classify/assets/gtr123_model.ckpt"):
    """Run the gtr123 classifier over a list of nodule candidates.

    Args:
        image_itk: ITK dicom image
        nodule_list: List of nodules, each a dict with "x", "y", "z" keys
            giving the candidate position in the original image.
        model_path: Path to the torch model
            (Default value = "src/algorithms/classify/assets/gtr123_model.ckpt")

    Returns:
        List of dicts with the nodule coordinates plus "p_concerning", the
        predicted probability that the nodule is concerning.  Empty list if
        `nodule_list` is empty.
    """
    if not nodule_list:
        return []

    casenet = CaseNet()
    casenet.load_state_dict(torch.load(model_path))
    casenet.eval()

    if torch.cuda.is_available():
        casenet = torch.nn.DataParallel(casenet).cuda()

    # Convert to a numpy volume (z, y, x ordering), normalise intensities and
    # resample to isotropic 1mm spacing before cropping.
    image = sitk.GetArrayFromImage(image_itk)
    spacing = np.array(image_itk.GetSpacing())[::-1]
    image = lum_trans(image)
    image = resample(image, spacing, np.array([1, 1, 1]), order=1)[0]

    crop = SimpleCrop()

    results = []
    for nodule in nodule_list:
        nod_location = np.array([np.float32(nodule[s]) for s in ["z", "y", "x"]])
        nod_location *= spacing  # scale candidate position by the original voxel spacing
        cropped_image, coords = crop(image[np.newaxis], nod_location)
        # `.volatile = True` disables autograd in the pinned torch==0.2.
        # NOTE(review): volatile was removed in torch>=0.4 — wrap the forward
        # pass in torch.no_grad() when upgrading.
        cropped_image = Variable(torch.from_numpy(cropped_image[np.newaxis]).float())
        cropped_image.volatile = True
        coords = Variable(torch.from_numpy(coords[np.newaxis]).float())
        coords.volatile = True
        _, pred, _ = casenet(cropped_image, coords)
        results.append({"x": nodule["x"],
                        "y": nodule["y"],
                        "z": nodule["z"],
                        "p_concerning": float(pred.data.cpu().numpy())})

    return results
40 changes: 22 additions & 18 deletions prediction/src/algorithms/classify/trained_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
for if nodules are concerning or not.
"""

import numpy as np
import keras.models
from src.algorithms.classify.src import gtr123_model
from src.preprocess.load_ct import load_ct, MetaData

import SimpleITK as sitk


def predict(dicom_path, centroids, model_path=None,
preprocess_ct=None, preprocess_model_input=None):
Expand Down Expand Up @@ -43,23 +44,26 @@ def predict(dicom_path, centroids, model_path=None,
'z': int,
'p_concerning': float}
"""
if not len(centroids) or model_path is None:
return []
reader = sitk.ImageSeriesReader()
filenames = reader.GetGDCMSeriesFileNames(dicom_path)

if not filenames:
raise ValueError("The path contains neither .mhd nor .dcm files")

reader.SetFileNames(filenames)
image = reader.Execute()

model = keras.models.load_model(model_path)
ct_array, meta = load_ct(dicom_path)
if preprocess_ct is not None:
meta = MetaData(meta)
ct_array = preprocess_ct(ct_array, meta)
if not isinstance(ct_array, np.ndarray):
raise TypeError('The signature of preprocess_ct must be ' +
'callable[list[DICOM], ndarray] -> ndarray')
if preprocess_ct:
meta = load_ct(dicom_path)[1]
voxel_data = preprocess_ct(image, MetaData(meta))
else:
voxel_data = image

patches = preprocess_model_input(ct_array, centroids)
predictions = model.predict(patches)
predictions = predictions.astype(np.float)
if preprocess_model_input:
preprocessed = preprocess_model_input(voxel_data, centroids)
else:
preprocessed = voxel_data

for i, centroid in enumerate(centroids):
centroid['p_concerning'] = predictions[i, 0]
model_path = model_path or "src/algorithms/classify/assets/gtr123_model.ckpt"

return centroids
return gtr123_model.predict(preprocessed, centroids, model_path)
Loading

0 comments on commit 45db7be

Please sign in to comment.