Skip to content

Commit

Permalink
Allow user to set grayscale when replacing videos (mp4/avi only) (#787)
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys authored Jun 24, 2022
1 parent 5e27751 commit 37f1423
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 59 deletions.
102 changes: 90 additions & 12 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ class which inherits from `AppCommand` (or a more specialized class such as
import attr
import operator
import os
import cv2
import re
import sys
import subprocess

from enum import Enum
from glob import glob
from pathlib import PurePath
from pathlib import PurePath, Path
from typing import Callable, Dict, Iterator, List, Optional, Type, Tuple

import numpy as np
Expand Down Expand Up @@ -1516,30 +1517,107 @@ def do_action(context: CommandContext, params: dict):


class ReplaceVideo(EditCommand):
topics = [UpdateTopic.video]
topics = [UpdateTopic.video, UpdateTopic.frame]

@staticmethod
def do_action(context: CommandContext, params: dict):
new_paths = params["new_video_paths"]
def do_action(context: CommandContext, params: dict) -> bool:

import_list = params["import_list"]

for import_item, video in import_list:
import_params = import_item["params"]

# TODO: Will need to create a new backend if import has different extension.
if (
Path(video.backend.filename).suffix
!= Path(import_params["filename"]).suffix
):
raise TypeError(
"Importing videos with different extensions is not supported."
)
video.backend.reset(**import_params)

# Remove frames in video past last frame index
last_vid_frame = video.last_frame_idx
lfs: List[LabeledFrame] = list(context.labels.get(video))
if lfs is not None:
lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame]
context.labels.remove_frames(lfs)

for video, new_path in zip(context.labels.videos, new_paths):
if new_path != video.backend.filename:
video.backend.filename = new_path
video.backend.reset()
# Update seekbar and video length through callbacks
context.state.emit("video")

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
"""Shows gui for replacing videos in project."""
paths = [video.backend.filename for video in context.labels.videos]

okay = MissingFilesDialog(filenames=paths, replace=True).exec_()
def _get_truncation_message(truncation_messages, path, video):
reader = cv2.VideoCapture(path)
last_vid_frame = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
lfs: List[LabeledFrame] = list(context.labels.get(video))
if lfs is not None:
lfs.sort(key=lambda lf: lf.frame_idx)
last_lf_frame = lfs[-1].frame_idx
lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame]

# Message to warn users that labels will be removed if proceed
if last_lf_frame > last_vid_frame:
message = (
"<p><strong>Warning:</strong> Replacing this video will "
f"remove {len(lfs)} labeled frames.</p>"
f"<p><em>Current video</em>: <b>{Path(video.filename).name}</b>"
f" (last label at frame {last_lf_frame})<br>"
f"<em>Replacement video</em>: <b>{Path(path).name}"
f"</b> ({last_vid_frame} frames)</p>"
)
# Assumes that a project won't import the same video multiple times
truncation_messages[path] = message

return truncation_messages

# Warn user: newly added labels will be discarded if project is not saved
if not context.state["filename"] or context.state["has_changes"]:
QMessageBox(
text=("You have unsaved changes. Please save before replacing videos.")
).exec_()
return False

# Select the videos we want to swap
old_paths = [video.backend.filename for video in context.labels.videos]
paths = list(old_paths)
okay = MissingFilesDialog(filenames=paths, replace=True).exec_()
if not okay:
return False

params["new_video_paths"] = paths
# Only return an import list for videos we swap
new_paths = [
(path, video_idx)
for video_idx, (path, old_path) in enumerate(zip(paths, old_paths))
if path != old_path
]

return True
new_paths = []
old_videos = dict()
all_videos = context.labels.videos
truncation_messages = dict()
for video_idx, (path, old_path) in enumerate(zip(paths, old_paths)):
if path != old_path:
new_paths.append(path)
old_videos[path] = all_videos[video_idx]
truncation_messages = _get_truncation_message(
truncation_messages, path, video=all_videos[video_idx]
)

