From 55661479a316c0ddef22c17f81bb7910201dbb34 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Mon, 22 Jul 2024 17:18:12 +0800 Subject: [PATCH 01/15] Add checks for monai bundles after download and warn if the version is incompatible Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 56146546e8..7d0e13cc50 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -27,6 +27,7 @@ import torch from torch.cuda import is_available +from monai._version import get_versions from monai.apps.mmars.mmars import _get_all_ngc_models from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent @@ -336,6 +337,23 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path: return Path(bundle_dir) +def _check_monai_version(bundle_dir: PathLike) -> None: + """Get the `monai_versions` from the metadata.json and compare if it is smaller than the package version""" + metadata_file = Path(bundle_dir) / "configs" / "metadata.json" + if not metadata_file.exists(): + logger.warning(f"metadata file not found in {metadata_file}.") + return + with open(metadata_file, "r") as f: + metadata = json.load(f) + monai_version = metadata.get("monai_version", None) + version_dict = get_versions() + package_version = version_dict.get("version", None) + if package_version and monai_version and package_version < monai_version: + logger.warning( + f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}." + ) + + def download( name: str | None = None, version: str | None = None, @@ -501,6 +519,8 @@ def download( f"got source: {source_}." ) + _check_monai_version(bundle_dir_) + @deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") @deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") From 3831fe7929c755110b8ffd9fb0beb5fa2551d468 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 15:33:28 +0800 Subject: [PATCH 02/15] working ver Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 105 ++++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 26 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 7d0e13cc50..ddb3ce19ad 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -28,7 +28,6 @@ from torch.cuda import is_available from monai._version import get_versions -from monai.apps.mmars.mmars import _get_all_ngc_models from monai.apps.utils import _basename, download_url, extractall, get_logger from monai.bundle.config_item import ConfigComponent from monai.bundle.config_parser import ConfigParser @@ -169,17 +168,28 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" +def _get_ngc_base_url() -> str: + return "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit" + + def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip" + return f"{_get_ngc_base_url()}/{model_name.lower()}/versions/{version}/zip" + + +def _get_ngc_private_base_url(repo: str) -> str: + return f"https://api.ngc.nvidia.com/v2/{repo}/models" def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str: - return f"https://api.ngc.nvidia.com/v2/{repo}/models/{model_name.lower()}/versions/{version}/zip" + return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip" + + +def _get_monaihosting_base_url() -> str: + return "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" def _get_monaihosting_bundle_url(model_name: str, version: str) -> str: - monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" + return f"{_get_monaihosting_base_url()}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None: @@ -268,8 +278,7 @@ def _get_ngc_token(api_key, retry=0): def _get_latest_bundle_version_monaihosting(name): - url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - full_url = f"{url}/{name.lower()}" + full_url = f"{_get_monaihosting_base_url()}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) @@ -280,18 +289,62 @@ def _get_latest_bundle_version_monaihosting(name): return model_info["model"]["latestVersionIdStr"] -def _get_latest_bundle_version_private_registry(name, repo, headers=None): - url = f"https://api.ngc.nvidia.com/v2/{repo}/models" - full_url = f"{url}/{name.lower()}" - requests_get, has_requests = optional_import("requests", name="get") +def _list_latest_versions(data, max_versions: int = 3): + """ + Extract the latest versions from the data dictionary. + + Args: + data: the data dictionary. + max_versions: the maximum number of versions to return. + + Returns: + versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0']. + """ + # Check if the data is a dictionary and it has the key 'modelVersions' + if not isinstance(data, dict) or 'modelVersions' not in data: + raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.") + + # Extract the list of model versions + model_versions = data['modelVersions'] + + if not isinstance(model_versions, list) or len(model_versions) == 0 or "createdDate" not in model_versions[0] or "versionId" not in model_versions[0]: + raise ValueError("The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'.") + + # Sort the versions by the 'createdDate' in descending order + sorted_versions = sorted(model_versions, key=lambda x: x['createdDate'], reverse=True) + return [v["versionId"] for v in sorted_versions[:max_versions]] + + +def _get_latest_bundle_version_ngc(name: str, repo: str = None, headers: dict | None = None) -> dict[str, list[str] | str]: + version_dict = get_versions() + package_version = version_dict.get("version", None) + base_url = _get_ngc_private_base_url(repo) if repo else _get_ngc_base_url() + version_endpoint = base_url + f"/{name.lower()}/versions/" + if has_requests: - headers = {} if headers is None else headers - resp = requests_get(full_url, headers=headers) + version_header = {'Accept-Encoding': 'gzip, deflate'} # Excluding 'zstd' + if headers: + version_header.update(headers) + resp = requests_get(version_endpoint, headers=version_header) resp.raise_for_status() + model_info = json.loads(resp.text) + latest_versions = _list_latest_versions(model_info) + + for version in latest_versions: + file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" + try: + resp = requests_get(file_endpoint, headers=headers) + metadata = json.loads(resp.text) + # if the package version is not available or the model is compatible with the package version + if not package_version or metadata["monai_version"] <= package_version: + return version + except Exception as e: + raise ValueError(f"Failed to get metadata from {file_endpoint}.") from e + + # if no compatible version is found, return the latest version + return latest_versions[0] else: - raise ValueError("NGC API requires requests package. Please install it.") - model_info = json.loads(resp.text) - return model_info["model"]["latestVersionIdStr"] + raise ValueError("requests package is required, please install it.") def _get_latest_bundle_version( @@ -299,17 +352,13 @@ def _get_latest_bundle_version( ) -> dict[str, list[str] | str] | Any | None: if source == "ngc": name = _add_ngc_prefix(name) - model_dict = _get_all_ngc_models(name) - for v in model_dict.values(): - if v["name"] == name: - return v["latest"] - return None + return _get_latest_bundle_version_ngc(name) elif source == "monaihosting": return _get_latest_bundle_version_monaihosting(name) elif source == "ngc_private": headers = kwargs.pop("headers", {}) name = _add_ngc_prefix(name) - return _get_latest_bundle_version_private_registry(name, repo, headers) + return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers) elif source == "github": repo_owner, repo_name, tag_name = repo.split("/") return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"] @@ -348,10 +397,14 @@ def _check_monai_version(bundle_dir: PathLike) -> None: monai_version = metadata.get("monai_version", None) version_dict = get_versions() package_version = version_dict.get("version", None) - if package_version and monai_version and package_version < monai_version: - logger.warning( - f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}." - ) + if package_version and monai_version: + # treat rc versions as the same as the release version + package_version = re.sub(r"rc\d.*", "", package_version) + monai_version = re.sub(r"rc\d.*", "", monai_version) + if package_version < monai_version: + logger.warning( + f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}." + ) def download( From 49d875e1428b3a5299a6061ea6d5d1b15c9611e6 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 15:35:02 +0800 Subject: [PATCH 03/15] format Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index ddb3ce19ad..5e56587b84 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -296,33 +296,42 @@ def _list_latest_versions(data, max_versions: int = 3): Args: data: the data dictionary. max_versions: the maximum number of versions to return. - + Returns: versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0']. """ # Check if the data is a dictionary and it has the key 'modelVersions' - if not isinstance(data, dict) or 'modelVersions' not in data: + if not isinstance(data, dict) or "modelVersions" not in data: raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.") # Extract the list of model versions - model_versions = data['modelVersions'] + model_versions = data["modelVersions"] + + if ( + not isinstance(model_versions, list) + or len(model_versions) == 0 + or "createdDate" not in model_versions[0] + or "versionId" not in model_versions[0] + ): + raise ValueError( + "The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'." + ) - if not isinstance(model_versions, list) or len(model_versions) == 0 or "createdDate" not in model_versions[0] or "versionId" not in model_versions[0]: - raise ValueError("The model versions are not a list or it is empty or it does not have the keys 'createdDate' and 'versionId'.") - # Sort the versions by the 'createdDate' in descending order - sorted_versions = sorted(model_versions, key=lambda x: x['createdDate'], reverse=True) + sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True) return [v["versionId"] for v in sorted_versions[:max_versions]] -def _get_latest_bundle_version_ngc(name: str, repo: str = None, headers: dict | None = None) -> dict[str, list[str] | str]: +def _get_latest_bundle_version_ngc( + name: str, repo: str = None, headers: dict | None = None +) -> dict[str, list[str] | str]: version_dict = get_versions() package_version = version_dict.get("version", None) base_url = _get_ngc_private_base_url(repo) if repo else _get_ngc_base_url() version_endpoint = base_url + f"/{name.lower()}/versions/" if has_requests: - version_header = {'Accept-Encoding': 'gzip, deflate'} # Excluding 'zstd' + version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' if headers: version_header.update(headers) resp = requests_get(version_endpoint, headers=version_header) @@ -340,7 +349,7 @@ def _get_latest_bundle_version_ngc(name: str, repo: str = None, headers: dict | return version except Exception as e: raise ValueError(f"Failed to get metadata from {file_endpoint}.") from e - + # if no compatible version is found, return the latest version return latest_versions[0] else: From 4c6f9ce5f1a88b1cce3037f810cc4a61993c1680 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 15:41:32 +0800 Subject: [PATCH 04/15] comment Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 5e56587b84..e529d7bcb2 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -331,7 +331,7 @@ def _get_latest_bundle_version_ngc( version_endpoint = base_url + f"/{name.lower()}/versions/" if has_requests: - version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' + version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements if headers: version_header.update(headers) resp = requests_get(version_endpoint, headers=version_header) From 97a1b0e99dfdcd824a69e431f75be2473e984068 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 16:26:52 +0800 Subject: [PATCH 05/15] fixes Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 65 +++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index e529d7bcb2..be0fda05fb 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -289,6 +289,33 @@ def _get_latest_bundle_version_monaihosting(name): return model_info["model"]["latestVersionIdStr"] +def _examine_monai_version(monai_version: str) -> tuple[bool, str]: + """Examine if the package version is compatible with the MONAI version in the metadata.""" + version_dict = get_versions() + package_version = version_dict.get("version", None) + if package_version is None: + return False, "Package version is not available. Skipping version check." + # treat rc versions as the same as the release version + package_version = re.sub(r"rc\d.*", "", package_version) + monai_version = re.sub(r"rc\d.*", "", monai_version) + if package_version < monai_version: + return False, f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}." + return True, "" + + +def _check_monai_version(bundle_dir: PathLike) -> None: + """Get the `monai_versions` from the metadata.json and compare if it is smaller than the package version""" + metadata_file = Path(bundle_dir) / "configs" / "metadata.json" + if not metadata_file.exists(): + logger.warning(f"metadata file not found in {metadata_file}.") + return + with open(metadata_file, "r") as f: + metadata = json.load(f) + is_compatible, msg = _examine_monai_version(metadata["monai_version"]) + if not is_compatible: + logger.warning(msg) + + def _list_latest_versions(data, max_versions: int = 3): """ Extract the latest versions from the data dictionary. @@ -325,8 +352,6 @@ def _list_latest_versions(data, max_versions: int = 3): def _get_latest_bundle_version_ngc( name: str, repo: str = None, headers: dict | None = None ) -> dict[str, list[str] | str]: - version_dict = get_versions() - package_version = version_dict.get("version", None) base_url = _get_ngc_private_base_url(repo) if repo else _get_ngc_base_url() version_endpoint = base_url + f"/{name.lower()}/versions/" @@ -341,14 +366,13 @@ def _get_latest_bundle_version_ngc( for version in latest_versions: file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" - try: - resp = requests_get(file_endpoint, headers=headers) - metadata = json.loads(resp.text) - # if the package version is not available or the model is compatible with the package version - if not package_version or metadata["monai_version"] <= package_version: - return version - except Exception as e: - raise ValueError(f"Failed to get metadata from {file_endpoint}.") from e + resp = requests_get(file_endpoint, headers=headers) + metadata = json.loads(resp.text) + resp.raise_for_status() + # if the package version is not available or the model is compatible with the package version + is_compatible, _ = _examine_monai_version(metadata["monai_version"]) + if is_compatible: + return version # if no compatible version is found, return the latest version return latest_versions[0] @@ -395,27 +419,6 @@ def _process_bundle_dir(bundle_dir: PathLike | None = None) -> Path: return Path(bundle_dir) -def _check_monai_version(bundle_dir: PathLike) -> None: - """Get the `monai_versions` from the metadata.json and compare if it is smaller than the package version""" - metadata_file = Path(bundle_dir) / "configs" / "metadata.json" - if not metadata_file.exists(): - logger.warning(f"metadata file not found in {metadata_file}.") - return - with open(metadata_file, "r") as f: - metadata = json.load(f) - monai_version = metadata.get("monai_version", None) - version_dict = get_versions() - package_version = version_dict.get("version", None) - if package_version and monai_version: - # treat rc versions as the same as the release version - package_version = re.sub(r"rc\d.*", "", package_version) - monai_version = re.sub(r"rc\d.*", "", monai_version) - if package_version < monai_version: - logger.warning( - f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}." - ) - - def download( name: str | None = None, version: str | None = None, From d6edbfac2fb4446581ad97b53582e4b8cd511903 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 16:28:31 +0800 Subject: [PATCH 06/15] fix format Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index be0fda05fb..7dbe90adee 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -299,7 +299,10 @@ def _examine_monai_version(monai_version: str) -> tuple[bool, str]: package_version = re.sub(r"rc\d.*", "", package_version) monai_version = re.sub(r"rc\d.*", "", monai_version) if package_version < monai_version: - return False, f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}." + return ( + False, + f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.", + ) return True, "" @@ -372,7 +375,7 @@ def _get_latest_bundle_version_ngc( # if the package version is not available or the model is compatible with the package version is_compatible, _ = _examine_monai_version(metadata["monai_version"]) if is_compatible: - return version + return version # if no compatible version is found, return the latest version return latest_versions[0] From 2d534715688921b6a727d0d0f6ab148a454656ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jul 2024 08:29:59 +0000 Subject: [PATCH 07/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 7dbe90adee..924d18a880 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -312,7 +312,7 @@ def _check_monai_version(bundle_dir: PathLike) -> None: if not metadata_file.exists(): logger.warning(f"metadata file not found in {metadata_file}.") return - with open(metadata_file, "r") as f: + with open(metadata_file) as f: metadata = json.load(f) is_compatible, msg = _examine_monai_version(metadata["monai_version"]) if not is_compatible: From cb1596f15d03f44b6698a562cbf51cf1cc765b50 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 16:35:45 +0800 Subject: [PATCH 08/15] fix mypy Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 50 ++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 924d18a880..32aa3e339f 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -319,7 +319,7 @@ def _check_monai_version(bundle_dir: PathLike) -> None: logger.warning(msg) -def _list_latest_versions(data, max_versions: int = 3): +def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]: """ Extract the latest versions from the data dictionary. @@ -353,35 +353,35 @@ def _list_latest_versions(data, max_versions: int = 3): def _get_latest_bundle_version_ngc( - name: str, repo: str = None, headers: dict | None = None -) -> dict[str, list[str] | str]: + name: str, repo: str | None = None, headers: dict | None = None +) -> str: base_url = _get_ngc_private_base_url(repo) if repo else _get_ngc_base_url() version_endpoint = base_url + f"/{name.lower()}/versions/" - if has_requests: - version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements - if headers: - version_header.update(headers) - resp = requests_get(version_endpoint, headers=version_header) - resp.raise_for_status() - model_info = json.loads(resp.text) - latest_versions = _list_latest_versions(model_info) - - for version in latest_versions: - file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" - resp = requests_get(file_endpoint, headers=headers) - metadata = json.loads(resp.text) - resp.raise_for_status() - # if the package version is not available or the model is compatible with the package version - is_compatible, _ = _examine_monai_version(metadata["monai_version"]) - if is_compatible: - return version - - # if no compatible version is found, return the latest version - return latest_versions[0] - else: + if not has_requests: raise ValueError("requests package is required, please install it.") + version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements + if headers: + version_header.update(headers) + resp = requests_get(version_endpoint, headers=version_header) + resp.raise_for_status() + model_info = json.loads(resp.text) + latest_versions = _list_latest_versions(model_info) + + for version in latest_versions: + file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json" + resp = requests_get(file_endpoint, headers=headers) + metadata = json.loads(resp.text) + resp.raise_for_status() + # if the package version is not available or the model is compatible with the package version + is_compatible, _ = _examine_monai_version(metadata["monai_version"]) + if is_compatible: + return version + + # if no compatible version is found, return the latest version + return latest_versions[0] + def _get_latest_bundle_version( source: str, name: str, repo: str, **kwargs: Any From dd3108a9316c9918ed17c31c0987ba7f2dc66e92 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 16:43:13 +0800 Subject: [PATCH 09/15] fix Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 32aa3e339f..d5aaffec40 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -352,9 +352,7 @@ def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]: return [v["versionId"] for v in sorted_versions[:max_versions]] -def _get_latest_bundle_version_ngc( - name: str, repo: str | None = None, headers: dict | None = None -) -> str: +def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str: base_url = _get_ngc_private_base_url(repo) if repo else _get_ngc_base_url() version_endpoint = base_url + f"/{name.lower()}/versions/" From 2d95c264834f8a0e42c8272c830930f5af768104 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Tue, 23 Jul 2024 19:10:24 +0800 Subject: [PATCH 10/15] add test Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 6 ++--- tests/test_bundle_download.py | 51 +++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d5aaffec40..e56949fddf 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -306,9 +306,9 @@ def _examine_monai_version(monai_version: str) -> tuple[bool, str]: return True, "" -def _check_monai_version(bundle_dir: PathLike) -> None: +def _check_monai_version(bundle_dir: PathLike, name: str) -> None: """Get the `monai_versions` from the metadata.json and compare if it is smaller than the package version""" - metadata_file = Path(bundle_dir) / "configs" / "metadata.json" + metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json" if not metadata_file.exists(): logger.warning(f"metadata file not found in {metadata_file}.") return @@ -585,7 +585,7 @@ def download( f"got source: {source_}." ) - _check_monai_version(bundle_dir_) + _check_monai_version(bundle_dir_, name_) @deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.") diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py index fe7caf5c17..331d228f1e 100644 --- a/tests/test_bundle_download.py +++ b/tests/test_bundle_download.py @@ -16,6 +16,7 @@ import tempfile import unittest from unittest.case import skipUnless +from unittest.mock import patch import numpy as np import torch @@ -24,6 +25,7 @@ import monai.networks.nets as nets from monai.apps import check_hash from monai.bundle import ConfigParser, create_workflow, load +from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download from monai.utils import optional_import from tests.utils import ( SkipIfBeforePyTorchVersion, @@ -207,6 +209,55 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve file_path = os.path.join(tempdir, bundle_name, file) self.assertTrue(os.path.exists(file_path)) + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) + def test_examine_monai_version(self, mock_get_versions): + self.assertTrue(_examine_monai_version("1.1")[0]) # Should return True, compatible + self.assertTrue(_examine_monai_version("1.2rc1")[0]) # Should return True, compatible + self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible + + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2rc1"}) + def test_examine_monai_version_rc(self, mock_get_versions): + self.assertTrue(_examine_monai_version("1.2")[0]) # Should return True, compatible + self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible + + def test_list_latest_versions(self): + """Test listing of the latest versions.""" + data = { + "modelVersions": [ + {"createdDate": "2021-01-01", "versionId": "1.0"}, + {"createdDate": "2021-01-02", "versionId": "1.1"}, + {"createdDate": "2021-01-03", "versionId": "1.2"}, + ] + } + self.assertEqual(_list_latest_versions(data), ["1.2", "1.1", "1.0"]) + self.assertEqual(_list_latest_versions(data, max_versions=2), ["1.2", "1.1"]) + data = { + "modelVersions": [ + {"createdDate": "2021-01-01", "versionId": "1.0"}, + {"createdDate": "2021-01-02", "versionId": "1.1"}, + ] + } + self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"]) + + @skip_if_quick + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) + def test_download_monaihosting(self, mock_get_versions): + """Test checking MONAI version from a metadata file.""" + with patch("monai.bundle.scripts.logger") as mock_logger: + with tempfile.TemporaryDirectory() as tempdir: + download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="monaihosting") + # Should have a warning message because the latest version is using monai > 1.2 + mock_logger.warning.assert_called_once() + + @skip_if_quick + @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"}) + def test_download_ngc(self, mock_get_versions): + """Test checking MONAI version from a metadata file.""" + with patch("monai.bundle.scripts.logger") as mock_logger: + with tempfile.TemporaryDirectory() as tempdir: + download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="ngc") + mock_logger.warning.assert_not_called() + @skip_if_no_cuda class TestLoad(unittest.TestCase): From b6eb67a0418760e6662181f9a4c7dbacb05334f3 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Date: Wed, 24 Jul 2024 10:31:34 +0800 Subject: [PATCH 11/15] Update monai/bundle/scripts.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> --- monai/bundle/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index e56949fddf..d4d5844a70 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -307,7 +307,7 @@ def _examine_monai_version(monai_version: str) -> tuple[bool, str]: def _check_monai_version(bundle_dir: PathLike, name: str) -> None: - """Get the `monai_versions` from the metadata.json and compare if it is smaller than the package version""" + """Get the `monai_version` from the metadata.json and compare if it is smaller than the installed `monai` package version""" metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json" if not metadata_file.exists(): logger.warning(f"metadata file not found in {metadata_file}.") From 20c84244f806df3f9e9b6d9aaf68e0e9d085489f Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Wed, 24 Jul 2024 10:36:54 +0800 Subject: [PATCH 12/15] fix comment Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index d4d5844a70..8f4309aae5 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -67,6 +67,9 @@ DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting") PPRINT_CONFIG_N = 5 +MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" +NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit" + def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict: """ @@ -168,12 +171,8 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam return f"https://github.com/{repo_owner}/{repo_name}/releases/download/{tag_name}/{filename}" -def _get_ngc_base_url() -> str: - return "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit" - - def _get_ngc_bundle_url(model_name: str, version: str) -> str: - return f"{_get_ngc_base_url()}/{model_name.lower()}/versions/{version}/zip" + return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip" def _get_ngc_private_base_url(repo: str) -> str: @@ -184,12 +183,8 @@ def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip" -def _get_monaihosting_base_url() -> str: - return "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting" - - def _get_monaihosting_bundle_url(model_name: str, version: str) -> str: - return f"{_get_monaihosting_base_url()}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" + return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip" def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None: @@ -278,7 +273,7 @@ def _get_ngc_token(api_key, retry=0): def _get_latest_bundle_version_monaihosting(name): - full_url = f"{_get_monaihosting_base_url()}/{name.lower()}" + full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}" requests_get, has_requests = optional_import("requests", name="get") if has_requests: resp = requests_get(full_url) @@ -292,8 +287,8 @@ def _get_latest_bundle_version_monaihosting(name): def _examine_monai_version(monai_version: str) -> tuple[bool, str]: """Examine if the package version is compatible with the MONAI version in the metadata.""" version_dict = get_versions() - package_version = version_dict.get("version", None) - if package_version is None: + package_version = version_dict.get("version", "0+unknown") + if package_version == "0+unknown": return False, "Package version is not available. Skipping version check." # treat rc versions as the same as the release version package_version = re.sub(r"rc\d.*", "", package_version) @@ -353,7 +348,7 @@ def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]: def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str: - base_url = _get_ngc_private_base_url(repo) if repo else _get_ngc_base_url() + base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL version_endpoint = base_url + f"/{name.lower()}/versions/" if not has_requests: From cbfd184c3968ccb9de89e0911687787d701fe9ec Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Wed, 24 Jul 2024 10:43:07 +0800 Subject: [PATCH 13/15] fix comment 2 Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 8f4309aae5..fcce193cf7 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -290,6 +290,8 @@ def _examine_monai_version(monai_version: str) -> tuple[bool, str]: package_version = version_dict.get("version", "0+unknown") if package_version == "0+unknown": return False, "Package version is not available. Skipping version check." + if monai_version == "0+unknown": + return False, "MONAI version is not specified in the bundle. Skipping version check." # treat rc versions as the same as the release version package_version = re.sub(r"rc\d.*", "", package_version) monai_version = re.sub(r"rc\d.*", "", monai_version) @@ -309,7 +311,7 @@ def _check_monai_version(bundle_dir: PathLike, name: str) -> None: return with open(metadata_file) as f: metadata = json.load(f) - is_compatible, msg = _examine_monai_version(metadata["monai_version"]) + is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown")) if not is_compatible: logger.warning(msg) From 1ce127dc0fd476a194499fd680b6becd363538eb Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Wed, 24 Jul 2024 14:59:13 +0800 Subject: [PATCH 14/15] fix comment 3 Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index fcce193cf7..0689fed969 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -372,6 +372,8 @@ def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: # if the package version is not available or the model is compatible with the package version is_compatible, _ = _examine_monai_version(metadata["monai_version"]) if is_compatible: + if version != latest_versions[0]: + logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.") return version # if no compatible version is found, return the latest version @@ -552,8 +554,8 @@ def download( version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers) if source_ == "github": if version_ is not None: - name_ = "_v".join([name_, version_]) - _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_) + name_ver = "_v".join([name_, version_]) + _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_) elif source_ == "ngc": From 8f309c859b188af9f0f49fa5a3fee7152dec92b0 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Wed, 24 Jul 2024 15:07:54 +0800 Subject: [PATCH 15/15] fix Signed-off-by: Mingxin Zheng --- monai/bundle/scripts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 0689fed969..4967b6cf50 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -553,8 +553,7 @@ def download( if version_ is None: version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers) if source_ == "github": - if version_ is not None: - name_ver = "_v".join([name_, version_]) + name_ver = "_v".join([name_, version_]) if version_ is not None else name_ _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_) elif source_ == "monaihosting": _download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)