Skip to content

Commit

Permalink
Split trainer cli function into two functions (#1197)
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys authored Mar 2, 2023
1 parent c636d96 commit 7abfac7
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
11 changes: 8 additions & 3 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1799,8 +1799,8 @@ def visualize_example(example):
)


def main(args: Optional[List] = None):
"""Create CLI for training and run."""
def create_trainer_using_cli(args: Optional[List] = None):
"""Create CLI for training."""
import argparse

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -2004,10 +2004,15 @@ def main(args: Optional[List] = None):
test_labels=args.test_labels,
video_search_paths=args.video_paths,
)
trainer.train()

return trainer


def main(args: Optional[List] = None):
"""Create CLI for training and run."""
trainer = create_trainer_using_cli(args=args)
trainer.train()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion tests/nn/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TopdownConfmapsModelTrainer,
TopDownMultiClassModelTrainer,
Trainer,
main as sleap_train,
create_trainer_using_cli as sleap_train,
)

sleap.use_cpu_only()
Expand Down

0 comments on commit 7abfac7

Please sign in to comment.