Skip to content

Commit

Permalink
Parse skeleton symmetries from SLP files (#53)
Browse files Browse the repository at this point in the history
* Add Symmetry parsing to slp

* Update centered pair fixture to include symmetries
  • Loading branch information
talmo authored Jul 20, 2023
1 parent a341f14 commit 8b444b1
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 17 deletions.
31 changes: 26 additions & 5 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Video,
Skeleton,
Edge,
Symmetry,
Node,
Track,
Point,
Expand Down Expand Up @@ -123,7 +124,7 @@ def read_skeletons(labels_path: str) -> list[Skeleton]:
skeleton_objects = []
for skel in metadata["skeletons"]:
# Parse out the cattr-based serialization stuff from the skeleton links.
edge_inds = []
edge_inds, symmetry_inds = [], []
for link in skel["links"]:
if "py/reduce" in link["type"]:
edge_type = link["type"]["py/reduce"][1]["py/tuple"][0]
Expand All @@ -133,20 +134,40 @@ def read_skeletons(labels_path: str) -> list[Skeleton]:
if edge_type == 1: # 1 -> real edge, 2 -> symmetry edge
edge_inds.append((link["source"], link["target"]))

elif edge_type == 2:
symmetry_inds.append((link["source"], link["target"]))

# Re-index correctly.
skeleton_node_inds = [node["id"] for node in skel["nodes"]]
node_names = [node_names[i] for i in skeleton_node_inds]

# Create nodes.
nodes = []
for name in node_names:
nodes.append(Node(name=name))

# Create edges.
edge_inds = [
(skeleton_node_inds.index(s), skeleton_node_inds.index(d))
for s, d in edge_inds
]
nodes = []
for name in node_names:
nodes.append(Node(name=name))
edges = []
for edge in edge_inds:
edges.append(Edge(source=nodes[edge[0]], destination=nodes[edge[1]]))
skel = Skeleton(nodes=nodes, edges=edges, name=skel["graph"]["name"])

# Create symmetries.
symmetry_inds = [
(skeleton_node_inds.index(s), skeleton_node_inds.index(d))
for s, d in symmetry_inds
]
symmetries = []
for symmetry in symmetry_inds:
symmetries.append(Symmetry([nodes[symmetry[0]], nodes[symmetry[1]]]))

# Create the full skeleton.
skel = Skeleton(
nodes=nodes, edges=edges, symmetries=symmetries, name=skel["graph"]["name"]
)
skeleton_objects.append(skel)
return skeleton_objects

Expand Down
Binary file modified tests/data/slp/centered_pair_predictions.slp
Binary file not shown.
4 changes: 2 additions & 2 deletions tests/fixtures/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@


@pytest.fixture
def labels_predictions(slp_predictions):
def labels_predictions(centered_pair):
"""Labels object containing predicted instances, multiple tracks and a single video."""
return sleap_io.load_slp(slp_predictions)
return sleap_io.load_slp(centered_pair)
15 changes: 11 additions & 4 deletions tests/fixtures/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

@pytest.fixture
def slp_typical():
"""Typical SLP file including `PredictedInstance`, `Instance`, `Track` and `Skeleton` objects."""
"""Typical SLP file including."""
return "tests/data/slp/typical.slp"


Expand All @@ -27,14 +27,21 @@ def slp_minimal_pkg():


@pytest.fixture
def slp_predictions():
"""A more complex example containing predicted instances from multiple tracks and a single video"""
def centered_pair():
"""Example with predicted instances from multiple tracks and a single video.
This project:
- Has 1 grayscale video with 1100 frames, cropped to 384x384 with 2 flies
- Has a 24 node skeleton with edges and symmetries
- Has 0 user instances and 2274 predicted instances
- Has 2 correct tracks and 25 extraneous tracks
"""
return "tests/data/slp/centered_pair_predictions.slp"


@pytest.fixture
def slp_predictions_with_provenance():
"""The slp file generated with the collab tutorial and sleap version 1.27"""
"""The slp file generated with the colab tutorial and sleap version 1.2.7."""
return "tests/data/slp/predictions_1.2.7_provenance_and_tracking.slp"


Expand Down
12 changes: 6 additions & 6 deletions tests/io/test_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def test_default_metadata_overwriting(nwbfile, slp_predictions_with_provenance):
assert pose_estimation_series.rate == expected_sampling_rate


def test_complex_case_append(nwbfile, slp_predictions):
labels = load_slp(slp_predictions)
def test_complex_case_append(nwbfile, centered_pair):
labels = load_slp(centered_pair)
nwbfile = append_nwb_data(labels, nwbfile)

# Test matching number of processing modules
Expand Down Expand Up @@ -167,8 +167,8 @@ def test_complex_case_append(nwbfile, slp_predictions):
assert node_name in pose_estimation_container.pose_estimation_series


def test_complex_case_append_with_timestamps_metadata(nwbfile, slp_predictions):
labels = load_slp(slp_predictions)
def test_complex_case_append_with_timestamps_metadata(nwbfile, centered_pair):
labels = load_slp(centered_pair)

number_of_frames = 1100 # extracted using ffmpeg probe
video_sample_rate = 15.0 # 15 Hz extracted using ffmpeg probe for the video stream
Expand Down Expand Up @@ -236,8 +236,8 @@ def test_typical_case_write(slp_typical, tmp_path):
assert len(nwbfile.processing) == number_of_videos


def test_get_timestamps(nwbfile, slp_predictions):
labels = load_slp(slp_predictions)
def test_get_timestamps(nwbfile, centered_pair):
labels = load_slp(centered_pair)
nwbfile = append_nwb_data(labels, nwbfile)
processing = nwbfile.processing["SLEAP_VIDEO_000_centered_pair_low_quality"]
assert True
Expand Down
12 changes: 12 additions & 0 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,18 @@ def test_read_instances_from_predicted(slp_real_data):
assert len(lf.unused_predictions) == 0


def test_read_skeleton(centered_pair):
skeletons = read_skeletons(centered_pair)
assert len(skeletons) == 1
skeleton = skeletons[0]
assert type(skeleton) == Skeleton
assert len(skeleton.nodes) == 24
assert len(skeleton.edges) == 23
assert len(skeleton.symmetries) == 20
assert Node("wingR") in skeleton.symmetries[0].nodes
assert Node("wingL") in skeleton.symmetries[0].nodes


def test_read_videos_pkg(slp_minimal_pkg):
videos = read_videos(slp_minimal_pkg)
assert len(videos) == 1
Expand Down

0 comments on commit 8b444b1

Please sign in to comment.