Skip to content

Commit

Permalink
Centralize video extensions (#1244)
Browse files Browse the repository at this point in the history
* Add video extension support lists to io.video module

* Use centralized extension definitions

* Remove some unused code

* Add some coverage for indirect coverage reduction
  • Loading branch information
talmo authored Mar 28, 2023
1 parent 55e4707 commit b305eda
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 69 deletions.
7 changes: 3 additions & 4 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@
import sleap
from sleap.gui.dialogs.metrics import MetricsTableDialog
from sleap.skeleton import Skeleton
from sleap.instance import Instance, LabeledFrame
from sleap.instance import Instance
from sleap.io.dataset import Labels
from sleap.io.video import available_video_exts
from sleap.info.summary import StatisticSeries
from sleap.gui.commands import CommandContext, UpdateTopic
from sleap.gui.widgets.views import CollapsibleWidget
Expand Down Expand Up @@ -275,8 +276,6 @@ def dropEvent(self, event):

exts = [Path(f).suffix for f in filenames]

VIDEO_EXTS = (".mp4", ".avi", ".h5") # TODO: make this list global

if len(exts) == 1 and exts[0].lower() == ".slp":
if self.state["project_loaded"]:
# Merge
Expand All @@ -285,7 +284,7 @@ def dropEvent(self, event):
# Load
self.commands.openProject(filename=filenames[0], first_open=True)

elif all([ext.lower() in VIDEO_EXTS for ext in exts]):
elif all([ext.lower() in available_video_exts() for ext in exts]):
# Import videos
self.commands.showImportVideos(filenames=filenames)

Expand Down
60 changes: 28 additions & 32 deletions sleap/gui/dialogs/importvideos.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@
)

from sleap.gui.widgets.video import GraphicsView
from sleap.io.video import Video
from sleap.io.video import (
Video,
MediaVideo,
HDF5Video,
NumpyVideo,
ImgStoreVideo,
SingleImageVideo,
available_video_exts,
)
from sleap.gui.dialogs.filedialog import FileDialog

