Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to set grayscale when replacing videos (mp4/avi only) #787

Merged
merged 16 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 90 additions & 12 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ class which inherits from `AppCommand` (or a more specialized class such as
import attr
import operator
import os
import cv2
import re
import sys
import subprocess

from enum import Enum
from glob import glob
from pathlib import PurePath
from pathlib import PurePath, Path
from typing import Callable, Dict, Iterator, List, Optional, Type, Tuple

import numpy as np
Expand Down Expand Up @@ -1516,30 +1517,107 @@ def do_action(context: CommandContext, params: dict):


class ReplaceVideo(EditCommand):
topics = [UpdateTopic.video]
topics = [UpdateTopic.video, UpdateTopic.frame]

@staticmethod
def do_action(context: CommandContext, params: dict):
new_paths = params["new_video_paths"]
def do_action(context: CommandContext, params: dict) -> bool:

import_list = params["import_list"]

for import_item, video in import_list:
import_params = import_item["params"]

# TODO: Will need to create a new backend if import has different extension.
roomrys marked this conversation as resolved.
Show resolved Hide resolved
if (
Path(video.backend.filename).suffix
!= Path(import_params["filename"]).suffix
):
raise TypeError(
"Importing videos with different extensions is not supported."
)
video.backend.reset(**import_params)

# Remove frames in video past last frame index
last_vid_frame = video.last_frame_idx
lfs: List[LabeledFrame] = list(context.labels.get(video))
if lfs is not None:
lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame]
context.labels.remove_frames(lfs)

for video, new_path in zip(context.labels.videos, new_paths):
if new_path != video.backend.filename:
video.backend.filename = new_path
video.backend.reset()
# Update seekbar and video length through callbacks
context.state.emit("video")

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
"""Shows gui for replacing videos in project."""
paths = [video.backend.filename for video in context.labels.videos]

okay = MissingFilesDialog(filenames=paths, replace=True).exec_()
def _get_truncation_message(truncation_messages, path, video):
reader = cv2.VideoCapture(path)
last_vid_frame = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
lfs: List[LabeledFrame] = list(context.labels.get(video))
if lfs is not None:
lfs.sort(key=lambda lf: lf.frame_idx)
last_lf_frame = lfs[-1].frame_idx
lfs = [lf for lf in lfs if lf.frame_idx > last_vid_frame]
roomrys marked this conversation as resolved.
Show resolved Hide resolved

# Message to warn users that labels will be removed if proceed
if last_lf_frame > last_vid_frame:
message = (
"<p><strong>Warning:</strong> Replacing this video will "
f"remove {len(lfs)} labeled frames.</p>"
f"<p><em>Current video</em>: <b>{Path(video.filename).name}</b>"
f" (last label at frame {last_lf_frame})<br>"
f"<em>Replacement video</em>: <b>{Path(path).name}"
f"</b> ({last_vid_frame} frames)</p>"
)
# Assumes that a project won't import the same video multiple times
truncation_messages[path] = message
roomrys marked this conversation as resolved.
Show resolved Hide resolved

return truncation_messages

# Warn user: newly added labels will be discarded if project is not saved
if not context.state["filename"] or context.state["has_changes"]:
QMessageBox(
text=("You have unsaved changes. Please save before replacing videos.")
).exec_()
return False
Comment on lines +1578 to +1583
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is triggered even when labels are not added/changed (but another "change" was registered). Is there a way to check that just the labels object has changes?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be amazing -- I've looked into it in the past but it doesn't seem trivial. Search for "has_changes" across the codebase.

In general, it's just a tricky pattern to implement. Here's some relevant literature: https://refactoring.guru/design-patterns/state


# Select the videos we want to swap
old_paths = [video.backend.filename for video in context.labels.videos]
paths = list(old_paths)
okay = MissingFilesDialog(filenames=paths, replace=True).exec_()
if not okay:
return False

params["new_video_paths"] = paths
# Only return an import list for videos we swap
new_paths = [
(path, video_idx)
for video_idx, (path, old_path) in enumerate(zip(paths, old_paths))
if path != old_path
]

return True
new_paths = []
old_videos = dict()
all_videos = context.labels.videos
truncation_messages = dict()
for video_idx, (path, old_path) in enumerate(zip(paths, old_paths)):
if path != old_path:
new_paths.append(path)
old_videos[path] = all_videos[video_idx]
truncation_messages = _get_truncation_message(
truncation_messages, path, video=all_videos[video_idx]
)

