From ddcc6cbc8f1376914ddc1e5ef82dfebc3104bcd1 Mon Sep 17 00:00:00 2001 From: roomrys <38435167+roomrys@users.noreply.github.com> Date: Fri, 18 Mar 2022 18:20:02 -0700 Subject: [PATCH] Last merge fixes --- sleap/nn/data/pipelines.py | 6 ++++++ sleap/nn/evals.py | 13 +++++-------- sleap/nn/inference.py | 12 +++++------- sleap/nn/training.py | 1 + sleap/version.py | 2 +- tests/nn/test_inference.py | 8 ++++++++ 6 files changed, 26 insertions(+), 16 deletions(-) diff --git a/sleap/nn/data/pipelines.py b/sleap/nn/data/pipelines.py index 3c874095b..b0892f8a1 100644 --- a/sleap/nn/data/pipelines.py +++ b/sleap/nn/data/pipelines.py @@ -1022,6 +1022,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline: provider=data_provider, ) aug_config = self.optimization_config.augmentation_config + if aug_config.random_flip: + pipeline += RandomFlipper.from_skeleton( + self.data_config.labels.skeletons[0], + horizontal=aug_config.flip_horizontal, + ) + pipeline += ImgaugAugmenter.from_config(aug_config) if aug_config.random_crop: pipeline += RandomCropper( diff --git a/sleap/nn/evals.py b/sleap/nn/evals.py index 7e71236b3..670c339c7 100644 --- a/sleap/nn/evals.py +++ b/sleap/nn/evals.py @@ -42,6 +42,7 @@ TopDownPredictor, BottomUpPredictor, BottomUpMultiClassPredictor, + TopDownMultiClassPredictor, SingleInstancePredictor, ) @@ -688,20 +689,16 @@ def evaluate_model( confmap_model=model, ) elif isinstance(head_config, MultiInstanceConfig): - predictor = sleap.nn.inference.BottomUpPredictor( - bottomup_config=cfg, bottomup_model=model - ) + predictor = BottomUpPredictor(bottomup_config=cfg, bottomup_model=model) elif isinstance(head_config, SingleInstanceConfmapsHeadConfig): - predictor = sleap.nn.inference.SingleInstancePredictor( - confmap_config=cfg, confmap_model=model - ) + predictor = SingleInstancePredictor(confmap_config=cfg, confmap_model=model) elif isinstance(head_config, MultiClassBottomUpConfig): - predictor = sleap.nn.inference.BottomUpMultiClassPredictor( + predictor = BottomUpMultiClassPredictor( config=cfg, model=model, ) elif isinstance(head_config, MultiClassTopDownConfig): - predictor = sleap.nn.inference.TopDownMultiClassPredictor( + predictor = TopDownMultiClassPredictor( centroid_config=None, centroid_model=None, confmap_config=cfg, diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 4518d323a..351e9ffa0 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -3512,16 +3512,17 @@ class TopDownMultiClassPredictor(Predictor): def _initialize_inference_model(self): """Initialize the inference model from the trained models and configuration.""" use_gt_centroid = self.centroid_config is None - use_gt_confmap = self.confmap_config is None # TODO + use_gt_confmap = self.confmap_config is None + if use_gt_confmap: + raise ValueError( + "Both a centroid and a confidence map model must be provided to " + "initialize a TopDownMultiClassPredictor.") if use_gt_centroid: centroid_crop_layer = CentroidCropGroundTruth( crop_size=self.confmap_config.data.instance_cropping.crop_size ) else: - # if use_gt_confmap: - # crop_size = 1 - # else: crop_size = self.confmap_config.data.instance_cropping.crop_size centroid_crop_layer = CentroidCrop( keras_model=self.centroid_model.keras_model, @@ -3535,9 +3536,6 @@ def _initialize_inference_model(self): return_confmaps=False, ) - # if use_gt_confmap: - # instance_peaks_layer = FindInstancePeaksGroundTruth() - # else: cfg = self.confmap_config instance_peaks_layer = TopDownMultiClassFindPeaks( keras_model=self.confmap_model.keras_model, diff --git a/sleap/nn/training.py b/sleap/nn/training.py index 391c5d68c..06b86a841 100644 --- a/sleap/nn/training.py +++ b/sleap/nn/training.py @@ -548,6 +548,7 @@ def sanitize_scope_name(name: Text) -> Text: CentroidConfmapsPipeline, TopdownConfmapsPipeline, BottomUpPipeline, + TopDownMultiClassPipeline, BottomUpMultiClassPipeline, SingleInstanceConfmapsPipeline, ) diff --git a/sleap/version.py b/sleap/version.py index ff43cccd9..0463d12df 100644 --- a/sleap/version.py +++ b/sleap/version.py @@ -12,7 +12,7 @@ """ -__version__ = "1.2.0a6" +__version__ = "1.2.0" def versions(): diff --git a/tests/nn/test_inference.py b/tests/nn/test_inference.py index bf1fe9599..36db3243e 100644 --- a/tests/nn/test_inference.py +++ b/tests/nn/test_inference.py @@ -698,6 +698,8 @@ def test_load_model( min_centroid_model_path, min_centered_instance_model_path, min_bottomup_model_path, + min_topdown_multiclass_model_path, + min_bottomup_multiclass_model_path, ): predictor = load_model(min_single_instance_robot_model_path) assert isinstance(predictor, SingleInstancePredictor) @@ -707,3 +709,9 @@ def test_load_model( predictor = load_model(min_bottomup_model_path) assert isinstance(predictor, BottomUpPredictor) + + predictor = load_model([min_centroid_model_path, min_topdown_multiclass_model_path]) + assert isinstance(predictor, TopDownMultiClassPredictor) + + predictor = load_model(min_bottomup_multiclass_model_path) + assert isinstance(predictor, BottomUpMultiClassPredictor)