Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add triangulated 3d pose attribute #2082

Draft
wants to merge 22 commits into
base: liezl/add-gui-elements-for-sessions
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3772,20 +3772,10 @@ def do_action(cls, context: CommandContext, params: dict):
calib=session.camera_cluster,
excluded_views=frame_group.excluded_views,
) # F x T x N x 3

# Reproject onto all views
pts_reprojected = reproject(
points_3d,
calib=session.camera_cluster,
excluded_views=frame_group.excluded_views,
) # M=include x F=1 x T x N x 2

# Squeeze back to the original shape
points_reprojected = np.squeeze(pts_reprojected, axis=1) # M=include x TxNx2

# Update or create/insert ("upsert") instance points

frame_group.upsert_points(
points=points_reprojected,
points=points_3d,
instance_groups=instance_groups,
exclude_complete=True,
)
Expand Down
133 changes: 104 additions & 29 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sleap.instance import Instance, LabeledFrame, PredictedInstance
from sleap.io.video import Video
from sleap.util import compute_oks, deep_iterable_converter
from sleap_anipose.triangulation import reproject

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -452,6 +453,7 @@ class InstanceGroup:
_dummy_instance: Optional[Instance] = field(default=None)
camera_cluster: Optional[CameraCluster] = field(default=None)
_score: Optional[float] = field(default=None)
_triangulation: Optional[np.ndarray] = field(default=None)

def __attrs_post_init__(self):
"""Initialize `InstanceGroup` object."""
Expand All @@ -460,6 +462,14 @@ def __attrs_post_init__(self):
for cam, instance in self._instance_by_camcorder.items():
self._camcorder_by_instance[instance] = cam

@property
def triangulation(self) -> Optional[np.ndarray]:
"""Return the array stored in `_triangulation` (None until one is assigned)."""
# NOTE(review): presumably the triangulated 3D points for this group; the
# expected shape is not established here — confirm against the caller.
return self._triangulation

@triangulation.setter
def triangulation(self, triangulation: np.ndarray):
"""Store `triangulation` in `_triangulation`, replacing any previous value."""
# No validation or copy is performed; the passed array object is kept as-is.
self._triangulation = triangulation

