Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for monai bundles after download and warn if incompatible #7938

Merged
merged 17 commits into from
Jul 24, 2024
Merged
134 changes: 110 additions & 24 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch
from torch.cuda import is_available

from monai.apps.mmars.mmars import _get_all_ngc_models
from monai._version import get_versions
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
Expand Down Expand Up @@ -168,17 +168,28 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}"


def _get_ngc_base_url() -> str:
return "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit"
ericspod marked this conversation as resolved.
Show resolved Hide resolved


def _get_ngc_bundle_url(model_name: str, version: str) -> str:
return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"
return f"{_get_ngc_base_url()}/{model_name.lower()}/versions/{version}/zip"


def _get_ngc_private_base_url(repo: str) -> str:
return f"https://api.ngc.nvidia.com/v2/{repo}/models"


def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip"
return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip"


def _get_monaihosting_base_url() -> str:
return "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"


def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
return f"{_get_monaihosting_base_url()}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"


def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:
Expand Down Expand Up @@ -267,8 +278,7 @@ def _get_ngc_token(api_key, retry=0):


def _get_latest_bundle_version_monaihosting(name):
url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
full_url = f"{url}/{name.lower()}"
full_url = f"{_get_monaihosting_base_url()}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
resp = requests_get(full_url)
Expand All @@ -279,36 +289,110 @@ def _get_latest_bundle_version_monaihosting(name):
return model_info["model"]["latestVersionIdStr"]


def _get_latest_bundle_version_private_registry(name, repo, headers=None):
url = f"https://api.ngc.nvidia.com/v2/{repo}/models"
full_url = f"{url}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
headers = {} if headers is None else headers
resp = requests_get(full_url, headers=headers)
resp.raise_for_status()
else:
raise ValueError("NGC API requires requests package. Please install it.")
def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
"""Examine if the package version is compatible with the MONAI version in the metadata."""
version_dict = get_versions()
package_version = version_dict.get("version", None)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
if package_version is None:
return False, "Package version is not available. Skipping version check."
# treat rc versions as the same as the release version
package_version = re.sub(r"rc\d.*", "", package_version)
monai_version = re.sub(r"rc\d.*", "", monai_version)
if package_version < monai_version:
return (
False,
f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.",
)
return True, ""


def _check_monai_version(bundle_dir: PathLike, name: str) -> None:
"""Get the `monai_versions` from the metadata.json and compare if it is smaller than the package version"""
mingxin-zheng marked this conversation as resolved.
Show resolved Hide resolved
metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json"
if not metadata_file.exists():
logger.warning(f"metadata file not found in {metadata_file}.")
return
with open(metadata_file) as f:
metadata = json.load(f)
is_compatible, msg = _examine_monai_version(metadata["monai_version"])
if not is_compatible:
logger.warning(msg)


def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:
"""
Extract the latest versions from the data dictionary.

Args:
data: the data dictionary.
max_versions: the maximum number of versions to return.

Returns:
versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0'].
"""
# Check if the data is a dictionary and it has the key 'modelVersions'
if not isinstance(data, dict) or "modelVersions" not in data:
raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.")

# Extract the list of model versions
model_versions = data["modelVersions"]

if (
not isinstance(model_versions, list)
or len(model_versions) == 0
or "createdDate" not in model_versions[0]
or "versionId" not in model_versions[0]
):
raise ValueError(
"The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'."
)

# Sort the versions by the 'createdDate' in descending order
sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True)
return [v["versionId"] for v in sorted_versions[:max_versions]]


def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:
base_url = _get_ngc_private_base_url(repo) if repo else _get_ngc_base_url()
version_endpoint = base_url + f"/{name.lower()}/versions/"

if not has_requests:
raise ValueError("requests package is required, please install it.")

version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
if headers:
version_header.update(headers)
resp = requests_get(version_endpoint, headers=version_header)
resp.raise_for_status()
model_info = json.loads(resp.text)
return model_info["model"]["latestVersionIdStr"]
latest_versions = _list_latest_versions(model_info)

for version in latest_versions:
file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
resp = requests_get(file_endpoint, headers=headers)
metadata = json.loads(resp.text)
resp.raise_for_status()
# if the package version is not available or the model is compatible with the package version
is_compatible, _ = _examine_monai_version(metadata["monai_version"])
if is_compatible:
return version

# if no compatible version is found, return the latest version
return latest_versions[0]


def _get_latest_bundle_version(
source: str, name: str, repo: str, **kwargs: Any
) -> dict[str, list[str] | str] | Any | None:
if source == "ngc":
name = _add_ngc_prefix(name)
model_dict = _get_all_ngc_models(name)
for v in model_dict.values():
if v["name"] == name:
return v["latest"]
return None
return _get_latest_bundle_version_ngc(name)
elif source == "monaihosting":
return _get_latest_bundle_version_monaihosting(name)
elif source == "ngc_private":
headers = kwargs.pop("headers", {})
name = _add_ngc_prefix(name)
return _get_latest_bundle_version_private_registry(name, repo, headers)
return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
Expand Down Expand Up @@ -501,6 +585,8 @@ def download(
f"got source: {source_}."
)

_check_monai_version(bundle_dir_, name_)
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved


@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
Expand Down
51 changes: 51 additions & 0 deletions tests/test_bundle_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tempfile
import unittest
from unittest.case import skipUnless
from unittest.mock import patch

import numpy as np
import torch
Expand All @@ -24,6 +25,7 @@
import monai.networks.nets as nets
from monai.apps import check_hash
from monai.bundle import ConfigParser, create_workflow, load
from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download
from monai.utils import optional_import
from tests.utils import (
SkipIfBeforePyTorchVersion,
Expand Down Expand Up @@ -207,6 +209,55 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))

@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_examine_monai_version(self, mock_get_versions):
self.assertTrue(_examine_monai_version("1.1")[0]) # Should return True, compatible
self.assertTrue(_examine_monai_version("1.2rc1")[0]) # Should return True, compatible
self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible

@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2rc1"})
def test_examine_monai_version_rc(self, mock_get_versions):
self.assertTrue(_examine_monai_version("1.2")[0]) # Should return True, compatible
self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible

def test_list_latest_versions(self):
"""Test listing of the latest versions."""
data = {
"modelVersions": [
{"createdDate": "2021-01-01", "versionId": "1.0"},
{"createdDate": "2021-01-02", "versionId": "1.1"},
{"createdDate": "2021-01-03", "versionId": "1.2"},
]
}
self.assertEqual(_list_latest_versions(data), ["1.2", "1.1", "1.0"])
self.assertEqual(_list_latest_versions(data, max_versions=2), ["1.2", "1.1"])
data = {
"modelVersions": [
{"createdDate": "2021-01-01", "versionId": "1.0"},
{"createdDate": "2021-01-02", "versionId": "1.1"},
]
}
self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"])

@skip_if_quick
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_download_monaihosting(self, mock_get_versions):
"""Test checking MONAI version from a metadata file."""
with patch("monai.bundle.scripts.logger") as mock_logger:
with tempfile.TemporaryDirectory() as tempdir:
download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="monaihosting")
# Should have a warning message because the latest version is using monai > 1.2
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
mock_logger.warning.assert_called_once()

@skip_if_quick
@patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
def test_download_ngc(self, mock_get_versions):
"""Test checking MONAI version from a metadata file."""
with patch("monai.bundle.scripts.logger") as mock_logger:
with tempfile.TemporaryDirectory() as tempdir:
download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="ngc")
mock_logger.warning.assert_not_called()
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved


@skip_if_no_cuda
class TestLoad(unittest.TestCase):
Expand Down
Loading