Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resumable Training #1130

Merged
merged 7 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/guides/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ optional arguments:
Path to labels file to use for test. If specified,
overrides the path specified in the training job
config.
--base_checkpoint BASE_CHECKPOINT
Path to base checkpoint (directory containing best_model.h5)
to resume training from.
--tensorboard Enable TensorBoard logging to the run path if not
already specified in the training job config.
--save_viz Enable saving of prediction visualizations to the run
Expand Down
151 changes: 117 additions & 34 deletions sleap/gui/learning/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sleap.gui.dialogs.formbuilder import YamlFormWidget
from sleap.gui.learning import runners, scopedkeydict, configs, datagen, receptivefield

from typing import Dict, List, Optional, Text, Optional
from typing import Dict, List, Optional, Text, Optional, cast

from qtpy import QtWidgets, QtCore

Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(

self.current_pipeline = ""

self.tabs = dict()
self.tabs: Dict[str, TrainingEditorWidget] = dict()
self.shown_tab_names = []

self._cfg_getter = configs.TrainingConfigsGetter.make_from_labels_filename(
Expand Down Expand Up @@ -628,6 +628,7 @@ def run(self):
"""Run with current dialog settings."""

pipeline_form_data = self.pipeline_form_widget.get_form_data()

items_for_inference = self.get_items_for_inference(pipeline_form_data)

config_info_list = self.get_every_head_config_data(pipeline_form_data)
Expand Down Expand Up @@ -900,12 +901,13 @@ def __init__(
self._cfg_list_widget = None
self._receptive_field_widget = None
self._use_trained_model = None
self._resume_training = None
self._require_trained = require_trained
self.head = head

yaml_name = "training_editor_form"

self.form_widgets = dict()
self.form_widgets: Dict[str, YamlFormWidget] = dict()

for key in ("model", "data", "augmentation", "optimization", "outputs"):
self.form_widgets[key] = YamlFormWidget.from_name(
Expand Down Expand Up @@ -960,11 +962,10 @@ def __init__(

# If we have an object which gets a list of config files,
# then we'll show a menu to allow selection from the list.

if self._cfg_getter:
if self._cfg_getter is not None:
self._cfg_list_widget = configs.TrainingConfigFilesWidget(
cfg_getter=self._cfg_getter,
head_name=head,
head_name=cast(str, head), # Expect head to be a string
require_trained=require_trained,
)
self._cfg_list_widget.onConfigSelection.connect(
Expand All @@ -974,20 +975,22 @@ def __init__(

layout.addWidget(self._cfg_list_widget)

# Add option for using trained model from selected config
if self._require_trained:
self._update_use_trained()
else:
self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model")
self._use_trained_model.setEnabled(False)
self._use_trained_model.setVisible(False)

self._use_trained_model.stateChanged.connect(self._update_use_trained)
if self._require_trained:
self._update_use_trained()
elif self._cfg_list_widget is not None:
# Add option for using trained model from selected config file
self._use_trained_model = QtWidgets.QCheckBox("Use Trained Model")
self._use_trained_model.setEnabled(False)
self._use_trained_model.setVisible(False)
self._resume_training = QtWidgets.QCheckBox("Resume Training")
self._resume_training.setEnabled(False)
self._resume_training.setVisible(False)

layout.addWidget(self._use_trained_model)
self._use_trained_model.stateChanged.connect(self._update_use_trained)
self._resume_training.stateChanged.connect(self._update_use_trained)

elif self._require_trained:
self._update_use_trained()
layout.addWidget(self._use_trained_model)
layout.addWidget(self._resume_training)

layout.addWidget(self._layout_widget(col_layout))
self.setLayout(layout)
Expand Down Expand Up @@ -1017,9 +1020,15 @@ def acceptSelectedConfigInfo(self, cfg_info: configs.ConfigFileInfo):
self._load_config(cfg_info)

has_trained_model = cfg_info.has_trained_model
if self._use_trained_model:
if self._use_trained_model is not None:
self._use_trained_model.setChecked(self._require_trained)
self._use_trained_model.setVisible(has_trained_model)
self._use_trained_model.setEnabled(has_trained_model)
# Redundant check (for readability) since this checkbox exists if the above does
if self._resume_training is not None:
self._use_trained_model.setChecked(False)
self._resume_training.setVisible(has_trained_model)
self._resume_training.setEnabled(has_trained_model)

self.update_receptive_field()

Expand Down Expand Up @@ -1060,17 +1069,64 @@ def _load_config(self, cfg_info: configs.ConfigFileInfo):
# self._cfg_list_widget.setUserConfigData(cfg_form_data_dict)

def _update_use_trained(self, check_state=0):
if self._require_trained:
use_trained = True
else:
use_trained = check_state == QtCore.Qt.CheckState.Checked
"""Update config GUI based on _use_trained_model and _resume_training checkboxes.

This function is called when either _use_trained_model or _resume_training checkbox
is checked/unchecked or when _require_trained is changed.

If _require_trained is True, then we'll disable all fields.
If _use_trained_model is checked, then we'll disable all fields.
If _resume_training is checked, then we'll disable only the model field.

Args:
check_state (int, optional): Check state of checkbox. Defaults to 0. Unused.

Returns:
None

Side Effects:
Disables/Enables fields based on checkbox values (and _required_training).
"""

# Check which checkbox changed its value (if any)
sender = self.sender()

if sender is None: # If sender is None, then _required_training is True
pass
# Uncheck _resume_training checkbox if _use_trained_model is unchecked
elif (sender == self._use_trained_model) and (
not self._use_trained_model.isChecked()
):
self._resume_training.setChecked(False)

# Check _use_trained_model checkbox if _resume_training is checked
elif (sender == self._resume_training) and self._resume_training.isChecked():
self._use_trained_model.setChecked(True)

# Update form widgets
use_trained_params = self.use_trained
use_model_params = self.resume_training
for form in self.form_widgets.values():
form.set_enabled(not use_trained)
form.set_enabled(not use_trained_params)

# If user wants to use trained model, then reset form to match config
if use_trained and self._cfg_list_widget:
if use_trained_params or use_model_params:
cfg_info = self._cfg_list_widget.getSelectedConfigInfo()

# If user wants to resume training, then reset only model form to match config
if use_model_params:
self.form_widgets["model"].set_enabled(False)

# Set model form to match config
cfg = cfg_info.config
cfg_dict = cattr.unstructure(cfg)
model_dict = {"model": cfg_dict["model"]}
key_val_dict = scopedkeydict.ScopedKeyDict.from_hierarchical_dict(
model_dict
).key_val_dict
self.set_fields_from_key_val_dict(key_val_dict)

# If user wants to use trained model, then reset entire form to match config
if use_trained_params:
self._load_config(cfg_info)

self._set_head()
Expand Down Expand Up @@ -1100,17 +1156,32 @@ def _set_backbone_from_key_val_dict(self, cfg_key_val_dict):

@property
def use_trained(self) -> bool:
use_trained = False
if self._require_trained:
use_trained = True
elif self._use_trained_model and self._use_trained_model.isChecked():
use_trained = True
return use_trained
if self._require_trained or (
(self._use_trained_model is not None)
and self._use_trained_model.isChecked()
and (not self.resume_training)
):
return True

return False

@property
def resume_training(self) -> bool:
    """Whether the "Resume Training" checkbox exists and is checked.

    Returns:
        True if the resume-training checkbox widget was created and the
        user has checked it; False otherwise (including when the widget
        was never created).
    """
    checkbox = self._resume_training
    return checkbox is not None and checkbox.isChecked()

@property
def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]:
trained_config_info = self._cfg_list_widget.getSelectedConfigInfo()
if trained_config_info is None:
# If `TrainingEditorWidget` was initialized with a config getter, then
# we expect to have a list of config files
if self._cfg_list_widget is None:
return None

trained_config_info: Optional[
configs.ConfigFileInfo
] = self._cfg_list_widget.getSelectedConfigInfo()
if (trained_config_info is None) or (not trained_config_info.has_trained_model):
return None

if self.use_trained:
Expand All @@ -1123,13 +1194,25 @@ def trained_config_info_to_use(self) -> Optional[configs.ConfigFileInfo]:
trained_config.outputs.run_name_prefix = ""
trained_config.outputs.run_name_suffix = None

if self.resume_training:
# Get the folder path of trained config and set it as the output folder
trained_config_info.config.model.base_checkpoint = str(
Path(cast(str, trained_config_info.path)).parent
)
else:
trained_config_info.config.model.base_checkpoint = None

return trained_config_info

@property
def has_trained_config_selected(self) -> bool:
    """Whether the currently selected config file has a trained model.

    Returns:
        True if a config-list widget exists, a config is selected in it,
        and that config reports an associated trained model; False in
        every other case.
    """
    list_widget = self._cfg_list_widget
    if list_widget is None:
        return False
    selected = list_widget.getSelectedConfigInfo()
    return bool(selected and selected.has_trained_model)

def get_all_form_data(self) -> dict:
Expand Down
2 changes: 2 additions & 0 deletions sleap/nn/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,9 @@ class ModelConfig:
Attributes:
backbone: Configurations related to the main network architecture.
heads: Configurations related to the output heads.
base_checkpoint: Path to model folder for loading a checkpoint. Should contain the .h5 file
"""

backbone: BackboneConfig = attr.ib(factory=BackboneConfig)
heads: HeadsConfig = attr.ib(factory=HeadsConfig)
base_checkpoint: Optional[Text] = None
32 changes: 30 additions & 2 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,22 @@ def _setup_model(self):
for i, output in enumerate(self.model.keras_model.outputs):
logger.info(f" [{i}] = {output}")

# Resuming training if flagged
if self.config.model.base_checkpoint is not None:
# TODO (AL): Add flexibilty to resume from any checkpoint (e.g.
# latest_model, specific epoch, etc.)

# Grab the 'best_model.h5' file from the previous training run
# and load it into the current model
previous_model_path = os.path.join(
self.config.model.base_checkpoint, "best_model.h5"
)

self.keras_model.load_weights(previous_model_path)
logger.info(f"Loaded previous model weights from {previous_model_path}")
else:
logger.info("Training from scratch")

@property
def keras_model(self) -> tf.keras.Model:
"""Alias for `self.model.keras_model`."""
Expand Down Expand Up @@ -1783,7 +1799,7 @@ def visualize_example(example):
)


def main():
def main(args: Optional[List] = None):
"""Create CLI for training and run."""
import argparse

Expand Down Expand Up @@ -1825,6 +1841,14 @@ def main():
"specified in the training job config."
),
)
parser.add_argument(
"--base_checkpoint",
type=str,
help=(
"Path to base checkpoint (directory containing best_model.h5) to resume "
"training from."
),
)
parser.add_argument(
"--tensorboard",
action="store_true",
Expand Down Expand Up @@ -1883,7 +1907,7 @@ def main():
),
)

args, _ = parser.parse_known_args()
args, _ = parser.parse_known_args(args)

# Find job configuration file.
job_filename = args.training_job_path
Expand Down Expand Up @@ -1916,6 +1940,8 @@ def main():
if len(args.video_paths) == 0:
args.video_paths = None

job_config.model.base_checkpoint = args.base_checkpoint

logger.info("Versions:")
sleap.versions()

Expand Down Expand Up @@ -1980,6 +2006,8 @@ def main():
)
trainer.train()

return trainer


if __name__ == "__main__":
main()
Loading