Skip to content

Commit

Permalink
Remember instance grouping after testing hypotheses
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Dec 2, 2023
1 parent 2d1e23f commit c66154a
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 177 deletions.
96 changes: 74 additions & 22 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from sleap.gui.state import GuiState
from sleap.gui.suggestions import VideoFrameSuggestions
from sleap.instance import Instance, LabeledFrame, Point, PredictedInstance, Track
from sleap.io.cameras import Camcorder, RecordingSession
from sleap.io.cameras import Camcorder, InstanceGroup, RecordingSession
from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
Expand Down Expand Up @@ -3421,13 +3421,14 @@ def do_action(cls, context: CommandContext, params: dict):
# Get params
video = params.get("video", None) or context.state["video"]
session = params.get("session", None) or context.labels.get_session(video)
instances = params["instances"]
instances: Dict[int, Dict[Camcorder, List[Instance]]] = params["instances"]
frame_idx: int = params["frame_idx"]
session = cast(RecordingSession, session) # Could be None if no labels or video

# Get best instance grouping and reprojected coords
instances_and_reprojected_coords = (
TriangulateSession.get_instance_grouping_and_reprojected_coords(
session=session, instance_hypotheses=instances
session=session, frame_idx=frame_idx, instance_hypotheses=instances
)
)

Expand Down Expand Up @@ -3491,6 +3492,7 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bo
return

cams_to_include = params.get("cams_to_include", None) or session.linked_cameras
selected_cam: Camcorder = session.get_camera(video)

# If not enough `Camcorder`s available/specified, then return
if not TriangulateSession.verify_enough_views(
Expand All @@ -3506,6 +3508,7 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bo
context=context,
session=session,
frame_idx=frame_idx,
selected_cam=selected_cam,
cams_to_include=cams_to_include,
show_dialog=show_dialog,
)
Expand All @@ -3514,15 +3517,17 @@ def verify_views_and_instances(cls, context: CommandContext, params: dict) -> bo
if not instance:
return False

# Add instances to params dict
# Add instances and frame_idx to params dict
params["instances"] = instances
params["frame_idx"] = frame_idx

return True

