Skip to content

Commit

Permalink
add double distance unet
Browse files Browse the repository at this point in the history
  • Loading branch information
Henley13 committed Jun 2, 2021
1 parent 9ebf35e commit 4dfc431
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 3 deletions.
4 changes: 3 additions & 1 deletion bigfish/deep_learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .models_segmentation import check_pretrained_weights
from .models_segmentation import build_compile_3_classes_model
from .models_segmentation import build_compile_distance_model
from .models_segmentation import build_compile_double_distance_model


_utils_models = [
Expand All @@ -60,7 +61,8 @@
"load_pretrained_model",
"check_pretrained_weights",
"build_compile_3_classes_model",
"build_compile_distance_model"]
"build_compile_distance_model",
"build_compile_double_distance_model"]


__all__ = (_utils_models, _models_segmentation)
126 changes: 124 additions & 2 deletions bigfish/deep_learning/models_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.layers import Softmax
from tensorflow.python.keras.layers import Concatenate
from tensorflow.python.keras.engine.training import Model

from .utils_models import EncoderDecoder
Expand All @@ -28,7 +29,8 @@ def load_pretrained_model(model_name, channel):
Parameters
----------
model_name : str
Name of the model used ('3-classes' or 'distance_edge').
Name of the model used ('3-classes', 'distance_edge',
'double_distance_edge).
channel : str
Input channel for the model ('nuc' or 'cell').
Expand All @@ -54,6 +56,10 @@ def load_pretrained_model(model_name, channel):
elif model_name == "distance_edge" and channel == "cell":
model = build_compile_distance_model()

# unet distance map to edge for both nucleus and cell segmentation
elif model_name == "double_distance_edge":
model = build_compile_double_distance_model()

else:
raise ValueError("Model name and channel to segment are not "
"consistent: {0} - {1}.".format(model_name, channel))
Expand Down Expand Up @@ -89,7 +95,10 @@ def check_pretrained_weights(model_name, channel):

# get path checkpoint
path_weights_directory = _get_weights_directory()
pretrained_directory = "_".join([channel, model_name])
if model_name == "double_distance_edge":
pretrained_directory = model_name
else:
pretrained_directory = "_".join([channel, model_name])
path_directory = os.path.join(path_weights_directory, pretrained_directory)
download = False

Expand All @@ -115,6 +124,13 @@ def check_pretrained_weights(model_name, channel):
hash_data = ""
url_index = ""
hash_index = ""
elif model_name == "double_distance_edge":
url_checkpoint = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint"
hash_checkpoint = "02988027faf3f16b4088ee83c2ade14098e8ffb325c23a576cc639dae48aa936"
url_data = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint.data-00000-of-00001"
hash_data = "614c50d25bbdc793c2a54d0e64a31849757e6f969d192aefd1290e43b6fa5146"
url_index = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint.index"
hash_index = "bdac339c67c1071e73a856005624e78fc5889927ad16ad1315a60740267966cf"
else:
raise ValueError("Model name and channel to segment are not "
"consistent: {0} - {1}.".format(model_name, channel))
Expand Down Expand Up @@ -339,3 +355,109 @@ def _get_distance_model(inputs):
name="label_distance")(features_core) # (B, H, W, 1)

return output_surface, output_distance


def build_compile_double_distance_model():
"""Build and compile a Unet model to predict foreground and a distance map
from nucleus and cell images.
This model version takes two images as input (for nucleus and cell).
Returns
-------
model_distance : Tensorflow model
Compiled Unet model.
"""
# define inputs
inputs_nuc = Input(
shape=(None, None, 1), dtype="float32", name="nuc")
inputs_cell = Input(
shape=(None, None, 1), dtype="float32", name="cell")
inputs = [inputs_nuc, inputs_cell]

# define model
(output_distance_nuc, output_surface_cell,
output_distance_cell) = _get_double_distance_model(inputs)
outputs = [output_distance_nuc, output_surface_cell, output_distance_cell]
model_distance = Model(
inputs,
outputs,
name="DoubleDistanceModel")

# losses
loss_distance_nuc = tf.keras.losses.MeanAbsoluteError()
loss_surface_cell = tf.keras.losses.BinaryCrossentropy()
loss_distance_cell = tf.keras.losses.MeanAbsoluteError()
losses = [[loss_distance_nuc],
[loss_surface_cell], [loss_distance_cell]]
losses_weight = [[1.0], [1.0], [1.0]]

# metrics
metric_distance_nuc = tf.metrics.MeanAbsoluteError(name="mae")
metric_surface_cell = tf.metrics.BinaryAccuracy(name="accuracy")
metric_distance_cell = tf.metrics.MeanAbsoluteError(name="mae")
metrics = [[metric_distance_nuc],
[metric_surface_cell], [metric_distance_cell]]

# compile model
model_distance.compile(
optimizer='adam',
loss=losses,
loss_weights=losses_weight,
metrics=metrics)

return model_distance


def _get_double_distance_model(inputs):
"""Build Unet architecture that return nucleus and cell distance maps, and
a cell surface prediction.
Parameters
----------
inputs : List[tensorflow.keras.Input object]
List of two input layer with shape (B, H, W, 1) or (H, W, 1).
Returns
-------
output_distance_nuc : tensorflow.keras.layers object
Output layer for the nucleus distance map with shape (B, H, W, 1) or
(H, W, 1).
output_surface_cell : tensorflow.keras.layers object
Output layer for the cell foreground/background prediction with shape
(B, H, W, 1) or (H, W, 1).
output_distance_cell : tensorflow.keras.layers object
Output layer for the cell distance map with shape (B, H, W, 1) or
(H, W, 1).
"""
# compute feature map
inputs_nuc, inputs_cell = inputs
inputs = Concatenate(
axis=-1)([inputs_nuc, inputs_cell]) # (B, H, W, 2)
features_core = EncoderDecoder(
name="encoder_decoder")(inputs) # (B, H, W, 32)

# compute distance output nucleus
output_distance_nuc = SameConv(
filters=1,
kernel_size=(1, 1),
activation="relu",
name="label_distance_nuc")(features_core) # (B, H, W, 1)

# compute surface output cell
output_surface_cell = SameConv(
filters=1,
kernel_size=(1, 1),
activation="sigmoid",
name="label_2_cell")(features_core) # (B, H, W, 1)

# compute distance output cell
output_distance_cell = SameConv(
filters=1,
kernel_size=(1, 1),
activation="relu",
name="label_distance_cell")(features_core) # (B, H, W, 1)

return output_distance_nuc, output_surface_cell, output_distance_cell

0 comments on commit 4dfc431

Please sign in to comment.