From 4dfc43196fa5ecb2d02dc2e407427f83a7623372 Mon Sep 17 00:00:00 2001 From: Arthur Imbert Date: Wed, 2 Jun 2021 10:09:29 +0200 Subject: [PATCH] add double distance unet --- bigfish/deep_learning/__init__.py | 4 +- bigfish/deep_learning/models_segmentation.py | 126 ++++++++++++++++++- 2 files changed, 127 insertions(+), 3 deletions(-) diff --git a/bigfish/deep_learning/__init__.py b/bigfish/deep_learning/__init__.py index c26b462..75f8d3e 100644 --- a/bigfish/deep_learning/__init__.py +++ b/bigfish/deep_learning/__init__.py @@ -46,6 +46,7 @@ from .models_segmentation import check_pretrained_weights from .models_segmentation import build_compile_3_classes_model from .models_segmentation import build_compile_distance_model +from .models_segmentation import build_compile_double_distance_model _utils_models = [ @@ -60,7 +61,8 @@ "load_pretrained_model", "check_pretrained_weights", "build_compile_3_classes_model", - "build_compile_distance_model"] + "build_compile_distance_model", + "build_compile_double_distance_model"] __all__ = (_utils_models, _models_segmentation) diff --git a/bigfish/deep_learning/models_segmentation.py b/bigfish/deep_learning/models_segmentation.py index d814ff9..7164300 100644 --- a/bigfish/deep_learning/models_segmentation.py +++ b/bigfish/deep_learning/models_segmentation.py @@ -14,6 +14,7 @@ from tensorflow.python.keras.layers import Input from tensorflow.python.keras.layers import Softmax +from tensorflow.python.keras.layers import Concatenate from tensorflow.python.keras.engine.training import Model from .utils_models import EncoderDecoder @@ -28,7 +29,8 @@ def load_pretrained_model(model_name, channel): Parameters ---------- model_name : str - Name of the model used ('3-classes' or 'distance_edge'). + Name of the model used ('3-classes', 'distance_edge', + 'double_distance_edge). channel : str Input channel for the model ('nuc' or 'cell'). @@ -54,6 +56,10 @@ def load_pretrained_model(model_name, channel): elif model_name == "distance_edge" and channel == "cell": model = build_compile_distance_model() + # unet distance map to edge for both nucleus and cell segmentation + elif model_name == "double_distance_edge": + model = build_compile_double_distance_model() + else: raise ValueError("Model name and channel to segment are not " "consistent: {0} - {1}.".format(model_name, channel)) @@ -89,7 +95,10 @@ def check_pretrained_weights(model_name, channel): # get path checkpoint path_weights_directory = _get_weights_directory() - pretrained_directory = "_".join([channel, model_name]) + if model_name == "double_distance_edge": + pretrained_directory = model_name + else: + pretrained_directory = "_".join([channel, model_name]) path_directory = os.path.join(path_weights_directory, pretrained_directory) download = False @@ -115,6 +124,13 @@ def check_pretrained_weights(model_name, channel): hash_data = "" url_index = "" hash_index = "" + elif model_name == "double_distance_edge": + url_checkpoint = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint" + hash_checkpoint = "02988027faf3f16b4088ee83c2ade14098e8ffb325c23a576cc639dae48aa936" + url_data = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint.data-00000-of-00001" + hash_data = "614c50d25bbdc793c2a54d0e64a31849757e6f969d192aefd1290e43b6fa5146" + url_index = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint.index" + hash_index = "bdac339c67c1071e73a856005624e78fc5889927ad16ad1315a60740267966cf" else: raise ValueError("Model name and channel to segment are not " "consistent: {0} - {1}.".format(model_name, channel)) @@ -339,3 +355,109 @@ def _get_distance_model(inputs): name="label_distance")(features_core) # (B, H, W, 1) return output_surface, output_distance + + +def build_compile_double_distance_model(): + """Build and compile a Unet model to predict foreground and a distance map + from nucleus and cell images. + + This model version takes two images as input (for nucleus and cell). + + Returns + ------- + model_distance : Tensorflow model + Compiled Unet model. + + """ + # define inputs + inputs_nuc = Input( + shape=(None, None, 1), dtype="float32", name="nuc") + inputs_cell = Input( + shape=(None, None, 1), dtype="float32", name="cell") + inputs = [inputs_nuc, inputs_cell] + + # define model + (output_distance_nuc, output_surface_cell, + output_distance_cell) = _get_double_distance_model(inputs) + outputs = [output_distance_nuc, output_surface_cell, output_distance_cell] + model_distance = Model( + inputs, + outputs, + name="DoubleDistanceModel") + + # losses + loss_distance_nuc = tf.keras.losses.MeanAbsoluteError() + loss_surface_cell = tf.keras.losses.BinaryCrossentropy() + loss_distance_cell = tf.keras.losses.MeanAbsoluteError() + losses = [[loss_distance_nuc], + [loss_surface_cell], [loss_distance_cell]] + losses_weight = [[1.0], [1.0], [1.0]] + + # metrics + metric_distance_nuc = tf.metrics.MeanAbsoluteError(name="mae") + metric_surface_cell = tf.metrics.BinaryAccuracy(name="accuracy") + metric_distance_cell = tf.metrics.MeanAbsoluteError(name="mae") + metrics = [[metric_distance_nuc], + [metric_surface_cell], [metric_distance_cell]] + + # compile model + model_distance.compile( + optimizer='adam', + loss=losses, + loss_weights=losses_weight, + metrics=metrics) + + return model_distance + + +def _get_double_distance_model(inputs): + """Build Unet architecture that return nucleus and cell distance maps, and + a cell surface prediction. + + Parameters + ---------- + inputs : List[tensorflow.keras.Input object] + List of two input layer with shape (B, H, W, 1) or (H, W, 1). + + Returns + ------- + output_distance_nuc : tensorflow.keras.layers object + Output layer for the nucleus distance map with shape (B, H, W, 1) or + (H, W, 1). + output_surface_cell : tensorflow.keras.layers object + Output layer for the cell foreground/background prediction with shape + (B, H, W, 1) or (H, W, 1). + output_distance_cell : tensorflow.keras.layers object + Output layer for the cell distance map with shape (B, H, W, 1) or + (H, W, 1). + + """ + # compute feature map + inputs_nuc, inputs_cell = inputs + inputs = Concatenate( + axis=-1)([inputs_nuc, inputs_cell]) # (B, H, W, 2) + features_core = EncoderDecoder( + name="encoder_decoder")(inputs) # (B, H, W, 32) + + # compute distance output nucleus + output_distance_nuc = SameConv( + filters=1, + kernel_size=(1, 1), + activation="relu", + name="label_distance_nuc")(features_core) # (B, H, W, 1) + + # compute surface output cell + output_surface_cell = SameConv( + filters=1, + kernel_size=(1, 1), + activation="sigmoid", + name="label_2_cell")(features_core) # (B, H, W, 1) + + # compute distance output cell + output_distance_cell = SameConv( + filters=1, + kernel_size=(1, 1), + activation="relu", + name="label_distance_cell")(features_core) # (B, H, W, 1) + + return output_distance_nuc, output_surface_cell, output_distance_cell