From 9ae5fcbdd9c4a836e2c1ed74ae45a82747510f0a Mon Sep 17 00:00:00 2001 From: Christopher Soelistyo Date: Mon, 16 Jan 2023 16:41:08 +0000 Subject: [PATCH 1/3] saved models directly in addition to model weights --- cellxpredict/train.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/cellxpredict/train.py b/cellxpredict/train.py index ef6742f..808a311 100644 --- a/cellxpredict/train.py +++ b/cellxpredict/train.py @@ -63,16 +63,24 @@ def train_encoder(config: ConfigBase): callbacks=[tensorboard_callback, montage_callback], ) - # save the model weights + # save the model and model weights config.model_dir.mkdir(parents=True, exist_ok=True) - model_filename = config.model_dir / config.filename("weights") - model.encoder.save_weights(model_filename.with_suffix(".h5")) + model_filename = config.model_dir / config.model + model.encoder.save(model_filename) + + model_weights_filename = config.model_dir / config.filename("weights") + model.encoder.save_weights(model_weights_filename.with_suffix(".h5")) - decoder_filename = config.filename("weights").replace("encoder", "decoder") + decoder_filename = config.model.replace("encoder", "decoder") model_filename = config.model_dir / decoder_filename + model.decoder.save(model_filename) + + decoder_weights_filename = config.filename("weights").replace("encoder", "decoder") + model_weights_filename = config.model_dir / decoder_weights_filename model.decoder.save_weights(model_filename.with_suffix(".h5")) + def train_projector(config: ConfigBase): """Train the projector model.""" from sklearn.decomposition import PCA From 64d346ac191b96ecde3dc781d06081e29237f9ab Mon Sep 17 00:00:00 2001 From: Christopher Soelistyo Date: Mon, 16 Jan 2023 16:43:04 +0000 Subject: [PATCH 2/3] saved models directly in addition to model weights --- cellxpredict/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cellxpredict/train.py b/cellxpredict/train.py index 808a311..1443e98 100644 --- a/cellxpredict/train.py +++ b/cellxpredict/train.py @@ -77,7 +77,7 @@ def train_encoder(config: ConfigBase): decoder_weights_filename = config.filename("weights").replace("encoder", "decoder") model_weights_filename = config.model_dir / decoder_weights_filename - model.decoder.save_weights(model_filename.with_suffix(".h5")) + model.decoder.save_weights(model_weights_filename.with_suffix(".h5")) From 30c4519aca5bdc56077ba42f34e57509af6f68c8 Mon Sep 17 00:00:00 2001 From: Christopher Soelistyo Date: Mon, 16 Jan 2023 16:44:02 +0000 Subject: [PATCH 3/3] saved models directly in addition to model weights --- cellxpredict/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cellxpredict/train.py b/cellxpredict/train.py index 1443e98..2bf83f0 100644 --- a/cellxpredict/train.py +++ b/cellxpredict/train.py @@ -80,7 +80,6 @@ def train_encoder(config: ConfigBase): model.decoder.save_weights(model_weights_filename.with_suffix(".h5")) - def train_projector(config: ConfigBase): """Train the projector model.""" from sklearn.decomposition import PCA