import_list = ImportVideos().ask(
filenames=new_paths, messages=truncation_messages
)
# Remove videos that no longer correlate to filenames.
old_videos_to_replace = [
old_videos[imp["params"]["filename"]] for imp in import_list
]
params["import_list"] = zip(import_list, old_videos_to_replace)

return len(import_list) > 0


class RemoveVideo(EditCommand):
Expand Down
105 changes: 60 additions & 45 deletions sleap/gui/dialogs/importvideos.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ class ImportVideos:
def __init__(self):
self.result = []

def ask(self, filenames: Optional[List[str]] = None):
def ask(
self,
filenames: Optional[List[str]] = None,
messages: Optional[Dict[str, str]] = None,
):
"""Runs the import UI.
1. Show file selection dialog.
Expand All @@ -59,17 +63,21 @@ def ask(self, filenames: Optional[List[str]] = None):
Returns:
List with dict of the parameters for each file to import.
"""
messages = dict() if messages is None else messages

if filenames is None:
filenames, filter = FileDialog.openMultiple(
None,
"Select videos to import...", # dialogue title
".", # initial path
"Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)",
)

if len(filenames) > 0:
importer = ImportParamDialog(filenames)
importer = ImportParamDialog(filenames, messages)
importer.accepted.connect(lambda: importer.get_data(self.result))
importer.exec_()

return self.result

@classmethod
Expand All @@ -91,9 +99,12 @@ class ImportParamDialog(QDialog):
filenames (list): List of files we want to import.
"""

def __init__(self, filenames: List[str], *args, **kwargs):
def __init__(
self, filenames: List[str], messages: Dict[str, str] = None, *args, **kwargs
):
super(ImportParamDialog, self).__init__(*args, **kwargs)

messages = dict() if messages is None else messages
self.import_widgets = []

self.setWindowTitle("Video Import Options")
Expand Down Expand Up @@ -135,7 +146,6 @@ def __init__(self, filenames: List[str], *args, **kwargs):
outer_layout = QVBoxLayout()

scroll_widget = QScrollArea()
# scroll_widget.setWidgetResizable(False)
scroll_widget.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
scroll_widget.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

Expand All @@ -152,7 +162,10 @@ def __init__(self, filenames: List[str], *args, **kwargs):
this_type = import_type
break
if this_type is not None:
import_item_widget = ImportItemWidget(file_name, this_type)
message = messages[file_name] if file_name in messages else ""
import_item_widget = ImportItemWidget(
file_name, this_type, message=message
)
self.import_widgets.append(import_item_widget)
scroll_layout.addWidget(import_item_widget)
else:
Expand Down Expand Up @@ -197,6 +210,7 @@ def __init__(self, filenames: List[str], *args, **kwargs):
button_layout.addWidget(import_button)

outer_layout.addLayout(button_layout)
self.adjustSize()

self.setLayout(outer_layout)

Expand All @@ -220,14 +234,6 @@ def get_data(self, import_result=None):
import_result.append(import_item.get_data())
return import_result

def boundingRect(self) -> QRectF:
"""Method required by Qt."""
return QRectF()

def paint(self, painter, option, widget=None):
"""Method required by Qt."""
pass

def set_all_grayscale(self):
for import_item in self.import_widgets:
widget_elements = import_item.options_widget.widget_elements
Expand Down Expand Up @@ -269,7 +275,14 @@ class ImportItemWidget(QFrame):
import_type (dict): Data about user-selectable import parameters.
"""

def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
def __init__(
self,
file_path: str,
import_type: Dict[str, Any],
message: str = "",
*args,
**kwargs,
):
super(ImportItemWidget, self).__init__(*args, **kwargs)

self.file_path = file_path
Expand All @@ -287,6 +300,9 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
self.options_widget = ImportParamWidget(
parent=self, file_path=self.file_path, import_type=self.import_type
)

self.message_widget = MessageWidget(parent=self, message=message)

self.preview_widget = VideoPreviewWidget(parent=self)
self.preview_widget.setFixedSize(200, 200)

Expand All @@ -295,6 +311,7 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
)

