Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing functools.wraps decorator to deprecated decorator and handle dataclass properly #699

Merged
merged 18 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/monty/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import subprocess
import sys
import warnings
from dataclasses import is_dataclass
from datetime import datetime
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -121,6 +122,7 @@ def craft_message(
return msg

def deprecated_function_decorator(old: Callable) -> Callable:
@functools.wraps(old)
def wrapped(*args, **kwargs):
msg = craft_message(old, replacement, message, _deadline)
warnings.warn(msg, category=category, stacklevel=2)
Expand All @@ -129,14 +131,23 @@ def wrapped(*args, **kwargs):
return wrapped

def deprecated_class_decorator(cls: Type) -> Type:
original_init = cls.__init__
# Modify __post_init__ for dataclass
if is_dataclass(cls) and hasattr(cls, "__post_init__"):
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
original_init = cls.__post_init__
else:
original_init = cls.__init__

@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
msg = craft_message(cls, replacement, message, _deadline)
warnings.warn(msg, category=category, stacklevel=2)
original_init(self, *args, **kwargs)

cls.__init__ = new_init
if is_dataclass(cls) and hasattr(cls, "__post_init__"):
cls.__post_init__ = new_init
else:
cls.__init__ = new_init

return cls

# Convert deadline to datetime type
Expand Down
45 changes: 26 additions & 19 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from importlib import import_module
from inspect import getfullargspec
from pathlib import Path
from typing import Any, Dict
from typing import Any
from uuid import UUID, uuid4

try:
Expand Down Expand Up @@ -174,17 +174,17 @@ def as_dict(self) -> dict:
"""
A JSON serializable dict representation of an object.
"""
d = {
d: dict[str, Any] = {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
}

try:
parent_module = self.__class__.__module__.split(".", maxsplit=1)[0]
module_version = import_module(parent_module).__version__ # type: ignore
module_version = import_module(parent_module).__version__
d["@version"] = str(module_version)
except (AttributeError, ImportError):
d["@version"] = None # type: ignore
d["@version"] = None

spec = getfullargspec(self.__class__.__init__)

Expand Down Expand Up @@ -225,21 +225,24 @@ def recursive_as_dict(obj):
)
d[c] = recursive_as_dict(a)
if hasattr(self, "kwargs"):
# type: ignore
d.update(**self.kwargs) # pylint: disable=E1101
d.update(**self.kwargs)
if spec.varargs is not None and getattr(self, spec.varargs, None) is not None:
d.update({spec.varargs: getattr(self, spec.varargs)})
if hasattr(self, "_kwargs"):
d.update(**self._kwargs) # pylint: disable=E1101
d.update(**self._kwargs)
if isinstance(self, Enum):
d.update({"value": self.value}) # pylint: disable=E1101
d.update({"value": self.value})
return d

@classmethod
def from_dict(cls, d):
"""
:param d: Dict representation.
:return: MSONable class.

Args:
d: Dict representation.

Returns:
MSONable class.
"""
decoded = {
k: MontyDecoder().process_decoded(v)
Expand Down Expand Up @@ -547,27 +550,31 @@ class MontyEncoder(json.JSONEncoder):
json.dumps(object, cls=MontyEncoder)
"""

def __init__(self, *args, allow_unserializable_objects=False, **kwargs):
def __init__(
self, *args, allow_unserializable_objects: bool = False, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self._allow_unserializable_objects = allow_unserializable_objects
self._name_object_map: Dict[str, Any] = {}
self._index = 0
self._name_object_map: dict[str, Any] = {}
self._index: int = 0

def _update_name_object_map(self, o):
name = f"{self._index:012}-{str(uuid4())}"
self._index += 1
self._name_object_map[name] = o
return {"@object_reference": name}

def default(self, o) -> dict: # pylint: disable=E0202
def default(self, o) -> dict:
"""
Overriding default method for JSON encoding. This method does two
things: (a) If an object has a to_dict property, return the to_dict
output. (b) If the @module and @class keys are not in the to_dict,
add them to the output automatically. If the object has no to_dict
property, the default Python json encoder default method is called.

Args:
o: Python object.

Return:
Python dict representation.
"""
Expand All @@ -584,13 +591,13 @@ def default(self, o) -> dict: # pylint: disable=E0202

if torch is not None and isinstance(o, torch.Tensor):
# Support for Pytorch Tensors.
d = {
d: dict[str, Any] = {
"@module": "torch",
"@class": "Tensor",
"dtype": o.type(),
}
if "Complex" in o.type():
d["data"] = [o.real.tolist(), o.imag.tolist()] # type: ignore
d["data"] = [o.real.tolist(), o.imag.tolist()]
else:
d["data"] = o.numpy().tolist()
return d
Expand Down Expand Up @@ -651,7 +658,7 @@ def default(self, o) -> dict: # pylint: disable=E0202
and dataclasses.is_dataclass(o)
):
# This handles dataclasses that are not subclasses of MSONAble.
d = dataclasses.asdict(o)
d = dataclasses.asdict(o) # type: ignore[call-overload]
elif hasattr(o, "as_dict"):
d = o.as_dict()
elif isinstance(o, Enum):
Expand All @@ -673,10 +680,10 @@ def default(self, o) -> dict: # pylint: disable=E0202
if "@version" not in d:
try:
parent_module = o.__class__.__module__.split(".")[0]
module_version = import_module(parent_module).__version__ # type: ignore
module_version = import_module(parent_module).__version__
d["@version"] = str(module_version)
except (AttributeError, ImportError):
d["@version"] = None # type: ignore
d["@version"] = None
return d
except AttributeError:
return json.JSONEncoder.default(self, o)
Expand Down
19 changes: 17 additions & 2 deletions tests/test_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def func_replace():

@deprecated(func_replace, "Use func_replace instead")
def func_old():
"""This is the old function."""
pass

with warnings.catch_warnings(record=True) as w:
Expand All @@ -28,6 +29,10 @@ def func_old():
assert issubclass(w[0].category, FutureWarning)
assert "Use func_replace instead" in str(w[0].message)

# Check metadata preservation
assert func_old.__name__ == "func_old"
assert func_old.__doc__ == "This is the old function."

def test_deprecated_str_replacement(self):
@deprecated("func_replace")
def func_old():
Expand Down Expand Up @@ -112,13 +117,23 @@ def method_a(self):

@deprecated(replacement=TestClassNew)
class TestClassOld:
"""A dummy class for tests."""
"""A dummy old class for tests."""

class_attrib_old = "OLD_ATTRIB"

def method_b(self):
"""This is method_b."""
pass

with pytest.warns(FutureWarning, match="TestClassOld is deprecated"):
TestClassOld()
old_class = TestClassOld()

# Check metadata preservation
assert TestClassOld.__doc__ == "A dummy old class for tests."
assert old_class.class_attrib_old == "OLD_ATTRIB"
assert TestClassOld.__module__ == __name__

assert TestClassOld.method_b.__doc__ == "This is method_b."

def test_deprecated_deadline(self, monkeypatch):
with pytest.raises(DeprecationWarning):
Expand Down
Loading