Skip to content

Commit

Permalink
Add Labels.skeleton convenience attribute (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo authored Jul 20, 2023
1 parent 8b444b1 commit ae61fb9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
13 changes: 13 additions & 0 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,19 @@ def video(self) -> Video:
"in the labels. Use Labels.videos instead."
)

@property
def skeleton(self) -> Skeleton:
"""Return the skeleton if there is only a single skeleton in the labels."""
if len(self.skeletons) == 0:
raise ValueError("There are no skeletons in the labels.")
elif len(self.skeletons) == 1:
return self.skeletons[0]
else:
raise ValueError(
"Labels.skeleton can only be used when there is only a single skeleton "
"saved in the labels. Use Labels.skeletons instead."
)

def find(
self,
video: Video,
Expand Down
2 changes: 1 addition & 1 deletion sleap_io/model/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _update_node_map(self, attr, nodes):
self._node_name_map = {node.name: node for node in nodes}
self._node_ind_map = {node: i for i, node in enumerate(nodes)}

nodes: list[Node] = field(on_setattr=_update_node_map)
nodes: list[Node] = field(factory=list, on_setattr=_update_node_map)
edges: list[Edge] = field(factory=list)
symmetries: list[Symmetry] = field(factory=list)
name: Optional[str] = None
Expand Down
15 changes: 15 additions & 0 deletions tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,18 @@ def test_labels_video():
labels.videos.append(Video(filename="test2"))
with pytest.raises(ValueError):
labels.video


def test_labels_skeleton():
labels = Labels()

with pytest.raises(ValueError):
labels.skeleton

skel = Skeleton(["A"])
labels.skeletons.append(skel)
assert labels.skeleton == skel

labels.skeletons.append(Skeleton(["B"]))
with pytest.raises(ValueError):
labels.skeleton

0 comments on commit ae61fb9

Please sign in to comment.