Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tracking: robust assignment of the best score to an instance #1062

Merged
merged 6 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions sleap/config/pipeline_form.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,16 @@ inference:
label: Elapsed Frame Window
type: int
default: 5
- name: tracking.robust
label: 'Robust quantile of similarity scores'
help: 'For a value between 0 and 1 (excluded), use a robust quantile
of the similarity scores to assign a track to an instance.<br />If equal to 1,
use the max similarity score (non-robust).'
type: optional_double
default_disabled: true
none_label: Use max (non-robust)
range: 0,1
default: 0.95
- name: tracking.save_shifted_instances
label: Save shifted instances
type: bool
Expand Down Expand Up @@ -333,6 +343,16 @@ inference:
label: Elapsed Frame Window
type: int
default: 5
- name: tracking.robust
label: 'Robust quantile of similarity scores'
help: 'For a value between 0 and 1 (excluded), use a robust quantile
of the similarity scores to assign a track to an instance.<br />If equal to 1,
use the max similarity score (non-robust).'
type: optional_double
default_disabled: true
none_label: Use max (non-robust)
range: 0,1
default: 0.95
- type: text
text: '<b>Kalman filter-based tracking</b>:<br />
Uses the above tracking options to track instances for an initial
Expand Down
1 change: 1 addition & 0 deletions sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def make_predict_cli_call(
optional_items_as_nones = (
"tracking.target_instance_count",
"tracking.kf_init_frame_count",
"tracking.robust",
)

for key in optional_items_as_nones:
Expand Down
40 changes: 31 additions & 9 deletions sleap/nn/tracker/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,23 @@ def from_candidate_instances(
candidate_instances: List[InstanceType],
similarity_function: Callable,
matching_function: Callable,
robust_best_instance: float = 1.0,
):
"""Calculates (and stores) matches for a frame from candidate instance.

Args:
untracked_instances: list of untracked instances in the frame.
candidate_instances: list of instances use as match.
similarity_function: a function that returns the similarity between
two instances (untracked and candidate).
matching_function: function used to find the best match from the
cost matrix. See the classmethod `from_cost_matrix`.
robust_best_instance (float): if the value is between 0 and 1
(excluded), use a robust quantile similarity score for the
track. If the value is 1, use the max similarity (non-robust).
For selecting a robust score, 0.95 is a good value.

"""
cost = np.ndarray((0,))
candidate_tracks = []

Expand All @@ -425,9 +440,8 @@ def from_candidate_instances(
# Compute similarity matrix between untracked instances and best
# candidate for each track.
candidate_tracks = list(candidate_instances_by_track.keys())
matching_similarities = np.full(
(len(untracked_instances), len(candidate_tracks)), np.nan
)
dims = (len(untracked_instances), len(candidate_tracks))
matching_similarities = np.full(dims, np.nan)

for i, untracked_instance in enumerate(untracked_instances):

Expand All @@ -443,19 +457,27 @@ def from_candidate_instances(
for candidate_instance in track_instances
]

# Keep the best scoring instance for this track.
best_ind = np.argmax(track_matching_similarities)

# Use the best similarity score for matching.
best_similarity = track_matching_similarities[best_ind]
if 0 < robust_best_instance < 1:
# Robust, use the similarity score in the q-quantile for matching.
best_similarity = np.quantile(
track_matching_similarities,
robust_best_instance,
)
else:
# Non-robust, use the max similarity score for matching.
best_similarity = np.max(track_matching_similarities)
# Keep the best similarity score for this track.
matching_similarities[i, j] = best_similarity

# Perform matching between untracked instances and candidates.
cost = -matching_similarities
cost[np.isnan(cost)] = np.inf

return cls.from_cost_matrix(
cost, untracked_instances, candidate_tracks, matching_function
cost,
untracked_instances,
candidate_tracks,
matching_function,
)

@classmethod
Expand Down
16 changes: 16 additions & 0 deletions sleap/nn/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ class Tracker(BaseTracker):
after the other tracking has run for all frames.
min_new_track_points: We won't spawn a new track for an instance with
fewer than this many points.
robust_best_instance (float): if the value is between 0 and 1 (excluded),
use a robust quantile similarity score for the track. If the value is 1,
use the max similarity (non-robust). For selecting a robust score,
0.95 is a good value.
"""

track_window: int = 5
Expand All @@ -414,6 +418,7 @@ class Tracker(BaseTracker):
target_instance_count: int = 0
pre_cull_function: Optional[Callable] = None
post_connect_single_breaks: bool = False
robust_best_instance: float = 1.0

min_new_track_points: int = 0

Expand Down Expand Up @@ -510,6 +515,7 @@ def track(
candidate_instances=candidate_instances,
similarity_function=self.similarity_function,
matching_function=self.matching_function,
robust_best_instance=self.robust_best_instance,
)

# Store the most recent match data (for outside inspection).
Expand Down Expand Up @@ -596,6 +602,7 @@ def make_tracker_by_name(
similarity: str = "instance",
match: str = "greedy",
track_window: int = 5,
robust: float = 1.0,
min_new_track_points: int = 0,
min_match_points: int = 0,
# Optical flow options
Expand Down Expand Up @@ -663,6 +670,7 @@ def pre_cull_function(inst_list):

tracker_obj = cls(
track_window=track_window,
robust_best_instance=robust,
min_new_track_points=min_new_track_points,
similarity_function=similarity_function,
matching_function=matching_function,
Expand Down Expand Up @@ -751,6 +759,14 @@ def get_by_name_factory_options(cls):
option["options"] = list(match_policies.keys())
options.append(option)

option = dict(name="robust", default=1)
option["type"] = float
option["help"] = (
"Robust quantile of similarity score for instance matching. "
"If equal to 1, keep the max similarity score (non-robust)."
)
options.append(option)

option = dict(name="track_window", default=5)
option["type"] = int
option["help"] = "How many frames back to look for matches"
Expand Down