From 603f93beb2055c3c19e626cc7c22ee8e7a4418c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Franz=20Kir=C3=A1ly?= Date: Sat, 10 Feb 2024 12:20:01 +0100 Subject: [PATCH] [ENH] testing estimators whose package dependencies are changed in `pyproject.toml` (#5727) This PR adds a condition to differential testing, so classes whose dependencies have been updated in `pyproject.toml` are always tested. This logic is based on a utility that determines which package dependencies are changed by a pull request, and adds a condition to `run_test_for_class` that checks whether any dependency of a class is among them. The utility could further be useful in: * hypothetical test environment setup per estimator, such as discussed in https://github.com/sktime/sktime/issues/5719 --- sktime/tests/test_switch.py | 68 +++++++++++++++++++++++-------- sktime/utils/git_diff.py | 80 +++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 17 deletions(-) diff --git a/sktime/tests/test_switch.py b/sktime/tests/test_switch.py index 443be8569ca..742cc457d03 100644 --- a/sktime/tests/test_switch.py +++ b/sktime/tests/test_switch.py @@ -16,26 +16,36 @@ def run_test_for_class(cls): 1. whether all required soft dependencies are present. If not, does not run the test. - If yes, runs the test if and only if - at least one of conditions 2, 3 below are met. + If yes, behaviour depends on ONLY_CHANGED_MODULES setting: + if off (False), always runs the test (return True); + if on (True), runs test if and only if + at least one of conditions 2, 3, 4 below are met. 2. Condition 2: - * if ONLY_CHANGED_MODULES setting is on, condition 2 is met if and only - if the module containing the class/func has changed according to is_class_changed - * if ONLY_CHANGED_MODULES if off, condition 2 is always met. + If the module containing the class/func has changed according to is_class_changed, + or one of the modules containing any parent classes in sktime, + then condition 2 is met. 3.
Condition 3: - If the object is an sktime BaseObject, and one of the test classes + If the object is an sktime ``BaseObject``, and one of the test classes covering the class have changed, then condition 3 is met. + 4. Condition 4: + + If the object is an sktime ``BaseObject``, and the package requirements + for any of its dependencies have changed in ``pyproject.toml``, + condition 4 is met. + cls can also be a list of classes or functions, - in this case the test is run if and only if: + in this case the test is run if and only if both of the following are True: * all required soft dependencies are present - * if yes, if any of the estimators in the list should be tested by - criterion 2 or 3 above + * if ``ONLY_CHANGED_MODULES`` is True, additionally, + if any of the estimators in the list should be tested by + at least one of criteria 2-4 above. + If ``ONLY_CHANGED_MODULES`` is False, this condition is always True. Parameters ---------- @@ -51,9 +61,11 @@ class for which to determine whether it should be tested cls = [cls] from sktime.tests.test_all_estimators import ONLY_CHANGED_MODULES - from sktime.utils.git_diff import is_class_changed + from sktime.utils.git_diff import get_packages_with_changed_specs, is_class_changed from sktime.utils.validation._dependencies import _check_estimator_deps + PACKAGE_REQ_CHANGED = get_packages_with_changed_specs() + def _required_deps_present(obj): """Check if all required soft dependencies are present, return bool.""" if hasattr(obj, "get_class_tag"): @@ -81,24 +93,46 @@ def _tests_covering_class_changed(cls): test_classes = get_test_classes_for_obj(cls) return any(is_class_changed(x) for x in test_classes) + def _is_impacted_by_pyproject_change(cls): + """Check if the dep specifcations of cls have changed, return bool.""" + from packaging.requirements import Requirement + + if not isclass(cls) or not hasattr(cls, "get_class_tags"): + return False + + cls_reqs = cls.get_class_tag("python_dependencies", []) + if cls_reqs is 
None: + cls_reqs = [] + if not isinstance(cls_reqs, list): + cls_reqs = [cls_reqs] + package_deps = [Requirement(req).name for req in cls_reqs] + + return any(x in PACKAGE_REQ_CHANGED for x in package_deps) + # Condition 1: # if any of the required soft dependencies are not present, do not run the test if not all(_required_deps_present(x) for x in cls): return False # otherwise, continue + # if ONLY_CHANGED_MODULES is off: always True + # tests are always run if soft dependencies are present + if not ONLY_CHANGED_MODULES: + return True + # Condition 2: - # if ONLY_CHANGED_MODULES is on, run the test if and only if # any of the modules containing any of the classes in the list have changed - if ONLY_CHANGED_MODULES: - cond2 = any(_is_class_changed_or_sktime_parents(x) for x in cls) - else: - cond2 = True + # or any of the modules containing any parent classes in sktime have changed + cond2 = any(_is_class_changed_or_sktime_parents(x) for x in cls) # Condition 3: # if the object is an sktime BaseObject, and one of the test classes # covering the class have changed, then run the test cond3 = any(_tests_covering_class_changed(x) for x in cls) - # run the test if and only if at least one of the conditions 2, 3 are met - return cond2 or cond3 + # Condition 4: + # the package requirements for any dependency in pyproject.toml have changed + cond4 = any(_is_impacted_by_pyproject_change(x) for x in cls) + + # run the test if and only if at least one of the conditions 2-4 are met + return cond2 or cond3 or cond4 diff --git a/sktime/utils/git_diff.py b/sktime/utils/git_diff.py index a88a7e9ac6a..fb0b29a70fb 100644 --- a/sktime/utils/git_diff.py +++ b/sktime/utils/git_diff.py @@ -78,3 +78,83 @@ class to get module string from, e.g., NaiveForecaster """ module_str = get_module_from_class(cls) return is_module_changed(module_str) + + +def get_changed_lines(file_path, only_indented=True): + """Get changed or added lines from a file. 
+ + Compares the current branch to the origin-main branch. + + Parameters + ---------- + file_path : str + path to file to get changed lines from + only_indented : bool, default=True + if True, only indented lines are returned, otherwise all lines are returned; + more precisely, only changed/added lines starting with a space are returned + + Returns + ------- + list of str : changed or added lines on current branch + """ + cmd = f"git diff remotes/origin/main -- {file_path}" + + try: + # Run 'git diff' command to get the changes in the specified file + result = subprocess.check_output(cmd, shell=True, text=True) + + # if only indented lines are requested, add space to start_chars + start_chars = "+" + if only_indented: + start_chars += " " + + # Extract the changed or new lines and return as a list of strings + changed_lines = [ + line.strip() for line in result.split("\n") if line.startswith(start_chars) + ] + # remove first character ('+') from each line + changed_lines = [line[1:] for line in changed_lines] + + return changed_lines + + except subprocess.CalledProcessError: + return [] + + +def get_packages_with_changed_specs(): + """Get packages with changed or added specs. + + Returns + ------- + list of str : names of packages with changed or added specs + """ + from packaging.requirements import Requirement + + changed_lines = get_changed_lines("pyproject.toml") + + packages = [] + for line in changed_lines: + if line.find("'") > line.find('"') and line.find('"') != -1: + sep = '"' + elif line.find("'") == -1: + sep = '"' + else: + sep = "'" + + splits = line.split(sep) + if len(splits) < 2: + continue + + req = line.split(sep)[1] + + # deal with ; python_version >= "3.7" in requirements + if ";" in req: + req = req.split(";")[0] + + pkg = Requirement(req).name + packages.append(pkg) + + # make unique + packages = list(set(packages)) + + return packages