Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(fix) Make bias statistics complete for all elements #4496

Open
wants to merge 106 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 95 commits
Commits
Show all changes
106 commits
Select commit Hold shift + click to select a range
32da243
4424
SumGuo-88 Dec 23, 2024
adf2315
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2024
4f6f63d
issues4424-2
SumGuo-88 Dec 26, 2024
b9bac38
ll
SumGuo-88 Dec 26, 2024
1db3408
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
543a318
ll
SumGuo-88 Dec 26, 2024
ba72382
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
26f9a17
lll
SumGuo-88 Dec 26, 2024
25a803c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
dc64307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
8f962b5
allchange
SumGuo-88 Jan 2, 2025
b88e7fc
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
f57498d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
faeb7c5
test
SumGuo-88 Jan 2, 2025
725f1dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
ca7fc84
stat
SumGuo-88 Jan 2, 2025
394cf04
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
05128d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
4828619
check
SumGuo-88 Jan 2, 2025
37ccce4
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
c9406e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
2224f61
chec坑
SumGuo-88 Jan 2, 2025
ba12c2c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
9fcee84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
fe8579e
check3
SumGuo-88 Jan 2, 2025
f004dff
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
11138ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
a4a97a3
test
SumGuo-88 Jan 3, 2025
88566fe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
10e538d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
203dc4e
ttt
SumGuo-88 Jan 3, 2025
bb9fbe1
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
6a65561
t
SumGuo-88 Jan 3, 2025
603aee9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
1c103c4
d
SumGuo-88 Jan 3, 2025
4173040
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
e3a1c9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
533e95e
ll
SumGuo-88 Jan 3, 2025
38dc18c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
714c197
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
6713c1a
last
SumGuo-88 Jan 4, 2025
e42e38d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
1c15cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
33c716d
q
SumGuo-88 Jan 4, 2025
6d38b94
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
6bbced8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
b462c97
ll
SumGuo-88 Jan 4, 2025
0d7154c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
0c7baa0
ll
SumGuo-88 Jan 4, 2025
28d94af
ll
SumGuo-88 Jan 4, 2025
87dcd66
l
SumGuo-88 Jan 4, 2025
5d33060
ll
SumGuo-88 Jan 4, 2025
521e3a6
ll
SumGuo-88 Jan 4, 2025
379d4ad
Merge branch 'devel' into devel
SumGuo-88 Jan 5, 2025
0dabf77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
a23528c
Update stat.py
SumGuo-88 Jan 5, 2025
0a97b54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
49744ed
Update deepmd/pt/utils/stat.py
SumGuo-88 Jan 6, 2025
556a684
Simplify logic and remove "not"
SumGuo-88 Jan 6, 2025
aa2633d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
27999af
check import
SumGuo-88 Jan 6, 2025
83b7f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
817d2ec
Add assert to ensure that the new frame contains the required elements
SumGuo-88 Jan 6, 2025
234e461
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
93a748f
check import
SumGuo-88 Jan 6, 2025
4a38f1d
check import
SumGuo-88 Jan 6, 2025
78b2a10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
6a5d169
check test.py
SumGuo-88 Jan 6, 2025
26205d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
3ccb4b9
check ut
SumGuo-88 Jan 6, 2025
f669ac5
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
87de0e8
check ut
SumGuo-88 Jan 6, 2025
0939ef1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
7ec779f
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
24d1386
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
708bc78
check msi defalut value
SumGuo-88 Jan 7, 2025
02f3f28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
c648e9e
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
d36a24a
check ut cuda
SumGuo-88 Jan 7, 2025
b6a483a
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
050dbaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
b00f8de
check ut
SumGuo-88 Jan 7, 2025
2f37dfe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
47fe45b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
98890c2
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
b9bdee5
make truetype for more sys
SumGuo-88 Jan 9, 2025
a30053f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
cfbc88a
Add skip element check function to Chang bias
SumGuo-88 Jan 9, 2025
73a20b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
0b29b05
make changebias control minframes
SumGuo-88 Jan 10, 2025
85d4da3
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
0400233
check merge
SumGuo-88 Jan 10, 2025
c05ffb1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
139f037
improve ut with all frames
SumGuo-88 Jan 10, 2025
5e826bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
edf1d91
check ut
SumGuo-88 Jan 10, 2025
3887013
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
eb9f068
check
SumGuo-88 Jan 10, 2025
10ef768
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
c2dc7ef
check skip logic and def name
SumGuo-88 Jan 10, 2025
8763165
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
d5596bf
improve warning readable
SumGuo-88 Jan 10, 2025
9f389ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
0c76ad9
check args
SumGuo-88 Jan 10, 2025
58647f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
4ce9cfb
check stat.py
SumGuo-88 Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,18 @@ def main_parser() -> argparse.ArgumentParser:
default=None,
help="Model branch chosen for changing bias if multi-task model.",
)
parser_change_bias.add_argument(
"--skip-elementcheck",
action="store_false",
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
help="Enable this option to skip element checks if any error occurs while retrieving statistical data.",
)
parser_change_bias.add_argument(
"-mf",
"--min-frames",
default=10,
type=int,
help="The minimum number of frames for each element used for statistics.",
)