@staticmethod
def get_and_verify_enough_instances(
session: RecordingSession,
frame_idx: int,
selected_cam: Camcorder,
context: Optional[CommandContext] = None,
cams_to_include: Optional[List[Camcorder]] = None,
show_dialog: bool = True,
Expand All @@ -3534,6 +3539,7 @@ def get_and_verify_enough_instances(
Args:
session: The `RecordingSession` containing the `Camcorder`s.
frame_idx: Frame index to get instances from (0-indexed).
selected_cam: The `Camcorder` object to determine the current view.
context: The optional command context used to display a dialog.
cams_to_include: List of `Camcorder`s to include. Default is all.
track: `Track` object used to find instances accross views. Default is -1
Expand All @@ -3552,6 +3558,7 @@ def get_and_verify_enough_instances(
] = TriangulateSession.get_products_of_instances(
session=session,
frame_idx=frame_idx,
selected_cam=selected_cam,
cams_to_include=cams_to_include,
)
return instances
Expand Down Expand Up @@ -3759,6 +3766,7 @@ def get_all_views_at_frame(
@staticmethod
def get_instance_grouping_and_reprojected_coords(
session: RecordingSession,
frame_idx: int,
instance_hypotheses: Dict[int, Dict[Camcorder, List[Instance]]],
):
"""Get instance grouping and reprojected coords with lowest reprojection error.
Expand All @@ -3769,6 +3777,7 @@ def get_instance_grouping_and_reprojected_coords(
Args:
session: The `RecordingSession` containing the `Camcorder`s.
frame_idx: Frame index to get views from (0-indexed).
instance_hypotheses: Dict with frame identifier keys (not the frame index)
and values of another inner dict with `Camcorder` keys and
`List[Instance]` values.
Expand Down Expand Up @@ -3803,6 +3812,9 @@ def get_instance_grouping_and_reprojected_coords(
reprojection_error_per_frame=reprojection_error_per_frame,
)

# Assign to `InstanceGroup`s
session.update_instance_group(frame_idx, best_instances)

# Just for type hinting
best_instances = cast(Dict[Camcorder, List[Instance]], best_instances)
instances_and_coords = cast(
Expand Down Expand Up @@ -4019,13 +4031,15 @@ def calculate_error_per_frame(
def get_products_of_instances(
session: RecordingSession,
frame_idx: int,
selected_cam: Camcorder,
cams_to_include: Optional[List[Camcorder]] = None,
) -> Dict[int, Dict[Camcorder, List[Instance]]]:
"""Get all (multi-instance) possible products of instances across views.
Args:
session: The `RecordingSession` containing the `Camcorder`s.
frame_idx: Frame index to get instances from (0-indexed).
selected_cam: The `Camcorder` object to determine the current view.
cams_to_include: List of `Camcorder`s to include. Default is all.
require_multiple_views: If True, then raise and error if one or less views
or instances are found.
Expand Down Expand Up @@ -4062,33 +4076,69 @@ def get_products_of_instances(
skeleton=skeleton,
)

# Get permutations of instances from other views
instances_permutations: Dict[Camcorder, Iterator[Tuple]] = {}
for cam, instances_in_view in instances.items():
# Append a dummy instance to all lists of instances if less than the max length
num_missing = 1
def _fill_in_missing_instances(instances_in_view: List[Instance]):
"""Fill in missing instances with dummy instances up to max number of instances.
Args:
instances_in_view: List of instances in a view.
"""

num_instances = len(instances_in_view)

if num_instances < max_num_instances:
num_missing = max_num_instances - num_instances

# Extend the list first
instances_in_view.extend([dummy_instance] * num_missing)

# Permute instances into all possible orderings w/in a view
instances_permutations[cam] = permutations(instances_in_view)
return instances_in_view

# Get products of instances from other views into all possible groupings
# Ordering of dict_values is preserved in Python 3.7+
products_of_instances: Iterator[Iterator[Tuple]] = product(
*instances_permutations.values()
)
# The existing grouping of instances
instance_group: Optional[
Dict[Camcorder, List[Instance]]
] = session.get_instance_group(frame_idx=frame_idx)

# Reorganize products by cam and add selected instance to each permutation
instances_hypotheses: Dict[int, Dict[Camcorder, List[Instance]]] = {}
for frame_id, prod in enumerate(products_of_instances):
instances_hypotheses[frame_id] = {
cam: [*inst] for cam, inst in zip(instances.keys(), prod)
}
# TODO(LM): This should be skipped if we are doing greedy matching, not if instance_group is None
if instance_group is None:
# Get permutations of instances from other views
instances_permutations: Dict[Camcorder, Iterator[Tuple]] = {}
for cam, instances_in_view in instances.items():
# Append a dummy instance to all lists of instances if less than the max length
instances_in_view = _fill_in_missing_instances(instances_in_view)

# Permute instances into all possible orderings w/in a view
instances_permutations[cam] = permutations(instances_in_view)

# Get products of instances from other views into all possible groupings
# Ordering of dict_values is preserved in Python 3.7+
products_of_instances: Iterator[Iterator[Tuple]] = product(
*instances_permutations.values()
)

# Reorganize products by cam and add selected instance to each permutation
for frame_id, prod in enumerate(products_of_instances):
instances_hypotheses[frame_id] = {
cam: list(inst) for cam, inst in zip(instances.keys(), prod)
}
else:
# Remove instances in selected view from instance grouping, we can't assume
# that all instances in the selected view will be in the instance_group
instance_group.pop(selected_cam, None)

# Get instances in current view
instances_in_view = instances[selected_cam]

# Fill in with dummy instances if less than max number of instances
instances_in_view = _fill_in_missing_instances(instances_in_view)

# Permute instances into all possible orderings w/in a view
instances_permutations = permutations(instances_in_view)

# Create hypotheses
for frame_id, perm in enumerate(instances_permutations):
instances_hypotheses[frame_id] = instance_group.copy()
instances_hypotheses[frame_id].update({selected_cam: list(perm)})

# Expect "max # instances in view" ** "# views" frames (a.k.a. hypotheses)
return instances_hypotheses
Expand Down Expand Up @@ -4327,10 +4377,12 @@ def group_instances_and_coords(
] # len(T) of N x 2

# TODO(LM): I think we will need a reconsumable iterator here.
insts_and_coords_in_frame[cam]: Tuple[Instance, np.ndarray] = zip(
insts_and_coords_in_view: Tuple[Instance, np.ndarray] = zip(
instances_in_frame_ordered,
insts_coords_in_frame,
)
insts_and_coords_in_frame[cam] = insts_and_coords_in_view

insts_and_coords[frame_idx] = insts_and_coords_in_frame

return insts_and_coords
Expand Down
Loading

0 comments on commit c66154a

Please sign in to comment.