Skip to content

Commit

Permalink
Merge branch 'develop' into arlo/max_instances
Browse files Browse the repository at this point in the history
  • Loading branch information
sheridana authored Dec 7, 2022
2 parents 3b7041e + 5956782 commit 73343b7
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 9 deletions.
20 changes: 20 additions & 0 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,16 @@ inference:
label: Elapsed Frame Window
type: int
default: 5
- name: tracking.robust
label: 'Robust quantile of similarity scores'
help: 'For a value between 0 and 1 (excluded), use a robust quantile
of the similarity scores to assign a track to an instance.<br />If equal to 1,
use the max similarity score (non-robust).'
type: optional_double
default_disabled: true
none_label: Use max (non-robust)
range: 0,1
default: 0.95
- name: tracking.save_shifted_instances
label: Save shifted instances
type: bool
Expand Down Expand Up @@ -333,6 +343,16 @@ inference:
label: Elapsed Frame Window
type: int
default: 5
- name: tracking.robust
label: 'Robust quantile of similarity scores'
help: 'For a value between 0 and 1 (excluded), use a robust quantile
of the similarity scores to assign a track to an instance.<br />If equal to 1,
use the max similarity score (non-robust).'
type: optional_double
default_disabled: true
none_label: Use max (non-robust)
range: 0,1
default: 0.95
- type: text
text: '<b>Kalman filter-based tracking</b>:<br />
Uses the above tracking options to track instances for an initial
Expand Down
1 change: 1 addition & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def make_predict_cli_call(
optional_items_as_nones = (
"tracking.target_instance_count",
"tracking.kf_init_frame_count",
"tracking.robust",
)

for key in optional_items_as_nones:
Expand Down
40 changes: 31 additions & 9 deletions sleap/nn/tracker/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,23 @@ def from_candidate_instances(
candidate_instances: List[InstanceType],
similarity_function: Callable,
matching_function: Callable,
robust_best_instance: float = 1.0,
):
"""Calculates (and stores) matches for a frame from candidate instance.
Args:
untracked_instances: list of untracked instances in the frame.
candidate_instances: list of instances use as match.
similarity_function: a function that returns the similarity between
two instances (untracked and candidate).
matching_function: function used to find the best match from the
cost matrix. See the classmethod `from_cost_matrix`.
robust_best_instance (float): if the value is between 0 and 1
(excluded), use a robust quantile similarity score for the
track. If the value is 1, use the max similarity (non-robust).
For selecting a robust score, 0.95 is a good value.
"""
cost = np.ndarray((0,))
candidate_tracks = []

Expand All @@ -425,9 +440,8 @@ def from_candidate_instances(
# Compute similarity matrix between untracked instances and best
# candidate for each track.
candidate_tracks = list(candidate_instances_by_track.keys())
matching_similarities = np.full(
(len(untracked_instances), len(candidate_tracks)), np.nan
)
dims = (len(untracked_instances), len(candidate_tracks))
matching_similarities = np.full(dims, np.nan)

for i, untracked_instance in enumerate(untracked_instances):

Expand All @@ -443,19 +457,27 @@ def from_candidate_instances(
for candidate_instance in track_instances
]

# Keep the best scoring instance for this track.
best_ind = np.argmax(track_matching_similarities)

# Use the best similarity score for matching.
best_similarity = track_matching_similarities[best_ind]
if 0 < robust_best_instance < 1:
# Robust, use the similarity score in the q-quantile for matching.
best_similarity = np.quantile(
track_matching_similarities,
robust_best_instance,
)
else:
# Non-robust, use the max similarity score for matching.
best_similarity = np.max(track_matching_similarities)
# Keep the best similarity score for this track.
matching_similarities[i, j] = best_similarity

# Perform matching between untracked instances and candidates.
cost = -matching_similarities
cost[np.isnan(cost)] = np.inf

return cls.from_cost_matrix(
cost, untracked_instances, candidate_tracks, matching_function
cost,
untracked_instances,
candidate_tracks,
matching_function,
)

@classmethod
Expand Down
16 changes: 16 additions & 0 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ class Tracker(BaseTracker):
after the other tracking has run for all frames.
min_new_track_points: We won't spawn a new track for an instance with
fewer than this many points.
robust_best_instance (float): if the value is between 0 and 1 (excluded),
use a robust quantile similarity score for the track. If the value is 1,
use the max similarity (non-robust). For selecting a robust score,
0.95 is a good value.
"""

track_window: int = 5
Expand All @@ -414,6 +418,7 @@ class Tracker(BaseTracker):
target_instance_count: int = 0
pre_cull_function: Optional[Callable] = None
post_connect_single_breaks: bool = False
robust_best_instance: float = 1.0

min_new_track_points: int = 0

Expand Down Expand Up @@ -510,6 +515,7 @@ def track(
candidate_instances=candidate_instances,
similarity_function=self.similarity_function,
matching_function=self.matching_function,
robust_best_instance=self.robust_best_instance,
)

# Store the most recent match data (for outside inspection).
Expand Down Expand Up @@ -596,6 +602,7 @@ def make_tracker_by_name(
similarity: str = "instance",
match: str = "greedy",
track_window: int = 5,
robust: float = 1.0,
min_new_track_points: int = 0,
min_match_points: int = 0,
# Optical flow options
Expand Down Expand Up @@ -663,6 +670,7 @@ def pre_cull_function(inst_list):

tracker_obj = cls(
track_window=track_window,
robust_best_instance=robust,
min_new_track_points=min_new_track_points,
similarity_function=similarity_function,
matching_function=matching_function,
Expand Down Expand Up @@ -751,6 +759,14 @@ def get_by_name_factory_options(cls):
option["options"] = list(match_policies.keys())
options.append(option)

option = dict(name="robust", default=1)
option["type"] = float
option["help"] = (
"Robust quantile of similarity score for instance matching. "
"If equal to 1, keep the max similarity score (non-robust)."
)
options.append(option)

option = dict(name="track_window", default=5)
option["type"] = int
option["help"] = "How many frames back to look for matches"
Expand Down

0 comments on commit 73343b7

Please sign in to comment.