Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InstancesList class to handle backref to LabeledFrame #1807

Merged
merged 13 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 139 additions & 28 deletions sleap/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ class Instance:
from_predicted: Optional["PredictedInstance"] = attr.ib(default=None)
_points: PointArray = attr.ib(default=None)
_nodes: List = attr.ib(default=None)
frame: Union["LabeledFrame", None] = attr.ib(default=None)
frame: Union["LabeledFrame", None] = attr.ib(default=None) # TODO(LM): Make private
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making frame a private attribute as indicated by the TODO comment.

- frame: Union["LabeledFrame", None] = attr.ib(default=None)  # TODO(LM): Make private
+ _frame: Union["LabeledFrame", None] = attr.ib(default=None)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
frame: Union["LabeledFrame", None] = attr.ib(default=None) # TODO(LM): Make private
_frame: Union["LabeledFrame", None] = attr.ib(default=None) # TODO(LM): Make private


# The underlying Point array type that this instances point array should be.
_point_array_type = PointArray
Expand Down Expand Up @@ -1214,6 +1214,9 @@ def unstructure_instance(x: Instance):

converter.register_unstructure_hook(Instance, unstructure_instance)
converter.register_unstructure_hook(PredictedInstance, unstructure_instance)
converter.register_unstructure_hook(
InstancesList, lambda x: [converter.unstructure(inst) for inst in x]
)

## STRUCTURE HOOKS

Expand Down Expand Up @@ -1247,6 +1250,7 @@ def structure_instances_list(x, type):
converter.register_structure_hook(
Union[List[Instance], List[PredictedInstance]], structure_instances_list
)
converter.register_structure_hook(InstancesList, structure_instances_list)

# Structure forward reference for PredictedInstance for the Instance.from_predicted
# attribute.
Expand Down Expand Up @@ -1278,6 +1282,127 @@ def structure_point_array(x, t):
return converter


class InstancesList(list):
"""A list of `Instance`s associated with a `LabeledFrame`.

This class should only be used for the `LabeledFrame.instances` attribute.
"""

def __init__(self, *args, labeled_frame: Optional["LabeledFrame"] = None):
super(InstancesList, self).__init__(*args)

# Set the labeled frame for each instance
self.labeled_frame = labeled_frame

@property
def labeled_frame(self) -> "LabeledFrame":
"""Return the `LabeledFrame` associated with this list of instances."""

return self._labeled_frame

@labeled_frame.setter
def labeled_frame(self, labeled_frame: "LabeledFrame"):
"""Set the `LabeledFrame` associated with this list of instances.

This updates the `frame` attribute on each instance.

Args:
labeled_frame: The `LabeledFrame` to associate with this list of instances.
"""

try:
# If the labeled frame is the same as the one we're setting, then skip
if self._labeled_frame == labeled_frame:
return
except AttributeError:
# Only happens on init and updates each instance.frame (even if None)
pass

# Otherwise, update the frame for each instance
self._labeled_frame = labeled_frame
for instance in self:
instance.frame = labeled_frame

def append(self, instance: Union[Instance, PredictedInstance]):
"""Append an `Instance` or `PredictedInstance` to the list, setting the frame.

Args:
item: The `Instance` or `PredictedInstance` to append to the list.
"""

if not isinstance(instance, (Instance, PredictedInstance)):
raise ValueError(
f"InstancesList can only contain Instance or PredictedInstance objects,"
f" but got {type(instance)}."
)
instance.frame = self.labeled_frame
super().append(instance)

def extend(self, instances: List[Union[PredictedInstance, Instance]]):
"""Extend the list with a list of `Instance`s or `PredictedInstance`s.

Args:
instances: A list of `Instance` or `PredictedInstance` objects to add to the
list.

Returns:
None
"""
for instance in instances:
self.append(instance)

def __delitem__(self, index):
"""Remove instance (by index), and set instance.frame to None."""

instance: Instance = self.__getitem__(index)
super().__delitem__(index)

# Modify the instance to remove reference to the frame
instance.frame = None

def insert(self, index: int, instance: Union[Instance, PredictedInstance]) -> None:
super().insert(index, instance)
instance.frame = self.labeled_frame

def __setitem__(self, index, instance: Union[Instance, PredictedInstance]):
"""Set nth instance in frame to the given instance.

Args:
index: The index of instance to replace with new instance.
value: The new instance to associate with frame.

Returns:
None.
"""
super().__setitem__(index, instance)
instance.frame = self.labeled_frame

def pop(self, index: int) -> Union[Instance, PredictedInstance]:
"""Remove and return instance at index, setting instance.frame to None."""

instance = super().pop(index)
instance.frame = None
return instance

def remove(self, instance: Union[Instance, PredictedInstance]) -> None:
"""Remove instance from list, setting instance.frame to None."""
super().remove(instance)
instance.frame = None

def clear(self) -> None:
"""Remove all instances from list, setting instance.frame to None."""
for instance in self:
instance.frame = None
super().clear()

def copy(self) -> list:
"""Return a shallow copy of the list of instances as a list.

Note: This will not return an `InstancesList` object, but a normal list.
"""
return list(self)

Comment on lines +1285 to +1404
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tip

Codebase Verification

The search results indicate that some methods of the InstancesList class, such as append, extend, remove, and copy, have corresponding unit tests. However, there were no matches for other methods like pop, insert, clear, delitem, and setitem.

To ensure comprehensive test coverage, it is recommended to verify and, if necessary, add unit tests for the following methods:

  • pop
  • insert
  • clear
  • delitem
  • setitem

