Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(dpmodel): move save_dp_model and load_dp_model to a seperated module #3701

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def serialize_hook(self) -> Callable[[str], dict]:
Callable[[str], dict]
The serialize hook of the backend.
"""
from deepmd.dpmodel.utils.network import (
from deepmd.dpmodel.utils.serialization import (
load_dp_model,
)

Expand All @@ -115,7 +115,7 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
from deepmd.dpmodel.utils.network import (
from deepmd.dpmodel.utils.serialization import (
save_dp_model,
)

Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from deepmd.dpmodel.utils.batch_size import (
AutoBatchSize,
)
from deepmd.dpmodel.utils.network import (
from deepmd.dpmodel.utils.serialization import (
load_dp_model,
)
from deepmd.infer.deep_dipole import (
Expand Down
10 changes: 5 additions & 5 deletions deepmd/dpmodel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@
NativeLayer,
NativeNet,
NetworkCollection,
load_dp_model,
make_embedding_network,
make_fitting_network,
make_multilayer_network,
save_dp_model,
traverse_model_dict,
)
from .nlist import (
build_multiple_neighbor_list,
Expand All @@ -32,6 +29,11 @@
phys2inter,
to_face_distance,
)
from .serialization import (
load_dp_model,
save_dp_model,
traverse_model_dict,
)

__all__ = [
"EnvMat",
Expand All @@ -46,8 +48,6 @@
"load_dp_model",
"save_dp_model",
"traverse_model_dict",
"PRECISION_DICT",
"DEFAULT_PRECISION",
"build_neighbor_list",
"nlist_distinguish_types",
"get_multiple_nlist_key",
Expand Down
107 changes: 0 additions & 107 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@

import copy
import itertools
import json
from datetime import (
datetime,
)
from typing import (
Callable,
ClassVar,
Expand All @@ -19,7 +15,6 @@
Union,
)

import h5py
import numpy as np

from deepmd.utils.version import (
Expand All @@ -38,108 +33,6 @@
)


def traverse_model_dict(model_obj, callback: callable, is_variable: bool = False):
"""Traverse a model dict and call callback on each variable.

Parameters
----------
model_obj : object
The model object to traverse.
callback : callable
The callback function to call on each variable.
is_variable : bool, optional
Whether the current node is a variable.

Returns
-------
object
The model object after traversing.
"""
if isinstance(model_obj, dict):
for kk, vv in model_obj.items():
model_obj[kk] = traverse_model_dict(
vv, callback, is_variable=is_variable or kk == "@variables"
)
elif isinstance(model_obj, list):
for ii, vv in enumerate(model_obj):
model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
elif model_obj is None:
return model_obj
elif is_variable:
model_obj = callback(model_obj)
return model_obj


class Counter:
"""A callable counter.

Examples
--------
>>> counter = Counter()
>>> counter()
0
>>> counter()
1
"""

def __init__(self):
self.count = -1

def __call__(self):
self.count += 1
return self.count


# TODO: move save_dp_model and load_dp_model to a seperated module
# should be moved to otherwhere...
def save_dp_model(filename: str, model_dict: dict) -> None:
"""Save a DP model to a file in the native format.

Parameters
----------
filename : str
The filename to save to.
model_dict : dict
The model dict to save.
"""
model_dict = model_dict.copy()
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
)
save_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))


def load_dp_model(filename: str) -> dict:
"""Load a DP model from a file in the native format.

Parameters
----------
filename : str
The filename to load from.

Returns
-------
dict
The loaded model dict, including meta information.
"""
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
return model_dict


class NativeLayer(NativeOP):
"""Native representation of a layer.

Expand Down
115 changes: 115 additions & 0 deletions deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from datetime import (
datetime,
)
from typing import (
Callable,
)

import h5py

try:
from deepmd._version import version as __version__
except ImportError:
__version__ = "unknown"

Check warning on line 15 in deepmd/dpmodel/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/utils/serialization.py#L14-L15

Added lines #L14 - L15 were not covered by tests


def traverse_model_dict(model_obj, callback: Callable, is_variable: bool = False):
"""Traverse a model dict and call callback on each variable.

Parameters
----------
model_obj : object
The model object to traverse.
callback : callable
The callback function to call on each variable.
is_variable : bool, optional
Whether the current node is a variable.

Returns
-------
object
The model object after traversing.
"""
if isinstance(model_obj, dict):
for kk, vv in model_obj.items():
model_obj[kk] = traverse_model_dict(
vv, callback, is_variable=is_variable or kk == "@variables"
)
elif isinstance(model_obj, list):
for ii, vv in enumerate(model_obj):
model_obj[ii] = traverse_model_dict(vv, callback, is_variable=is_variable)
elif model_obj is None:
return model_obj
elif is_variable:
model_obj = callback(model_obj)
return model_obj


class Counter:
"""A callable counter.

Examples
--------
>>> counter = Counter()
>>> counter()
0
>>> counter()
1
"""

def __init__(self):
self.count = -1

def __call__(self):
self.count += 1
return self.count


def save_dp_model(filename: str, model_dict: dict) -> None:
"""Save a DP model to a file in the native format.

Parameters
----------
filename : str
The filename to save to.
model_dict : dict
The model dict to save.
"""
model_dict = model_dict.copy()
variable_counter = Counter()
with h5py.File(filename, "w") as f:
model_dict = traverse_model_dict(
model_dict,
lambda x: f.create_dataset(
f"variable_{variable_counter():04d}", data=x
).name,
)
save_dict = {
"software": "deepmd-kit",
"version": __version__,
# use UTC+0 time
"time": str(datetime.utcnow()),
**model_dict,
}
f.attrs["json"] = json.dumps(save_dict, separators=(",", ":"))


def load_dp_model(filename: str) -> dict:
"""Load a DP model from a file in the native format.

Parameters
----------
filename : str
The filename to load from.

Returns
-------
dict
The loaded model dict, including meta information.
"""
with h5py.File(filename, "r") as f:
model_dict = json.loads(f.attrs["json"])
model_dict = traverse_model_dict(model_dict, lambda x: f[x][()].copy())
return model_dict
Loading