Skip to content

Commit

Permalink
Allow passing in Labels to app.main (#1524)
Browse files Browse the repository at this point in the history
* Allow passing in `Labels` to `app.main`

* Load the labels object through command

* Add warning when unable to switch back to CPU mode
  • Loading branch information
roomrys authored Oct 11, 2023
1 parent ed77b49 commit 79f7fba
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 19 deletions.
37 changes: 31 additions & 6 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
import platform
import random
import re
import traceback
from logging import getLogger
from pathlib import Path
from typing import Callable, List, Optional, Tuple

Expand Down Expand Up @@ -85,6 +87,9 @@
from sleap.util import parse_uri_path


logger = getLogger(__name__)


class MainWindow(QMainWindow):
"""The SLEAP GUI application.
Expand All @@ -101,6 +106,7 @@ class MainWindow(QMainWindow):
def __init__(
self,
labels_path: Optional[str] = None,
labels: Optional[Labels] = None,
reset: bool = False,
no_usage_data: bool = False,
*args,
Expand All @@ -118,7 +124,7 @@ def __init__(
self.setAcceptDrops(True)

self.state = GuiState()
self.labels = Labels()
self.labels = labels or Labels()

self.commands = CommandContext(
state=self.state, app=self, update_callback=self.on_data_update
Expand Down Expand Up @@ -175,8 +181,10 @@ def __init__(
print("Restoring GUI state...")
self.restoreState(prefs["window state"])

if labels_path:
if labels_path is not None:
self.commands.loadProjectFile(filename=labels_path)
elif labels is not None:
self.commands.loadLabelsObject(labels=labels)
else:
self.state["project_loaded"] = False

Expand Down Expand Up @@ -1594,8 +1602,7 @@ def _show_keyboard_shortcuts_window(self):
ShortcutDialog().exec_()


def main(args: Optional[list] = None):
"""Starts new instance of app."""
def create_parser():

import argparse

Expand Down Expand Up @@ -1635,6 +1642,13 @@ def main(args: Optional[list] = None):
default=False,
)

return parser


def main(args: Optional[list] = None, labels: Optional[Labels] = None):
"""Starts new instance of app."""

parser = create_parser()
args = parser.parse_args(args)

if args.nonnative:
Expand All @@ -1651,12 +1665,23 @@ def main(args: Optional[list] = None):
app.setWindowIcon(QtGui.QIcon(sleap.util.get_package_file("gui/icon.png")))

window = MainWindow(
labels_path=args.labels_path, reset=args.reset, no_usage_data=args.no_usage_data
labels_path=args.labels_path,
labels=labels,
reset=args.reset,
no_usage_data=args.no_usage_data,
)
window.showMaximized()

# Disable GPU in GUI process. This does not affect subprocesses.
sleap.use_cpu_only()
try:
sleap.use_cpu_only()
except RuntimeError: # Visible devices cannot be modified after being initialized
logger.warning(
"Running processes on the GPU. Restarting your GUI should allow switching "
"back to CPU-only mode.\n"
"Received the following error when trying to switch back to CPU-only mode:"
)
traceback.print_exc()

# Print versions.
print()
Expand Down
25 changes: 12 additions & 13 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from enum import Enum
from glob import glob
from pathlib import Path, PurePath
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union

import attr
import cv2
Expand Down Expand Up @@ -260,16 +260,15 @@ def loadLabelsObject(self, labels: Labels, filename: Optional[str] = None):
"""
self.execute(LoadLabelsObject, labels=labels, filename=filename)

def loadProjectFile(self, filename: str):
def loadProjectFile(self, filename: Union[str, Labels]):
"""Loads given labels file into GUI.
Args:
filename: The path to the saved labels dataset. If None,
then don't do anything.
filename: The path to the saved labels dataset or the `Labels` object.
If None, then don't do anything.
Returns:
None
"""
self.execute(LoadProjectFile, filename=filename)

Expand Down Expand Up @@ -647,9 +646,8 @@ def do_action(context: "CommandContext", params: dict):
Returns:
None.
"""
filename = params["filename"]
filename = params.get("filename", None) # If called with just a Labels object
labels: Labels = params["labels"]

context.state["labels"] = labels
Expand All @@ -669,7 +667,9 @@ def do_action(context: "CommandContext", params: dict):
context.state["video"] = labels.videos[0]

context.state["project_loaded"] = True
context.state["has_changes"] = params.get("changed_on_load", False)
context.state["has_changes"] = params.get("changed_on_load", False) or (
filename is None
)

# This is not listed as an edit command since we want a clean changestack
context.app.on_data_update([UpdateTopic.project, UpdateTopic.all])
Expand All @@ -683,17 +683,16 @@ def ask(context: "CommandContext", params: dict):
if len(filename) == 0:
return

gui_video_callback = Labels.make_gui_video_callback(
search_paths=[os.path.dirname(filename)], context=params
)

has_loaded = False
labels = None
if type(filename) == Labels:
if isinstance(filename, Labels):
labels = filename
filename = None
has_loaded = True
else:
gui_video_callback = Labels.make_gui_video_callback(
search_paths=[os.path.dirname(filename)], context=params
)
try:
labels = Labels.load_file(filename, video_search=gui_video_callback)
has_loaded = True
Expand Down

0 comments on commit 79f7fba

Please sign in to comment.