Skip to content

Commit

Permalink
add and debug deep learning models
Browse files Browse the repository at this point in the history
  • Loading branch information
Henley13 committed Jun 29, 2021
1 parent efc51ff commit 37fb620
Showing 1 changed file with 94 additions and 68 deletions.
162 changes: 94 additions & 68 deletions bigfish/deep_learning/models_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
"""

import os
import warnings
from zipfile import ZipFile

import bigfish.stack as stack

Expand All @@ -23,65 +25,60 @@

# ### Pre-trained models ###

def load_pretrained_model(model_name, channel):
def load_pretrained_model(channel, model_name):
"""Build and compile a model, then load its pretrained weights.
Parameters
----------
model_name : str
Name of the model used ('3-classes', 'distance_edge',
'double_distance_edge).
channel : str
Input channel for the model ('nuc' or 'cell').
Input channel for the model ('nuc', 'cell' or 'double').
model_name : str
Name of the model used ('3-classes' or 'double_distance_edge).
Returns
-------
model : tensorflow.keras.model object
Pretrained Unet model.
"""
# TODO fix warning partial restoration with distance model

# check parameters
stack.check_parameter(model_name=str,
channel=str)
stack.check_parameter(channel=str,
model_name=str)

# unet 3-classes for nucleus segmentation
if model_name == "3_classes" and channel == "nuc":
model = build_compile_3_classes_model()

# unet 3-classes for cell segmentation
elif model_name == "3_classes" and channel == "cell":
model = build_compile_3_classes_model()

# unet distance map to edge for cell segmentation
elif model_name == "distance_edge" and channel == "cell":
model = build_compile_distance_model()

# unet distance map to edge for both nucleus and cell segmentation
elif model_name == "double_distance_edge":
elif model_name == "distance_edge" and channel == "double":
model = build_compile_double_distance_model()

else:
raise ValueError("Model name and channel to segment are not "
"consistent: {0} - {1}.".format(model_name, channel))
"consistent: {0} - {1}.".format(channel, model_name))

# load weights
path_pretrained_directory = check_pretrained_weights(model_name, channel)
path_pretrained_directory = check_pretrained_weights(channel, model_name)
path_checkpoint = os.path.join(
path_pretrained_directory, "checkpoint")
model.load_weights(path_checkpoint)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model.load_weights(path_checkpoint)

return model


def check_pretrained_weights(model_name, channel):
def check_pretrained_weights(channel, model_name):
"""Check pretrained weights exist and download them if necessary.
Parameters
----------
channel : str
Input channel for the model ('nuc', 'cell' or 'double').
model_name : str
Name of the model used ('3-classes' or 'distance_edge').
channel : str
Input channel for the model ('nuc' or 'cell').
Returns
-------
Expand All @@ -90,59 +87,44 @@ def check_pretrained_weights(model_name, channel):
"""
# check parameters
stack.check_parameter(model_name=str,
channel=str)
stack.check_parameter(channel=str,
model_name=str)

# get path checkpoint
path_weights_directory = _get_weights_directory()
if model_name == "double_distance_edge":
pretrained_directory = model_name
else:
pretrained_directory = "_".join([channel, model_name])
pretrained_directory = "_".join([channel, model_name])
path_directory = os.path.join(path_weights_directory, pretrained_directory)
download = False

# get url and hash
if model_name == "3_classes" and channel == "nuc":
url_checkpoint = ""
hash_checkpoint = ""
url_data = ""
hash_data = ""
url_index = ""
hash_index = ""
elif model_name == "3_classes" and channel == "cell":
url_checkpoint = ""
hash_checkpoint = ""
url_data = ""
hash_data = ""
url_index = ""
hash_index = ""
elif model_name == "distance_edge" and channel == "cell":
url_checkpoint = ""
hash_checkpoint = ""
url_data = ""
hash_data = ""
url_index = ""
hash_index = ""
elif model_name == "double_distance_edge":
url_checkpoint = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint"
if channel == "nuc" and model_name == "3_classes":
url_zip_file = "https://github.com/fish-quant/big-fish-examples/releases/download/weights_1/nuc_3_classes.zip"
hash_checkpoint = "02988027faf3f16b4088ee83c2ade14098e8ffb325c23a576cc639dae48aa936"
hash_data = "284d6b8cadb3eddd691f1407bfdd1a7e6fa085c57a2feac07446e84aa3b1baf8"
hash_index = "c924f0f2be179340dc5c75e21833fd831d5a4efdfb48382edbbcff748a162bc4"
hash_log = "b6d08b9fbacd3430d89eef2cd5c3246eb8c93e7b970ba17a76c8a9d72f58c8e7"
elif channel == "double" and model_name == "distance_edge":
url_zip_file = "https://github.com/fish-quant/big-fish-examples/releases/download/weights_0/double_distance_edge.zip"
hash_checkpoint = "02988027faf3f16b4088ee83c2ade14098e8ffb325c23a576cc639dae48aa936"
url_data = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint.data-00000-of-00001"
hash_data = "614c50d25bbdc793c2a54d0e64a31849757e6f969d192aefd1290e43b6fa5146"
url_index = "https://github.com/fish-quant/big-fish-examples/releases/download/0.0.2/checkpoint.index"
hash_index = "bdac339c67c1071e73a856005624e78fc5889927ad16ad1315a60740267966cf"
hash_data = "528aecbc6418df8d0d73fb23b3518ef779ae5135e15ca3e2e67518726f998d4d"
hash_index = "cbebdd86868e46507733fccca589f940a438f4fde6f5d935e61f4112047db63e"
hash_log = "63bc7e629d464bbf6cd3afd466f4d89a3be10edb2633de7ab825384a61326a09"
else:
raise ValueError("Model name and channel to segment are not "
"consistent: {0} - {1}.".format(model_name, channel))
"consistent: {0} - {1}.".format(channel, model_name))

# case where pretrained directory exists
if os.path.isdir(path_directory):

# paths
path_checkpoint = os.path.join(path_directory, "checkpoint")
path_data = os.path.join(path_directory,
"checkpoint.data-00000-of-00001")
path_index = os.path.join(path_directory, "checkpoint.index")
path_checkpoint = os.path.join(
path_directory, "checkpoint")
path_data = os.path.join(
path_directory, "checkpoint.data-00000-of-00001")
path_index = os.path.join(
path_directory, "checkpoint.index")
path_log = os.path.join(
path_directory, "log")

# checkpoint available and not corrupted
if os.path.exists(path_checkpoint):
Expand Down Expand Up @@ -174,23 +156,67 @@ def check_pretrained_weights(model_name, channel):
else:
download = True

# log available and not corrupted
if os.path.exists(path_log):
try:
stack.check_hash(path_log, hash_log)
except IOError:
print("{0} seems corrupted.".format(path_log))
download = True
else:
download = True

# case where pretrained directory does not exist
else:
os.mkdir(path_directory)
download = True

# download checkpoint files
if download:
print("downloading checkpoint files...")
path = stack.load_and_save_url(
url_checkpoint, path_directory, "checkpoint")
print("downloading model weights...")

# download zipfile
stack.load_and_save_url(
remote_url=url_zip_file,
directory=path_weights_directory)

# unzip
path_zipfile = os.path.join(
path_weights_directory, "{0}.zip".format(pretrained_directory))
with ZipFile(path_zipfile, 'r') as zip:
zip.extract(
member="{0}/checkpoint".format(pretrained_directory),
path=path_weights_directory)
zip.extract(
member="{0}/checkpoint.data-00000-of-00001"
.format(pretrained_directory),
path=path_weights_directory)
zip.extract(
member="{0}/checkpoint.index".format(pretrained_directory),
path=path_weights_directory)
zip.extract(
member="{0}/log".format(pretrained_directory),
path=path_weights_directory)

# check files consistency
path = os.path.join(
path_weights_directory,
"{0}/checkpoint".format(pretrained_directory))
stack.check_hash(path, hash_checkpoint)
path = stack.load_and_save_url(
url_data, path_directory, "checkpoint.data-00000-of-00001")
path = os.path.join(
path_weights_directory,
"{0}/checkpoint.data-00000-of-00001".format(pretrained_directory))
stack.check_hash(path, hash_data)
path = stack.load_and_save_url(
url_index, path_directory, "checkpoint.index")
path = os.path.join(
path_weights_directory,
"{0}/checkpoint.index".format(pretrained_directory))
stack.check_hash(path, hash_index)
path = os.path.join(
path_weights_directory,
"{0}/log".format(pretrained_directory))
stack.check_hash(path, hash_log)

# remove zipfile
os.remove(path_zipfile)

return path_directory

Expand Down

0 comments on commit 37fb620

Please sign in to comment.