Skip to content

Commit

Permalink
Last merge fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Mar 19, 2022
1 parent 27120ca commit ddcc6cb
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 16 deletions.
6 changes: 6 additions & 0 deletions sleap/nn/data/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,12 @@ def make_training_pipeline(self, data_provider: Provider) -> Pipeline:
provider=data_provider,
)
aug_config = self.optimization_config.augmentation_config
if aug_config.random_flip:
pipeline += RandomFlipper.from_skeleton(
self.data_config.labels.skeletons[0],
horizontal=aug_config.flip_horizontal,
)

pipeline += ImgaugAugmenter.from_config(aug_config)
if aug_config.random_crop:
pipeline += RandomCropper(
Expand Down
13 changes: 5 additions & 8 deletions sleap/nn/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
TopDownPredictor,
BottomUpPredictor,
BottomUpMultiClassPredictor,
TopDownMultiClassPredictor,
SingleInstancePredictor,
)

Expand Down Expand Up @@ -688,20 +689,16 @@ def evaluate_model(
confmap_model=model,
)
elif isinstance(head_config, MultiInstanceConfig):
predictor = sleap.nn.inference.BottomUpPredictor(
bottomup_config=cfg, bottomup_model=model
)
predictor = BottomUpPredictor(bottomup_config=cfg, bottomup_model=model)
elif isinstance(head_config, SingleInstanceConfmapsHeadConfig):
predictor = sleap.nn.inference.SingleInstancePredictor(
confmap_config=cfg, confmap_model=model
)
predictor = SingleInstancePredictor(confmap_config=cfg, confmap_model=model)
elif isinstance(head_config, MultiClassBottomUpConfig):
predictor = sleap.nn.inference.BottomUpMultiClassPredictor(
predictor = BottomUpMultiClassPredictor(
config=cfg,
model=model,
)
elif isinstance(head_config, MultiClassTopDownConfig):
predictor = sleap.nn.inference.TopDownMultiClassPredictor(
predictor = TopDownMultiClassPredictor(
centroid_config=None,
centroid_model=None,
confmap_config=cfg,
Expand Down
12 changes: 5 additions & 7 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3512,16 +3512,17 @@ class TopDownMultiClassPredictor(Predictor):
def _initialize_inference_model(self):
"""Initialize the inference model from the trained models and configuration."""
use_gt_centroid = self.centroid_config is None
use_gt_confmap = self.confmap_config is None # TODO
use_gt_confmap = self.confmap_config is None
if use_gt_confmap:
raise ValueError(
"Both a centroid and a confidence map model must be provided to "
"initialize a TopDownMultiClassPredictor.")

if use_gt_centroid:
centroid_crop_layer = CentroidCropGroundTruth(
crop_size=self.confmap_config.data.instance_cropping.crop_size
)
else:
# if use_gt_confmap:
# crop_size = 1
# else:
crop_size = self.confmap_config.data.instance_cropping.crop_size
centroid_crop_layer = CentroidCrop(
keras_model=self.centroid_model.keras_model,
Expand All @@ -3535,9 +3536,6 @@ def _initialize_inference_model(self):
return_confmaps=False,
)

# if use_gt_confmap:
# instance_peaks_layer = FindInstancePeaksGroundTruth()
# else:
cfg = self.confmap_config
instance_peaks_layer = TopDownMultiClassFindPeaks(
keras_model=self.confmap_model.keras_model,
Expand Down
1 change: 1 addition & 0 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ def sanitize_scope_name(name: Text) -> Text:
CentroidConfmapsPipeline,
TopdownConfmapsPipeline,
BottomUpPipeline,
TopDownMultiClassPipeline,
BottomUpMultiClassPipeline,
SingleInstanceConfmapsPipeline,
)
Expand Down
2 changes: 1 addition & 1 deletion sleap/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"""


__version__ = "1.2.0a6"
__version__ = "1.2.0"


def versions():
Expand Down
8 changes: 8 additions & 0 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,8 @@ def test_load_model(
min_centroid_model_path,
min_centered_instance_model_path,
min_bottomup_model_path,
min_topdown_multiclass_model_path,
min_bottomup_multiclass_model_path,
):
predictor = load_model(min_single_instance_robot_model_path)
assert isinstance(predictor, SingleInstancePredictor)
Expand All @@ -707,3 +709,9 @@ def test_load_model(

predictor = load_model(min_bottomup_model_path)
assert isinstance(predictor, BottomUpPredictor)

predictor = load_model([min_centroid_model_path, min_topdown_multiclass_model_path])
assert isinstance(predictor, TopDownMultiClassPredictor)

predictor = load_model(min_bottomup_multiclass_model_path)
assert isinstance(predictor, BottomUpMultiClassPredictor)

0 comments on commit ddcc6cb

Please sign in to comment.