Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tracking backwards in time #14

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 72 additions & 73 deletions sleap_mot/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import attrs
import cv2
import numpy as np
from copy import deepcopy
from collections import deque
import threading

import sleap_io as sio
from sleap_mot.candidates.fixed_window import FixedWindowCandidates
Expand Down Expand Up @@ -216,24 +216,72 @@

can_load_images = labels.video.exists()

context_frames = deque(maxlen=self.candidate.window_size)
prev_frame_untracked = True
def initialize_and_track(bout, untracked_frames, start, end):
self.initialize_tracker(bout[-5:] if len(bout) > 5 else bout)
for untracked_lf in untracked_frames[start:end]:
img = untracked_lf.image if can_load_images else None
try:
if untracked_lf.frame_idx % 1000 == 0:
print(f"Tracking frame: {untracked_lf.frame_idx}", flush=True)
untracked_lf.instances = self.track_frame(
untracked_lf.instances, untracked_lf.frame_idx, image=img
)
except Exception as e:
print(

Check warning on line 230 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L229-L230

Added lines #L229 - L230 were not covered by tests
f"Error tracking frame {untracked_lf.frame_idx}: {e}",
flush=True,
)

tracked_frames, untracked_frames = [], []
untracked_frames_grouped = []
tracked_frames_grouped = []
prev_frame_tracked = True

for lf in labels:
if lf.frame_idx % 100 == 0:
logging.debug(f"Processing frame: {lf.frame_idx}")
if lf.instances and all(inst.track is not None for inst in lf.instances):
if not prev_frame_tracked:
tracked_frames_grouped.append(tracked_frames)
tracked_frames = []

Check warning on line 244 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L242-L244

Added lines #L242 - L244 were not covered by tests

tracked_frames.append(lf)
prev_frame_tracked = True

Check warning on line 247 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L246-L247

Added lines #L246 - L247 were not covered by tests

else:
if prev_frame_tracked and untracked_frames:
untracked_frames_grouped.append(untracked_frames)
untracked_frames = []

Check warning on line 252 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L251-L252

Added lines #L251 - L252 were not covered by tests

untracked_frames.append(lf)
prev_frame_tracked = False

if all(inst.track is not None for inst in lf.instances):
if prev_frame_untracked:
context_frames.clear()
prev_frame_untracked = False
context_frames.append(lf)
tracked_frames_grouped.append(tracked_frames)
untracked_frames_grouped.append(untracked_frames)

for i in range(len(untracked_frames_grouped)):
first_bout = (
tracked_frames_grouped[i] if i < len(tracked_frames_grouped) else []
)
second_bout = (
tracked_frames_grouped[i + 1]
if i + 1 < len(tracked_frames_grouped)
else []
)
untracked_bout = untracked_frames_grouped[i]

if first_bout and second_bout:
half_idx = len(untracked_bout) // 2

Check warning on line 272 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L272

Added line #L272 was not covered by tests
elif first_bout:
half_idx = len(untracked_bout)

Check warning on line 274 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L274

Added line #L274 was not covered by tests
else:
if not prev_frame_untracked:
self.initialize_tracker(context_frames)
prev_frame_untracked = True
img = lf.image if can_load_images else None
lf.instances = self.track_frame(lf.instances, lf.frame_idx, image=img)
half_idx = 0

if first_bout:
initialize_and_track(first_bout, untracked_bout, 0, half_idx)

Check warning on line 279 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L279

Added line #L279 was not covered by tests
if second_bout:
initialize_and_track(second_bout, untracked_bout, half_idx, None)

Check warning on line 281 in sleap_mot/tracker.py

View check run for this annotation

Codecov / codecov/patch

sleap_mot/tracker.py#L281

Added line #L281 was not covered by tests
if not first_bout and not second_bout:
initialize_and_track([], untracked_bout, 0, None)

labels.update()
return labels

Expand All @@ -247,69 +295,20 @@
# get features for the untracked instances.
current_instances = self.get_features(untracked_instances, frame_idx, image)

if self.is_local_queue:
has_tracks = any(
inst.src_instance.track is not None for inst in current_instances
)
else:
has_tracks = any(
inst.track is not None for inst in current_instances.src_instances
)

if has_tracks:
if not self.candidate.current_tracks:
current_tracked_instances = self.candidate.add_new_tracks(
current_instances, maintain_track_ids=True
)
candidates_list = self.generate_candidates()

else:
if self.is_local_queue:
cost_matrix = np.ones(
(len(current_instances), len(self.candidate.current_tracks))
)
for i, inst in enumerate(current_instances):
if inst.src_instance.track is not None:
track_name = int(inst.src_instance.track.name.split("_")[1])
cost_matrix[i][
track_name
] = 0 # No cost for keeping the same track

else:
cost_matrix = np.ones(
(
len(current_instances.src_instances),
len(self.candidate.current_tracks),
)
)
for i, inst in enumerate(current_instances.src_instances):
if inst.track is not None:
track_name = int(inst.track.name.split("_")[1])
if candidates_list:
candidates_feature_dict = self.update_candidates(candidates_list, image)

cost_matrix[i][
track_name
] = 0 # No cost for keeping the same track
scores = self.get_scores(current_instances, candidates_feature_dict)
cost_matrix = self.scores_to_cost_matrix(scores)

current_tracked_instances = self.assign_tracks(
current_instances, cost_matrix
)
current_tracked_instances = self.assign_tracks(
current_instances, cost_matrix
)

else:
candidates_list = self.generate_candidates()

if candidates_list:
candidates_feature_dict = self.update_candidates(candidates_list, image)

scores = self.get_scores(current_instances, candidates_feature_dict)
cost_matrix = self.scores_to_cost_matrix(scores)

current_tracked_instances = self.assign_tracks(
current_instances, cost_matrix
)

else:
current_tracked_instances = self.candidate.add_new_tracks(
current_instances
)
current_tracked_instances = self.candidate.add_new_tracks(current_instances)

# Convert the `current_instances` back to `List[sio.PredictedInstance]` objects.
if self.is_local_queue:
Expand Down
Loading