import h5py
Expand Down Expand Up @@ -67,11 +75,25 @@ def ask(
messages = dict() if messages is None else messages

if filenames is None:

any_video_exts = " ".join(["*." + ext for ext in available_video_exts()])
media_video_exts = " ".join(["*." + ext for ext in MediaVideo.EXTS])
hdf5_video_exts = " ".join(["*." + ext for ext in HDF5Video.EXTS])
numpy_video_exts = " ".join(["*." + ext for ext in NumpyVideo.EXTS])
imgstore_video_exts = " ".join(["*." + ext for ext in ImgStoreVideo.EXTS])
siv_video_exts = " ".join(["*." + ext for ext in SingleImageVideo.EXTS])

filenames, filter = FileDialog.openMultiple(
None,
"Select videos to import...", # dialogue title
".", # initial path
"Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)",
f"Any Video ({any_video_exts});;"
f"Media ({media_video_exts});;"
f"HDF5 ({hdf5_video_exts});;"
f"Numpy ({numpy_video_exts});;"
f"ImgStore ({imgstore_video_exts});;"
f"Single image ({siv_video_exts});;"
"Any File (*.*)",
)

if len(filenames) > 0:
Expand Down Expand Up @@ -113,7 +135,7 @@ def __init__(
self.import_types = [
{
"video_type": "hdf5",
"match": "h5,hdf5",
"match": ",".join(HDF5Video.EXTS),
"video_class": Video.from_hdf5,
"params": [
{
Expand All @@ -132,19 +154,19 @@ def __init__(
},
{
"video_type": "mp4",
"match": "mp4,avi",
"match": ",".join(MediaVideo.EXTS),
"video_class": Video.from_media,
"params": [{"name": "grayscale", "type": "check"}],
},
{
"video_type": "imgstore",
"match": "json",
"match": ",".join(ImgStoreVideo.EXTS),
"video_class": Video.from_filename,
"params": [],
},
{
"video_type": "single_image",
"match": "jpg,png,tif,jpeg,tiff",
"match": ",".join(SingleImageVideo.EXTS),
"video_class": Video.from_filename,
"params": [{"name": "grayscale", "type": "check"}],
},
Expand Down Expand Up @@ -635,29 +657,3 @@ def plot(self, idx=0):

# Display image
self.view.setImage(image)


# if __name__ == "__main__":

# app = QApplication([])

# # import_list = ImportVideos().ask()

# filenames = [
# "tests/data/videos/centered_pair_small.mp4",
# "tests/data/videos/small_robot.mp4",
# ]

# messages = {"tests/data/videos/small_robot.mp4": "Testing messages"}

# import_list = []
# importer = ImportParamDialog(filenames, messages=messages)
# importer.accepted.connect(lambda: importer.get_data(import_list))
# importer.exec_()

# for import_item in import_list:
# vid = import_item["video_class"](**import_item["params"])
# print(
# "Imported video data: (%d, %d), %d f, %d c"
# % (vid.width, vid.height, vid.frames, vid.channels)
# )
7 changes: 0 additions & 7 deletions sleap/gui/dialogs/missingfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,3 @@ def headerData(
elif orientation == QtCore.Qt.Vertical:
return section
return None


# if __name__ == "__main__":
# app = QtWidgets.QApplication()
# win = MissingFilesDialog(["m:/centered_pair_small.mp4", "m:/small_robot.mp4"])
# result = win.exec_()
# print(result)
17 changes: 0 additions & 17 deletions sleap/gui/learning/receptivefield.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,3 @@ def _set_field_size(self, size: Optional[int] = None, scale: float = 1.0):
self.box.setRect(
scene_center.x(), scene_center.y(), scaled_box_size, scaled_box_size
)


def demo_receptive_field():
app = QtWidgets.QApplication([])

video = Video.from_filename("tests/data/videos/centered_pair_small.mp4")

win = ReceptiveFieldImageWidget()
win.setImage(video.get_frame(0))
win._set_field_size(50)

win.show()
app.exec_()


if __name__ == "__main__":
demo_receptive_field()
33 changes: 29 additions & 4 deletions sleap/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class HDF5Video:
convert_range: Whether we should convert data to [0, 255]-range
"""

EXTS = ("h5", "hdf5", "slp")

filename: str = attr.ib(default=None)
dataset: str = attr.ib(default=None)
input_format: str = attr.ib(default="channels_last")
Expand Down Expand Up @@ -349,6 +351,8 @@ class MediaVideo:
bgr: Whether color channels ordered as (blue, green, red).
"""

EXTS = ("mp4", "avi", "mov", "mj2", "mkv")

filename: str = attr.ib()
grayscale: bool = attr.ib()
bgr: bool = attr.ib(default=True)
Expand Down Expand Up @@ -514,6 +518,8 @@ class NumpyVideo:
* numpy data shape: (frames, height, width, channels)
"""

EXTS = ("npy", "npz")

filename: Union[str, np.ndarray] = attr.ib()

def __attrs_post_init__(self):
Expand Down Expand Up @@ -621,6 +627,8 @@ class ImgStoreVideo:
indices on :class:`LabeledFrame` objects in the dataset.
"""

EXTS = ("json", "yaml")

filename: str = attr.ib(default=None)
index_by_original: bool = attr.ib(default=True)
_store_ = None
Expand Down Expand Up @@ -800,6 +808,8 @@ class SingleImageVideo:
filenames: Files to load as video.
"""

EXTS = ("jpg", "jpeg", "png", "tif", "tiff")

filename: Optional[str] = attr.ib(default=None)
filenames: Optional[List[str]] = attr.ib(factory=list)
height_: Optional[int] = attr.ib(default=None)
Expand Down Expand Up @@ -1251,16 +1261,16 @@ def from_filename(cls, filename: str, *args, **kwargs) -> "Video":
"""
filename = Video.fixup_path(filename)

if filename.lower().endswith(("h5", "hdf5", "slp")):
if filename.lower().endswith(HDF5Video.EXTS):
backend_class = HDF5Video
elif filename.endswith(("npy")):
elif filename.endswith(NumpyVideo.EXTS):
backend_class = NumpyVideo
elif filename.lower().endswith(("mp4", "avi", "mov")):
elif filename.lower().endswith(MediaVideo.EXTS):
backend_class = MediaVideo
kwargs["dataset"] = "" # prevent serialization from breaking
elif os.path.isdir(filename) or "metadata.yaml" in filename:
backend_class = ImgStoreVideo
elif filename.lower().endswith(("jpg", "jpeg", "png", "tif", "tiff")):
elif filename.lower().endswith(SingleImageVideo.EXTS):
backend_class = SingleImageVideo
else:
raise ValueError("Could not detect backend for specified filename.")
Expand Down Expand Up @@ -1600,6 +1610,21 @@ def fixup_path(
return path


def available_video_exts() -> Tuple[str]:
"""Return tuple of supported video extensions.
Returns:
Tuple of supported video extensions.
"""
return (
MediaVideo.EXTS
+ HDF5Video.EXTS
+ NumpyVideo.EXTS
+ SingleImageVideo.EXTS
+ ImgStoreVideo.EXTS
)


def load_video(
filename: str,
grayscale: Optional[bool] = None,
Expand Down
8 changes: 4 additions & 4 deletions sleap/nn/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def compute_oks(
Ronch & Perona. "Benchmarking and Error Diagnosis in Multi-Instance Pose
Estimation." ICCV (2017).
"""
if points_gt.ndim != 3 or points_pr.ndim != 3:
raise ValueError(
"Points must be rank-3 with shape (n_instances, n_nodes, n_ed)."
)
if points_gt.ndim == 2:
points_gt = np.expand_dims(points_gt, axis=0)
if points_pr.ndim == 2:
points_pr = np.expand_dims(points_pr, axis=0)

if scale is None:
scale = compute_instance_area(points_gt)
Expand Down
23 changes: 22 additions & 1 deletion tests/nn/test_evals.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
import numpy as np
import sleap
from sleap.nn.evals import load_metrics
from sleap.nn.evals import load_metrics, compute_oks


sleap.use_cpu_only()


def test_compute_oks():
inst_gt = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
oks = compute_oks(inst_gt, inst_pr)
np.testing.assert_allclose(oks, 1)

inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
oks = compute_oks(inst_gt, inst_pr)
np.testing.assert_allclose(oks, 2 / 3)

inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
inst_pr = np.array([[0, 0], [1, 1], [2, 2]]).astype("float32")
oks = compute_oks(inst_gt, inst_pr)
np.testing.assert_allclose(oks, 1)

inst_gt = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
inst_pr = np.array([[0, 0], [1, 1], [np.nan, np.nan]]).astype("float32")
oks = compute_oks(inst_gt, inst_pr)
np.testing.assert_allclose(oks, 1)


def test_load_metrics(min_centered_instance_model_path):
model_path = min_centered_instance_model_path

Expand Down

0 comments on commit b305eda

Please sign in to comment.