Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing functools.wraps decorator to deprecated decorator and handle dataclass properly #699

Merged
merged 18 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/monty/dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import subprocess
import sys
import warnings
from dataclasses import is_dataclass
from datetime import datetime
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -121,6 +122,7 @@ def craft_message(
return msg

def deprecated_function_decorator(old: Callable) -> Callable:
@functools.wraps(old)
def wrapped(*args, **kwargs):
msg = craft_message(old, replacement, message, _deadline)
warnings.warn(msg, category=category, stacklevel=2)
Expand All @@ -129,14 +131,23 @@ def wrapped(*args, **kwargs):
return wrapped

def deprecated_class_decorator(cls: Type) -> Type:
original_init = cls.__init__
# Modify __post_init__ for dataclass
if is_dataclass(cls) and hasattr(cls, "__post_init__"):
DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
original_init = cls.__post_init__
else:
original_init = cls.__init__

@functools.wraps(original_init)
def new_init(self, *args, **kwargs):
msg = craft_message(cls, replacement, message, _deadline)
warnings.warn(msg, category=category, stacklevel=2)
original_init(self, *args, **kwargs)

cls.__init__ = new_init
if is_dataclass(cls) and hasattr(cls, "__post_init__"):
cls.__post_init__ = new_init
else:
cls.__init__ = new_init

return cls

# Convert deadline to datetime type
Expand Down
64 changes: 34 additions & 30 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import dataclasses
import datetime
import json
import os
Expand All @@ -17,48 +18,44 @@
from importlib import import_module
from inspect import getfullargspec
from pathlib import Path
from typing import Any, Dict
from typing import Any
from uuid import UUID, uuid4

try:
import numpy as np
except ImportError:
np = None # type: ignore
np = None

try:
import pydantic
except ImportError:
pydantic = None # type: ignore
pydantic = None

try:
from pydantic_core import core_schema
except ImportError:
core_schema = None # type: ignore
core_schema = None

try:
import bson
except ImportError:
bson = None # type: ignore
bson = None

try:
from ruamel.yaml import YAML
except ImportError:
YAML = None # type: ignore
YAML = None

try:
import orjson
except ImportError:
orjson = None # type: ignore
orjson = None

try:
import dataclasses
except ImportError:
dataclasses = None # type: ignore
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's safe to remove guard for dataclasses now.

Added in version 3.7.


try:
import torch
except ImportError:
torch = None # type: ignore
torch = None

__version__ = "3.0.0"

Expand Down Expand Up @@ -174,17 +171,17 @@ def as_dict(self) -> dict:
"""
A JSON serializable dict representation of an object.
"""
d = {
d: dict[str, Any] = {
"@module": self.__class__.__module__,
"@class": self.__class__.__name__,
}

try:
parent_module = self.__class__.__module__.split(".", maxsplit=1)[0]
module_version = import_module(parent_module).__version__ # type: ignore
module_version = import_module(parent_module).__version__
d["@version"] = str(module_version)
except (AttributeError, ImportError):
d["@version"] = None # type: ignore
d["@version"] = None

spec = getfullargspec(self.__class__.__init__)

Expand Down Expand Up @@ -225,21 +222,24 @@ def recursive_as_dict(obj):
)
d[c] = recursive_as_dict(a)
if hasattr(self, "kwargs"):
# type: ignore
d.update(**self.kwargs) # pylint: disable=E1101
d.update(**self.kwargs)
if spec.varargs is not None and getattr(self, spec.varargs, None) is not None:
d.update({spec.varargs: getattr(self, spec.varargs)})
if hasattr(self, "_kwargs"):
d.update(**self._kwargs) # pylint: disable=E1101
d.update(**self._kwargs)
if isinstance(self, Enum):
d.update({"value": self.value}) # pylint: disable=E1101
d.update({"value": self.value})
return d

@classmethod
def from_dict(cls, d):
"""
:param d: Dict representation.
:return: MSONable class.

Args:
d: Dict representation.

Returns:
MSONable class.
"""
decoded = {
k: MontyDecoder().process_decoded(v)
Expand Down Expand Up @@ -547,27 +547,31 @@ class MontyEncoder(json.JSONEncoder):
json.dumps(object, cls=MontyEncoder)
"""

def __init__(self, *args, allow_unserializable_objects=False, **kwargs):
def __init__(
self, *args, allow_unserializable_objects: bool = False, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self._allow_unserializable_objects = allow_unserializable_objects
self._name_object_map: Dict[str, Any] = {}
self._index = 0
self._name_object_map: dict[str, Any] = {}
self._index: int = 0

def _update_name_object_map(self, o):
name = f"{self._index:012}-{str(uuid4())}"
self._index += 1
self._name_object_map[name] = o
return {"@object_reference": name}

def default(self, o) -> dict: # pylint: disable=E0202
def default(self, o) -> dict:
"""
Overriding default method for JSON encoding. This method does two
things: (a) If an object has a to_dict property, return the to_dict
output. (b) If the @module and @class keys are not in the to_dict,
add them to the output automatically. If the object has no to_dict
property, the default Python json encoder default method is called.

Args:
o: Python object.

Return:
Python dict representation.
"""
Expand All @@ -584,13 +588,13 @@ def default(self, o) -> dict: # pylint: disable=E0202

if torch is not None and isinstance(o, torch.Tensor):
# Support for Pytorch Tensors.
d = {
d: dict[str, Any] = {
"@module": "torch",
"@class": "Tensor",
"dtype": o.type(),
}
if "Complex" in o.type():
d["data"] = [o.real.tolist(), o.imag.tolist()] # type: ignore
d["data"] = [o.real.tolist(), o.imag.tolist()]
else:
d["data"] = o.numpy().tolist()
return d
Expand Down Expand Up @@ -651,7 +655,7 @@ def default(self, o) -> dict: # pylint: disable=E0202
and dataclasses.is_dataclass(o)
):
# This handles dataclasses that are not subclasses of MSONAble.
d = dataclasses.asdict(o)
d = dataclasses.asdict(o) # type: ignore[call-overload]
elif hasattr(o, "as_dict"):
d = o.as_dict()
elif isinstance(o, Enum):
Expand All @@ -673,10 +677,10 @@ def default(self, o) -> dict: # pylint: disable=E0202
if "@version" not in d:
try:
parent_module = o.__class__.__module__.split(".")[0]
module_version = import_module(parent_module).__version__ # type: ignore
module_version = import_module(parent_module).__version__
d["@version"] = str(module_version)
except (AttributeError, ImportError):
d["@version"] = None # type: ignore
d["@version"] = None
return d
except AttributeError:
return json.JSONEncoder.default(self, o)
Expand Down
50 changes: 49 additions & 1 deletion tests/test_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import unittest
import warnings
from dataclasses import dataclass
from unittest.mock import patch

import pytest
Expand All @@ -19,6 +20,7 @@ def func_replace():

@deprecated(func_replace, "Use func_replace instead")
def func_old():
"""This is the old function."""
pass

with warnings.catch_warnings(record=True) as w:
Expand All @@ -28,6 +30,10 @@ def func_old():
assert issubclass(w[0].category, FutureWarning)
assert "Use func_replace instead" in str(w[0].message)

# Check metadata preservation
assert func_old.__name__ == "func_old"
assert func_old.__doc__ == "This is the old function."

def test_deprecated_str_replacement(self):
@deprecated("func_replace")
def func_old():
Expand Down Expand Up @@ -112,13 +118,55 @@ def method_a(self):

@deprecated(replacement=TestClassNew)
class TestClassOld:
"""A dummy old class for tests."""

class_attrib_old = "OLD_ATTRIB"

def method_b(self):
"""This is method_b."""
pass

with pytest.warns(FutureWarning, match="TestClassOld is deprecated"):
old_class = TestClassOld()

# Check metadata preservation
assert old_class.__doc__ == "A dummy old class for tests."
assert old_class.class_attrib_old == "OLD_ATTRIB"
assert old_class.__module__ == __name__

assert old_class.method_b.__doc__ == "This is method_b."

def test_deprecated_dataclass(self):
@dataclass
class TestClassNew:
"""A dummy class for tests."""

def __post_init__(self):
print("Hello.")

def method_a(self):
pass

@deprecated(replacement=TestClassNew)
@dataclass
class TestClassOld:
"""A dummy old class for tests."""

class_attrib_old = "OLD_ATTRIB"

def __post_init__(self):
print("Hello.")

def method_b(self):
"""This is method_b."""
pass

with pytest.warns(FutureWarning, match="TestClassOld is deprecated"):
TestClassOld()
old_class = TestClassOld()

# Check metadata preservation
assert old_class.__doc__ == "A dummy old class for tests."
assert old_class.class_attrib_old == "OLD_ATTRIB"

def test_deprecated_deadline(self, monkeypatch):
with pytest.raises(DeprecationWarning):
Expand Down
Loading