From 00b194a02212d80ed0ea071f3b83db8965462708 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Thu, 16 Nov 2023 09:37:25 -0500 Subject: [PATCH] Release v1 (#1595) * test * test --- test/smoke_test/smoke_test.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py index 3d1b6af64..5e0d0ace7 100644 --- a/test/smoke_test/smoke_test.py +++ b/test/smoke_test/smoke_test.py @@ -52,6 +52,20 @@ def forward(self, x): output = self.fc1(x) return output +def load_json_from_basedir(filename: str): + try: + if os.path.exists(BASE_DIR / filename): + with open(BASE_DIR / filename) as fptr: + return json.load(fptr) + else: + return None + except FileNotFoundError as exc: + raise ImportError(f"File {filename} not found error: {exc.strerror}") from exc + except json.JSONDecodeError as exc: + raise ImportError(f"Invalid JSON {filename}") from exc + +def read_release_matrix(): + return load_json_from_basedir("release_matrix.json") def check_version(package: str) -> None: # only makes sense to check nightly package where dates are known @@ -62,6 +76,16 @@ def check_version(package: str) -> None: raise RuntimeError( f"Torch version mismatch, expected {stable_version} for channel {channel}. But its {torch.__version__}" ) + release_version = read_release_matrix() + if package == "all": + for module in MODULES: + imported_module = importlib.import_module(module["name"]) + module_version = imported_module.__version__ + if not module_version.startswith(release_version[module["name"]]): + raise RuntimeError( + f"{module['name']} version mismatch, expected {release_version[module['name']]} for channel {channel}. But its {module_version}" + ) + else: print(f"Skip version check for channel {channel} as stable version is None")