diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 43be5cf31..7d0d25a56 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -1799,8 +1799,8 @@ def visualize_example(example): ) -def main(args: Optional[List] = None): - """Create CLI for training and run.""" +def create_trainer_using_cli(args: Optional[List] = None): + """Create CLI for training.""" import argparse parser = argparse.ArgumentParser() @@ -2004,10 +2004,15 @@ def main(args: Optional[List] = None): test_labels=args.test_labels, video_search_paths=args.video_paths, ) - trainer.train() return trainer +def main(args: Optional[List] = None): + """Create CLI for training and run.""" + trainer = create_trainer_using_cli(args=args) + trainer.train() + + if __name__ == "__main__": main() diff --git a/tests/nn/test_training.py b/tests/nn/test_training.py index b3eda8676..b95d177ff 100644 --- a/tests/nn/test_training.py +++ b/tests/nn/test_training.py @@ -22,7 +22,7 @@ TopdownConfmapsModelTrainer, TopDownMultiClassModelTrainer, Trainer, - main as sleap_train, + create_trainer_using_cli as sleap_train, ) sleap.use_cpu_only()