diff --git a/src/super_gradients/__init__.py b/src/super_gradients/__init__.py index f2de3ba001..329f8d8586 100755 --- a/src/super_gradients/__init__.py +++ b/src/super_gradients/__init__.py @@ -1,3 +1,5 @@ +__version__ = "3.2.0" + from super_gradients.common import init_trainer, is_distributed, object_names from super_gradients.training import losses, utils, datasets_utils, DataAugmentation, Trainer, KDTrainer, QATTrainer from super_gradients.common.registry.registry import ARCHITECTURES @@ -23,6 +25,4 @@ "AutoTrainBatchSizeSelectionCallback", ] -__version__ = "3.2.0" - env_sanity_check() diff --git a/src/super_gradients/common/deprecate.py b/src/super_gradients/common/deprecate.py new file mode 100644 index 0000000000..516ec85ff4 --- /dev/null +++ b/src/super_gradients/common/deprecate.py @@ -0,0 +1,78 @@ +import warnings +from functools import wraps +from typing import Optional +from pkg_resources import parse_version + + +def deprecated(deprecated_since: str, removed_from: str, target: Optional[callable] = None, reason: str = ""): + """ + Decorator to mark a callable as deprecated. Works on functions and classes. + It provides a clear and actionable warning message informing + the user about the version in which the function was deprecated, the version in which it will be removed, + and guidance on how to replace it. + + :param deprecated_since: Version number when the function was deprecated. + :param removed_from: Version number when the function will be removed. + :param target: (Optional) The new function that should be used as a replacement. If provided, it will guide the user to the updated function. + :param reason: (Optional) Additional information or reason for the deprecation. + + Example usage: + If a direct replacement function exists: + >> from new.module.path import new_get_local_rank + + >> @deprecated(deprecated_since='3.2.0', removed_from='4.0.0', target=new_get_local_rank, reason="Replaced for optimization") + >> def get_local_rank(): + >> return new_get_local_rank() + + If there's no direct replacement: + >> @deprecated(deprecated_since='3.2.0', removed_from='4.0.0', reason="Function is no longer needed due to XYZ reason") + >> def some_old_function(): + >> # ... function logic ... + + When calling a deprecated function: + >> from some_module import get_local_rank + >> get_local_rank() + DeprecationWarning: Function `some_module.get_local_rank` is deprecated. Deprecated since version `3.2.0` + and will be removed in version `4.0.0`. Reason: `Replaced for optimization`. + Please update your code: + [-] from `some_module` import `get_local_rank` + [+] from `new.module.path` import `new_get_local_rank`. + """ + + def decorator(old_func: callable) -> callable: + @wraps(old_func) + def wrapper(*args, **kwargs): + if not wrapper._warned: + import super_gradients + + is_still_supported = parse_version(super_gradients.__version__) < parse_version(removed_from) + status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed" + message = ( + f"Callable `{old_func.__module__}.{old_func.__name__}` {status_msg} since version `{deprecated_since}` " + f"and will be removed in version `{removed_from}`.\n" + ) + if reason: + message += f"Reason: {reason}.\n" + + if target is not None: + message += ( + f"Please update your code:\n" + f" [-] from `{old_func.__module__}` import `{old_func.__name__}`\n" + f" [+] from `{target.__module__}` import `{target.__name__}`" + ) + + if is_still_supported: + warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed. + warnings.warn(message, DeprecationWarning, stacklevel=2) + wrapper._warned = True + else: + raise ImportError(message) + + return old_func(*args, **kwargs) + + # Each decorated object will have its own _warned state + # This state ensures that the warning will appear only once, to avoid polluting the console in case the function is called too often. + wrapper._warned = False + return wrapper + + return decorator diff --git a/src/super_gradients/training/models/__init__.py b/src/super_gradients/training/models/__init__.py index 64becb156b..e8e29a9954 100755 --- a/src/super_gradients/training/models/__init__.py +++ b/src/super_gradients/training/models/__init__.py @@ -1,4 +1,4 @@ -import warnings +from super_gradients.common.deprecate import deprecated from .sg_module import SgModule from .classification_models.base_classifer import BaseClassifier @@ -135,51 +135,28 @@ from super_gradients.training.utils import make_divisible as _make_divisible_current_version, HpmStruct as CurrVersionHpmStruct -def make_deprecated(func, reason): - def inner(*args, **kwargs): - with warnings.catch_warnings(): - warnings.simplefilter("once", DeprecationWarning) - warnings.warn(reason, category=DeprecationWarning, stacklevel=2) - warnings.warn(reason, DeprecationWarning) - return func(*args, **kwargs) +@deprecated(deprecated_since="3.1.0", removed_from="3.4.0", target=_make_divisible_current_version) +def make_divisible(x: int, divisor: int, ceil: bool = True) -> int: + """ + Returns x evenly divisible by divisor. + If ceil=True it will return the closest larger number to the original x, and ceil=False the closest smaller number. + """ + return _make_divisible_current_version(x=x, divisor=divisor, ceil=ceil) - return inner +@deprecated(deprecated_since="3.1.0", removed_from="3.4.0", target=BasicResNetBlock, reason="This block was renamed to BasicResNetBlock for better clarity.") +class BasicBlock(BasicResNetBlock): + ... -make_divisible = make_deprecated( - func=_make_divisible_current_version, - reason="You're importing `make_divisible` from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n" - "Please update your code to import it as follows:\n" - "[-] from super_gradients.training.models import make_divisible\n" - "[+] from super_gradients.training.utils import make_divisible\n", -) +@deprecated(deprecated_since="3.1.0", removed_from="3.4.0", target=NewBottleneck, reason="This block was renamed to BasicResNetBlock for better clarity.") +class Bottleneck(NewBottleneck): + ... -BasicBlock = make_deprecated( - func=BasicResNetBlock, - reason="You're importing `BasicBlock` class from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n" - "This block was renamed to BasicResNetBlock for better clarity.\n" - "Please update your code to import it as follows:\n" - "[-] from super_gradients.training.models import BasicBlock\n" - "[+] from super_gradients.training.models import BasicResNetBlock\n", -) -Bottleneck = make_deprecated( - func=NewBottleneck, - reason="You're importing `Bottleneck` class from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n" - "This block was renamed to BasicResNetBlock for better clarity.\n" - "Please update your code to import it as follows:\n" - "[-] from super_gradients.training.models import Bottleneck\n" - "[+] from super_gradients.training.models.classification_models.resnet import Bottleneck\n", -) - -HpmStruct = make_deprecated( - func=CurrVersionHpmStruct, - reason="You're importing `HpmStruct` class from `super_gradients.training.models`. This is deprecated since SuperGradients 3.1.0.\n" - "Please update your code to import it as follows:\n" - "[-] from super_gradients.training.models import HpmStruct\n" - "[+] from super_gradients.training.utils import HpmStruct\n", -) +@deprecated(deprecated_since="3.1.0", removed_from="3.4.0", target=CurrVersionHpmStruct) +class HpmStruct(CurrVersionHpmStruct): + ... __all__ = [ diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index 03b55e1503..d89d5bc1d5 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -23,6 +23,7 @@ TestTransforms, TestPostPredictionCallback, TestModelPredict, + TestDeprecationDecorator, ) from tests.end_to_end_tests import TestTrainer from tests.unit_tests.detection_utils_test import TestDetectionUtils @@ -153,6 +154,7 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestModelPredict)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDetectionModelExport)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(SlidingWindowTest)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestDeprecationDecorator)) def _add_modules_to_end_to_end_tests_suite(self): """ diff --git a/tests/unit_tests/__init__.py b/tests/unit_tests/__init__.py index b818fcaa25..d83d7a4b14 100644 --- a/tests/unit_tests/__init__.py +++ b/tests/unit_tests/__init__.py @@ -25,6 +25,7 @@ from tests.unit_tests.transforms_test import TestTransforms from tests.unit_tests.post_prediction_callback_test import TestPostPredictionCallback from tests.unit_tests.test_predict import TestModelPredict +from tests.unit_tests.test_deprecate import TestDeprecationDecorator __all__ = [ "CrashTipTest", @@ -53,4 +54,5 @@ "TestTransforms", "TestPostPredictionCallback", "TestModelPredict", + "TestDeprecationDecorator", ] diff --git a/tests/unit_tests/test_deprecate.py b/tests/unit_tests/test_deprecate.py new file mode 100644 index 0000000000..37845d6059 --- /dev/null +++ b/tests/unit_tests/test_deprecate.py @@ -0,0 +1,132 @@ +import warnings +import unittest +from unittest.mock import patch + +from super_gradients.common.deprecate import deprecated + + +class TestDeprecationDecorator(unittest.TestCase): + def setUp(self): + """Prepare required functions before each test.""" + self.new_function_message = "This is the new function!" + + def new_func(): + return self.new_function_message + + @deprecated(deprecated_since="3.2.0", removed_from="10.0.0", target=new_func, reason="Replaced for optimization") + def fully_configured_deprecated_func(): + return new_func() + + @deprecated(deprecated_since="3.2.0", removed_from="10.0.0") + def basic_deprecated_func(): + return new_func() + + self.new_func = new_func + self.fully_configured_deprecated_func = fully_configured_deprecated_func + self.basic_deprecated_func = basic_deprecated_func + + class NewClass: + def __init__(self): + pass + + @deprecated(deprecated_since="3.2.0", removed_from="10.0.0", target=NewClass, reason="Replaced for optimization") + class DeprecatedClass: + def __init__(self): + pass + + self.NewClass = NewClass + self.DeprecatedClass = DeprecatedClass + + def test_emits_warning(self): + """Ensure that the deprecated function emits a warning when called.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.fully_configured_deprecated_func() + self.assertEqual(len(w), 1) + + def test_displays_deprecated_version(self): + """Ensure that the warning contains the version in which the function was deprecated.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.fully_configured_deprecated_func() + self.assertTrue(any("3.2.0" in str(warning.message) for warning in w)) + + def test_displays_removed_version(self): + """Ensure that the warning contains the version in which the function will be removed.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.fully_configured_deprecated_func() + self.assertTrue(any("10.0.0" in str(warning.message) for warning in w)) + + def test_guidance_on_replacement(self): + """Ensure that if a replacement target is provided, guidance on using the new function is included in the warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.fully_configured_deprecated_func() + self.assertTrue(any("new_func" in str(warning.message) for warning in w)) + + def test_displays_reason(self): + """Ensure that if provided, the reason for deprecation is included in the warning.""" + reason_str = "Replaced for optimization" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.fully_configured_deprecated_func() + self.assertTrue(any(reason_str in str(warning.message) for warning in w)) + + def test_triggered_only_once(self): + """Ensure that the deprecation warning is triggered only once even if the deprecated function is called multiple times.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + for _ in range(10): + self.fully_configured_deprecated_func() + self.assertEqual(len(w), 1, "Only one warning should be emitted") + + def test_basic_deprecation_emits_warning(self): + """Ensure that a function with minimal deprecation configuration emits a warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + self.basic_deprecated_func() + self.assertEqual(len(w), 1) + + def test_class_deprecation_warning(self): + """Ensure that creating an instance of a deprecated class emits a warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _ = self.DeprecatedClass() # Instantiate the deprecated class + self.assertEqual(len(w), 1) + + def test_class_deprecation_message_content(self): + """Ensure that the emitted warning for a deprecated class contains relevant information including target class.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _ = self.DeprecatedClass() + self.assertTrue(any("3.2.0" in str(warning.message) for warning in w)) + self.assertTrue(any("10.0.0" in str(warning.message) for warning in w)) + self.assertTrue(any("DeprecatedClass" in str(warning.message) for warning in w)) + self.assertTrue(any("Replaced for optimization" in str(warning.message) for warning in w)) + self.assertTrue(any("NewClass" in str(warning.message) for warning in w)) + + def test_raise_error_when_library_version_equals_removal_version(self): + """Ensure that an error is raised when the library's version equals the function's removal version.""" + with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be equal to removal version + with self.assertRaises(ImportError): + + @deprecated(deprecated_since="3.2.0", removed_from="10.1.0", target=self.new_func) + def deprecated_func_version_equal(): + return + + deprecated_func_version_equal() + + def test_no_error_when_library_version_below_removal_version(self): + """Ensure that no error is raised when the library's version is below the function's removal version.""" + with patch("super_gradients.__version__", "10.1.0"): # Mocking the version to be below removal version + + @deprecated(deprecated_since="3.2.0", removed_from="10.2.0", target=self.new_func) + def deprecated_func_version_below(): + return + + deprecated_func_version_below() + + +if __name__ == "__main__": + unittest.main()