This will help ensure that the InstancesList class is thoroughly tested for all its functionalities.

Analysis chain

The InstancesList class is a significant addition. Ensure that the logic for managing instance frames upon list operations is thoroughly tested, especially methods that modify the list (e.g., append, extend, pop, etc.).

Scripts executed

The following scripts were executed for the analysis:


Script:

#!/bin/bash
# Description: Verify all methods of InstancesList are covered by unit tests.

# Test: Search for test cases covering InstancesList methods. Expect: Test cases for each method.
rg --type python $'def test_.*InstancesList'

Length of output: 76



Script:

#!/bin/bash
# Description: Verify all methods of InstancesList are covered by unit tests.

# Step 1: Find all Python test files
test_files=$(fd --type f --extension py 'test_')

# Step 2: Search for test cases covering InstancesList methods in the found test files
for file in $test_files; do
    echo "Searching in $file"
    rg 'def test_.*InstancesList' "$file"
done

Length of output: 18401



Script:

#!/bin/bash
# Description: Verify specific methods of InstancesList are covered by unit tests.

# List of methods to verify
methods=("append" "extend" "pop" "remove" "insert" "clear" "copy" "delitem" "setitem")

# Step 1: Find all Python test files
test_files=$(fd --type f --extension py 'test_')

# Step 2: Search for test cases covering specific InstancesList methods in the found test files
for method in "${methods[@]}"; do
    for file in $test_files; do
        echo "Searching for method $method in $file"
        rg "def test_.*$method" "$file"
    done
done

Length of output: 166380

Tools
Ruff

1292-1292: Use super() instead of super(__class__, self) (UP008)

Remove __super__ parameters


@attr.s(auto_attribs=True, eq=False, repr=False, str=False)
class LabeledFrame:
"""Holds labeled data for a single frame of a video.
Expand All @@ -1290,9 +1415,7 @@ class LabeledFrame:

video: Video = attr.ib()
frame_idx: int = attr.ib(converter=int)
_instances: Union[List[Instance], List[PredictedInstance]] = attr.ib(
default=attr.Factory(list)
)
_instances: InstancesList = attr.ib(default=attr.Factory(InstancesList))

def __attrs_post_init__(self):
"""Called by attrs.
Expand All @@ -1302,8 +1425,7 @@ def __attrs_post_init__(self):
"""

# Make sure all instances have a reference to this frame
for instance in self.instances:
instance.frame = self
self.instances = self._instances

def __len__(self) -> int:
"""Return number of instances associated with frame."""
Expand All @@ -1319,13 +1441,8 @@ def index(self, value: Instance) -> int:

def __delitem__(self, index):
"""Remove instance (by index) from frame."""
value = self.instances.__getitem__(index)

self.instances.__delitem__(index)

# Modify the instance to remove reference to this frame
value.frame = None

def __repr__(self) -> str:
"""Return a readable representation of the LabeledFrame."""
return (
Expand All @@ -1348,9 +1465,6 @@ def insert(self, index: int, value: Instance):
"""
self.instances.insert(index, value)

# Modify the instance to have a reference back to this frame
value.frame = self

def __setitem__(self, index, value: Instance):
"""Set nth instance in frame to the given instance.

Expand All @@ -1363,9 +1477,6 @@ def __setitem__(self, index, value: Instance):
"""
self.instances.__setitem__(index, value)

# Modify the instance to have a reference back to this frame
value.frame = self

def find(
self, track: Optional[Union[Track, int]] = -1, user: bool = False
) -> List[Instance]:
Expand Down Expand Up @@ -1393,7 +1504,7 @@ def instances(self) -> List[Instance]:
return self._instances

@instances.setter
def instances(self, instances: List[Instance]):
def instances(self, instances: Union[InstancesList, List[Instance]]):
"""Set the list of instances associated with this frame.

Updates the `frame` attribute on each instance to the
Expand All @@ -1408,9 +1519,11 @@ def instances(self, instances: List[Instance]):
None
"""

# Make sure to set the frame for each instance to this LabeledFrame
for instance in instances:
instance.frame = self
# Make sure to set the LabeledFrame for each instance to this frame
if isinstance(instances, InstancesList):
instances.labeled_frame = self
else:
instances = InstancesList(instances, labeled_frame=self)

self._instances = instances

Expand Down Expand Up @@ -1685,22 +1798,20 @@ def complex_frame_merge(
* list of conflicting instances from base
* list of conflicting instances from new
"""
merged_instances = []
redundant_instances = []
extra_base_instances = copy(base_frame.instances)
extra_new_instances = []
merged_instances: List[Instance] = [] # Only used for informing user
redundant_instances: List[Instance] = []
extra_base_instances: List[Instance] = list(base_frame.instances)
extra_new_instances: List[Instance] = []

for new_inst in new_frame:
redundant = False
for base_inst in base_frame.instances:
if new_inst.matches(base_inst):
base_inst.frame = None
extra_base_instances.remove(base_inst)
redundant_instances.append(base_inst)
redundant = True
continue
if not redundant:
new_inst.frame = None
extra_new_instances.append(new_inst)

conflict = False
Expand Down Expand Up @@ -1732,7 +1843,7 @@ def complex_frame_merge(
else:
# No conflict, so include all instances in base
base_frame.instances.extend(extra_new_instances)
merged_instances = copy(extra_new_instances)
merged_instances: List[Instance] = copy(extra_new_instances)
extra_base_instances = []
extra_new_instances = []

Expand Down
Loading