def _create_dummy_instance(self, instance: Optional[Instance] = None):
"""Create a dummy instance to fill in for missing instances.

Expand Down Expand Up @@ -838,32 +848,114 @@ def get_cam(self, instance: Instance) -> Optional[Camcorder]:
def update_points(
    self,
    points: np.ndarray,
    projection_bounds: np.ndarray,
    cams_to_include: Optional[List[Camcorder]] = None,
    exclude_complete: bool = True,
    excluded_views: Optional[List[str]] = None,
):
    """Reproject 3D points into each view and update the `Instance`s there.

    The 3D points are reprojected with `reproject` using this group's
    `CameraCluster`, and the resulting per-view 2D points are handed to
    `update_points_from_2d`.

    Args:
        points: Numpy array of shape (N, 3) where N is the number of Nodes and 3
            is for x, y, z. Leading singleton axes (e.g. (1, 1, N, 3)) are
            dropped; the node axis is kept even when N is 1.
        projection_bounds: Numpy array of shape (M, 2) where M is the number of
            views. Forwarded unchanged to `update_points_from_2d`.
        cams_to_include: List of `Camcorder`s to include in the update. If None,
            then all `Camcorder`s in the `CameraCluster` are included. Default is
            None.
            NOTE(review): the order presumably has to match the order of the
            views returned by `reproject` — confirm.
        exclude_complete: If True, then do not update points that are marked as
            complete. Default is True.
        excluded_views: List of `Camcorder` names to exclude from the update. If
            None, no views are excluded.

    Raises:
        ValueError: If `points` is not of shape (N, 3) after dropping leading
            singleton axes, or if the number of included plus excluded
            `Camcorder`s does not match the size of the `CameraCluster`.
    """
    # Ensure we are working with a float array
    points = points.astype(float)

    # Drop only the leading singleton axes. A blanket `squeeze()` would also
    # collapse the node axis of a single-node skeleton (N == 1) and make the
    # shape check below fail.
    if points.ndim > 2:
        leading_singletons = tuple(
            axis for axis in range(points.ndim - 2) if points.shape[axis] == 1
        )
        points = np.squeeze(points, axis=leading_singletons)
    if points.ndim != 2:
        raise ValueError(
            f"Expected `points` to be of shape (N, 3), got {points.shape}."
        )
    n_coords = points.shape[1]
    if n_coords != 3:
        raise ValueError(f"Expected 3 coordinates in `points`, got {n_coords}.")

    # If no `Camcorder`s specified, then update `Instance`s for all `CameraCluster`
    if cams_to_include is None:
        cams_to_include = self.camera_cluster.cameras

    if excluded_views is None:
        excluded_views = ()

    if len(cams_to_include) + len(excluded_views) != len(self.camera_cluster):
        raise ValueError(
            f"The number of `Camcorder`s to include {len(cams_to_include)} plus the number of `Camcorder`s "
            f"to exclude {len(excluded_views)} does not match the number of `Camcorder`s in the "
            f"`CameraCluster` {len(self.camera_cluster)}."
        )

    # Reproject 3D points into 2D points for each camera view
    pts_reprojected = reproject(
        np.expand_dims(points, axis=(0, 1)),  # F=1 x T=1 x N x 3
        calib=self.camera_cluster,
        excluded_views=excluded_views,
    )  # M=include x F=1 x T=1 x N x 2

    # Drop the singleton frame/track axes added above
    points_reprojected = np.squeeze(pts_reprojected, axis=(1, 2))  # M=include x N x 2

    # Update the points for each `Instance` in the `InstanceGroup` using 2d points
    self.update_points_from_2d(
        points_reprojected=points_reprojected,
        projection_bounds=projection_bounds,
        cams_to_include=cams_to_include,
        exclude_complete=exclude_complete,
    )

def update_points_from_2d(
self,
points_reprojected: np.ndarray,
projection_bounds: np.ndarray,
cams_to_include: Optional[List[Camcorder]] = None,
exclude_complete: bool = True,
):

# Check that correct shape was passed in
n_views, n_nodes, _ = points.shape
assert n_views == len(cams_to_include), (
f"Number of views in `points` ({n_views}) does not match the number of "
f"Camcorders in `cams_to_include` ({len(cams_to_include)})."
points_shape = points_reprojected.shape
try:
n_views, n_nodes, n_coords = points_reprojected.shape
if n_views != len(cams_to_include):
raise ValueError(
f"Number of views in `points` ({n_views}) does not match the number"
f" of Camcorders in `cams_to_include` ({len(cams_to_include)})."
)
if n_coords != 2:
raise ValueError(f"Expected 2 coordinates in `points`, got {n_coords}.")
except ValueError as e:
raise ValueError(
f"Expected `points_reprojected` to be of shape (M, N, 2), got "
f"{points_shape}.\n\n{e}"
)

# Ensure we are working with a float array
points_reprojected = points_reprojected.astype(np.float64)

# If no `Camcorder`s specified, then update `Instance`s for all `CameraCluster`
if cams_to_include is None:
cams_to_include = self.camera_cluster.cameras

# Get projection bounds (based on video height/width)
bounds = projection_bounds # TODO: make sure projection bounds are the shape they need to be in update points
bounds_expanded_x = bounds[:, None, 0]
bounds_expanded_y = bounds[:, None, 1]

# Create masks for out-of-bounds x and y coordinates
out_of_bounds_x = (points_reprojected[..., 0] < 0) | (
points_reprojected[..., 0] > bounds_expanded_x
)
out_of_bounds_y = (points_reprojected[..., 1] < 0) | (
points_reprojected[..., 1] > bounds_expanded_y
)

# Replace out-of-bounds x and y coordinates with nan
points_reprojected[out_of_bounds_x, 0] = np.nan
points_reprojected[out_of_bounds_y, 1] = np.nan

# Calculate OKS scores for the points
gt_points = self.numpy(
Expand All @@ -883,13 +975,14 @@ def update_points(
if not isinstance(instance, PredictedInstance):
instance_oks = compute_oks(
gt_points[cam_idx, :, :],
points[cam_idx, :, :],
points_reprojected[cam_idx, :, :],
)
oks_scores[cam_idx] = instance_oks

# Update the points for the instance
instance.update_points(
points=points[cam_idx, :, :], exclude_complete=exclude_complete
points=points_reprojected[cam_idx, :, :],
exclude_complete=exclude_complete,
)

# Update the score for the InstanceGroup to be the average OKS score
Expand Down Expand Up @@ -2289,31 +2382,11 @@ def upsert_points(
complete. Default is True.
"""

# Check that the correct shape was passed in
n_views, n_instances, n_nodes, n_coords = points.shape
assert n_views == len(
self.cams_to_include
), f"Expected {len(self.cams_to_include)} views, got {n_views}."
assert n_instances == len(
instance_groups
), f"Expected {len(instance_groups)} instances, got {n_instances}."
assert n_coords == 2, f"Expected 2 coordinates, got {n_coords}."

# Ensure we are working with a float array
points = points.astype(float)

# Get projection bounds (based on video height/width)
bounds = self.session.projection_bounds
bounds_expanded_x = bounds[:, None, None, 0]
bounds_expanded_y = bounds[:, None, None, 1]

# Create masks for out-of-bounds x and y coordinates
out_of_bounds_x = (points[..., 0] < 0) | (points[..., 0] > bounds_expanded_x)
out_of_bounds_y = (points[..., 1] < 0) | (points[..., 1] > bounds_expanded_y)

# Replace out-of-bounds x and y coordinates with nan
points[out_of_bounds_x, 0] = np.nan
points[out_of_bounds_y, 1] = np.nan

# Update points for each `InstanceGroup`
for ig_idx, instance_group in enumerate(instance_groups):
Expand All @@ -2326,11 +2399,13 @@ def upsert_points(
self.create_and_add_missing_instances(instance_group=instance_group)

# Update points for each `Instance` in `InstanceGroup`
instance_points = points[:, ig_idx, :, :] # M x N x 2
instance_points = points[ig_idx, :, :] # N x 3
instance_group.update_points(
points=instance_points,
cams_to_include=self.cams_to_include,
excluded_views=self.excluded_views,
exclude_complete=exclude_complete,
projection_bounds=bounds,
)

def _raise_if_instance_not_in_instance_group(self, instance: Instance):
Expand Down
Loading
Loading