Skip to content

Commit

Permalink
Supervised Identity Prediction (#460) (#679)
Browse files Browse the repository at this point in the history
* squash merge from roomrys/sleap-1 (#460)

* Add test dataset with tracks

* Add track indices to labels provider

* Add identity class map generator

* Update docstring

* Add class map model trainer, head and config

* Add inference

* Docs and tests for identity module
- Slightly modified matching to greedy-like behavior
* Fix inference
- Add imports to evals inference
- Move the common Predictor.predict() method to base class
- Fix docstrings for new inference classes
- Add test model and integration test

* Generate tracks from config metadata if not provided

* Force typecasting in identity functions

* Force boolean masking op

* Clean up inference module
- Move common Predictor methods to base class
- Switch to `model.predict_on_batch()` for massive performance increase
  with `predictor.predict()`.
- Enable prediction directly on arrays (slow)

* Enable Qt5Agg backend only when necessary during training

* Top-down supervised identity prediction (#476)

* Add sizematcher to new training pipelines

* Fix topdown ID visualization during training

* Add LabeledFrame.tracked_instances property for filtering
- Greedy checking in has_* properties

* Add Labels.copy() method for creating deep copies
- Works by serializing and deserializing to JSON (inefficient, but
  guaranteed to work since we have lots of coverage on I/O)

* Extract labels with tracked instances
- Add copy kwarg to extract to return deep copies
- Remove user and/or untracked instances in with_user_labels_only().
  Previously this functionality was blocked since we couldn't remove the
  instances from labeled frames without affecting the source labels.
- Add remove_untracked_instances() utility for filtering out instances
  from the labels.

* Add track filtering in LabelsReader provider
- This is slightly redundant with
  Labels.with_user_labels_only(..., with_track_only=True) but serves as
  an extra guarantee that we don't train on instances without tracks
  accidentally, regardless of how the data is preprocessed. Can still
  emit "empty" frames if no instances have tracks set, however.

* Add track filtering in DataReaders during training
- Auto-enabled when training from ID models
- Filters out instances without tracks BEFORE train/val splitting
- Split is now done on copy of labels
- Fix DataReaders arg typing
- Tests for DataReaders

* Add crop size detection to topdown ID models
- Add training integration test for topdown ID

* add removal of untracked instances for labeled instances (#460)

* Add removal of untracked instances for labeled instances
- previously used `LabeledFrame.tracked_instances()` which only returns predicted instances with tracking
- created `LabeledFrame.remove_untracked` which returns both user labeled and predicted instances with tracking
* Formatting
- using black v20.8b1

* add tests for Labels and LabeledFrames (#460)

* Add tests
- test `Labels.remove_untracked_instances()` for both cases of `remove_empty_frames: bool`
- test `LabeledFrames.remove_untracked()` for both user-labeled and predicted frames

* formatting (#460)

* add newline (no indent) at end of files which had failed Lint test

* clean-up comments and unneeded parenthesis (#460)

* Last merge fixes

* Lint
  • Loading branch information
roomrys authored Mar 19, 2022
1 parent 1c57287 commit a2738ec
Show file tree
Hide file tree
Showing 34 changed files with 4,373 additions and 85 deletions.
4 changes: 4 additions & 0 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,6 +1434,10 @@ def tracked_instances(self) -> List[PredictedInstance]:
if type(inst) == PredictedInstance and inst.track is not None
]

def remove_untracked(self):
    """Remove every instance in this frame that lacks a track assignment.

    Mutates ``self.instances`` in place, keeping only instances whose
    ``track`` attribute is set (not ``None``).
    """
    kept = []
    for inst in self.instances:
        if inst.track is not None:
            kept.append(inst)
    self.instances = kept

@property
def has_user_instances(self) -> bool:
"""Return whether the frame contains any user instances."""
Expand Down
45 changes: 42 additions & 3 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,9 +960,36 @@ def user_labeled_frame_inds(self) -> List[int]:
"""Return a list of indices of frames with user labeled instances."""
return [i for i, lf in enumerate(self.labeled_frames) if lf.has_user_instances]

def with_user_labels_only(self) -> "Labels":
"""Return a new `Labels` object with only user labels."""
return self.extract(self.user_labeled_frame_inds)
def with_user_labels_only(
    self,
    user_instances_only: bool = True,
    with_track_only: bool = False,
    copy: bool = True,
) -> "Labels":
    """Return a new `Labels` containing only user labels.

    This is useful as a preprocessing step to train on only user-labeled data.

    Args:
        user_instances_only: If `True` (the default), predicted instances are
            dropped from frames that also contain user instances.
        with_track_only: If `True`, instances without a track are dropped.
        copy: If `True` (the default), create a new copy of all of the
            extracted labeled frames and associated labels. If `False`, a
            shallow copy with references to the original labeled frames and
            other objects will be returned. Warning: If returning a shallow
            copy, predicted and untracked instances will be removed from the
            original labels as well!

    Returns:
        A new `Labels` with only the specified subset of frames and instances.
    """
    # Start from just the frames that contain user-labeled instances.
    result = self.extract(self.user_labeled_frame_inds, copy=copy)
    if user_instances_only:
        result.remove_predictions()
    if with_track_only:
        result.remove_untracked_instances()
    # Filtering above may leave frames with no instances at all.
    result.remove_empty_frames()
    return result

def get_labeled_frame_count(self, video: Optional[Video] = None, filter: Text = ""):
return self._cache.get_frame_count(video, filter)
Expand Down Expand Up @@ -1617,6 +1644,18 @@ def remove_predictions(self, new_labels: Optional["Labels"] = None):
# Keep only labeled frames with no conflicting predictions.
self.labeled_frames = keep_lfs

def remove_untracked_instances(self, remove_empty_frames: bool = True):
    """Remove all instances that do not have a track assignment.

    Args:
        remove_empty_frames: If `True` (the default), also remove frames that
            are left without any instances after the untracked ones are
            dropped.
    """
    # Per-frame filtering is delegated to LabeledFrame.remove_untracked().
    for frame in self.labeled_frames:
        frame.remove_untracked()
    if remove_empty_frames:
        self.remove_empty_frames()

@classmethod
def complex_merge_between(
cls, base_labels: "Labels", new_labels: "Labels", unify: bool = True
Expand Down
4 changes: 4 additions & 0 deletions sleap/nn/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
MultiInstanceConfmapsHeadConfig,
PartAffinityFieldsHeadConfig,
MultiInstanceConfig,
ClassMapsHeadConfig,
MultiClassBottomUpConfig,
ClassVectorsHeadConfig,
MultiClassTopDownConfig,
HeadsConfig,
LEAPConfig,
UNetConfig,
Expand Down
152 changes: 151 additions & 1 deletion sleap/nn/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class SingleInstanceConfmapsHeadConfig:
results in confidence maps that are 0.5x the size of the input. Increasing
this value can considerably speed up model performance and decrease memory
requirements, at the cost of decreased spatial resolution.
loss_weight: Scalar float used to weigh the loss term for this head during
training. Increase this to encourage the optimization to focus on improving
this specific output in multi-head models.
offset_refinement: If `True`, model will also output an offset refinement map
used to achieve subpixel localization of peaks during inference. This can
improve the localization accuracy of the model at the cost of additional
Expand All @@ -40,6 +43,7 @@ class SingleInstanceConfmapsHeadConfig:
part_names: Optional[List[Text]] = None
sigma: float = 5.0
output_stride: int = 1
loss_weight: float = 1.0
offset_refinement: bool = False


Expand Down Expand Up @@ -72,6 +76,9 @@ class CentroidsHeadConfig:
results in confidence maps that are 0.5x the size of the input. Increasing
this value can considerably speed up model performance and decrease memory
requirements, at the cost of decreased spatial resolution.
loss_weight: Scalar float used to weigh the loss term for this head during
training. Increase this to encourage the optimization to focus on improving
this specific output in multi-head models.
offset_refinement: If `True`, model will also output an offset refinement map
used to achieve subpixel localization of peaks during inference. This can
improve the localization accuracy of the model at the cost of additional
Expand All @@ -84,6 +91,7 @@ class CentroidsHeadConfig:
anchor_part: Optional[Text] = None
sigma: float = 5.0
output_stride: int = 1
loss_weight: float = 1.0
offset_refinement: bool = False


Expand Down Expand Up @@ -129,6 +137,9 @@ class CenteredInstanceConfmapsHeadConfig:
results in confidence maps that are 0.5x the size of the input. Increasing
this value can considerably speed up model performance and decrease memory
requirements, at the cost of decreased spatial resolution.
loss_weight: Scalar float used to weigh the loss term for this head during
training. Increase this to encourage the optimization to focus on improving
this specific output in multi-head models.
offset_refinement: If `True`, model will also output an offset refinement map
used to achieve subpixel localization of peaks during inference. This can
improve the localization accuracy of the model at the cost of additional
Expand All @@ -142,6 +153,7 @@ class CenteredInstanceConfmapsHeadConfig:
part_names: Optional[List[Text]] = None
sigma: float = 5.0
output_stride: int = 1
loss_weight: float = 1.0
offset_refinement: bool = False


Expand Down Expand Up @@ -274,6 +286,125 @@ class MultiInstanceConfig:
pafs: PartAffinityFieldsHeadConfig = attr.ib(factory=PartAffinityFieldsHeadConfig)


@attr.s(auto_attribs=True)
class ClassMapsHeadConfig:
    """Configurations for class map heads.

    These heads are used in bottom-up multi-instance models that classify detected
    points using a fixed set of learned classes (e.g., animal identities).

    Class maps are an image-space representation of the probability that each class
    occupies a given pixel. This is similar to semantic segmentation, however only
    the pixels in the neighborhood of the landmarks have a class assignment.

    Attributes:
        classes: List of string names of the classes that this head will predict.
        sigma: Spread of the Gaussian distribution that determines the neighborhood
            that the class maps will be nonzero around each landmark.
        output_stride: The stride of the output class maps relative to the input
            image. This is the reciprocal of the resolution, e.g., an output stride
            of 2 results in maps that are 0.5x the size of the input. This should
            be the same size as the confidence maps they are associated with.
        loss_weight: Scalar float used to weigh the loss term for this head during
            training. Increase this to encourage the optimization to focus on
            improving this specific output in multi-head models.
    """

    classes: Optional[List[Text]] = None
    sigma: float = 5.0
    output_stride: int = 1
    loss_weight: float = 1.0


@attr.s(auto_attribs=True)
class MultiClassBottomUpConfig:
    """Configuration for multi-instance confidence map and class map models.

    This configuration specifies a multi-head model that outputs both multi-instance
    confidence maps and class maps, which together enable multi-instance pose
    tracking in a bottom-up fashion, i.e., no instance cropping, centroids or PAFs
    are required.

    The limitation with this approach is that the classes, e.g., animal identities,
    must be labeled in the training data and cannot be generalized beyond those
    classes. This is still useful for applications in which the animals are uniquely
    identifiable and tracking their identities at inference time is critical, e.g.,
    for closed loop experiments.

    Attributes:
        confmaps: Part confidence map configuration (see the description in
            `MultiInstanceConfmapsHeadConfig`).
        class_maps: Class map configuration (see the description in
            `ClassMapsHeadConfig`).
    """

    confmaps: MultiInstanceConfmapsHeadConfig = attr.ib(
        factory=MultiInstanceConfmapsHeadConfig
    )
    class_maps: ClassMapsHeadConfig = attr.ib(factory=ClassMapsHeadConfig)


@attr.s(auto_attribs=True)
class ClassVectorsHeadConfig:
    """Configurations for class vectors heads.

    These heads are used in top-down multi-instance models that classify detected
    points using a fixed set of learned classes (e.g., animal identities).

    Class vectors represent the probability that the image is associated with each
    of the specified classes. This is similar to a standard classification task.

    Attributes:
        classes: List of string names of the classes that this head will predict.
        num_fc_layers: Number of fully-connected layers before the classification
            output layer. These can help in transforming general image features
            into classification-specific features.
        num_fc_units: Number of units (dimensions) in the fully-connected layers
            before classification. Increasing this can improve the
            representational capacity in the pre-classification layers.
        global_pool: If `True`, pool features spatially before the fully-connected
            layers. NOTE(review): semantics inferred from the name — confirm
            against the head implementation that builds this layer.
        output_stride: The stride of the output class maps relative to the input
            image. This is the reciprocal of the resolution, e.g., an output stride
            of 2 results in maps that are 0.5x the size of the input. This should
            be the same size as the confidence maps they are associated with.
        loss_weight: Scalar float used to weigh the loss term for this head during
            training. Increase this to encourage the optimization to focus on
            improving this specific output in multi-head models.
    """

    classes: Optional[List[Text]] = None
    num_fc_layers: int = 1
    num_fc_units: int = 64
    global_pool: bool = True
    output_stride: int = 1
    loss_weight: float = 1.0


@attr.s(auto_attribs=True)
class MultiClassTopDownConfig:
    """Configuration for centered-instance confidence map and class vector models.

    This configuration specifies a multi-head model that outputs both
    centered-instance confidence maps and class vectors, which together enable
    multi-instance pose tracking in a top-down fashion, i.e., instance-centered
    crops followed by pose estimation and classification.

    The limitation with this approach is that the classes, e.g., animal identities,
    must be labeled in the training data and cannot be generalized beyond those
    classes. This is still useful for applications in which the animals are uniquely
    identifiable and tracking their identities at inference time is critical, e.g.,
    for closed loop experiments.

    Attributes:
        confmaps: Part confidence map configuration (see the description in
            `CenteredInstanceConfmapsHeadConfig`).
        class_vectors: Class vectors configuration (see the description in
            `ClassVectorsHeadConfig`).
    """

    confmaps: CenteredInstanceConfmapsHeadConfig = attr.ib(
        factory=CenteredInstanceConfmapsHeadConfig
    )
    class_vectors: ClassVectorsHeadConfig = attr.ib(factory=ClassVectorsHeadConfig)


@oneof
@attr.s(auto_attribs=True)
class HeadsConfig:
Expand All @@ -286,12 +417,16 @@ class HeadsConfig:
centroid: An instance of `CentroidsHeadConfig`.
centered_instance: An instance of `CenteredInstanceConfmapsHeadConfig`.
multi_instance: An instance of `MultiInstanceConfig`.
multi_class_bottomup: An instance of `MultiClassBottomUpConfig`.
multi_class_topdown: An instance of `MultiClassTopDownConfig`.
"""

single_instance: Optional[SingleInstanceConfmapsHeadConfig] = None
centroid: Optional[CentroidsHeadConfig] = None
centered_instance: Optional[CenteredInstanceConfmapsHeadConfig] = None
multi_instance: Optional[MultiInstanceConfig] = None
multi_class_bottomup: Optional[MultiClassBottomUpConfig] = None
multi_class_topdown: Optional[MultiClassTopDownConfig] = None


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -459,20 +594,35 @@ class PretrainedEncoderConfig:
"""Configuration for UNet backbone with pretrained encoder.
Attributes:
encoder: Name of the network architecture to use as the encoder.
encoder: Name of the network architecture to use as the encoder. Valid encoder
names are:
- `"vgg16", "vgg19",`
- `"resnet18", "resnet34", "resnet50", "resnet101", "resnet152"`
- `"resnext50", "resnext101"`
- `"inceptionv3", "inceptionresnetv2"`
- `"densenet121", "densenet169", "densenet201"`
- `"seresnet18", "seresnet34", "seresnet50", "seresnet101", "seresnet152",`
`"seresnext50", "seresnext101", "senet154"`
- `"mobilenet", "mobilenetv2"`
- `"efficientnetb0", "efficientnetb1", "efficientnetb2", "efficientnetb3",`
`"efficientnetb4", "efficientnetb5", "efficientnetb6", "efficientnetb7"`
Defaults to `"efficientnetb0"`.
pretrained: If `True`, use initialized with weights pretrained on ImageNet.
decoder_filters: Base number of filters for the upsampling blocks in the
decoder.
decoder_filters_rate: Factor to scale the number of filters by at each
consecutive upsampling block in the decoder.
output_stride: Stride of the final output.
decoder_batchnorm: If `True` (the default), use batch normalization in the
decoder layers.
"""

encoder: Text = attr.ib(default="efficientnetb0")
pretrained: bool = True
decoder_filters: int = 256
decoder_filters_rate: float = 1.0
output_stride: int = 2
decoder_batchnorm: bool = True


@oneof
Expand Down
1 change: 1 addition & 0 deletions sleap/nn/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sleap.nn.data import confidence_maps
from sleap.nn.data import instance_centroids
from sleap.nn.data import instance_cropping
from sleap.nn.data import identity
from sleap.nn.data import normalization
from sleap.nn.data import pipelines
from sleap.nn.data import providers
Expand Down
Loading

0 comments on commit a2738ec

Please sign in to comment.