Skip to content

Commit

Permalink
Few extra tests for InstanceGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Nov 30, 2023
1 parent 4037122 commit 89550a1
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 18 deletions.
32 changes: 21 additions & 11 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,11 @@ def instances(self) -> List["Instance"]:
"""List of `Instance` objects."""
return list(self._instance_by_camcorder.values())

@property
def cameras(self) -> List[Camcorder]:
"""List of `Camcorder` objects."""
return list(self._instance_by_camcorder.keys())

def get_instance(self, cam: Camcorder) -> Optional["Instance"]:
"""Retrieve `Instance` linked to `Camcorder`.
Expand Down Expand Up @@ -825,12 +830,15 @@ def __getitem__(
) -> Union[Camcorder, "Instance"]:
"""Grab a `Camcorder` of `Instance` from the `InstanceGroup`."""

def _raise_key_error():
raise KeyError(f"Key {idx_or_key} not found in {self.__class__.__name__}.")

# Try to find in `self.camera_cluster.cameras`
if isinstance(idx_or_key, int):
try:
return self.instances[idx_or_key]
except IndexError:
pass
_raise_key_error()

# Return a `Instance` if `idx_or_key` is a `Camcorder``
if isinstance(idx_or_key, Camcorder):
Expand All @@ -843,10 +851,7 @@ def __getitem__(
except:
pass

raise KeyError(
f"Key {idx_or_key} not found in {self.__class__.__name__} or "
"associated metadata."
)
_raise_key_error()

def __len__(self):
return len(self.instances)
Expand All @@ -855,23 +860,27 @@ def __repr__(self):
return f"{self.__class__.__name__}(frame_idx={self.frame_idx}, instances={len(self)}, camera_cluster={self.camera_cluster})"

@classmethod
def from_dict(cls, d: dict) -> "InstanceGroup":
def from_dict(cls, d: dict) -> Optional["InstanceGroup"]:
"""Creates an `InstanceGroup` object from a dictionary.
Args:
d: Dictionary with `Camcorder` keys and `Instance` values.
Returns:
`InstanceGroup` object.
`InstanceGroup` object or None if no "real" (determined by `frame_idx` other
than None) instances found.
"""

# Ensure not to mutate the original dictionary
d_copy = d.copy()

frame_idx = None
for cam, instance in d.copy().items():
for cam, instance in d_copy.copy().items():
camera_cluster = cam.camera_cluster

# Remove dummy instances (determined by not having a frame index)
if instance.frame_idx is None:
d.pop(cam)
d_copy.pop(cam)
# Grab the frame index from non-dummy instances
elif frame_idx is None:
frame_idx = instance.frame_idx
Expand All @@ -885,8 +894,9 @@ def from_dict(cls, d: dict) -> "InstanceGroup":
f"does not match instance frame index {instance.frame_idx}."
)

if len(d) == 0:
if len(d_copy) == 0:
logger.warning("Cannot create `InstanceGroup`: No real instances found.")
return None

frame_idx = cast(
int, frame_idx
Expand All @@ -895,5 +905,5 @@ def from_dict(cls, d: dict) -> "InstanceGroup":
return cls(
frame_idx=frame_idx,
camera_cluster=camera_cluster,
instance_by_camcorder=d,
instance_by_camcorder=d_copy,
)
46 changes: 39 additions & 7 deletions tests/io/test_cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,16 +307,48 @@ def test_instance_group(multiview_min_session_labels: Labels):
)
instance_by_camera[cam] = instance

# Add a dummy instance to make sure it gets ignored
instance_by_camera[cam] = dummy_instance

# Test `from_dict`
instance_group = InstanceGroup.from_dict(d=instance_by_camera)
assert isinstance(instance_group, InstanceGroup)
assert instance_group.frame_idx == frame_idx
assert instance_group.camera_cluster == camera_cluster
for cam in session.linked_cameras:
try:
instance = instance_group[cam]
for camera in session.linked_cameras:
if camera == cam:
assert instance_by_camera[camera] == dummy_instance
assert camera not in instance_group.cameras
else:
instance = instance_group[camera]
assert isinstance(instance, Instance)
assert instance_group[cam] == instance_by_camera[cam]
assert instance_group[instance] == cam
except:
assert instance_by_camera[cam] == dummy_instance
assert instance_group[camera] == instance_by_camera[camera]
assert instance_group[instance] == camera

# Test `__repr__`
print(instance_group)

# Test `__len__`
assert len(instance_group) == len(instance_by_camera) - 1

# Test `get_cam`
assert instance_group.get_cam(dummy_instance) is None

# Test `get_instance`
assert instance_group.get_instance(cam) is None

# Test `instances` property
assert len(instance_group.instances) == len(instance_by_camera) - 1

# Test `cameras` property
assert len(instance_group.cameras) == len(instance_by_camera) - 1

# Test `__getitem__` with `int` key
assert isinstance(instance_group[0], Instance)
with pytest.raises(KeyError):
instance_group[len(instance_group)]

# Populate with only dummy instance and test `from_dict`
instance_by_camera = {cam: dummy_instance}
instance_group = InstanceGroup.from_dict(d=instance_by_camera)
assert instance_group is None

0 comments on commit 89550a1

Please sign in to comment.