Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hang on model manager install unit tests #5835

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion invokeai/app/services/download/download_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def _download_next_item(self) -> None:
job.job_started = get_iso_timestamp()
self._do_download(job)
self._signal_job_complete(job)

except (OSError, HTTPError) as excp:
job.error_type = excp.__class__.__name__ + f"({str(excp)})"
job.error = traceback.format_exc()
Expand Down
6 changes: 6 additions & 0 deletions invokeai/app/services/model_install/model_install_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class InstallStatus(str, Enum):

WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
Expand Down Expand Up @@ -229,6 +230,11 @@ def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING

@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE

@property
def running(self) -> bool:
"""Return true if job is running."""
Expand Down
21 changes: 5 additions & 16 deletions invokeai/app/services/model_install/model_install_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
ModelRepoVariant,
ModelType,
)
from invokeai.backend.model_manager.hash import FastModelHash
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
CivitaiMetadataFetch,
Expand Down Expand Up @@ -153,7 +152,6 @@ def install_path(
config["source"] = model_path.resolve().as_posix()

info: AnyModelConfig = self._probe_model(Path(model_path), config)
old_hash = info.current_hash

if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
Expand All @@ -167,8 +165,6 @@ def install_path(
raise DuplicateModelException(
f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
) from excp
new_hash = FastModelHash.hash(new_path)
assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."

return self._register(
new_path,
Expand Down Expand Up @@ -284,7 +280,7 @@ def sync_to_config(self) -> None:
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
self._cached_model_paths = {Path(x.path) for x in self.record_store.all_models()}
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)
search = ModelSearch(on_model_found=callback, config=self._app_config)
self._models_installed.clear()
search.search(scan_dir)
return list(self._models_installed)
Expand Down Expand Up @@ -370,7 +366,7 @@ def _install_next_item(self) -> None:
self._signal_job_errored(job)

elif (
job.waiting or job.downloading
job.waiting or job.downloads_done
): # local jobs will be in waiting state, remote jobs will be downloading state
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
Expand Down Expand Up @@ -448,7 +444,7 @@ def _scan_models_directory(self) -> None:
installed.update(self.scan_directory(models_dir))
self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")

def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
def _sync_model_path(self, key: str) -> AnyModelConfig:
"""
Move model into the location indicated by its basetype, type and name.

Expand All @@ -469,14 +465,7 @@ def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyMod
new_path = models_dir / model.base.value / model.type.value / model.name
self._logger.info(f"Moving {model.name} to {new_path}.")
new_path = self._move_model(old_path, new_path)
new_hash = FastModelHash.hash(new_path)
model.path = new_path.relative_to(models_dir).as_posix()
if model.current_hash != new_hash:
assert (
ignore_hash_change
), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
model.current_hash = new_hash
self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
self.record_store.update_model(key, model)
return model

Expand Down Expand Up @@ -749,8 +738,8 @@ def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._download_cache.pop(download_job.source, None)

# are there any more active jobs left in this task?
if all(x.complete for x in install_job.download_parts):
# now enqueue job for actual installation into the models directory
if install_job.downloading and all(x.complete for x in install_job.download_parts):
install_job.status = InstallStatus.DOWNLOADS_DONE
self._install_queue.put(install_job)

# Let other threads know that the number of downloads has changed
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ dependencies = [
"pre-commit",
"pytest>6.0.0",
"pytest-cov",
"pytest-timeout",
"pytest-datadir",
"requests_testadapter",
"httpx",
Expand Down Expand Up @@ -186,9 +187,10 @@ version = { attr = "invokeai.version.__version__" }

#=== Begin: PyTest and Coverage
[tool.pytest.ini_options]
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers --timeout 60 -m \"not slow\""
markers = [
"slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
"timeout: Marks the timeout override."
]
[tool.coverage.run]
branch = true
Expand Down
24 changes: 17 additions & 7 deletions tests/app/routers/test_images.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from pathlib import Path
from typing import Any

import pytest
from fastapi import BackgroundTasks
from fastapi.testclient import TestClient

Expand All @@ -9,7 +11,11 @@
from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.invoker import Invoker

client = TestClient(app)

@pytest.fixture(autouse=True, scope="module")
def client(invokeai_root_dir: Path) -> TestClient:
os.environ["INVOKEAI_ROOT"] = invokeai_root_dir.as_posix()
return TestClient(app)


class MockApiDependencies(ApiDependencies):
Expand All @@ -19,7 +25,7 @@ def __init__(self, invoker) -> None:
self.invoker = invoker


def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
prepare_download_images_test(monkeypatch, mock_invoker)

response = client.post("/api/v1/images/download", json={"image_names": ["test.png"]})
Expand All @@ -28,7 +34,9 @@ def test_download_images_from_list(monkeypatch: Any, mock_invoker: Invoker) -> N
assert json_response["bulk_download_item_name"] == "test.zip"


def test_download_images_from_board_id_empty_image_name_list(monkeypatch: Any, mock_invoker: Invoker) -> None:
def test_download_images_from_board_id_empty_image_name_list(
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
) -> None:
expected_board_name = "test"

def mock_get(*args, **kwargs):
Expand Down Expand Up @@ -56,15 +64,17 @@ def mock_add_task(*args, **kwargs):
monkeypatch.setattr(BackgroundTasks, "add_task", mock_add_task)


def test_download_images_with_empty_image_list_and_no_board_id(monkeypatch: Any, mock_invoker: Invoker) -> None:
def test_download_images_with_empty_image_list_and_no_board_id(
monkeypatch: Any, mock_invoker: Invoker, client: TestClient
) -> None:
prepare_download_images_test(monkeypatch, mock_invoker)

response = client.post("/api/v1/images/download", json={"image_names": []})

assert response.status_code == 400


def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker) -> None:
def test_get_bulk_download_image(tmp_path: Path, monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
mock_file: Path = tmp_path / "test.zip"
mock_file.write_text("contents")

Expand All @@ -82,7 +92,7 @@ def mock_add_task(*args, **kwargs):
assert response.content == b"contents"


def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker) -> None:
def test_get_bulk_download_image_not_found(monkeypatch: Any, mock_invoker: Invoker, client: TestClient) -> None:
monkeypatch.setattr("invokeai.app.api.routers.images.ApiDependencies", MockApiDependencies(mock_invoker))

def mock_add_task(*args, **kwargs):
Expand All @@ -96,7 +106,7 @@ def mock_add_task(*args, **kwargs):


def test_get_bulk_download_image_image_deleted_after_response(
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path
monkeypatch: Any, mock_invoker: Invoker, tmp_path: Path, client: TestClient
) -> None:
mock_file: Path = tmp_path / "test.zip"
mock_file.write_text("contents")
Expand Down
4 changes: 4 additions & 0 deletions tests/app/services/download/test_download_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def broken_callback(job: DownloadJob) -> None:
queue.stop()


@pytest.mark.timeout(timeout=15, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService()

Expand All @@ -182,6 +183,9 @@ def cancelled_callback(job: DownloadJob) -> None:
nonlocal cancelled
cancelled = True

def handler(signum, frame):
raise TimeoutError("Join took too long to return")

job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
Expand Down
2 changes: 2 additions & 0 deletions tests/app/services/model_install/test_model_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def test_delete_register(
store.get_model(key)


@pytest.mark.timeout(timeout=20, method="thread")
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))

Expand All @@ -221,6 +222,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]


@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))

Expand Down
2 changes: 2 additions & 0 deletions tests/backend/model_manager/model_manager_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List

Expand Down Expand Up @@ -149,6 +150,7 @@ def mm2_installer(

def stop_installer() -> None:
installer.stop()
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message

request.addfinalizer(stop_installer)
return installer
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
import logging
import shutil
from pathlib import Path

import pytest

Expand Down Expand Up @@ -58,3 +60,11 @@ def mock_services() -> InvocationServices:
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)


@pytest.fixture(scope="module")
def invokeai_root_dir(tmp_path_factory) -> Path:
root_template = Path(__file__).parent.resolve() / "backend/model_manager/data/invokeai_root"
temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
shutil.copytree(root_template, temp_dir)
return temp_dir
Loading