inner_layout.addWidget(self.options_widget)
inner_layout.addWidget(self.message_widget)
inner_layout.addWidget(self.preview_widget)
import_item_layout.addLayout(inner_layout)
self.setLayout(import_item_layout)
Expand Down Expand Up @@ -379,7 +396,7 @@ class ImportParamWidget(QWidget):

changed = Signal()

def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
def __init__(self, file_path: str, import_type: Dict[str, Any], *args, **kwargs):
super(ImportParamWidget, self).__init__(*args, **kwargs)

self.file_path = file_path
Expand Down Expand Up @@ -516,13 +533,17 @@ def _find_h5_datasets(self, data_path, data_object) -> list:
)
return options

def boundingRect(self) -> QRectF:
"""Method required by Qt."""
return QRectF()

def paint(self, painter, option, widget=None):
"""Method required by Qt."""
pass
class MessageWidget(QWidget):
"""Widget to show message."""

def __init__(self, message: str = str(), *args, **kwargs):
super().__init__(*args, **kwargs)
self.message = QLabel(message)
self.message.setStyleSheet("color: red")
self.layout = QVBoxLayout()
self.layout.addWidget(self.message)
self.setLayout(self.layout)


class VideoPreviewWidget(QWidget):
Expand Down Expand Up @@ -585,34 +606,28 @@ def plot(self, idx=0):
# Display image
self.view.setImage(image)

def boundingRect(self) -> QRectF:
"""Method required by Qt."""
return QRectF()

def paint(self, painter, option, widget=None):
"""Method required by Qt."""
pass

# if __name__ == "__main__":

if __name__ == "__main__":
# app = QApplication([])

app = QApplication([])
# # import_list = ImportVideos().ask()

# import_list = ImportVideos().ask()
# filenames = [
# "tests/data/videos/centered_pair_small.mp4",
# "tests/data/videos/small_robot.mp4",
# ]

filenames = [
"tests/data/videos/centered_pair_small.mp4",
"tests/data/videos/small_robot.mp4",
]
# messages = {"tests/data/videos/small_robot.mp4": "Testing messages"}

import_list = []
importer = ImportParamDialog(filenames)
importer.accepted.connect(lambda: importer.get_data(import_list))
importer.exec_()
# import_list = []
# importer = ImportParamDialog(filenames, messages=messages)
# importer.accepted.connect(lambda: importer.get_data(import_list))
# importer.exec_()

for import_item in import_list:
vid = import_item["video_class"](**import_item["params"])
print(
"Imported video data: (%d, %d), %d f, %d c"
% (vid.width, vid.height, vid.frames, vid.channels)
)
# for import_item in import_list:
# vid = import_item["video_class"](**import_item["params"])
# print(
# "Imported video data: (%d, %d), %d f, %d c"
# % (vid.width, vid.height, vid.frames, vid.channels)
# )
15 changes: 14 additions & 1 deletion sleap/gui/dialogs/missingfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os

from pathlib import Path, PurePath
from typing import Callable, List

from PySide2 import QtWidgets, QtCore, QtGui
Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(

self.filenames = filenames
self.missing = missing
self.replace = replace

missing_count = sum(missing)

Expand Down Expand Up @@ -88,11 +90,22 @@ def locateFile(self, idx: int):

caption = f"Please locate {old_filename}..."
filters = [f"Missing file type (*{old_ext})", "Any File (*.*)"]
filters = [filters[0]] if self.replace else filters
new_filename, _ = FileDialog.open(
None, dir=None, caption=caption, filter=";;".join(filters)
)

if new_filename:
path_new_filename = Path(new_filename)
paths = [str(PurePath(fn)) for fn in self.filenames]
if str(path_new_filename) in paths:
# Do not allow same video to be imported more than once.
QtWidgets.QMessageBox(
text=(
f"The file <b>{path_new_filename.name}</b> cannot be added to the "
"project multiple times."
)
).exec_()
elif new_filename:
# Try using this change to find other missing files
self.setFilename(idx, new_filename)

Expand Down
Loading

0 comments on commit 37f1423

Please sign in to comment.