From a210e0afa32dd9853ad10f08ebf30fb7361697e1 Mon Sep 17 00:00:00 2001 From: Louis Dupont Date: Sun, 27 Aug 2023 17:44:43 +0300 Subject: [PATCH] add version and tests --- src/super_gradients/common/deprecate.py | 20 ++++++--- tests/unit_tests/test_deprecate.py | 60 +++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 9 deletions(-) diff --git a/src/super_gradients/common/deprecate.py b/src/super_gradients/common/deprecate.py index 2318db8eae..be9f7ee971 100644 --- a/src/super_gradients/common/deprecate.py +++ b/src/super_gradients/common/deprecate.py @@ -1,16 +1,19 @@ import warnings from functools import wraps from typing import Optional +from pkg_resources import parse_version +import super_gradients -def deprecate_call(deprecated_in_v: str, removed_in_v: str, target: Optional[callable] = None, reason: str = ""): + +def deprecate_call(deprecated_in_v: str, remove_in_v: str, target: Optional[callable] = None, reason: str = ""): """ Decorator to mark a callable as deprecated. It provides a clear and actionable warning message informing the user about the version in which the function was deprecated, the version in which it will be removed, and guidance on how to replace it. :param deprecated_in_v: Version number when the function was deprecated. - :param removed_in_v: Version number when the function will be removed. + :param remove_in_v: Version number when the function will be removed. :param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function. :param reason: (Optional) Additional information or reason for the deprecation. @@ -18,12 +21,12 @@ def deprecate_call(deprecated_in_v: str, removed_in_v: str, target: Optional[cal If a direct replacement function exists: >> from new.module.path import new_get_local_rank - >> @deprecate_call(deprecated_in_v='3.2.0', removed_in_v='4.0.0', target=new_get_local_rank, reason="Replaced for optimization") + >> @deprecate_call(deprecated_in_v='3.2.0', remove_in_v='4.0.0', target=new_get_local_rank, reason="Replaced for optimization") >> def get_local_rank(): >> return new_get_local_rank() If there's no direct replacement: - >> @deprecate_call(deprecated_in_v='3.2.0', removed_in_v='4.0.0', reason="Function is no longer needed due to XYZ reason") + >> @deprecate_call(deprecated_in_v='3.2.0', remove_in_v='4.0.0', reason="Function is no longer needed due to XYZ reason") >> def some_old_function(): >> # ... function logic ... @@ -38,12 +41,19 @@ def deprecate_call(deprecated_in_v: str, removed_in_v: str, target: Optional[cal """ def decorator(old_func: callable) -> callable: + + if parse_version(super_gradients.__version__) >= parse_version(remove_in_v): + raise ValueError( + f"`super_gradients.__version__={super_gradients.__version__}` >= `remove_in_v={remove_in_v}`. " + f"Please remove {old_func.__module__}.{old_func.__name__} from your code base." + ) + @wraps(old_func) def wrapper(*args, **kwargs): if not wrapper._warned: message = ( f"Function `{old_func.__module__}.{old_func.__name__}` is deprecated since version `{deprecated_in_v}` " - f"and will be removed in version `{removed_in_v}`.\n" + f"and will be removed in version `{remove_in_v}`.\n" ) if reason: message += f"Reason: {reason}.\n" diff --git a/tests/unit_tests/test_deprecate.py b/tests/unit_tests/test_deprecate.py index 5a34bd1fbd..f517d8c4b1 100644 --- a/tests/unit_tests/test_deprecate.py +++ b/tests/unit_tests/test_deprecate.py @@ -1,5 +1,7 @@ -import unittest import warnings +import unittest +from unittest.mock import patch + from super_gradients.common.deprecate import deprecate_call @@ -11,11 +13,11 @@ def setUp(self): def new_func(): return self.new_function_message - @deprecate_call(deprecated_in_v="3.2.0", removed_in_v="4.0.0", target=new_func, reason="Replaced for optimization") + @deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.0.0", target=new_func, reason="Replaced for optimization") def fully_configured_deprecated_func(): return new_func() - @deprecate_call(deprecated_in_v="3.2.0", removed_in_v="4.0.0") + @deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.0.0") def basic_deprecated_func(): return new_func() @@ -23,6 +25,18 @@ def basic_deprecated_func(): self.fully_configured_deprecated_func = fully_configured_deprecated_func self.basic_deprecated_func = basic_deprecated_func + class NewClass: + def __init__(self): + pass + + @deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.0.0", target=NewClass, reason="Replaced for optimization") + class DeprecatedClass: + def __init__(self): + pass + + self.NewClass = NewClass + self.DeprecatedClass = DeprecatedClass + def test_emits_warning(self): """Ensure that the deprecated function emits a warning when called.""" with warnings.catch_warnings(record=True) as w: @@ -42,7 +56,7 @@ def test_displays_removed_version(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") self.fully_configured_deprecated_func() - self.assertTrue(any("4.0.0" in str(warning.message) for warning in w)) + self.assertTrue(any("10.0.0" in str(warning.message) for warning in w)) def test_guidance_on_replacement(self): """Ensure that if a replacement target is provided, guidance on using the new function is included in the warning.""" @@ -74,6 +88,44 @@ def test_basic_deprecation_emits_warning(self): self.basic_deprecated_func() self.assertEqual(len(w), 1) + def test_class_deprecation_warning(self): + """Ensure that creating an instance of a deprecated class emits a warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _ = self.DeprecatedClass() # Instantiate the deprecated class + self.assertEqual(len(w), 1) + + def test_class_deprecation_message_content(self): + """Ensure that the emitted warning for a deprecated class contains relevant information including target class.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _ = self.DeprecatedClass() + self.assertTrue(any("3.2.0" in str(warning.message) for warning in w)) + self.assertTrue(any("10.0.0" in str(warning.message) for warning in w)) + self.assertTrue(any("DeprecatedClass" in str(warning.message) for warning in w)) + self.assertTrue(any("Replaced for optimization" in str(warning.message) for warning in w)) + self.assertTrue(any("NewClass" in str(warning.message) for warning in w)) + + def test_raise_error_when_library_version_equals_removal_version(self): + """Ensure that an error is raised when the library's version equals the function's removal version.""" + with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be equal to removal version + with self.assertRaises(ValueError): + + @deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.1.0", target=self.new_func) + def deprecated_func_version_equal(): + return + + def test_no_error_when_library_version_below_removal_version(self): + """Ensure that no error is raised when the library's version is below the function's removal version.""" + with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be below removal version + + @deprecate_call(deprecated_in_v="3.2.0", remove_in_v="10.2.0", target=self.new_func) + def deprecated_func_version_below(): + return + + # Actually call the function to check no exception is raised + deprecated_func_version_below() + if __name__ == "__main__": unittest.main()