Skip to content

Commit

Permalink
[ENH] testing estimators whose package dependencies are changed in `p…
Browse files Browse the repository at this point in the history
…yproject.toml` (sktime#5727)

This PR adds a condition to differential testing, so classes whose
dependencies have been updated in `pyproject.toml` are always tested.

This logic is based on an utility that determines which package
dependencies are changed by a pull request, and adds a condition

The utility could further be useful in:

* hypothetical test environment setup per estimator, such as discussed
in sktime#5719
  • Loading branch information
fkiraly authored Feb 10, 2024
1 parent 02b1feb commit 603f93b
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 17 deletions.
68 changes: 51 additions & 17 deletions sktime/tests/test_switch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,36 @@ def run_test_for_class(cls):
1. whether all required soft dependencies are present.
If not, does not run the test.
If yes, runs the test if and only if
at least one of conditions 2, 3 below are met.
If yes, behaviour depends on ONLY_CHANGED_MODULES setting:
if off (False), always runs the test (return True);
if on (True), runs test if and only if
at least one of conditions 2, 3, 4 below are met.
2. Condition 2:
* if ONLY_CHANGED_MODULES setting is on, condition 2 is met if and only
if the module containing the class/func has changed according to is_class_changed
* if ONLY_CHANGED_MODULES if off, condition 2 is always met.
If the module containing the class/func has changed according to is_class_changed,
or one of the modules containing any parent classes in sktime,
then condition 2 is met.
3. Condition 3:
If the object is an sktime BaseObject, and one of the test classes
If the object is an sktime ``BaseObject``, and one of the test classes
covering the class have changed, then condition 3 is met.
4. Condition 4:
If the object is an sktime ``BaseObject``, and the package requirements
for any of its dependencies have changed in ``pyproject.toml``,
condition 4 is met.
cls can also be a list of classes or functions,
in this case the test is run if and only if:
in this case the test is run if and only if both of the following are True:
* all required soft dependencies are present
* if yes, if any of the estimators in the list should be tested by
criterion 2 or 3 above
* if ``ONLY_CHANGED_MODULES`` is True, additionally,
if any of the estimators in the list should be tested by
at least one of criteria 2-4 above.
If ``ONLY_CHANGED_MODULES`` is False, this condition is always True.
Parameters
----------
Expand All @@ -51,9 +61,11 @@ class for which to determine whether it should be tested
cls = [cls]

from sktime.tests.test_all_estimators import ONLY_CHANGED_MODULES
from sktime.utils.git_diff import is_class_changed
from sktime.utils.git_diff import get_packages_with_changed_specs, is_class_changed
from sktime.utils.validation._dependencies import _check_estimator_deps

PACKAGE_REQ_CHANGED = get_packages_with_changed_specs()

def _required_deps_present(obj):
"""Check if all required soft dependencies are present, return bool."""
if hasattr(obj, "get_class_tag"):
Expand Down Expand Up @@ -81,24 +93,46 @@ def _tests_covering_class_changed(cls):
test_classes = get_test_classes_for_obj(cls)
return any(is_class_changed(x) for x in test_classes)

def _is_impacted_by_pyproject_change(cls):
"""Check if the dep specifcations of cls have changed, return bool."""
from packaging.requirements import Requirement

if not isclass(cls) or not hasattr(cls, "get_class_tags"):
return False

cls_reqs = cls.get_class_tag("python_dependencies", [])
if cls_reqs is None:
cls_reqs = []
if not isinstance(cls_reqs, list):
cls_reqs = [cls_reqs]
package_deps = [Requirement(req).name for req in cls_reqs]

return any(x in PACKAGE_REQ_CHANGED for x in package_deps)

# Condition 1:
# if any of the required soft dependencies are not present, do not run the test
if not all(_required_deps_present(x) for x in cls):
return False
# otherwise, continue

# if ONLY_CHANGED_MODULES is off: always True
# tests are always run if soft dependencies are present
if not ONLY_CHANGED_MODULES:
return True

# Condition 2:
# if ONLY_CHANGED_MODULES is on, run the test if and only if
# any of the modules containing any of the classes in the list have changed
if ONLY_CHANGED_MODULES:
cond2 = any(_is_class_changed_or_sktime_parents(x) for x in cls)
else:
cond2 = True
# or any of the modules containing any parent classes in sktime have changed
cond2 = any(_is_class_changed_or_sktime_parents(x) for x in cls)

# Condition 3:
# if the object is an sktime BaseObject, and one of the test classes
# covering the class have changed, then run the test
cond3 = any(_tests_covering_class_changed(x) for x in cls)

# run the test if and only if at least one of the conditions 2, 3 are met
return cond2 or cond3
# Condition 4:
# the package requirements for any dependency in pyproject.toml have changed
cond4 = any(_is_impacted_by_pyproject_change(x) for x in cls)

# run the test if and only if at least one of the conditions 2-4 are met
return cond2 or cond3 or cond4
80 changes: 80 additions & 0 deletions sktime/utils/git_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,83 @@ class to get module string from, e.g., NaiveForecaster
"""
module_str = get_module_from_class(cls)
return is_module_changed(module_str)


def get_changed_lines(file_path, only_indented=True):
"""Get changed or added lines from a file.
Compares the current branch to the origin-main branch.
Parameters
----------
file_path : str
path to file to get changed lines from
only_indented : bool, default=True
if True, only indented lines are returned, otherwise all lines are returned;
more precisely, only changed/added lines starting with a space are returned
Returns
-------
list of str : changed or added lines on current branch
"""
cmd = f"git diff remotes/origin/main -- {file_path}"

try:
# Run 'git diff' command to get the changes in the specified file
result = subprocess.check_output(cmd, shell=True, text=True)

# if only indented lines are requested, add space to start_chars
start_chars = "+"
if only_indented:
start_chars += " "

# Extract the changed or new lines and return as a list of strings
changed_lines = [
line.strip() for line in result.split("\n") if line.startswith(start_chars)
]
# remove first character ('+') from each line
changed_lines = [line[1:] for line in changed_lines]

return changed_lines

except subprocess.CalledProcessError:
return []


def get_packages_with_changed_specs():
"""Get packages with changed or added specs.
Returns
-------
list of str : names of packages with changed or added specs
"""
from packaging.requirements import Requirement

changed_lines = get_changed_lines("pyproject.toml")

packages = []
for line in changed_lines:
if line.find("'") > line.find('"') and line.find('"') != -1:
sep = '"'
elif line.find("'") == -1:
sep = '"'
else:
sep = "'"

splits = line.split(sep)
if len(splits) < 2:
continue

req = line.split(sep)[1]

# deal with ; python_version >= "3.7" in requirements
if ";" in req:
req = req.split(";")[0]

pkg = Requirement(req).name
packages.append(pkg)

# make unique
packages = list(set(packages))

return packages

0 comments on commit 603f93b

Please sign in to comment.