-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lazily import torch/pydantic in the `json` module, speeding up `from monty.json import` by 10x
#713
Merged
shyuep
merged 16 commits into
materialsvirtuallab:master
from
DanielYang59:lazy-import-torch
Oct 21, 2024
Merged
Changes from 6 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
a7752ee
avoid import torch for type check in MontyEncoder
DanielYang59 196dcf9
lazily import torch for MontyDecoder
DanielYang59 f389e24
also lazily import numpy
DanielYang59 2bcc6cd
lazily import pydantic
DanielYang59 acfeb00
lazily load ruamel.yaml
DanielYang59 002f5af
revert lazy import ruamel, no obvious improvement
DanielYang59 f67796c
E1101 should disappear
DanielYang59 69cc7d9
use f-string
DanielYang59 14c1e3e
add type
DanielYang59 3b8d854
eagerly import numpy
DanielYang59 c7f012d
Merge remote-tracking branch 'upstream/master' into lazy-import-torch
DanielYang59 226194b
push a random change to rerun CI, torch install is flaky as always
DanielYang59 bc58f69
add comment about lazy import
DanielYang59 6f9580c
bump setup python version
DanielYang59 4a8377b
pre-commit migrate-config
DanielYang59 f33fd5a
pin 3.12
DanielYang59 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,42 +21,19 @@ | |
from typing import Any | ||
from uuid import UUID, uuid4 | ||
|
||
try: | ||
import numpy as np | ||
except ImportError: | ||
np = None | ||
|
||
try: | ||
import pydantic | ||
except ImportError: | ||
pydantic = None | ||
|
||
try: | ||
from pydantic_core import core_schema | ||
except ImportError: | ||
core_schema = None | ||
from ruamel.yaml import YAML | ||
|
||
try: | ||
import bson | ||
except ImportError: | ||
bson = None | ||
|
||
try: | ||
from ruamel.yaml import YAML | ||
except ImportError: | ||
YAML = None | ||
|
||
try: | ||
import orjson | ||
except ImportError: | ||
orjson = None | ||
|
||
|
||
try: | ||
import torch | ||
except ImportError: | ||
torch = None | ||
|
||
__version__ = "3.0.0" | ||
|
||
|
||
|
@@ -338,8 +315,11 @@ def __get_pydantic_core_schema__(cls, source_type, handler): | |
""" | ||
pydantic v2 core schema definition | ||
""" | ||
if core_schema is None: | ||
raise RuntimeError("Pydantic >= 2.0 is required for validation") | ||
try: | ||
from pydantic_core import core_schema | ||
|
||
except ImportError as exc: | ||
raise RuntimeError("Pydantic >= 2.0 is required for validation") from exc | ||
|
||
s = core_schema.with_info_plain_validator_function(cls.validate_monty_v2) | ||
|
||
|
@@ -586,8 +566,8 @@ def default(self, o) -> dict: | |
if isinstance(o, Path): | ||
return {"@module": "pathlib", "@class": "Path", "string": str(o)} | ||
|
||
if torch is not None and isinstance(o, torch.Tensor): | ||
# Support for Pytorch Tensors. | ||
# Support for Pytorch Tensors | ||
if _check_type(o, "torch.Tensor"): | ||
d: dict[str, Any] = { | ||
"@module": "torch", | ||
"@class": "Tensor", | ||
|
@@ -599,7 +579,9 @@ def default(self, o) -> dict: | |
d["data"] = o.numpy().tolist() | ||
return d | ||
|
||
if np is not None: | ||
try: | ||
import numpy as np | ||
|
||
if isinstance(o, np.ndarray): | ||
if str(o.dtype).startswith("complex"): | ||
return { | ||
|
@@ -616,6 +598,8 @@ def default(self, o) -> dict: | |
} | ||
if isinstance(o, np.generic): | ||
return o.item() | ||
except ImportError: | ||
pass | ||
|
||
if _check_type(o, "pandas.core.frame.DataFrame"): | ||
return { | ||
|
@@ -660,7 +644,7 @@ def default(self, o) -> dict: | |
raise AttributeError(e) | ||
|
||
try: | ||
if pydantic is not None and isinstance(o, pydantic.BaseModel): | ||
if _check_type(o, "pydantic.main.BaseModel"): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

    from pydantic import BaseModel

    class MyModel(BaseModel):
        name: str

    model_instance = MyModel(name="monty")
    print(type(model_instance).mro())

Gives:
DanielYang59 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
d = o.model_dump() | ||
elif ( | ||
dataclasses is not None | ||
|
@@ -790,11 +774,18 @@ def process_decoded(self, d): | |
return cls_.from_dict(data) | ||
if issubclass(cls_, Enum): | ||
return cls_(d["value"]) | ||
if pydantic is not None and issubclass( | ||
cls_, pydantic.BaseModel | ||
): # pylint: disable=E1101 | ||
d = {k: self.process_decoded(v) for k, v in data.items()} | ||
return cls_(**d) | ||
|
||
try: | ||
import pydantic | ||
|
||
if issubclass(cls_, pydantic.BaseModel): # pylint: disable=E1101 | ||
d = { | ||
k: self.process_decoded(v) for k, v in data.items() | ||
} | ||
return cls_(**d) | ||
except ImportError: | ||
pass | ||
|
||
if ( | ||
dataclasses is not None | ||
and (not issubclass(cls_, MSONable)) | ||
|
@@ -803,26 +794,39 @@ def process_decoded(self, d): | |
d = {k: self.process_decoded(v) for k, v in data.items()} | ||
return cls_(**d) | ||
|
||
elif torch is not None and modname == "torch" and classname == "Tensor": | ||
if "Complex" in d["dtype"]: | ||
return torch.tensor( # pylint: disable=E1101 | ||
[ | ||
np.array(r) + np.array(i) * 1j | ||
for r, i in zip(*d["data"]) | ||
], | ||
).type(d["dtype"]) | ||
return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101 | ||
|
||
elif np is not None and modname == "numpy" and classname == "array": | ||
if d["dtype"].startswith("complex"): | ||
return np.array( | ||
[ | ||
np.array(r) + np.array(i) * 1j | ||
for r, i in zip(*d["data"]) | ||
], | ||
dtype=d["dtype"], | ||
) | ||
return np.array(d["data"], dtype=d["dtype"]) | ||
elif modname == "torch" and classname == "Tensor": | ||
try: | ||
import numpy as np | ||
import torch | ||
|
||
if "Complex" in d["dtype"]: | ||
return torch.tensor( # pylint: disable=E1101 | ||
[ | ||
np.array(r) + np.array(i) * 1j | ||
for r, i in zip(*d["data"]) | ||
], | ||
).type(d["dtype"]) | ||
return torch.tensor(d["data"]).type(d["dtype"]) # pylint: disable=E1101 | ||
|
||
except ImportError: | ||
pass | ||
|
||
elif modname == "numpy" and classname == "array": | ||
try: | ||
import numpy as np | ||
|
||
if d["dtype"].startswith("complex"): | ||
return np.array( | ||
[ | ||
np.array(r) + np.array(i) * 1j | ||
for r, i in zip(*d["data"]) | ||
], | ||
dtype=d["dtype"], | ||
) | ||
return np.array(d["data"], dtype=d["dtype"]) | ||
|
||
except ImportError: | ||
pass | ||
|
||
elif modname == "pandas": | ||
import pandas as pd | ||
|
@@ -925,6 +929,7 @@ def jsanitize( | |
or (bson is not None and isinstance(obj, bson.objectid.ObjectId)) | ||
): | ||
return obj | ||
|
||
if isinstance(obj, (list, tuple)): | ||
return [ | ||
jsanitize( | ||
|
@@ -936,22 +941,30 @@ def jsanitize( | |
) | ||
for i in obj | ||
] | ||
if np is not None and isinstance(obj, np.ndarray): | ||
try: | ||
return [ | ||
jsanitize( | ||
i, | ||
strict=strict, | ||
allow_bson=allow_bson, | ||
enum_values=enum_values, | ||
recursive_msonable=recursive_msonable, | ||
) | ||
for i in obj.tolist() | ||
] | ||
except TypeError: | ||
return obj.tolist() | ||
if np is not None and isinstance(obj, np.generic): | ||
return obj.item() | ||
|
||
try: | ||
import numpy as np | ||
|
||
if isinstance(obj, np.ndarray): | ||
try: | ||
return [ | ||
jsanitize( | ||
i, | ||
strict=strict, | ||
allow_bson=allow_bson, | ||
enum_values=enum_values, | ||
recursive_msonable=recursive_msonable, | ||
) | ||
for i in obj.tolist() | ||
] | ||
except TypeError: | ||
return obj.tolist() | ||
|
||
if isinstance(obj, np.generic): | ||
return obj.item() | ||
except ImportError: | ||
pass | ||
|
||
if _check_type( | ||
obj, | ||
( | ||
|
@@ -961,6 +974,7 @@ def jsanitize( | |
), | ||
): | ||
return obj.to_dict() | ||
|
||
if isinstance(obj, dict): | ||
return { | ||
str(k): jsanitize( | ||
|
@@ -972,10 +986,13 @@ def jsanitize( | |
) | ||
for k, v in obj.items() | ||
} | ||
|
||
if isinstance(obj, (int, float)): | ||
return obj | ||
|
||
if obj is None: | ||
return None | ||
|
||
if isinstance(obj, (pathlib.Path, datetime.datetime)): | ||
return str(obj) | ||
|
||
|
@@ -997,7 +1014,7 @@ def jsanitize( | |
if isinstance(obj, str): | ||
return obj | ||
|
||
if pydantic is not None and isinstance(obj, pydantic.BaseModel): # pylint: disable=E1101 | ||
if _check_type(obj, "pydantic.main.BaseModel"): | ||
return jsanitize( | ||
MontyEncoder().default(obj), | ||
strict=strict, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Should be the correct type string — inspecting the MRO of a `torch.Tensor` instance gives:
[<class 'torch.Tensor'>, <class 'torch._C.TensorBase'>, <class 'object'>]