
(fix) Make bias statistics complete for all elements #4496

Open · wants to merge 106 commits into base: devel
Changes from 89 commits (106 commits total)

Commits
32da243
4424
SumGuo-88 Dec 23, 2024
adf2315
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2024
4f6f63d
issues4424-2
SumGuo-88 Dec 26, 2024
b9bac38
ll
SumGuo-88 Dec 26, 2024
1db3408
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
543a318
ll
SumGuo-88 Dec 26, 2024
ba72382
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
26f9a17
lll
SumGuo-88 Dec 26, 2024
25a803c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Dec 26, 2024
dc64307
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 26, 2024
8f962b5
allchange
SumGuo-88 Jan 2, 2025
b88e7fc
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
f57498d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
faeb7c5
test
SumGuo-88 Jan 2, 2025
725f1dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
ca7fc84
stat
SumGuo-88 Jan 2, 2025
394cf04
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
05128d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
4828619
check
SumGuo-88 Jan 2, 2025
37ccce4
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
c9406e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
2224f61
chec坑
SumGuo-88 Jan 2, 2025
ba12c2c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
9fcee84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
fe8579e
check3
SumGuo-88 Jan 2, 2025
f004dff
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 2, 2025
11138ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 2, 2025
a4a97a3
test
SumGuo-88 Jan 3, 2025
88566fe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
10e538d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
203dc4e
ttt
SumGuo-88 Jan 3, 2025
bb9fbe1
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
6a65561
t
SumGuo-88 Jan 3, 2025
603aee9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
1c103c4
d
SumGuo-88 Jan 3, 2025
4173040
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
e3a1c9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
533e95e
ll
SumGuo-88 Jan 3, 2025
38dc18c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 3, 2025
714c197
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 3, 2025
6713c1a
last
SumGuo-88 Jan 4, 2025
e42e38d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
1c15cf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
33c716d
q
SumGuo-88 Jan 4, 2025
6d38b94
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
6bbced8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 4, 2025
b462c97
ll
SumGuo-88 Jan 4, 2025
0d7154c
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 4, 2025
0c7baa0
ll
SumGuo-88 Jan 4, 2025
28d94af
ll
SumGuo-88 Jan 4, 2025
87dcd66
l
SumGuo-88 Jan 4, 2025
5d33060
ll
SumGuo-88 Jan 4, 2025
521e3a6
ll
SumGuo-88 Jan 4, 2025
379d4ad
Merge branch 'devel' into devel
SumGuo-88 Jan 5, 2025
0dabf77
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
a23528c
Update stat.py
SumGuo-88 Jan 5, 2025
0a97b54
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
49744ed
Update deepmd/pt/utils/stat.py
SumGuo-88 Jan 6, 2025
556a684
Simplify logic and remove "not"
SumGuo-88 Jan 6, 2025
aa2633d
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
27999af
check import
SumGuo-88 Jan 6, 2025
83b7f1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
817d2ec
Add assert to ensure that the new frame contains the required elements
SumGuo-88 Jan 6, 2025
234e461
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
93a748f
check import
SumGuo-88 Jan 6, 2025
4a38f1d
check import
SumGuo-88 Jan 6, 2025
78b2a10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
6a5d169
check test.py
SumGuo-88 Jan 6, 2025
26205d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
3ccb4b9
check ut
SumGuo-88 Jan 6, 2025
f669ac5
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 6, 2025
87de0e8
check ut
SumGuo-88 Jan 6, 2025
0939ef1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2025
7ec779f
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
24d1386
Update deepmd/utils/argcheck.py
SumGuo-88 Jan 7, 2025
708bc78
check msi defalut value
SumGuo-88 Jan 7, 2025
02f3f28
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
c648e9e
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
d36a24a
check ut cuda
SumGuo-88 Jan 7, 2025
b6a483a
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
050dbaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
b00f8de
check ut
SumGuo-88 Jan 7, 2025
2f37dfe
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 7, 2025
47fe45b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 7, 2025
98890c2
Merge branch 'devel' into devel
SumGuo-88 Jan 7, 2025
b9bdee5
make truetype for more sys
SumGuo-88 Jan 9, 2025
a30053f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
cfbc88a
Add skip element check function to Chang bias
SumGuo-88 Jan 9, 2025
73a20b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2025
0b29b05
make changebias control minframes
SumGuo-88 Jan 10, 2025
85d4da3
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
0400233
check merge
SumGuo-88 Jan 10, 2025
c05ffb1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
139f037
improve ut with all frames
SumGuo-88 Jan 10, 2025
5e826bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
edf1d91
check ut
SumGuo-88 Jan 10, 2025
3887013
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
SumGuo-88 Jan 10, 2025
eb9f068
check
SumGuo-88 Jan 10, 2025
10ef768
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
c2dc7ef
check skip logic and def name
SumGuo-88 Jan 10, 2025
8763165
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
d5596bf
improve warning readable
SumGuo-88 Jan 10, 2025
9f389ad
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
0c76ad9
check args
SumGuo-88 Jan 10, 2025
58647f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 10, 2025
4ce9cfb
check stat.py
SumGuo-88 Jan 10, 2025
5 changes: 5 additions & 0 deletions deepmd/main.py
@@ -737,6 +737,11 @@ def main_parser() -> argparse.ArgumentParser:
default=None,
help="Model branch chosen for changing bias if multi-task model.",
)
parser_change_bias.add_argument(
"--skip-elementcheck",
action="store_false",
help="Enable this option to skip element checks if any error occurs while retrieving statistical data.",
)

# --version
parser.add_argument(
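A minimal standalone sketch (plain argparse, not the actual deepmd CLI) of the `store_false` wiring used by the new flag: the parsed attribute defaults to True (element check enabled) and flips to False when `--skip-elementcheck` is passed, which is why the caller can forward it directly as the check-enable value.

```python
import argparse

# Toy parser illustrating the store_false semantics of the new flag;
# this is a sketch, not the deepmd main parser.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--skip-elementcheck",
    action="store_false",
    help="Skip element checks if any error occurs while retrieving statistical data.",
)

# argparse maps "--skip-elementcheck" to the attribute "skip_elementcheck".
assert parser.parse_args([]).skip_elementcheck is True
assert parser.parse_args(["--skip-elementcheck"]).skip_elementcheck is False
```

Note the inverted naming: the attribute is True when checking is still enabled, which matches how it is forwarded as `elem_check_stat` below.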
3 changes: 3 additions & 0 deletions deepmd/pt/entrypoints/main.py
@@ -386,6 +386,7 @@ def change_bias(
numb_batch: int = 0,
model_branch: Optional[str] = None,
output: Optional[str] = None,
elem_check_stat: bool = True,
) -> None:
if input_file.endswith(".pt"):
old_state_dict = torch.load(
@@ -472,6 +473,7 @@
data_single.systems,
data_single.dataloaders,
nbatches,
enable_element_completion=elem_check_stat,
)
updated_model = training.model_change_out_bias(
model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
@@ -555,6 +557,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
numb_batch=FLAGS.numb_batch,
model_branch=FLAGS.model_branch,
output=FLAGS.output,
elem_check_stat=FLAGS.skip_elementcheck,
)
elif FLAGS.command == "compress":
FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth"))
8 changes: 8 additions & 0 deletions deepmd/pt/train/training.py
@@ -143,6 +143,12 @@ def __init__(
self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5)
self.display_in_training = training_params.get("disp_training", True)
self.timing_in_training = training_params.get("time_training", True)
self.min_frames_per_element_forstat = training_params.get(
"min_frames_per_element_forstat", 10
)
self.enable_element_completion = training_params.get(
"enable_element_completion", True
)
self.change_bias_after_training = training_params.get(
"change_bias_after_training", False
)
@@ -227,6 +233,8 @@ def get_sample():
_training_data.systems,
_training_data.dataloaders,
_data_stat_nbatch,
self.min_frames_per_element_forstat,
self.enable_element_completion,
)
return sampled

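The two new training options are read with `dict.get` defaults, so existing input files that set neither key keep the previous behavior. A small sketch of the defaulting (default values taken from the diff above):

```python
# Sketch of how the new options default when absent from the
# "training" section of the input file (defaults from the diff).
training_params = {}  # user config that sets neither key

min_frames = training_params.get("min_frames_per_element_forstat", 10)
enable_completion = training_params.get("enable_element_completion", True)

# An explicit user setting overrides the default.
user_params = {"enable_element_completion": False}
assert user_params.get("enable_element_completion", True) is False
```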
41 changes: 38 additions & 3 deletions deepmd/pt/utils/dataset.py
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later


from collections import (
defaultdict,
)
from typing import (
Optional,
)

import numpy as np
from torch.utils.data import (
Dataset,
)
@@ -17,10 +20,10 @@

class DeepmdDataSetForLoader(Dataset):
def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None:
"""Construct DeePMD-style dataset containing frames cross different systems.
"""Construct DeePMD-style dataset containing frames across different systems.

Args:
-- systems: Paths to systems.
+- system: Path to the system.
- type_map: Atom types.
"""
self.system = system
@@ -40,6 +43,38 @@ def __getitem__(self, index):
b_data["natoms"] = self._natoms_vec
return b_data

def get_frame_index(self):
"""
Get, for each element in the system, the number of frames containing it and their frame indices.
This function is only used for mixed-type systems.

Returns
-------
element_counts : dict
A dictionary where:
- The key is the element type.
- The value is another dictionary with the following keys:
- "frames": int
The total number of frames in which the element appears.
- "indices": list of int
A list of row indices where the element is found in the dataset.
"""
element_counts = defaultdict(lambda: {"frames": 0, "indices": []})
set_files = self._data_system.dirs
base_offset = 0
for set_file in set_files:
element_data = self._data_system._load_type_mix(set_file)
unique_elements = np.unique(element_data)
for elem in unique_elements:
frames_with_elem = np.any(element_data == elem, axis=1)
row_indices = np.where(frames_with_elem)[0]
row_indices_global = np.where(frames_with_elem)[0] + base_offset
element_counts[elem]["frames"] += len(row_indices)
element_counts[elem]["indices"].extend(row_indices_global.tolist())
base_offset += element_data.shape[0]
element_counts = dict(element_counts)
return element_counts

def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None:
"""Add data requirement for this data system."""
for data_item in data_requirement:
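To illustrate the structure `get_frame_index` returns, here is a self-contained NumPy sketch of the same scan over per-set type arrays. The arrays are hypothetical stand-ins for what `_load_type_mix` would load: one row per frame, one column per atom, values are element type indices.

```python
from collections import defaultdict

import numpy as np

# Hypothetical per-set type arrays (stand-ins for _load_type_mix output).
set_arrays = [
    np.array([[0, 0, 1], [0, 1, 1]]),  # first set: 2 frames
    np.array([[0, 2, 2]]),             # second set: 1 frame
]

element_counts = defaultdict(lambda: {"frames": 0, "indices": []})
base_offset = 0
for element_data in set_arrays:
    for elem in np.unique(element_data):
        # Which frames of this set contain the element at least once.
        frames_with_elem = np.any(element_data == elem, axis=1)
        # Shift local row indices by the offset of this set so that
        # indices are global across all sets.
        row_indices_global = np.where(frames_with_elem)[0] + base_offset
        element_counts[elem]["frames"] += len(row_indices_global)
        element_counts[elem]["indices"].extend(row_indices_global.tolist())
    base_offset += element_data.shape[0]

element_counts = dict(element_counts)
```

Here element 0 appears in all three frames, element 1 in the first two, and element 2 only in the last one.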
188 changes: 161 additions & 27 deletions deepmd/pt/utils/stat.py
@@ -35,51 +35,185 @@
log = logging.getLogger(__name__)


def make_stat_input(datasets, dataloaders, nbatches):
def make_stat_input(
datasets,
dataloaders,
nbatches,
min_frames_per_element_forstat=10,
enable_element_completion=True,
):
"""Pack data for statistics.
Element checking is only enabled with mixed_type.

Args:
-- dataset: A list of dataset to analyze.
+- datasets: A list of datasets to analyze.
- dataloaders: Corresponding dataloaders for the datasets.
- nbatches: Batch count for collecting stats.
- min_frames_per_element_forstat: Minimum frames required for statistics.
- enable_element_completion: Whether to perform missing element completion (default: True).

Returns
-------
-- a list of dicts, each of which contains data from a system
+- A list of dicts, each of which contains data from a system.
"""
lst = []
log.info(f"Packing data for statistics from {len(datasets)} systems")
for i in range(len(datasets)):
sys_stat = {}
with torch.device("cpu"):
iterator = iter(dataloaders[i])
numb_batches = min(nbatches, len(dataloaders[i]))
for _ in range(numb_batches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloaders[i])
stat_data = next(iterator)
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
else:
pass
total_element_types = set()
global_element_counts = {}
collect_ele = defaultdict(int)
if datasets[0].mixed_type:
if enable_element_completion:
log.info(
f"Element check enabled. "
f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
)
else:
log.info(
"Element completion is disabled. Skipping missing element handling."
)

def process_batches(dataloader, sys_stat):
"""Process batches from a dataloader to collect statistics."""
iterator = iter(dataloader)
numb_batches = min(nbatches, len(dataloader))
for _ in range(numb_batches):
try:
stat_data = next(iterator)
except StopIteration:
iterator = iter(dataloader)
stat_data = next(iterator)

Codecov warning: added lines deepmd/pt/utils/stat.py#L82-L84 were not covered by tests.
for dd in stat_data:
if stat_data[dd] is None:
sys_stat[dd] = None
elif isinstance(stat_data[dd], torch.Tensor):
if dd not in sys_stat:
sys_stat[dd] = []
sys_stat[dd].append(stat_data[dd])
elif isinstance(stat_data[dd], np.float32):
sys_stat[dd] = stat_data[dd]
else:
pass
Comment on lines +76 to +96
Contributor

🛠️ Refactor suggestion

Improve error handling and data type validation.

The function silently handles unknown data types and lacks error handling for tensor operations.

    def process_batches(dataloader, sys_stat):
        """Process batches from a dataloader to collect statistics."""
        iterator = iter(dataloader)
        numb_batches = min(nbatches, len(dataloader))
        for _ in range(numb_batches):
            try:
                stat_data = next(iterator)
            except StopIteration:
                iterator = iter(dataloader)
                stat_data = next(iterator)
-            for dd in stat_data:
+            for dd, value in stat_data.items():
+                try:
+                    if value is None:
+                        sys_stat[dd] = None
+                    elif isinstance(value, torch.Tensor):
+                        if dd not in sys_stat:
+                            sys_stat[dd] = []
+                        sys_stat[dd].append(value)
+                    elif isinstance(value, np.float32):
+                        sys_stat[dd] = value
+                    else:
+                        log.warning(f"Unexpected data type {type(value)} for key {dd}")
+                except Exception as e:
+                    log.error(f"Error processing key {dd}: {str(e)}")
+                    raise


def process_with_new_frame(sys_indices, newele_counter, miss):
for sys_info in sys_indices:
sys_index = sys_info["sys_index"]
frames = sys_info["frames"]
sys = datasets[sys_index]
for frame in frames:
newele_counter += 1
if newele_counter <= min_frames_per_element_forstat:
frame_data = sys.__getitem__(frame)
assert miss in frame_data["atype"], (
"Element check failed. "
"If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. "
"If you encountered this error during model training, set 'enable_element_completion' to False "
"in the 'training' section of your input file."
)
sys_stat_new = {}
for dd in frame_data:
if dd == "type":
continue
if frame_data[dd] is None:
sys_stat_new[dd] = None
elif isinstance(frame_data[dd], np.ndarray):
if dd not in sys_stat_new:
sys_stat_new[dd] = []
tensor_data = torch.from_numpy(frame_data[dd])
tensor_data = tensor_data.unsqueeze(0)
sys_stat_new[dd].append(tensor_data)
elif isinstance(frame_data[dd], np.float32):
sys_stat_new[dd] = frame_data[dd]
else:
pass
finalize_stats(sys_stat_new)
lst.append(sys_stat_new)
else:
break

Codecov warning: added line deepmd/pt/utils/stat.py#L131 was not covered by tests.
Comment on lines +98 to +132
Contributor

🛠️ Refactor suggestion

Add validation and error handling for frame processing.

The function needs better validation and error handling:

1. Add validation for empty sys_indices
2. Add error handling for frame data retrieval
3. Validate frame data before processing

    def process_with_new_frame(sys_indices, newele_counter, miss):
+        if not sys_indices:
+            log.warning(f"No system indices provided for element {miss}")
+            return
        for sys_info in sys_indices:
            sys_index = sys_info["sys_index"]
            frames = sys_info["frames"]
+            if not frames:
+                log.warning(f"No frames found for system {sys_index}")
+                continue
            sys = datasets[sys_index]
            for frame in frames:
                newele_counter += 1
                if newele_counter <= min_frames_per_element_forstat:
-                    frame_data = sys.__getitem__(frame)
+                    try:
+                        frame_data = sys.__getitem__(frame)
+                        if "atype" not in frame_data:
+                            log.warning(f"Frame {frame} does not contain type information")
+                            continue
+                        if miss not in frame_data["atype"]:
+                            log.warning(f"Frame {frame} does not contain element {miss}")
+                            continue
+                    except Exception as e:
+                        log.error(f"Failed to get frame {frame} from system {sys_index}: {e}")
+                        continue


def finalize_stats(sys_stat):
"""Finalize statistics by concatenating tensors."""
for key in sys_stat:
if isinstance(sys_stat[key], np.float32):
pass
-elif sys_stat[key] is None or sys_stat[key][0] is None:
+elif sys_stat[key] is None or (
+    isinstance(sys_stat[key], list)
+    and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
+):
sys_stat[key] = None
-elif isinstance(stat_data[dd], torch.Tensor):
+elif isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)

Comment on lines +134 to +147
Contributor

🛠️ Refactor suggestion

Add error handling for tensor operations.

The function lacks error handling for tensor operations which could fail.

    def finalize_stats(sys_stat):
        """Finalize statistics by concatenating tensors."""
        for key in sys_stat:
-            if isinstance(sys_stat[key], np.float32):
-                pass
-            elif sys_stat[key] is None or (
-                isinstance(sys_stat[key], list)
-                and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
-            ):
-                sys_stat[key] = None
-            elif isinstance(sys_stat[key][0], torch.Tensor):
-                sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+            try:
+                if isinstance(sys_stat[key], np.float32):
+                    continue
+                elif sys_stat[key] is None or (
+                    isinstance(sys_stat[key], list)
+                    and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
+                ):
+                    sys_stat[key] = None
+                elif isinstance(sys_stat[key][0], torch.Tensor):
+                    sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+                else:
+                    log.warning(f"Unexpected data type for key {key}")
+            except Exception as e:
+                log.error(f"Error finalizing stats for key {key}: {e}")
+                raise
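A NumPy stand-in for the `finalize_stats` logic in the diff (np.concatenate in place of torch.cat, so the sketch runs without torch): scalars pass through unchanged, empty or None-headed lists collapse to None, and lists of per-batch arrays are concatenated along the frame axis. The key names in the demo dict are hypothetical.

```python
import numpy as np

# NumPy sketch of finalize_stats from the diff; torch.cat is replaced
# by np.concatenate, and dict_to_device is omitted.
def finalize_stats(sys_stat):
    for key in sys_stat:
        if isinstance(sys_stat[key], np.float32):
            pass  # scalar metadata stays as-is
        elif sys_stat[key] is None or (
            isinstance(sys_stat[key], list)
            and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
        ):
            sys_stat[key] = None  # nothing collected for this key
        elif isinstance(sys_stat[key][0], np.ndarray):
            # Concatenate per-batch arrays along the frame axis.
            sys_stat[key] = np.concatenate(sys_stat[key], axis=0)
    return sys_stat

stats = {
    "energy": [np.zeros((2, 1)), np.ones((3, 1))],  # two batches of frames
    "find_energy": np.float32(1.0),                 # scalar flag
    "spin": [],                                     # key never populated
}
out = finalize_stats(stats)
```

After the call, `out["energy"]` is a single (5, 1) array, `out["spin"]` is None, and the scalar flag is untouched.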

for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
sys_stat = {}
with torch.device("cpu"):
process_batches(dataloader, sys_stat)
if datasets[0].mixed_type and enable_element_completion:
element_data = torch.cat(sys_stat["atype"], dim=0)
collect_values = torch.unique(element_data.flatten(), sorted=True)
for elem in collect_values.tolist():
frames_with_elem = torch.any(element_data == elem, dim=1)
row_indices = torch.where(frames_with_elem)[0]
collect_ele[elem] += len(row_indices)
finalize_stats(sys_stat)
lst.append(sys_stat)

Comment on lines +148 to +161
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add validation for dataset processing.

The main loop needs validation for empty datasets and error handling for tensor operations.

    for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
+        if not dataset or not dataloader:
+            log.warning(f"Empty dataset or dataloader at index {sys_index}")
+            continue
        sys_stat = {}
        with torch.device("cpu"):
-            process_batches(dataloader, sys_stat)
-            if datasets[0].mixed_type and enable_element_completion:
-                element_data = torch.cat(sys_stat["atype"], dim=0)
+            try:
+                process_batches(dataloader, sys_stat)
+                if datasets[0].mixed_type and enable_element_completion:
+                    if not sys_stat.get("atype"):
+                        log.warning(f"No type information found in dataset {sys_index}")
+                        continue
+                    element_data = torch.cat(sys_stat["atype"], dim=0)
+            except Exception as e:
+                log.error(f"Error processing dataset {sys_index}: {e}")
+                continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
-            sys_stat = {}
-            with torch.device("cpu"):
-                process_batches(dataloader, sys_stat)
-                if datasets[0].mixed_type and enable_element_completion:
-                    element_data = torch.cat(sys_stat["atype"], dim=0)
-                    collect_values = torch.unique(element_data.flatten(), sorted=True)
-                    for elem in collect_values.tolist():
-                        frames_with_elem = torch.any(element_data == elem, dim=1)
-                        row_indices = torch.where(frames_with_elem)[0]
-                        collect_ele[elem] += len(row_indices)
-            finalize_stats(sys_stat)
-            lst.append(sys_stat)
+        for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
+            if not dataset or not dataloader:
+                log.warning(f"Empty dataset or dataloader at index {sys_index}")
+                continue
+            sys_stat = {}
+            with torch.device("cpu"):
+                try:
+                    process_batches(dataloader, sys_stat)
+                    if datasets[0].mixed_type and enable_element_completion:
+                        if not sys_stat.get("atype"):
+                            log.warning(f"No type information found in dataset {sys_index}")
+                            continue
+                        element_data = torch.cat(sys_stat["atype"], dim=0)
+                        collect_values = torch.unique(element_data.flatten(), sorted=True)
+                        for elem in collect_values.tolist():
+                            frames_with_elem = torch.any(element_data == elem, dim=1)
+                            row_indices = torch.where(frames_with_elem)[0]
+                            collect_ele[elem] += len(row_indices)
+                except Exception as e:
+                    log.error(f"Error processing dataset {sys_index}: {e}")
+                    continue
+            finalize_stats(sys_stat)
+            lst.append(sys_stat)
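The element-counting step in the suggestion above can be sketched in isolation. The tensor values and shapes here are illustrative assumptions, not data from the PR: `element_data` stands in for the concatenated per-frame atom types of one mixed-type system.

```python
import torch

# Hypothetical per-frame atom types: 3 frames x 4 atoms (mixed type).
element_data = torch.tensor(
    [
        [0, 0, 1, 1],
        [0, 2, 2, 2],
        [1, 1, 1, 1],
    ]
)

collect_ele = {}
# Unique element types present anywhere in this system.
collect_values = torch.unique(element_data.flatten(), sorted=True)
for elem in collect_values.tolist():
    # Boolean per-frame mask: does this element appear in the frame?
    frames_with_elem = torch.any(element_data == elem, dim=1)
    # Indices of the frames containing the element.
    row_indices = torch.where(frames_with_elem)[0]
    collect_ele[elem] = collect_ele.get(elem, 0) + len(row_indices)

print(collect_ele)  # {0: 2, 1: 2, 2: 1}
```

Counting frames (rather than atoms) per element is what lets the later completion pass decide whether an element has enough statistical support.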

        # get frame index
        if datasets[0].mixed_type and enable_element_completion:
            element_counts = dataset.get_frame_index()
            for elem, data in element_counts.items():
                indices = data["indices"]
                count = data["frames"]
                total_element_types.add(elem)
                if elem not in global_element_counts:
                    global_element_counts[elem] = {"count": 0, "indices": []}
                    if count > min_frames_per_element_forstat:
                        global_element_counts[elem]["count"] += (
                            min_frames_per_element_forstat
                        )
                        indices = indices[:min_frames_per_element_forstat]
                        global_element_counts[elem]["indices"].append(
                            {"sys_index": sys_index, "frames": indices}
                        )
                    else:
                        global_element_counts[elem]["count"] += count
                        global_element_counts[elem]["indices"].append(
                            {"sys_index": sys_index, "frames": indices}
                        )
                else:
                    if (
                        global_element_counts[elem]["count"]
                        >= min_frames_per_element_forstat
                    ):
                        pass
                    else:
                        global_element_counts[elem]["count"] += count
                        global_element_counts[elem]["indices"].append(
                            {"sys_index": sys_index, "frames": indices}
                        )
    # Complement
    if datasets[0].mixed_type and enable_element_completion:
        for elem, data in global_element_counts.items():
            indices_count = data["count"]
            if indices_count < min_frames_per_element_forstat:
                log.warning(
                    f"The number of frames in your datasets with element {elem} is {indices_count}, "
                    f"which is less than the required {min_frames_per_element_forstat}"
                )
        collect_elements = collect_ele.keys()
        missing_elements = total_element_types - collect_elements
        collect_miss_element = set()
        for ele, count in collect_ele.items():
            if count < min_frames_per_element_forstat:
                collect_miss_element.add(ele)
                missing_elements.add(ele)
        for miss in missing_elements:
            sys_indices = global_element_counts[miss].get("indices", [])
            if miss in collect_miss_element:
                newele_counter = collect_ele.get(miss, 0)
            else:
                newele_counter = 0
            process_with_new_frame(sys_indices, newele_counter, miss)
    return lst
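The per-system aggregation above can be reduced to a small pure-Python sketch: each system reports, per element, how many frames contain it and which frame indices those are, and the global tally caps any single system's contribution at `min_frames_per_element_forstat`. All inputs here are hypothetical, and the early-exit branch is folded into a `continue` for brevity.

```python
min_frames_per_element_forstat = 2

# Hypothetical per-system reports: {element: {"indices": [...], "frames": n}}.
per_system_counts = [
    {8: {"indices": [0, 1, 2], "frames": 3}},
    {8: {"indices": [5], "frames": 1}, 1: {"indices": [0, 4], "frames": 2}},
]

global_element_counts = {}
total_element_types = set()
for sys_index, element_counts in enumerate(per_system_counts):
    for elem, data in element_counts.items():
        indices = data["indices"]
        count = data["frames"]
        total_element_types.add(elem)
        entry = global_element_counts.setdefault(elem, {"count": 0, "indices": []})
        if entry["count"] >= min_frames_per_element_forstat:
            continue  # already enough frames collected for this element
        if count > min_frames_per_element_forstat:
            # Cap the contribution of a single system.
            entry["count"] += min_frames_per_element_forstat
            indices = indices[:min_frames_per_element_forstat]
        else:
            entry["count"] += count
        entry["indices"].append({"sys_index": sys_index, "frames": indices})

print(global_element_counts[8])
# element 8 is capped at 2 frames from system 0; system 1's frame is skipped
```

The cap keeps the completion pass from drawing all of an element's statistics frames out of one system when several systems contain it.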


14 changes: 14 additions & 0 deletions deepmd/utils/argcheck.py
@@ -2893,6 +2893,20 @@ def training_args(
optional=True,
doc=doc_only_pt_supported + doc_gradient_max_norm,
),
        Argument(
            "min_frames_per_element_forstat",
            int,
            default=10,
            optional=True,
            doc="The minimum number of frames per element used for statistics when using the mixed type.",
        ),
        Argument(
            "enable_element_completion",
            bool,
            optional=True,
            default=True,
            doc="Whether to check element completeness when using the mixed type.",
        ),
]
variants = [
Variant(
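The two new training arguments would sit in the `training` section of the input script. The snippet below is an illustrative assumption of how they compose with existing keys (the surrounding keys are placeholders, not an excerpt from the docs):

```python
import json

# Hypothetical training section using the two new options; other keys omitted.
training_config = {
    "training": {
        "numb_steps": 1000,
        # Require at least 10 frames per element before its statistics count.
        "min_frames_per_element_forstat": 10,
        # Enable the element-completion check for mixed-type datasets.
        "enable_element_completion": True,
    }
}

print(json.dumps(training_config, indent=2))
```

Both keys are optional with the defaults shown, so existing input scripts keep working unchanged.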
16 changes: 8 additions & 8 deletions deepmd/utils/data.py
@@ -547,14 +547,6 @@
         if self.mixed_type:
             # nframes x natoms
             atom_type_mix = self._load_type_mix(set_name)
-            if self.enforce_type_map:
-                try:
-                    atom_type_mix_ = self.type_idx_map[atom_type_mix].astype(np.int32)
-                except IndexError as e:
-                    raise IndexError(
-                        f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
-                    ) from e
-                atom_type_mix = atom_type_mix_
             real_type = atom_type_mix.reshape([nframes, self.natoms])
             data["type"] = real_type
             natoms = data["type"].shape[1]
@@ -704,6 +696,14 @@
     def _load_type_mix(self, set_name: DPPath):
         type_path = set_name / "real_atom_types.npy"
         real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms])
+        if self.enforce_type_map:
+            try:
+                atom_type_mix_ = self.type_idx_map[real_type].astype(np.int32)
+            except IndexError as e:
+                raise IndexError(
+                    f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
+                ) from e
+            real_type = atom_type_mix_
         return real_type

def _make_idx_map(self, atom_type):
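The relocated `enforce_type_map` remapping relies on NumPy fancy indexing: indexing a map array with the raw type array translates every entry at once, and an out-of-range raw type raises `IndexError`, which the diff wraps in a more descriptive message. A standalone sketch with hypothetical array values:

```python
import numpy as np

# Hypothetical map from raw on-disk type ids to model type indices.
type_idx_map = np.array([2, 0, 1], dtype=np.int32)

# Raw per-frame, per-atom types as loaded from real_atom_types.npy.
real_type = np.array([[0, 1], [2, 1]], dtype=np.int32)

# Fancy indexing remaps every element in one shot.
remapped = type_idx_map[real_type].astype(np.int32)
print(remapped.tolist())  # [[2, 0], [1, 0]]

# A type id beyond the map's length triggers IndexError.
try:
    type_idx_map[np.array([[3]], dtype=np.int32)]
except IndexError:
    print("out-of-range type detected")
```

Moving the check into `_load_type_mix` means every caller that loads `real_atom_types.npy` gets the same validation, instead of only the set-loading path.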
Binary file not shown
7 changes: 7 additions & 0 deletions source/tests/pt/mixed_type_data/sys.000000/type.raw
@@ -0,0 +1,7 @@
0
0
0
0
0
0
0