Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit test for helper function json._check_type #716

Merged
merged 14 commits into from
Dec 10, 2024
27 changes: 14 additions & 13 deletions src/monty/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from enum import Enum
from hashlib import sha1
from importlib import import_module
from inspect import getfullargspec
from inspect import getfullargspec, isclass
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID, uuid4
Expand Down Expand Up @@ -68,12 +68,12 @@ def _load_redirect(redirect_file) -> dict:
return dict(redirect_dict)


def _check_type(obj, type_str: tuple[str, ...] | str) -> bool:
def _check_type(obj: object, type_str: tuple[str, ...] | str) -> bool:
"""Alternative to isinstance that avoids imports.

Checks whether obj is an instance of the type defined by type_str. This
removes the need to explicitly import type_str. Handles subclasses like
isinstance does. E.g.::
isinstance does. E.g.:
class A:
pass

Expand All @@ -88,21 +88,22 @@ class B(A):
assert isinstance(b, A)
assert not isinstance(a, B)

type_str: str | tuple[str]

Note for future developers: the type_str is not always obvious for an
object. For example, pandas.DataFrame is actually pandas.core.frame.DataFrame.
object. For example, pandas.DataFrame is actually "pandas.core.frame.DataFrame".
To find out the type_str for an object, run type(obj).mro(). This will
list all the types that an object can resolve to in order of generality
(all objects have the builtins.object as the last one).
(all objects have the "builtins.object" as the last one).
"""
type_str = type_str if isinstance(type_str, tuple) else (type_str,)
# I believe this try-except is only necessary for callable types
try:
mro = type(obj).mro()
except TypeError:
# This function is intended as an alternative of "isinstance",
# therefore wouldn't check class
if isclass(obj):
return False
return any(f"{o.__module__}.{o.__name__}" == ts for o in mro for ts in type_str)

type_str = type_str if isinstance(type_str, tuple) else (type_str,)

mro = type(obj).mro()

return any(f"{o.__module__}.{o.__qualname__}" == ts for o in mro for ts in type_str)


class MSONable:
Expand Down
129 changes: 129 additions & 0 deletions tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MontyDecoder,
MontyEncoder,
MSONable,
_check_type,
_load_redirect,
jsanitize,
load,
Expand Down Expand Up @@ -1068,3 +1069,131 @@ def test_enum(self):
assert d_ == {"v": "value_a"}
na2 = EnumAsDict.from_dict(d_)
assert na2 == na1


class TestCheckType:
def test_check_subclass(self):
class A:
pass

class B(A):
pass

a, b = A(), B()

class_name_A = f"{type(a).__module__}.{type(a).__qualname__}"
class_name_B = f"{type(b).__module__}.{type(b).__qualname__}"

# a is an instance of A, but not B
assert _check_type(a, class_name_A)
assert isinstance(a, A)
assert not _check_type(a, class_name_B)
assert not isinstance(a, B)

# b is an instance of both B and A
assert _check_type(b, class_name_B)
assert isinstance(b, B)
assert _check_type(b, class_name_A)
assert isinstance(b, A)

def test_check_class(self):
"""This should not work for classes."""

class A:
pass

class B(A):
pass

class_name_A = f"{A.__module__}.{A.__qualname__}"
class_name_B = f"{B.__module__}.{B.__qualname__}"

# Test class behavior (should return False, like isinstance does)
assert not _check_type(A, class_name_A)
assert not _check_type(B, class_name_B)
assert not _check_type(B, class_name_A)

def test_callable(self):
# Test function
def my_function():
pass

callable_class_name = (
f"{type(my_function).__module__}.{type(my_function).__qualname__}"
)

assert _check_type(my_function, callable_class_name), callable_class_name
assert isinstance(my_function, type(my_function))

# Test callable class
class MyCallableClass:
def __call__(self):
pass

callable_instance = MyCallableClass()
assert callable(callable_instance)

callable_class_instance_name = f"{type(callable_instance).__module__}.{type(callable_instance).__qualname__}"

assert _check_type(
callable_instance, callable_class_instance_name
), callable_class_instance_name
assert isinstance(callable_instance, MyCallableClass)

DanielYang59 marked this conversation as resolved.
Show resolved Hide resolved
def test_numpy(self):
# Test NumPy array
arr = np.array([1, 2, 3])

assert _check_type(arr, "numpy.ndarray")
assert isinstance(arr, np.ndarray)

# Test NumPy generic
scalar = np.float64(3.14)

assert _check_type(scalar, "numpy.generic")
assert isinstance(scalar, np.generic)

@pytest.mark.skipif(pd is None, reason="pandas is not installed")
def test_pandas(self):
# Test pandas DataFrame
df = pd.DataFrame({"a": [1, 2, 3]})

assert _check_type(df, "pandas.core.frame.DataFrame")
assert isinstance(df, pd.DataFrame)

assert _check_type(df, "pandas.core.base.PandasObject")
assert isinstance(df, pd.core.base.PandasObject)

# Test pandas Series
series = pd.Series([1, 2, 3])

assert _check_type(series, "pandas.core.series.Series")
assert isinstance(series, pd.Series)

assert _check_type(series, "pandas.core.base.PandasObject")
assert isinstance(series, pd.core.base.PandasObject)

@pytest.mark.skipif(torch is None, reason="torch is not installed")
def test_torch(self):
tensor = torch.tensor([1, 2, 3])

assert _check_type(tensor, "torch.Tensor")
assert isinstance(tensor, torch.Tensor)

@pytest.mark.skipif(pydantic is None, reason="pydantic is not installed")
def test_pydantic(self):
class MyModel(pydantic.BaseModel):
name: str

model_instance = MyModel(name="Alice")

assert _check_type(model_instance, "pydantic.main.BaseModel")
assert isinstance(model_instance, pydantic.BaseModel)

@pytest.mark.skipif(pint is None, reason="pint is not installed")
def test_pint(self):
ureg = pint.UnitRegistry()
qty = 3 * ureg.meter

assert _check_type(qty, "pint.registry.Quantity")
assert isinstance(qty, pint.Quantity)
Loading