diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py
index 76ee755f8..beee6dece 100644
--- a/sleap/io/dataset.py
+++ b/sleap/io/dataset.py
@@ -58,6 +58,7 @@
 import cattr
 import h5py as h5
 import numpy as np
+import datetime
 from sklearn.model_selection import train_test_split
 
 try:
@@ -2028,6 +2029,26 @@ def export(self, filename: str):
         SleapAnalysisAdaptor.write(filename, self)
 
+    def export_nwb(
+        self,
+        filename: str,
+        overwrite: bool = False,
+        session_description: str = "Processed SLEAP pose data",
+        identifier: Optional[str] = None,
+        session_start_time: Optional[datetime.datetime] = None,
+    ):
+        from sleap.io.format.ndx_pose import NDXPoseAdaptor
+
+        NDXPoseAdaptor.write(
+            NDXPoseAdaptor,
+            filename=filename,
+            labels=self,
+            overwrite=overwrite,
+            session_description=session_description,
+            identifier=identifier,
+            session_start_time=session_start_time,
+        )
+
     @classmethod
     def load_json(cls, filename: str, *args, **kwargs) -> "Labels":
         from .format import read
diff --git a/sleap/io/format/ndx_pose.py b/sleap/io/format/ndx_pose.py
index d182d229d..928a97cf5 100644
--- a/sleap/io/format/ndx_pose.py
+++ b/sleap/io/format/ndx_pose.py
@@ -4,9 +4,10 @@
 import datetime
 import re
 import numpy as np
+import uuid
 
-from pathlib import PurePath
-from typing import List
+from pathlib import Path, PurePath
+from typing import List, Optional
 
 from pynwb import NWBFile, NWBHDF5IO, ProcessingModule
 from ndx_pose import PoseEstimationSeries, PoseEstimation
@@ -68,7 +69,9 @@ def read(self, file: FileHandle) -> Labels:
         nwb_file = read_nwbfile.processing
 
         # Get list of videos
-        video_keys: List[str] = list(nwb_file.keys())
+        video_keys: List[str] = [
+            key for key in nwb_file.keys() if "SLEAP_VIDEO" in key
+        ]
         video_tracks = dict()
 
         # Get track keys
@@ -164,7 +167,15 @@ def read(self, file: FileHandle) -> Labels:
         labels = Labels(lfs)
         return labels
 
-    def write(self, filename: str, labels: Labels):
+    def write(
+        self,
+        filename: str,
+        labels: Labels,
+        overwrite: bool = False,
+        session_description: str = "Processed SLEAP pose data",
+        identifier: Optional[str] = None,
+        session_start_time: Optional[datetime.datetime] = None,
+    ):
         """Write all `PredictedInstance` objects in a `Labels` object to an NWB file.
 
         Use `Labels.numpy` to create a `pynwb.NWBFile` with a separate
@@ -198,6 +209,18 @@ def write(self, filename: str, labels: Labels):
         Args:
             filename: Output path for the NWB format file.
             labels: The `Labels` object to covert to a NWB format file.
+            overwrite: If True, overwrite an existing NWB file. If False, data
+                will be appended to the existing NWB file.
+            session_description: Description for the entire project. Stored under
+                the NWBFile "session_description" key. If appending data to a
+                preexisting file, the session_description will not be used.
+            identifier: Unique identifier for the project. If no identifier is
+                specified, a UUID will be generated. If appending data to a
+                preexisting file, the identifier will not be used.
+            session_start_time: The datetime associated with the project. If no
+                session_start_time is given, the current datetime will be used. If
+                appending data to a preexisting file, the session_start_time will
+                not be used.
 
         Returns:
             A `pynwb.NWBFile` with a separate `pynwb.ProcessingModule` for each
@@ -205,90 +228,123 @@ def write(self, filename: str, labels: Labels):
         """
-        skeleton = labels.skeleton
-
         # Check that this project contains predicted instances
         if len(labels.predicted_instances) == 0:
             raise TypeError(
                 "Only predicted instances are written to the NWB format. "
-                "This project has no predicted instances"
+                "This project has no predicted instances."
             )
 
-        print(f"\nCreating NWB file...")
-        nwb_file = NWBFile(
-            session_description="session_description",
-            identifier="identifier",
-            session_start_time=datetime.datetime.now(datetime.timezone.utc),
-        )
+        # Set optional kwargs if not specified by user
+        if session_start_time is None:
+            session_start_time = datetime.datetime.now(datetime.timezone.utc)
+        identifier = str(uuid.uuid4()) if identifier is None else identifier
 
-        for video_idx, video in enumerate(labels.videos):
-            # Create new processing module for each video
-            video_fn = PurePath(video.backend.filename)
-            nwb_processing_module = nwb_file.create_processing_module(
-                name=f"{video_idx:03}_{video_fn.stem}",
-                description=f"Processed SLEAP pose data for {video_fn.name} with "
-                f"{skeleton.name} skeleton.",
-            )
-
-            # Get tracks for each video
-            video_lfs = labels.get(video)
-            untracked = all(
-                [inst.track is None for lf in video_lfs for inst in lf.instances]
-            )
-            tracks_numpy = labels.numpy(
-                video=video,
-                all_frames=True,
-                untracked=untracked,
-                return_confidence=True,
-            )
-            n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape
-            timestamps = np.arange(n_frames)
-            for track_idx in list(range(n_tracks)):
-                pose_estimation_series: List[PoseEstimationSeries] = []
-
-                for node_idx, node in enumerate(skeleton.nodes):
-
-                    # Create instance of PoseEstimationSeries for each node
-                    data = tracks_numpy[:, track_idx, node_idx, :2]
-                    confidence = tracks_numpy[:, track_idx, node_idx, 2]
-
-                    pose_estimation_series.append(
-                        PoseEstimationSeries(
-                            name=f"{node.name}",
-                            description=f"Sequential trajectory of {node.name}.",
-                            data=data,
-                            unit="pixels",
-                            reference_frame="No reference.",
-                            timestamps=timestamps,
-                            confidence=confidence,
-                            confidence_definition="Point-wise confidence scores.",
-                        )
+        try:
+            io = None
+            if Path(filename).exists() and not overwrite:
+                # Append to file if it exists and we do not want to overwrite
+                print(f"\nOpening existing NWB file...")
+                io = NWBHDF5IO(filename, mode="a", load_namespaces=True)
+                nwb_file = io.read()
+            else:
+                # If file does not exist or we want to overwrite, create new file
+                if not overwrite:
+                    print(f"\nCould not find the file specified: {filename}")
+                print(f"\nCreating NWB file...")
+                nwb_file = NWBFile(
+                    session_description=session_description,
+                    identifier=identifier,
+                    session_start_time=session_start_time,
+                )
+                io = NWBHDF5IO(filename, mode="w")
+
+            skeleton = labels.skeleton
+
+            for video_idx, video in enumerate(labels.videos):
+                # Create new processing module for each video
+                video_fn = PurePath(video.backend.filename)
+                try:
+                    name = f"SLEAP_VIDEO_{video_idx:03}_{video_fn.stem}"
+                    nwb_processing_module = nwb_file.create_processing_module(
+                        name=name,
+                        description=f"{session_description} for {video_fn.name} with "
+                        f"{skeleton.name} skeleton.",
                     )
+                except ValueError:
+                    # Cannot overwrite or delete processing modules
+                    print(
+                        f"Processing module for {video_fn.name} already exists... "
" + f"Skipping: {name}" + ) + continue - # Combine each node's PoseEstimationSeries to create a PoseEstimation - name_prefix = "untracked" if untracked else "track" - pose_estimation = PoseEstimation( - name=f"{name_prefix}{track_idx:03}", - pose_estimation_series=pose_estimation_series, - description=( - f"Estimated positions of {skeleton.name} in video {video_fn} " - f"using SLEAP." - ), - original_videos=[f"{video_fn}"], - labeled_videos=[f"{video_fn}"], - dimensions=np.array([[video.backend.height, video.backend.width]]), - scorer=str(labels.provenance), - source_software="SLEAP", - source_software_version=f"{sleap.__version__}", - nodes=skeleton.node_names, - edges=skeleton.edge_inds, + # Get tracks for each video + video_lfs = labels.get(video) + untracked = all( + [inst.track is None for lf in video_lfs for inst in lf.instances] + ) + tracks_numpy = labels.numpy( + video=video, + all_frames=True, + untracked=untracked, + return_confidence=True, ) + n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape + timestamps = np.arange(n_frames) + for track_idx in list(range(n_tracks)): + pose_estimation_series: List[PoseEstimationSeries] = [] + + for node_idx, node in enumerate(skeleton.nodes): + + # Create instance of PoseEstimationSeries for each node + data = tracks_numpy[:, track_idx, node_idx, :2] + confidence = tracks_numpy[:, track_idx, node_idx, 2] + + pose_estimation_series.append( + PoseEstimationSeries( + name=f"{node.name}", + description=f"Sequential trajectory of {node.name}.", + data=data, + unit="pixels", + reference_frame="No reference.", + timestamps=timestamps, + confidence=confidence, + confidence_definition="Point-wise confidence scores.", + ) + ) + + # Combine each node's PoseEstimationSeries to create a PoseEstimation + name_prefix = "untracked" if untracked else "track" + pose_estimation = PoseEstimation( + name=f"{name_prefix}{track_idx:03}", + pose_estimation_series=pose_estimation_series, + description=( + f"Estimated positions of {skeleton.name} in video {video_fn} " + f"using SLEAP." 
+                        ),
+                        original_videos=[f"{video_fn}"],
+                        labeled_videos=[f"{video_fn}"],
+                        dimensions=np.array(
+                            [[video.backend.height, video.backend.width]]
+                        ),
+                        scorer=str(labels.provenance),
+                        source_software="SLEAP",
+                        source_software_version=f"{sleap.__version__}",
+                        nodes=skeleton.node_names,
+                        edges=skeleton.edge_inds,
+                    )
 
-                # Create a processing module for each
-                nwb_processing_module.add(pose_estimation)
+                    # Add the pose estimation to this video's processing module
+                    nwb_processing_module.add(pose_estimation)
 
-        path = filename
-        with NWBHDF5IO(path, mode="w") as io:
             io.write(nwb_file)
+
+        except Exception as e:
+            raise e
+
+        finally:
+            if io is not None:
+                io.close()
+                print(f"Finished writing NWB file to {filename}\n")
diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py
index bc1bab613..11fd32a53 100644
--- a/tests/io/test_dataset.py
+++ b/tests/io/test_dataset.py
@@ -1,7 +1,7 @@
 import os
 import pytest
 import numpy as np
-from pathlib import Path
+from pathlib import Path, PurePath
 
 import sleap
 from sleap.skeleton import Skeleton
@@ -9,7 +9,10 @@
 from sleap.io.video import Video, MediaVideo
 from sleap.io.dataset import Labels, load_file
 from sleap.io.legacy import load_labels_json_old
+from sleap.io.format.ndx_pose import NDXPoseAdaptor
+from sleap.io.format import filehandle
 from sleap.gui.suggestions import VideoFrameSuggestions, SuggestionFrame
+from tests.io.test_formats import assert_read_labels_match
 
 TEST_H5_DATASET = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5"
@@ -1493,3 +1496,14 @@ def test_remove_untracked_instances(min_tracks_2node_labels):
     # Test function with remove_empty_frames=True
     labels.remove_untracked_instances(remove_empty_frames=True)
     assert all([len(lf.instances) > 0 for lf in labels.labeled_frames])
+
+
+def test_export_nwb(centered_pair_predictions: Labels, tmpdir):
+    filename = str(PurePath(tmpdir, "ndx_pose_test.nwb"))
+
+    # Export to NWB file
+    centered_pair_predictions.export_nwb(filename)
+
+    # Read from NWB file
+    read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
+    assert_read_labels_match(centered_pair_predictions, read_labels)
diff --git a/tests/io/test_formats.py b/tests/io/test_formats.py
index 27ced1bce..ec94bc4c3 100644
--- a/tests/io/test_formats.py
+++ b/tests/io/test_formats.py
@@ -320,26 +320,7 @@ def test_tracking_scores(tmpdir, centered_pair_predictions_slp_path):
     assert hasattr(instance, "tracking_score")
 
 
-def test_nwb(centered_pair_predictions: Labels, small_robot_mp4_vid: Video, tmpdir):
-    """Test that `Labels` can be written to and recreated from an NWB file."""
-    labels = centered_pair_predictions
-    filename = str(PurePath(tmpdir, "ndx_pose_test.nwb"))
-
-    # Add another video with an untracked PredictedInstance
-    labels.videos.append(small_robot_mp4_vid)
-    pred_instance: PredictedInstance = PredictedInstance.from_instance(
-        labels[0].instances[0], score=5
-    )
-    pred_instance.track = None
-    lf = LabeledFrame(video=small_robot_mp4_vid, frame_idx=6, instances=[pred_instance])
-    labels.append(lf)
-
-    # Write to NWB file
-    NDXPoseAdaptor.write(NDXPoseAdaptor, filename, labels)
-
-    # Read from NWB file
-    read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
-
+def assert_read_labels_match(labels, read_labels):
     # Labeled Frames
     assert len(read_labels.labeled_frames) == len(labels.labeled_frames)
@@ -380,6 +361,38 @@
     assert read_labels.skeleton.edge_inds == labels.skeleton.edge_inds
     assert len(read_labels.tracks) == len(labels.tracks)
+
+
+def test_nwb(
+    centered_pair_predictions: Labels,
+    small_robot_mp4_vid: Video,
+    tmpdir,
+):
+    """Test that `Labels` can be written to and recreated from an NWB file."""
+
+    labels = centered_pair_predictions
+    filename = str(PurePath(tmpdir, "ndx_pose_test.nwb"))
+
+    # Add another video with an untracked PredictedInstance
+    labels.videos.append(small_robot_mp4_vid)
+    pred_instance: PredictedInstance = PredictedInstance.from_instance(
+        labels[0].instances[0], score=5
+    )
+    pred_instance.track = None
+    lf = LabeledFrame(video=small_robot_mp4_vid, frame_idx=6, instances=[pred_instance])
+    labels.append(lf)
+
+    # Write to NWB file
+    NDXPoseAdaptor.write(NDXPoseAdaptor, filename, labels)
+
+    # Read from NWB file
+    read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
+    assert_read_labels_match(labels, read_labels)
+
+    # Append to NWB File (no changes expected)
+    NDXPoseAdaptor.write(NDXPoseAdaptor, filename, labels)
+    read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
+    assert_read_labels_match(labels, read_labels)
+
     # Project with no predicted instances
     labels.instances = []
     with pytest.raises(TypeError):
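
Usage note (not part of the diff): a minimal sketch of how the `Labels.export_nwb` API added above could be called from user code. The file paths below are hypothetical placeholders.

import sleap

# Load a project that already contains predicted instances (hypothetical path).
labels = sleap.load_file("predictions.slp")

# First call creates the NWB file, replacing any existing file since overwrite=True.
labels.export_nwb("predictions.nwb", overwrite=True)

# A second call with overwrite=False appends to the same file; processing modules
# that already exist for a video are skipped rather than rewritten.
labels.export_nwb("predictions.nwb", overwrite=False)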