# --version
parser.add_argument(
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ def change_bias(
numb_batch: int = 0,
model_branch: Optional[str] = None,
output: Optional[str] = None,
elem_check_stat: bool = True,
min_frames: int = 10,
) -> None:
if input_file.endswith(".pt"):
old_state_dict = torch.load(
Expand Down Expand Up @@ -472,6 +474,8 @@ def change_bias(
data_single.systems,
data_single.dataloaders,
nbatches,
min_frames_per_element_forstat=min_frames,
enable_element_completion=elem_check_stat,
)
updated_model = training.model_change_out_bias(
model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
Expand Down Expand Up @@ -555,6 +559,8 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
numb_batch=FLAGS.numb_batch,
model_branch=FLAGS.model_branch,
output=FLAGS.output,
elem_check_stat=FLAGS.skip_elementcheck,
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
min_frames=FLAGS.min_frames,
)
elif FLAGS.command == "compress":
FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth"))
Expand Down
8 changes: 8 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def __init__(
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
self.display_in_training = training_params.get("disp_training", True)
self.timing_in_training = training_params.get("time_training", True)
self.min_frames_per_element_forstat = training_params.get(
"min_frames_per_element_forstat", 10
)
self.enable_element_completion = training_params.get(
"enable_element_completion", True
)
self.change_bias_after_training = training_params.get(
"change_bias_after_training", False
)
Expand Down Expand Up @@ -227,6 +233,8 @@ def get_sample():
_training_data.systems,
_training_data.dataloaders,
_data_stat_nbatch,
self.min_frames_per_element_forstat,
self.enable_element_completion,
)
return sampled

Expand Down
41 changes: 38 additions & 3 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from collections import (
defaultdict,
)
from typing import (
Optional,
)

import numpy as np
from torch.utils.data import (
Dataset,
)
Expand All @@ -17,10 +20,10 @@

class DeepmdDataSetForLoader(Dataset):
def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None:
"""Construct DeePMD-style dataset containing frames cross different systems.
"""Construct DeePMD-style dataset containing frames across different systems.

Args:
- systems: Paths to systems.
- system: Path to the system.
- type_map: Atom types.
"""
self.system = system
Expand All @@ -40,6 +43,38 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

def get_frame_index(self):
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
"""
Get the frame index and the number of frames with all the elements in the system.
This function is only used in the mixed type.

Returns
-------
element_counts : dict
A dictionary where:
- The key is the element type.
- The value is another dictionary with the following keys:
- "frames": int
The total number of frames in which the element appears.
- "indices": list of int
A list of row indices where the element is found in the dataset.
"""
element_counts = defaultdict(lambda: {"frames": 0, "indices": []})
set_files = self._data_system.dirs
base_offset = 0
for set_file in set_files:
element_data = self._data_system._load_type_mix(set_file)
unique_elements = np.unique(element_data)
for elem in unique_elements:
frames_with_elem = np.any(element_data == elem, axis=1)
row_indices = np.where(frames_with_elem)[0]
row_indices_global = np.where(frames_with_elem)[0] + base_offset
element_counts[elem]["frames"] += len(row_indices)
element_counts[elem]["indices"].extend(row_indices_global.tolist())
base_offset += element_data.shape[0]
element_counts = dict(element_counts)
return element_counts

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
Expand Down
188 changes: 161 additions & 27 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,51 +35,185 @@
log = logging.getLogger(__name__)


def make_stat_input(datasets, dataloaders, nbatches):
def make_stat_input(
datasets,
dataloaders,
nbatches,
min_frames_per_element_forstat=10,
enable_element_completion=True,
):
"""Pack data for statistics.
Element checking is only enabled with mixed_type.

Args:
- dataset: A list of dataset to analyze.
- datasets: A list of datasets to analyze.
- dataloaders: Corresponding dataloaders for the datasets.
- nbatches: Batch count for collecting stats.
- min_frames_per_element_forstat: Minimum frames required for statistics.
- enable_element_completion: Whether to perform missing element completion (default: True).

Returns
-------
- a list of dicts, each of which contains data from a system
- A list of dicts, each of which contains data from a system.
"""
lst = []
log.info(f"Packing data for statistics from {len(datasets)} systems")
for i in range(len(datasets)):
sys_stat = {}
with torch.device("cpu"):
iterator = iter(dataloaders[i])
numb_batches = min(nbatches, len(dataloaders[i]))
for _ in range(numb_batches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloaders[i])
stat_data = next(iterator)
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
else:
pass
total_element_types = set()
global_element_counts = {}
collect_ele = defaultdict(int)
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
if datasets[0].mixed_type:
if enable_element_completion:
log.info(
f"Element check enabled. "
f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
)
else:
log.info(
"Element completion is disabled. Skipping missing element handling."
)

SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
def process_batches(dataloader, sys_stat):
"""Process batches from a dataloader to collect statistics."""
iterator = iter(dataloader)
numb_batches = min(nbatches, len(dataloader))
for _ in range(numb_batches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloader)
stat_data = next(iterator)
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
else:
pass
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved

SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
def process_with_new_frame(sys_indices, newele_counter, miss):
for sys_info in sys_indices:
sys_index = sys_info["sys_index"]
frames = sys_info["frames"]
sys = datasets[sys_index]
for frame in frames:
newele_counter += 1
if newele_counter <= min_frames_per_element_forstat:
frame_data = sys.__getitem__(frame)
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
assert miss in frame_data["atype"], (
"Element check failed. "
"If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. "
"If you encountered this error during model training, set 'enable_element_completion' to False "
"in the 'training' section of your input file."
)
sys_stat_new = {}
for dd in frame_data:
if dd == "type":
continue
if frame_data[dd] is None:
sys_stat_new[dd] = None
elif isinstance(frame_data[dd], np.ndarray):
if dd not in sys_stat_new:
sys_stat_new[dd] = []
tensor_data = torch.from_numpy(frame_data[dd])
tensor_data = tensor_data.unsqueeze(0)
sys_stat_new[dd].append(tensor_data)
elif isinstance(frame_data[dd], np.float32):
sys_stat_new[dd] = frame_data[dd]
else:
pass
finalize_stats(sys_stat_new)
lst.append(sys_stat_new)
else:
break
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved

SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
def finalize_stats(sys_stat):
"""Finalize statistics by concatenating tensors."""
for key in sys_stat:
if isinstance(sys_stat[key], np.float32):
pass
elif sys_stat[key] is None or sys_stat[key][0] is None:
elif sys_stat[key] is None or (
isinstance(sys_stat[key], list)
and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
):
sys_stat[key] = None
elif isinstance(stat_data[dd], torch.Tensor):
elif isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved

SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
sys_stat = {}
with torch.device("cpu"):
process_batches(dataloader, sys_stat)
if datasets[0].mixed_type and enable_element_completion:
element_data = torch.cat(sys_stat["atype"], dim=0)
collect_values = torch.unique(element_data.flatten(), sorted=True)
for elem in collect_values.tolist():
frames_with_elem = torch.any(element_data == elem, dim=1)
row_indices = torch.where(frames_with_elem)[0]
collect_ele[elem] += len(row_indices)
finalize_stats(sys_stat)
lst.append(sys_stat)

SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
# get frame index
if datasets[0].mixed_type and enable_element_completion:
element_counts = dataset.get_frame_index()
for elem, data in element_counts.items():
indices = data["indices"]
count = data["frames"]
total_element_types.add(elem)
if elem not in global_element_counts:
global_element_counts[elem] = {"count": 0, "indices": []}
if count > min_frames_per_element_forstat:
global_element_counts[elem]["count"] += (
min_frames_per_element_forstat
)
indices = indices[:min_frames_per_element_forstat]
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)
else:
global_element_counts[elem]["count"] += count
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)
else:
if (
global_element_counts[elem]["count"]
>= min_frames_per_element_forstat
):
pass
else:
global_element_counts[elem]["count"] += count
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)
# Complement
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
if datasets[0].mixed_type and enable_element_completion:
for elem, data in global_element_counts.items():
indices_count = data["count"]
if indices_count < min_frames_per_element_forstat:
log.warning(
f"The number of frames in your datasets with element {elem} is {indices_count}, "
f"which is less than the required {min_frames_per_element_forstat}"
)
collect_elements = collect_ele.keys()
missing_elements = total_element_types - collect_elements
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
collect_miss_element = set()
for ele, count in collect_ele.items():
if count < min_frames_per_element_forstat:
collect_miss_element.add(ele)
missing_elements.add(ele)
for miss in missing_elements:
sys_indices = global_element_counts[miss].get("indices", [])
if miss in collect_miss_element:
newele_counter = collect_ele.get(miss, 0)
else:
newele_counter = 0
process_with_new_frame(sys_indices, newele_counter, miss)
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
return lst


Expand Down
14 changes: 14 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2893,6 +2893,20 @@ def training_args(
optional=True,
doc=doc_only_pt_supported + doc_gradient_max_norm,
),
Argument(
"min_frames_per_element_forstat",
int,
default=10,
optional=True,
doc="The minimum number of frames per element used for statistics when using the mixed type.",
),
Argument(
"enable_element_completion",
bool,
optional=True,
default=True,
doc="Whether to check elements when using the mixed type",
),
SumGuo-88 marked this conversation as resolved.
Show resolved Hide resolved
]
variants = [
Variant(
Expand Down
16 changes: 8 additions & 8 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,14 +547,6 @@ def _load_set(self, set_name: DPPath):
if self.mixed_type:
# nframes x natoms
atom_type_mix = self._load_type_mix(set_name)
if self.enforce_type_map:
try:
atom_type_mix_ = self.type_idx_map[atom_type_mix].astype(np.int32)
except IndexError as e:
raise IndexError(
f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
) from e
atom_type_mix = atom_type_mix_
real_type = atom_type_mix.reshape([nframes, self.natoms])
data["type"] = real_type
natoms = data["type"].shape[1]
Expand Down Expand Up @@ -704,6 +696,14 @@ def _load_type(self, sys_path: DPPath):
def _load_type_mix(self, set_name: DPPath):
type_path = set_name / "real_atom_types.npy"
real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms])
if self.enforce_type_map:
try:
atom_type_mix_ = self.type_idx_map[real_type].astype(np.int32)
except IndexError as e:
raise IndexError(
f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
) from e
real_type = atom_type_mix_
return real_type

def _make_idx_map(self, atom_type):
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading
Loading