import_list = ImportVideos().ask(
filenames=new_paths, messages=truncation_messages
)
# Remove videos that no longer correlate to filenames.
old_videos_to_replace = [
old_videos[imp["params"]["filename"]] for imp in import_list
]
params["import_list"] = zip(import_list, old_videos_to_replace)

return len(import_list) > 0


class RemoveVideo(EditCommand):
Expand Down
105 changes: 60 additions & 45 deletions sleap/gui/dialogs/importvideos.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ class ImportVideos:
def __init__(self):
self.result = []

def ask(self, filenames: Optional[List[str]] = None):
def ask(
self,
filenames: Optional[List[str]] = None,
messages: Optional[Dict[str, str]] = None,
):
"""Runs the import UI.

1. Show file selection dialog.
Expand All @@ -59,17 +63,21 @@ def ask(self, filenames: Optional[List[str]] = None):
Returns:
List with dict of the parameters for each file to import.
"""
messages = dict() if messages is None else messages

if filenames is None:
filenames, filter = FileDialog.openMultiple(
None,
"Select videos to import...", # dialogue title
".", # initial path
"Any Video (*.h5 *.hd5v *.mp4 *.avi *.json);;HDF5 (*.h5 *.hd5v);;ImgStore (*.json);;Media Video (*.mp4 *.avi);;Any File (*.*)",
)

if len(filenames) > 0:
importer = ImportParamDialog(filenames)
importer = ImportParamDialog(filenames, messages)
importer.accepted.connect(lambda: importer.get_data(self.result))
importer.exec_()

return self.result

@classmethod
Expand All @@ -91,9 +99,12 @@ class ImportParamDialog(QDialog):
filenames (list): List of files we want to import.
"""

def __init__(self, filenames: List[str], *args, **kwargs):
def __init__(
self, filenames: List[str], messages: Dict[str, str] = None, *args, **kwargs
):
super(ImportParamDialog, self).__init__(*args, **kwargs)

messages = dict() if messages is None else messages
self.import_widgets = []

self.setWindowTitle("Video Import Options")
Expand Down Expand Up @@ -135,7 +146,6 @@ def __init__(self, filenames: List[str], *args, **kwargs):
outer_layout = QVBoxLayout()

scroll_widget = QScrollArea()
# scroll_widget.setWidgetResizable(False)
scroll_widget.setVerticalScrollBarPolicy(Qt.ScrollBarAlwaysOn)
scroll_widget.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)

Expand All @@ -152,7 +162,10 @@ def __init__(self, filenames: List[str], *args, **kwargs):
this_type = import_type
break
if this_type is not None:
import_item_widget = ImportItemWidget(file_name, this_type)
message = messages[file_name] if file_name in messages else ""
import_item_widget = ImportItemWidget(
file_name, this_type, message=message
)
self.import_widgets.append(import_item_widget)
scroll_layout.addWidget(import_item_widget)
else:
Expand Down Expand Up @@ -197,6 +210,7 @@ def __init__(self, filenames: List[str], *args, **kwargs):
button_layout.addWidget(import_button)

outer_layout.addLayout(button_layout)
self.adjustSize()

self.setLayout(outer_layout)

Expand All @@ -220,14 +234,6 @@ def get_data(self, import_result=None):
import_result.append(import_item.get_data())
return import_result

def boundingRect(self) -> QRectF:
"""Method required by Qt."""
return QRectF()

def paint(self, painter, option, widget=None):
"""Method required by Qt."""
pass

def set_all_grayscale(self):
for import_item in self.import_widgets:
widget_elements = import_item.options_widget.widget_elements
Expand Down Expand Up @@ -269,7 +275,14 @@ class ImportItemWidget(QFrame):
import_type (dict): Data about user-selectable import parameters.
"""

def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
def __init__(
self,
file_path: str,
import_type: Dict[str, Any],
message: str = "",
*args,
**kwargs,
):
super(ImportItemWidget, self).__init__(*args, **kwargs)

self.file_path = file_path
Expand All @@ -287,6 +300,9 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
self.options_widget = ImportParamWidget(
parent=self, file_path=self.file_path, import_type=self.import_type
)

self.message_widget = MessageWidget(parent=self, message=message)

self.preview_widget = VideoPreviewWidget(parent=self)
self.preview_widget.setFixedSize(200, 200)

Expand All @@ -295,6 +311,7 @@ def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
)

inner_layout.addWidget(self.options_widget)
inner_layout.addWidget(self.message_widget)
inner_layout.addWidget(self.preview_widget)
import_item_layout.addLayout(inner_layout)
self.setLayout(import_item_layout)
Expand Down Expand Up @@ -379,7 +396,7 @@ class ImportParamWidget(QWidget):

changed = Signal()

def __init__(self, file_path: str, import_type: dict, *args, **kwargs):
def __init__(self, file_path: str, import_type: Dict[str, Any], *args, **kwargs):
super(ImportParamWidget, self).__init__(*args, **kwargs)

self.file_path = file_path
Expand Down Expand Up @@ -516,13 +533,17 @@ def _find_h5_datasets(self, data_path, data_object) -> list:
)
return options

def boundingRect(self) -> QRectF:
"""Method required by Qt."""
return QRectF()

def paint(self, painter, option, widget=None):
"""Method required by Qt."""
pass
class MessageWidget(QWidget):
"""Widget to show message."""

def __init__(self, message: str = str(), *args, **kwargs):
super().__init__(*args, **kwargs)
self.message = QLabel(message)
self.message.setStyleSheet("color: red")
self.layout = QVBoxLayout()
self.layout.addWidget(self.message)
self.setLayout(self.layout)


class VideoPreviewWidget(QWidget):
Expand Down Expand Up @@ -585,34 +606,28 @@ def plot(self, idx=0):
# Display image
self.view.setImage(image)

def boundingRect(self) -> QRectF:
"""Method required by Qt."""
return QRectF()

def paint(self, painter, option, widget=None):
"""Method required by Qt."""
pass

# if __name__ == "__main__":

if __name__ == "__main__":
# app = QApplication([])

app = QApplication([])
# # import_list = ImportVideos().ask()

# import_list = ImportVideos().ask()
# filenames = [
# "tests/data/videos/centered_pair_small.mp4",
# "tests/data/videos/small_robot.mp4",
# ]

filenames = [
"tests/data/videos/centered_pair_small.mp4",
"tests/data/videos/small_robot.mp4",
]
# messages = {"tests/data/videos/small_robot.mp4": "Testing messages"}

import_list = []
importer = ImportParamDialog(filenames)
importer.accepted.connect(lambda: importer.get_data(import_list))
importer.exec_()
# import_list = []
# importer = ImportParamDialog(filenames, messages=messages)
# importer.accepted.connect(lambda: importer.get_data(import_list))
# importer.exec_()

for import_item in import_list:
vid = import_item["video_class"](**import_item["params"])
print(
"Imported video data: (%d, %d), %d f, %d c"
% (vid.width, vid.height, vid.frames, vid.channels)
)
# for import_item in import_list:
# vid = import_item["video_class"](**import_item["params"])
# print(
# "Imported video data: (%d, %d), %d f, %d c"
# % (vid.width, vid.height, vid.frames, vid.channels)
# )
15 changes: 14 additions & 1 deletion sleap/gui/dialogs/missingfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os

from pathlib import Path, PurePath
from typing import Callable, List

from PySide2 import QtWidgets, QtCore, QtGui
Expand Down Expand Up @@ -47,6 +48,7 @@ def __init__(

self.filenames = filenames
self.missing = missing
self.replace = replace

missing_count = sum(missing)

Expand Down Expand Up @@ -88,11 +90,22 @@ def locateFile(self, idx: int):

caption = f"Please locate {old_filename}..."
filters = [f"Missing file type (*{old_ext})", "Any File (*.*)"]
filters = [filters[0]] if self.replace else filters
new_filename, _ = FileDialog.open(
None, dir=None, caption=caption, filter=";;".join(filters)
)

if new_filename:
path_new_filename = Path(new_filename)
paths = [str(PurePath(fn)) for fn in self.filenames]
if str(path_new_filename) in paths:
# Do not allow same video to be imported more than once.
QtWidgets.QMessageBox(
text=(
f"The file <b>{path_new_filename.name}</b> cannot be added to the "
"project multiple times."
)
).exec_()
elif new_filename:
# Try using this change to find other missing files
self.setFilename(idx, new_filename)

Expand Down
Loading