(fix) Make bias statistics complete for all elements #4496
base: devel
Conversation
📝 Walkthrough
The pull request introduces modifications in the DeepMD-kit's PyTorch utility modules, including a new public method.
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pt/utils/stat.py (1)
110-119: Avoid potential index-out-of-range errors when concatenating tensors.
The code checks `isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor)` without verifying that the list is non-empty, which can raise an IndexError when the list is empty. Consider adding a length check before accessing `sys_stat[key][0]`. Possible safe check:
```diff
-if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor):
+if isinstance(sys_stat[key], list) and len(sys_stat[key]) > 0 \
+        and isinstance(sys_stat[key][0], torch.Tensor):
```
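As a standalone illustration, here is a minimal sketch of the failure mode and the guarded check (variable and key names are illustrative, not taken from the PR):

```python
import torch

sys_stat = {"energy": []}  # an empty list for this key
key = "energy"

# Unguarded: `sys_stat[key][0]` raises IndexError on an empty list.
# Guarded: the length test short-circuits, so [0] is never evaluated.
if (
    isinstance(sys_stat[key], list)
    and len(sys_stat[key]) > 0
    and isinstance(sys_stat[key][0], torch.Tensor)
):
    sys_stat[key] = torch.cat(sys_stat[key], dim=0)
else:
    sys_stat[key] = None  # nothing collected for this key
```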
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- deepmd/pt/utils/dataset.py (1 hunks)
- deepmd/pt/utils/stat.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/dataset.py
44-44: SyntaxError: Expected an indented block after function definition
🔇 Additional comments (3)
deepmd/pt/utils/stat.py (3)
91-96: Validate 'atype' presence and handle empty arrays gracefully.
While updating the collected elements, consider explicitly checking that each 'atype' array is non-empty. If a dataset yields an empty array, it might bypass your missing-elements logic.
✅ Verification successful
Let me gather more information about how `atype` is handled in the codebase.
Let me try again with the correct file type flag for ripgrep.
Based on the code analysis, I can now generate the final response:
Empty arrays are handled safely by numpy.unique and set.update
The code uses `numpy.unique()`, which safely handles empty arrays by returning an empty array, and `set.update()`, which safely handles empty iterables. The logic for missing elements works correctly even with empty arrays, since an empty array's unique values simply do not add any elements to the set.
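A quick self-contained demonstration of the behavior this conclusion relies on (not code from the PR):

```python
import numpy as np

collected_elements = set()
empty_atype = np.array([], dtype=np.int64)

# np.unique on an empty array returns an empty array...
unique_vals = np.unique(empty_atype)
assert unique_vals.size == 0

# ...and set.update with an empty iterable is a no-op.
collected_elements.update(unique_vals.tolist())
assert collected_elements == set()
```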
🏁 Scripts executed
The following scripts were executed for the analysis:
Error: No shell scripts found
Script:
```bash
#!/bin/bash
# Search for patterns related to 'atype' handling and array emptiness checks
rg -A 3 "atype" --type python
# Also look for the specific class/function containing this code
ast-grep --pattern $'def $_(self, $_):
  $$$
  collected_elements = set()
  for sys_stat in lst:
    if "atype" in sys_stat:
      collected_elements.update($$$)
  $$$'
```
Length of output: 248
Script:
```bash
#!/bin/bash
# Search for patterns related to 'atype' handling and array emptiness checks
rg -A 3 "atype" -t py
# Look for similar array handling patterns in the codebase
rg -t py "numpy\.unique|\.update.*numpy" -A 2
# Find the specific file we're analyzing
fd "stat.py" --type f
```
Length of output: 89317
85-90: Consider using logging and verifying that all datasets are included.
You are only collecting the set of elements from the first dataset and simply printing them. If your intent is to gather possible elements across all datasets, you may want to iterate over all datasets. Also, using `log.info` would be more consistent than a raw `print` in a library setting.
Possible fix for the print statement:
```diff
- print('we want', all_elements)
+ log.info(f"We want elements: {all_elements}")
```
97-109: Ensure partial coverage logic for missing elements is intended.
Here, frames are appended until the number of records reaches `nbatches`. Verify that the short-circuiting with `break` actually meets your coverage needs for missing elements. If your goal is to gather as many frames as possible across all missing elements, you might need a more explicit limit or a different approach.
✅ Verification successful
Let me gather more context about this code to better understand the coverage logic for missing elements.
Let me search for more context about the initialization and usage of `nbatches` to better understand the coverage requirements.
Let me search for the initialization of `nbatches` to understand its purpose in the coverage logic.
Let me search for more context about how `nbatches` is used and its relationship with the dataset statistics.
Based on the gathered context, I can now provide a final response about the coverage logic for missing elements.
The current partial coverage logic is intentional and correct.
The code's behavior is appropriate because:
- `nbatches` is a parameter that defines the desired number of batches for collecting statistics, used consistently throughout the `make_stat_input` function.
- The break condition `if len(lst[i]['atype']) >= nbatches` ensures each system collects up to the specified number of batches, preventing over-sampling.
- This is consistent with how regular batches are collected earlier in the function, where `numb_batches = min(nbatches, len(dataloaders[i]))` is used.
- The goal is to gather a statistically significant sample (controlled by `nbatches`) rather than exhaustively processing all frames containing missing elements.
The current implementation correctly balances gathering sufficient data for statistics with maintaining consistent batch sizes across all elements.
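A condensed sketch of the sampling pattern described above; the names mirror the snippets quoted in this review rather than the exact source:

```python
def collect_frames_for_element(dataset, frame_indices, lst_i, nbatches):
    """Append frames for one system until its record count reaches nbatches."""
    for frame_idx in frame_indices:
        if len(lst_i["atype"]) >= nbatches:
            break  # enough frames for statistics; avoid over-sampling
        frame_data = dataset[frame_idx]
        for key, value in frame_data.items():
            lst_i.setdefault(key, []).append(value)
```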
🏁 Scripts executed
The following scripts were executed for the analysis:
Error: No shell scripts found
Script:
```bash
#!/bin/bash
# Let's first look at the broader context of this code
rg -B 10 -A 5 "for missing_element in missing_elements" deepmd/pt/utils/stat.py
# Also search for any related tests that might help understand the intended behavior
fd "test.*stat.*py" -x rg -l "missing_element"
# Look for the initialization of nbatches and lst variables
rg -B 5 "nbatches.*=|lst.*=" deepmd/pt/utils/stat.py
```
Length of output: 1724
Script:
```bash
#!/bin/bash
# Search for the function definition and its parameters
ast-grep --pattern 'def $FUNC($PARAMS) {
  $$$
  nbatches
  $$$
}' deepmd/pt/utils/stat.py
# Look for any documentation or comments about nbatches
rg -B 3 -A 3 "nbatches" deepmd/pt/utils/stat.py
# Check if there are any tests that verify this behavior
fd "test.*stat.*py" -x cat {}
```
Length of output: 66294
Script:
```bash
#!/bin/bash
# Look for the function definition and documentation
rg -B 10 "def make_stat_input" deepmd/pt/utils/stat.py
# Search for any tests that verify the behavior with missing elements
rg -B 5 -A 5 "missing_element" deepmd/pt/utils/stat.py
```
Length of output: 1079
Script:
```bash
#!/bin/bash
# Look for the full function implementation and its docstring
rg -B 20 -A 50 "def make_stat_input" deepmd/pt/utils/stat.py
# Search for any tests that verify the statistics collection
fd "test.*stat.*py" -x rg -A 10 "make_stat_input"
# Look for the initialization of lst and how frames are collected
rg -B 5 "lst.*=.*\[\]" deepmd/pt/utils/stat.py
```
Length of output: 4539
Actionable comments posted: 1
🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)
89-89: Remove debug print statement.
The print statement `print("we want", all_elements)` appears to be debug code that should be removed or replaced with proper logging.
```diff
- print("we want", all_elements)
+ log.debug(f"Required elements for statistics: {all_elements}")
```
97-111: Optimize nested loops and add error handling.
The nested loops for handling missing elements could be optimized, and error handling should be added for invalid frame indices.
```diff
 for missing_element in missing_elements:
     for i, dataset in enumerate(datasets):
         if hasattr(dataset, "element_to_frames"):
-            frame_indices = dataset.element_to_frames.get(
-                missing_element, []
-            )
+            try:
+                frame_indices = dataset.element_to_frames.get(missing_element, [])
+                if not frame_indices:
+                    continue
+
+                # Pre-check if we need more frames
+                if len(lst[i]["atype"]) >= nbatches:
+                    break
+
+                # Process frames in batch
+                for frame_idx in frame_indices:
+                    frame_data = dataset[frame_idx]
+                    if any(key not in lst[i] for key in frame_data):
+                        lst[i].update({key: [] for key in frame_data if key not in lst[i]})
+                    for key in frame_data:
+                        lst[i][key].append(frame_data[key])
+                    if len(lst[i]["atype"]) >= nbatches:
+                        break
+            except Exception as e:
+                log.warning(f"Error processing frames for element {missing_element}: {e}")
+                continue
-            for frame_idx in frame_indices:
-                if len(lst[i]["atype"]) >= nbatches:
-                    break
-                frame_data = dataset[frame_idx]
-                for key in frame_data:
-                    if key not in lst[i]:
-                        lst[i][key] = []
-                    lst[i][key].append(frame_data[key])
```
📒 Files selected for processing (2)
- deepmd/pt/utils/dataset.py (1 hunks)
- deepmd/pt/utils/stat.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/dataset.py
44-44: SyntaxError: Expected an indented block after function definition
🔇 Additional comments (2)
deepmd/pt/utils/dataset.py (2)
43-44: Fix the indentation error in method definition.
The method definition has incorrect indentation, which will cause a SyntaxError. It should be aligned with the other class methods.
```diff
-def _build_element_to_frames(self):
-    """Mapping element types to frame indexes"""
+    def _build_element_to_frames(self):
+        """Mapping element types to frame indexes"""
```
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff (0.8.2)
44-44: SyntaxError: Expected an indented block after function definition
43-53: 🛠️ Refactor suggestion
Make frame limit configurable and enhance documentation.
- The hard-coded limit of 10 frames per element should be configurable.
- The docstring should be more descriptive about the method's purpose and return value.
```diff
-    def _build_element_to_frames(self):
-        """Mapping element types to frame indexes"""
+    def _build_element_to_frames(self, max_frames_per_element: int = 10) -> dict[int, list[int]]:
+        """Build a mapping of element types to their corresponding frame indices.
+
+        Args:
+            max_frames_per_element: Maximum number of frames to store per element type.
+
+        Returns:
+            A dictionary mapping element types (int) to lists of frame indices (list[int])
+            where each element type appears.
+        """
         element_to_frames = {element: [] for element in range(self._ntypes)}
         for frame_idx in range(len(self)):
             frame_data = self._data_system.get_item_torch(frame_idx)
             elements = frame_data["atype"]
             for element in set(elements):
-                if len(element_to_frames[element]) < 10:
+                if len(element_to_frames[element]) < max_frames_per_element:
                     element_to_frames[element].append(frame_idx)
         return element_to_frames
```
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff (0.8.2)
44-44: SyntaxError: Expected an indented block after function definition
Actionable comments posted: 4
🧹 Nitpick comments (4)
source/tests/pt/test_make_stat_input.py (4)
14-23: Consider using `collections.defaultdict` for `element_to_frames`.
You can simplify the nested checks for element presence in the dictionary by using a `defaultdict(list)`, which would eliminate the need for the explicit `if atype not in self.element_to_frames:` condition.
```diff
-from collections import defaultdict
 class TestDataset:
     def __init__(self, samples):
         self.samples = samples
-        self.element_to_frames = {}
+        from collections import defaultdict
+        self.element_to_frames = defaultdict(list)
         for idx, sample in enumerate(samples):
             atypes = sample["atype"]
             for atype in atypes:
-                if atype not in self.element_to_frames:
-                    self.element_to_frames[atype] = []
                 self.element_to_frames[atype].append(idx)
```
25-28: Rename the property to better reflect usage.
Using `@property` but naming it `get_all_atype` can be confusing. Consider a more descriptive name like `all_atypes`, since Python properties typically avoid "get_" prefixes.
53-59: Remove or use the assigned `lst` variable.
The variable `lst` is assigned but never used, according to static analysis hints. Consider removing it or using it for additional assertions.
```diff
 def test_make_stat_input(self):
     nbatches = 1
-    lst = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches)
+    _ = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches)
     all_elements = self.system.get_all_atype
     unique_elements = {1, 2}
     self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements")
```
🧰 Tools
🪛 Ruff (0.8.2)
55-55: Local variable `lst` is assigned to but never used. Remove assignment to unused variable `lst` (F841)
61-62: Optional test runner inclusion.
Having the `if __name__ == "__main__": unittest.main()` block is fine. You could remove it if tests are run by a dedicated test runner.
📒 Files selected for processing (3)
- deepmd/pt/utils/dataset.py (4 hunks)
- deepmd/pt/utils/stat.py (1 hunks)
- source/tests/pt/test_make_stat_input.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_make_stat_input.py
55-55: Local variable `lst` is assigned to but never used. Remove assignment to unused variable `lst` (F841)
🔇 Additional comments (4)
source/tests/pt/test_make_stat_input.py (1)
40-52: Test setup looks good.
The dataset creation for testing is straightforward and clear. No issues found.
deepmd/pt/utils/dataset.py (2)
21-24: Docstring clarity is sufficient.
The docstring effectively describes constructor parameters. No corrections needed.
34-34: Initialization of element frames is a good approach.
Storing the result of `_build_element_to_frames()` in `self.element_to_frames` and `self.get_all_atype` reduces redundancy.
deepmd/pt/utils/stat.py (1)
86-94: No immediate issues with collection of atomic types.
Collecting and updating sets is correct.
Actionable comments posted: 2
🧹 Nitpick comments (4)
deepmd/utils/data.py (1)
667-674: Consider providing more informative error details when the type map lookup fails.
While raising `IndexError` is appropriate, developers might benefit from including the failing value(s). You could, for instance, collect and display the out-of-range types to provide immediate troubleshooting clues. A custom exception or a more descriptive error message can significantly improve clarity and debuggability. Here's an example of how you might refine the exception:
```diff
 except IndexError as e:
-    raise IndexError(
-        f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
-    ) from e
+    # Gather all invalid elements
+    invalid_types = np.unique(real_type[(real_type < 0) | (real_type >= len(self.type_idx_map))])
+    raise IndexError(
+        f"One or more invalid types found in 'real_atom_types.npy' of set {set_name}: {invalid_types}. "
+        f"Ensure all types are within [0, {self.get_ntypes()-1}]."
+    ) from e
```
deepmd/utils/argcheck.py (1)
2829-2834: Ensure user awareness of the new argument.
The new argument `min_frames_per_element_forstat` is useful for controlling statistic completeness. It might be helpful to specify the expected range (e.g., must be ≥ 1) and how large values impact memory or performance overhead.
source/tests/pt/test_make_stat_input.py (1)
68-68: Remove or utilize the unused variable.
The variable `lst` is assigned with the result of `make_stat_input(...)` but never used. If no further checks are applied, remove it to keep the code clean.
```diff
-    lst = make_stat_input(
+    make_stat_input(
```
🧰 Tools
🪛 Ruff (0.8.2)
68-68: Local variable `lst` is assigned to but never used. Remove assignment to unused variable `lst` (F841)
deepmd/pt/utils/stat.py (1)
188-197: Double-check sets for collected vs. missing elements.
This code block re-checks missing elements with:
```python
missing_element = all_element - collect_elements
```
Confirm that the logic aligns with the earlier `missing_elements` sets in lines 110–111 to avoid confusion or duplication. A toy example of the set arithmetic follows the tools note below.
🧰 Tools
🪛 Ruff (0.8.2)
188-188: SyntaxError: unindent does not match any outer indentation level
189-189: SyntaxError: Unexpected indentation
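The set arithmetic referenced above is simple to verify in isolation; a toy example with illustrative values:

```python
all_element = {0, 1, 2, 3}   # every element type expected by the model
collect_elements = {0, 2}    # element types actually seen during collection

missing_element = all_element - collect_elements
assert missing_element == {1, 3}
```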
📒 Files selected for processing (6)
- deepmd/pt/train/training.py (2 hunks)
- deepmd/pt/utils/dataset.py (3 hunks)
- deepmd/pt/utils/stat.py (3 hunks)
- deepmd/utils/argcheck.py (1 hunks)
- deepmd/utils/data.py (1 hunks)
- source/tests/pt/test_make_stat_input.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/utils/dataset.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/stat.py
134-134: SyntaxError: Expected a statement
134-134: SyntaxError: Expected a statement
134-134: SyntaxError: Expected a statement
134-134: SyntaxError: Expected a statement
135-135: SyntaxError: Unexpected indentation
144-144: SyntaxError: unindent does not match any outer indentation level
144-144: SyntaxError: Expected a statement
144-144: SyntaxError: Expected a statement
144-145: SyntaxError: Expected a statement
145-145: SyntaxError: Unexpected indentation
176-176: SyntaxError: Expected a statement
176-176: SyntaxError: Expected a statement
176-176: SyntaxError: Expected a statement
176-176: SyntaxError: Expected a statement
188-188: SyntaxError: unindent does not match any outer indentation level
189-189: SyntaxError: Unexpected indentation
231-231: SyntaxError: Expected a statement
231-231: SyntaxError: Expected a statement
231-231: SyntaxError: Expected a statement
231-231: SyntaxError: Expected a statement
source/tests/pt/test_make_stat_input.py
43-43: Loop control variable `idx` not used within loop body. Rename unused `idx` to `_idx` (B007)
68-68: Local variable `lst` is assigned to but never used. Remove assignment to unused variable `lst` (F841)
🔇 Additional comments (4)
deepmd/pt/train/training.py (2)
145-147: Add type check or validation for the statistic threshold.
While setting `self.min_frames_per_element_forstat`, consider ensuring it's a strictly positive integer. If a negative or zero value is passed, it may cause runtime issues or meaningless statistics.
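One possible shape for that validation, sketched against a hypothetical constructor (the attribute name follows the PR; the surrounding helper and config access are assumptions):

```python
def _validate_min_frames(value):
    """Reject non-integer or non-positive thresholds early."""
    if not isinstance(value, int) or isinstance(value, bool) or value < 1:
        raise ValueError(
            f"min_frames_per_element_forstat must be a positive integer, got {value!r}"
        )
    return value

# e.g. in the trainer's __init__ (hypothetical config lookup):
# self.min_frames_per_element_forstat = _validate_min_frames(
#     training_params.get("min_frames_per_element_forstat", 10)
# )
```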
232-232: Parameter usage looks correct.
Passing `self.min_frames_per_element_forstat` to `make_stat_input` aligns with the newly introduced functionality. Just be sure to verify that all call sites expect this parameter and properly handle out-of-range values.
source/tests/pt/test_make_stat_input.py (1)
41-50: Optimize loop variable usage.
The static analysis hint suggests renaming `idx` to `_idx` as if it were unused, but here `idx` is used to assign frames. The unused-loop-variable report is a false positive, so you can safely ignore it.
🧰 Tools
🪛 Ruff (0.8.2)
43-43: Loop control variable `idx` not used within loop body. Rename unused `idx` to `_idx` (B007)
deepmd/pt/utils/stat.py (1)
52-110: Confirm correctness of cumulative element count logic.
You aggregate `total_element_counts[elem]["count"]`, but also track `indices`. Ensure you don't exceed list boundaries when collecting indices for up to `min_frames_per_element_forstat`. If more frames exist, consider whether you need them to fulfill certain statistics.
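A minimal sketch of the bounding behavior in question, with made-up numbers (not the PR's exact bookkeeping):

```python
min_frames_per_element_forstat = 10
indices = list(range(25))  # frame indices containing some element

# Truncate before accumulating so the stored indices never exceed the cap.
kept = indices[:min_frames_per_element_forstat]
assert len(kept) <= min_frames_per_element_forstat
```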
Actionable comments posted: 0
🧹 Nitpick comments (6)
deepmd/pt/entrypoints/main.py (2)
389-389: Consider renaming the parameter for consistency.
The parameter name `elem_check_stat` doesn't match the CLI argument name `skip-elementcheck`. Additionally, their logic is inverted, which could be confusing.
```diff
-def change_bias(..., elem_check_stat: bool = True):
+def change_bias(..., enable_element_completion: bool = True):
```
560-560: Fix inverted logic between function and CLI argument.
The function parameter defaults to `True` but is set from a CLI flag that defaults to `False`. This inversion makes the code harder to understand.
```diff
-    elem_check_stat=FLAGS.skip_elementcheck,
+    enable_element_completion=not FLAGS.skip_elementcheck,
```
deepmd/pt/utils/stat.py (2)
61-74: Remove unused variable.
The variable `global_element_counts` is initialized but never used.
```diff
 total_element_types = set()
-global_element_counts = {}
 collect_ele = defaultdict(int)
```
75-96: Add error handling for unknown data types.
The function silently passes on unknown data types, which could hide potential issues.
```diff
 else:
-    pass
+    log.warning(f"Unexpected data type {type(stat_data[dd])} for key {dd}")
```
161-194: Add validation for element counts.
The element counting logic needs validation and could be simplified.
```diff
 if datasets[0].mixed_type and enable_element_completion:
+    if not isinstance(element_counts, dict):
+        log.warning(f"Invalid element counts for dataset {sys_index}")
+        continue
     for elem, data in element_counts.items():
+        if not isinstance(data, dict) or "indices" not in data or "frames" not in data:
+            log.warning(f"Invalid data format for element {elem}")
+            continue
         indices = data["indices"]
         count = data["frames"]
```
195-216: Improve error handling for missing elements.
The missing element handling could be improved with better error messages and validation.
```diff
 if datasets[0].mixed_type and enable_element_completion:
+    if not total_element_types:
+        log.warning("No elements found in any dataset")
+        return lst
     for elem, data in global_element_counts.items():
         indices_count = data["count"]
         if indices_count < min_frames_per_element_forstat:
             log.warning(
-                f"The number of frames in your datasets with element {elem} is {indices_count}, "
-                f"which is less than the required {min_frames_per_element_forstat}"
+                f"Insufficient frames for element {elem}: found {indices_count}, "
+                f"required {min_frames_per_element_forstat}. This may affect model accuracy."
             )
```
📒 Files selected for processing (3)
- deepmd/main.py (1 hunks)
- deepmd/pt/entrypoints/main.py (3 hunks)
- deepmd/pt/utils/stat.py (1 hunks)
🔇 Additional comments (5)
deepmd/main.py (1)
740-744: LGTM!
The CLI argument is well-defined with a clear help message that explains its purpose.
deepmd/pt/utils/stat.py (4)
38-57: LGTM!
The function signature and docstring are well-documented with clear parameter descriptions.
133-146: LGTM!
The `finalize_stats` helper function is well-structured and handles all edge cases.
97-132: 🛠️ Refactor suggestion
Add validation for frame data.
The function should validate that the frame actually contains the missing element before processing.
```diff
 if newele_counter <= min_frames_per_element_forstat:
     frame_data = sys.__getitem__(frame)
+    # Validate frame contains the missing element
+    if "atype" not in frame_data:
+        log.warning(f"Frame {frame} does not contain type information")
+        continue
+    frame_types = frame_data["atype"]
+    if miss not in frame_types:
+        log.warning(f"Frame {frame} does not contain element {miss}")
+        continue
```
Likely invalid or redundant comment.
147-160: 🛠️ Refactor suggestion
Add input validation.
The initialization section needs validation for input parameters and data consistency.
```diff
+if not datasets:
+    raise ValueError("No datasets provided")
+if len(datasets) != len(dataloaders):
+    raise ValueError("Number of datasets does not match number of dataloaders")
+if min_frames_per_element_forstat < 1:
+    raise ValueError("min_frames_per_element_forstat must be positive")
 total_element_types = set()
 collect_ele = defaultdict(int)
```
Likely invalid or redundant comment.
Adaptation for the change-bias feature has been added. Users can now adjust the minimum number of frames used for each element's statistics via the -mf or --min-frames option. Additionally, the --skip-elementcheck option has been introduced, allowing users to skip element checks when errors occur during change-bias operations.
Actionable comments posted: 4
🧹 Nitpick comments (2)
source/tests/pt/test_make_stat_input.py (2)
25-44
: Consider optimizing the collate function.
- Use
key in batch[0]
instead ofkey in batch[0].keys()
for better performance.- The bare
except
clause could mask important errors. Consider catching specific exceptions.Apply this diff to improve the implementation:
- for key in batch[0].keys(): + for key in batch[0]: items = [sample[key] for sample in batch] if isinstance(items[0], torch.Tensor): out[key] = torch.stack(items, dim=0) elif isinstance(items[0], np.ndarray): out[key] = torch.from_numpy(np.stack(items, axis=0)) else: try: out[key] = torch.tensor(items) - except Exception: + except (ValueError, TypeError) as e: + # Log warning about conversion failure out[key] = items🧰 Tools
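For reference, a self-contained version of such a collate function with both fixes applied (a sketch, not the test file's exact code):

```python
import numpy as np
import torch

def collate_fn(batch):
    """Stack per-key samples from a list of dicts into batched tensors."""
    out = {}
    for key in batch[0]:  # iterate keys directly; no .keys() needed
        items = [sample[key] for sample in batch]
        if isinstance(items[0], torch.Tensor):
            out[key] = torch.stack(items, dim=0)
        elif isinstance(items[0], np.ndarray):
            out[key] = torch.from_numpy(np.stack(items, axis=0))
        else:
            try:
                out[key] = torch.tensor(items)
            except (ValueError, TypeError):
                out[key] = items  # leave non-numeric payloads unbatched
    return out
```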
🪛 Ruff (0.8.2)
30-30: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
151-186: Improve test with no missing elements.
- Hard-coded ntypes value
- Missing tests with different batch sizes
- Complex dictionary manipulation could be simplified
Consider extracting the dictionary manipulation into a helper method for better readability:
```python
def _process_dict(dct):
    """Process dictionary by taking first element of multi-element values."""
    for key in ["find_box", "find_coord", "find_numb_copy", "find_energy"]:
        if key in dct:
            val = dct[key]
            if val.numel() > 1:
                dct[key] = val[0]
    return dct
```
📒 Files selected for processing (1)
- source/tests/pt/test_make_stat_input.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_make_stat_input.py
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
🔇 Additional comments (3)
source/tests/pt/test_make_stat_input.py (3)
1-22: LGTM! Well-organized imports.
The imports are logically grouped and include all necessary dependencies.
74-76: LGTM! Helper method is well-implemented.
The `count_non_zero_elements` helper method is concise and uses an appropriate threshold.
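The helper's body is not shown in this thread; a plausible implementation matching the description (names and threshold are assumptions):

```python
import torch

def count_non_zero_elements(tensor, threshold=1e-8):
    """Count entries whose magnitude exceeds a small threshold,
    avoiding exact floating-point comparison against zero."""
    return int((tensor.abs() > threshold).sum().item())

# count_non_zero_elements(torch.tensor([0.0, 1e-10, 0.5])) -> 1
```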
189-190: LGTM! Standard main block.
The main block correctly uses unittest's main function.
Actionable comments posted: 6
🧹 Nitpick comments (1)
source/tests/pt/test_make_stat_input.py (1)
30-31: Consider a minor optimization.
Use `key in batch[0]` instead of `key in batch[0].keys()` to avoid creating an unnecessary view object.
```diff
-    for key in batch[0].keys():
+    for key in batch[0]:
```
🧰 Tools
🪛 Ruff (0.8.2)
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
📒 Files selected for processing (1)
- source/tests/pt/test_make_stat_input.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_make_stat_input.py
30-30: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
🔇 Additional comments (1)
source/tests/pt/test_make_stat_input.py (1)
74-76: LGTM! Well-implemented helper method.
The `count_non_zero_elements` method is well-implemented with proper threshold handling for floating-point comparisons.
Actionable comments posted: 2
🧹 Nitpick comments (5)
deepmd/pt/entrypoints/main.py (2)
389-390: Align parameter names with `make_stat_input`.
The parameter names differ from those in `make_stat_input`: `skip_elem_check` vs `enable_element_completion`, and `min_frames` vs `min_frames_per_element_forstat`. This inconsistency could cause confusion. Consider using the same names or documenting the mapping.
```diff
-    skip_elem_check: bool = True,
-    min_frames: int = 10,
+    enable_element_completion: bool = False,
+    min_frames_per_element_forstat: int = 10,
```
477-478: Document parameter mapping in the function call.
The mapping between function parameters is not immediately clear.
```diff
 sampled_data = make_stat_input(
     data_single.systems,
     data_single.dataloaders,
     nbatches,
-    min_frames_per_element_forstat=min_frames,
-    enable_element_completion=not skip_elem_check,
+    # Map CLI parameters to their internal names
+    min_frames_per_element_forstat=min_frames,  # from --min-frames
+    enable_element_completion=not skip_elem_check,  # inverse of --skip-elementcheck
 )
```
deepmd/pt/utils/stat.py (4)
38-58: Function documentation needs improvement.
The docstring could be more detailed about the parameters and their implications.
```diff
 def make_stat_input(
     datasets,
     dataloaders,
     nbatches,
     min_frames_per_element_forstat=10,
     enable_element_completion=True,
 ):
-    """Pack data for statistics.
-    Element checking is only enabled with mixed_type.
-
-    Args:
-    - datasets: A list of datasets to analyze.
-    - dataloaders: Corresponding dataloaders for the datasets.
-    - nbatches: Batch count for collecting stats.
-    - min_frames_per_element_forstat: Minimum frames required for statistics.
-    - enable_element_completion: Whether to perform missing element completion (default: True).
-
-    Returns
-    -------
-    - A list of dicts, each of which contains data from a system.
-    """
+    """Pack data for statistics with element completion support.
+
+    This function processes datasets to collect statistics, with optional element completion
+    for mixed-type datasets. When enabled, it ensures each element type has sufficient
+    representation in the statistics.
+
+    Args:
+        datasets: A list of datasets to analyze
+        dataloaders: Corresponding dataloaders for the datasets
+        nbatches: Number of batches to process for statistics
+        min_frames_per_element_forstat: Minimum number of frames required per element type
+            for reliable statistics. If an element has fewer frames, a warning is logged
+        enable_element_completion: When True and datasets have mixed types, ensures each
+            element type meets the minimum frame requirement by collecting additional frames
+
+    Returns:
+        list[dict]: A list where each dict contains processed data from a system,
+            including tensors and statistics for each data type
+
+    Note:
+        Element completion only works with mixed-type datasets. For non-mixed-type
+        datasets, the element completion parameters are ignored.
+    """
```
61-74: Remove unused variable and improve logging.
The code initializes an unused variable and could benefit from more structured logging.
```diff
 total_element_types = set()
-global_element_counts = {}  # Unused variable
 collect_ele = defaultdict(int)
 if datasets[0].mixed_type:
     if enable_element_completion:
-        log.info(
-            f"Element check enabled. "
-            f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
-        )
+        log.info("Element completion enabled")
+        log.info(f"Minimum frames per element: {min_frames_per_element_forstat}")
     else:
-        log.info(
-            "Element completion is disabled. Skipping missing element handling."
-        )
```
133-146: Extract tensor processing logic in `finalize_stats`.
The tensor processing logic could be more robust and maintainable.
```diff
 def finalize_stats(sys_stat):
     """Finalize statistics by concatenating tensors."""
+    def process_tensor_list(tensor_list):
+        """Process a list of tensors with proper error handling."""
+        try:
+            if not tensor_list or tensor_list[0] is None:
+                return None
+            return torch.cat(tensor_list, dim=0)
+        except Exception as e:
+            log.error(f"Failed to concatenate tensors: {e}")
+            raise
+
     for key in sys_stat:
         if isinstance(sys_stat[key], np.float32):
             pass
         elif sys_stat[key] is None or (
             isinstance(sys_stat[key], list)
             and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
         ):
             sys_stat[key] = None
         elif isinstance(sys_stat[key][0], torch.Tensor):
-            sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+            sys_stat[key] = process_tensor_list(sys_stat[key])
     dict_to_device(sys_stat)
```
📒 Files selected for processing (5)
- deepmd/main.py (1 hunks)
- deepmd/pt/entrypoints/main.py (3 hunks)
- deepmd/pt/utils/dataset.py (3 hunks)
- deepmd/pt/utils/stat.py (1 hunks)
- deepmd/utils/argcheck.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- deepmd/pt/utils/dataset.py
- deepmd/main.py
- deepmd/utils/argcheck.py
Actionable comments posted: 5
🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)
38-58: Improve docstring and add input validation.
The docstring should provide more details about the new parameters:
- What are valid values for `min_frames_per_element_forstat`?
- What are the implications of enabling/disabling element completion?
Apply this diff to improve the docstring and add input validation:
```diff
 def make_stat_input(
     datasets,
     dataloaders,
     nbatches,
     min_frames_per_element_forstat=10,
     enable_element_completion=True,
 ):
     """Pack data for statistics.
     Element checking is only enabled with mixed_type.

     Args:
     - datasets: A list of datasets to analyze.
     - dataloaders: Corresponding dataloaders for the datasets.
     - nbatches: Batch count for collecting stats.
-    - min_frames_per_element_forstat: Minimum frames required for statistics.
-    - enable_element_completion: Whether to perform missing element completion (default: True).
+    - min_frames_per_element_forstat: Minimum number of frames required per element for statistics.
+      Must be a positive integer. Default is 10.
+    - enable_element_completion: Whether to perform missing element completion.
+      If True, ensures each element has at least min_frames_per_element_forstat frames.
+      If False, skips missing element handling. Default is True.

     Returns
     -------
     - A list of dicts, each of which contains data from a system.
     """
+    if not datasets:
+        raise ValueError("No datasets provided")
+    if len(datasets) != len(dataloaders):
+        raise ValueError("Number of datasets does not match number of dataloaders")
+    if min_frames_per_element_forstat < 1:
+        raise ValueError("min_frames_per_element_forstat must be positive")
```
61-74: Remove unused variable and improve logging messages.
The variable `global_element_counts` is initialized but never used. Also, the logging messages could be more informative.
Apply this diff:
```diff
 total_element_types = set()
-global_element_counts = {}
 global_type_name = {}
 collect_ele = defaultdict(int)
 if datasets[0].mixed_type:
     if enable_element_completion:
         log.info(
-            f"Element check enabled. "
-            f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
+            f"Element completion enabled. "
+            f"Ensuring each element has at least {min_frames_per_element_forstat} frames."
         )
     else:
         log.info(
-            "Element completion is disabled. Skipping missing element handling."
+            "Element completion disabled. Elements with insufficient frames will be skipped."
         )
```
📒 Files selected for processing (3)
- deepmd/pt/utils/dataset.py (3 hunks)
- deepmd/pt/utils/stat.py (1 hunks)
- deepmd/utils/data.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- deepmd/pt/utils/dataset.py
- deepmd/utils/data.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/stat.py
165-165: Undefined name `type_name` (F821)
Actionable comments posted: 1
📒 Files selected for processing (2)
- deepmd/pt/utils/dataset.py (3 hunks)
- deepmd/utils/data.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/utils/dataset.py
🔇 Additional comments (1)
deepmd/utils/data.py (1)
699-706: LGTM! Improved error handling for type mapping.
The error handling is well-implemented with:
- Clear error message including context (set name and available types)
- Proper exception chaining using `from e`
- Appropriate exception type (IndexError)
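A small generic illustration of the chaining pattern being praised here (not the data.py code itself):

```python
def lookup(type_idx_map, real_types):
    try:
        return [type_idx_map[t] for t in real_types]
    except IndexError as e:
        # 'from e' preserves the original traceback as __cause__, so the
        # root failure stays visible alongside the more descriptive message.
        raise IndexError(
            f"type out of range for a map of size {len(type_idx_map)}"
        ) from e
```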
Actionable comments posted: 6
🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)
38-58: Enhance docstring with more details.
The docstring could be improved with:
- Return type annotation
- More detailed parameter descriptions including valid ranges and constraints
- Documentation of exceptions that might be raised
```diff
     """Pack data for statistics.
     Element checking is only enabled with mixed_type.

     Args:
-    - datasets: A list of datasets to analyze.
-    - dataloaders: Corresponding dataloaders for the datasets.
-    - nbatches: Batch count for collecting stats.
-    - min_frames_per_element_forstat: Minimum frames required for statistics.
-    - enable_element_completion: Whether to perform missing element completion (default: True).
+    datasets: list
+        A list of datasets to analyze. Must not be empty.
+    dataloaders: list
+        Corresponding dataloaders for the datasets. Must match length of datasets.
+    nbatches: int
+        Batch count for collecting stats. Must be positive.
+    min_frames_per_element_forstat: int, optional
+        Minimum frames required for statistics per element. Must be positive.
+        Defaults to 10.
+    enable_element_completion: bool, optional
+        Whether to perform missing element completion. Only applies when mixed_type=True.
+        Defaults to True.

     Returns
     -------
-    - A list of dicts, each of which contains data from a system.
+    list[dict]
+        A list of dictionaries, each containing statistical data from a system.
+        Each dict contains tensor data for various properties.
+
+    Raises
+    ------
+    ValueError
+        If datasets is empty or if datasets and dataloaders lengths don't match.
+        If min_frames_per_element_forstat is less than 1.
+    AssertionError
+        If element check fails during frame processing.
     """
```
199-221: Improve error handling for missing elements.
Add validation for empty element sets and improve warning messages.
```diff
 if datasets[0].mixed_type and enable_element_completion:
+    if not total_element_types:
+        log.warning("No elements found in any dataset")
+        return lst
     for elem, data in global_element_counts.items():
         indices_count = data["count"]
         if indices_count < min_frames_per_element_forstat:
             log.warning(
-                f"The number of frames in your datasets with element {element_name} is {indices_count}, "
-                f"which is less than the set {min_frames_per_element_forstat}"
+                f"Insufficient frames for element {element_name}: found {indices_count}, "
+                f"required {min_frames_per_element_forstat}. This may affect model accuracy."
             )
+            if indices_count == 0:
+                log.error(f"No frames found for element {element_name}")
```
📒 Files selected for processing (1)
- deepmd/pt/utils/stat.py (1 hunks)
```python
    total_element_types = set()
    global_element_counts = {}
    global_type_name = {}
    collect_ele = defaultdict(int)
    if datasets[0].mixed_type:
        if enable_element_completion:
            log.info(
                f"Element check enabled. "
                f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
            )
        else:
            log.info(
                "Element completion is disabled. Skipping missing element handling."
            )
```
🛠️ Refactor suggestion
Remove unused variable `global_element_counts`.
The variable is initialized but never used in the code.
```diff
 total_element_types = set()
-global_element_counts = {}
 global_type_name = {}
 collect_ele = defaultdict(int)
```
```diff
 def finalize_stats(sys_stat):
     """Finalize statistics by concatenating tensors."""
     for key in sys_stat:
         if isinstance(sys_stat[key], np.float32):
             pass
-        elif sys_stat[key] is None or sys_stat[key][0] is None:
+        elif sys_stat[key] is None or (
+            isinstance(sys_stat[key], list)
+            and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
+        ):
             sys_stat[key] = None
-        elif isinstance(stat_data[dd], torch.Tensor):
+        elif isinstance(sys_stat[key][0], torch.Tensor):
             sys_stat[key] = torch.cat(sys_stat[key], dim=0)
     dict_to_device(sys_stat)
```
🛠️ Refactor suggestion
Add error handling for tensor operations.
The function lacks error handling for tensor operations, which could fail.
```diff
 def finalize_stats(sys_stat):
     """Finalize statistics by concatenating tensors."""
     for key in sys_stat:
-        if isinstance(sys_stat[key], np.float32):
-            pass
-        elif sys_stat[key] is None or (
-            isinstance(sys_stat[key], list)
-            and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
-        ):
-            sys_stat[key] = None
-        elif isinstance(sys_stat[key][0], torch.Tensor):
-            sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+        try:
+            if isinstance(sys_stat[key], np.float32):
+                continue
+            elif sys_stat[key] is None or (
+                isinstance(sys_stat[key], list)
+                and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None)
+            ):
+                sys_stat[key] = None
+            elif isinstance(sys_stat[key][0], torch.Tensor):
+                sys_stat[key] = torch.cat(sys_stat[key], dim=0)
+            else:
+                log.warning(f"Unexpected data type for key {key}")
+        except Exception as e:
+            log.error(f"Error finalizing stats for key {key}: {e}")
+            raise
     dict_to_device(sys_stat)
```
```python
    for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
        sys_stat = {}
        with torch.device("cpu"):
            process_batches(dataloader, sys_stat)
            if datasets[0].mixed_type and enable_element_completion:
                element_data = torch.cat(sys_stat["atype"], dim=0)
                collect_values = torch.unique(element_data.flatten(), sorted=True)
                for elem in collect_values.tolist():
                    frames_with_elem = torch.any(element_data == elem, dim=1)
                    row_indices = torch.where(frames_with_elem)[0]
                    collect_ele[elem] += len(row_indices)
        finalize_stats(sys_stat)
        lst.append(sys_stat)
```
🛠️ Refactor suggestion
Add validation for dataset processing.
The main loop needs validation for empty datasets and error handling for tensor operations.
```diff
 for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)):
+    if not dataset or not dataloader:
+        log.warning(f"Empty dataset or dataloader at index {sys_index}")
+        continue
     sys_stat = {}
     with torch.device("cpu"):
-        process_batches(dataloader, sys_stat)
-        if datasets[0].mixed_type and enable_element_completion:
-            element_data = torch.cat(sys_stat["atype"], dim=0)
+        try:
+            process_batches(dataloader, sys_stat)
+            if datasets[0].mixed_type and enable_element_completion:
+                if not sys_stat.get("atype"):
+                    log.warning(f"No type information found in dataset {sys_index}")
+                    continue
+                element_data = torch.cat(sys_stat["atype"], dim=0)
+        except Exception as e:
+            log.error(f"Error processing dataset {sys_index}: {e}")
+            continue
```
```python
def process_batches(dataloader, sys_stat):
    """Process batches from a dataloader to collect statistics."""
    iterator = iter(dataloader)
    numb_batches = min(nbatches, len(dataloader))
    for _ in range(numb_batches):
        try:
            stat_data = next(iterator)
        except StopIteration:
            iterator = iter(dataloader)
            stat_data = next(iterator)
        for dd in stat_data:
            if stat_data[dd] is None:
                sys_stat[dd] = None
            elif isinstance(stat_data[dd], torch.Tensor):
                if dd not in sys_stat:
                    sys_stat[dd] = []
                sys_stat[dd].append(stat_data[dd])
            elif isinstance(stat_data[dd], np.float32):
                sys_stat[dd] = stat_data[dd]
            else:
                pass
```
🛠️ Refactor suggestion
Improve error handling and data type validation.
The function silently handles unknown data types and lacks error handling for tensor operations.
```diff
 def process_batches(dataloader, sys_stat):
     """Process batches from a dataloader to collect statistics."""
     iterator = iter(dataloader)
     numb_batches = min(nbatches, len(dataloader))
     for _ in range(numb_batches):
         try:
             stat_data = next(iterator)
         except StopIteration:
             iterator = iter(dataloader)
             stat_data = next(iterator)
-        for dd in stat_data:
+        for dd, value in stat_data.items():
+            try:
+                if value is None:
+                    sys_stat[dd] = None
+                elif isinstance(value, torch.Tensor):
+                    if dd not in sys_stat:
+                        sys_stat[dd] = []
+                    sys_stat[dd].append(value)
+                elif isinstance(value, np.float32):
+                    sys_stat[dd] = value
+                else:
+                    log.warning(f"Unexpected data type {type(value)} for key {dd}")
+            except Exception as e:
+                log.error(f"Error processing key {dd}: {str(e)}")
+                raise
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def process_batches(dataloader, sys_stat):
    """Process batches from a dataloader to collect statistics."""
    iterator = iter(dataloader)
    numb_batches = min(nbatches, len(dataloader))
    for _ in range(numb_batches):
        try:
            stat_data = next(iterator)
        except StopIteration:
            iterator = iter(dataloader)
            stat_data = next(iterator)
        for dd, value in stat_data.items():
            try:
                if value is None:
                    sys_stat[dd] = None
                elif isinstance(value, torch.Tensor):
                    if dd not in sys_stat:
                        sys_stat[dd] = []
                    sys_stat[dd].append(value)
                elif isinstance(value, np.float32):
                    sys_stat[dd] = value
                else:
                    log.warning(f"Unexpected data type {type(value)} for key {dd}")
            except Exception as e:
                log.error(f"Error processing key {dd}: {str(e)}")
                raise
```
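Incidentally, the `StopIteration` handler in `process_batches` implements a wrap-around iterator: when the loader runs out before `nbatches` batches have been drawn, iteration restarts from the top. The same idiom in isolation (any iterable stands in for the DataLoader here):

```python
def take_batches(loader, n):
    """Yield n batches, restarting the loader once it is exhausted."""
    it = iter(loader)
    for _ in range(n):
        try:
            yield next(it)
        except StopIteration:
            it = iter(loader)
            yield next(it)

print(list(take_batches([1, 2, 3], 5)))  # [1, 2, 3, 1, 2]
```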
```python
# get frame index
if datasets[0].mixed_type and enable_element_completion:
    element_counts, type_name = dataset.get_frame_index_for_elements()
    for new_idx, elem_name in type_name.items():
        if new_idx not in global_type_name:
            global_type_name[new_idx] = elem_name
    for elem, data in element_counts.items():
        indices = data["indices"]
        count = data["frames"]
        total_element_types.add(elem)
        if elem not in global_element_counts:
            global_element_counts[elem] = {"count": 0, "indices": []}
            if count > min_frames_per_element_forstat:
                global_element_counts[elem]["count"] += (
                    min_frames_per_element_forstat
                )
                indices = indices[:min_frames_per_element_forstat]
                global_element_counts[elem]["indices"].append(
                    {"sys_index": sys_index, "frames": indices}
                )
            else:
                global_element_counts[elem]["count"] += count
                global_element_counts[elem]["indices"].append(
                    {"sys_index": sys_index, "frames": indices}
                )
        else:
            if (
                global_element_counts[elem]["count"]
                >= min_frames_per_element_forstat
            ):
                pass
            else:
                global_element_counts[elem]["count"] += count
                global_element_counts[elem]["indices"].append(
                    {"sys_index": sys_index, "frames": indices}
                )
# Complement
```
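For context, `get_frame_index_for_elements` is the dataset method added in this PR; the loop above only fixes its return contract: `element_counts` maps an element type to the frames containing it, and `type_name` maps type indices to element names. A hypothetical sketch of the `element_counts` half, inferred purely from how it is consumed here (not the actual implementation, which reads `atype` from the dataset):

```python
import numpy as np

def get_frame_index_for_elements(atypes):
    """Map each element to the frames that contain it.

    `atypes` is assumed to be an (nframes, natoms) integer array.
    Returns {elem: {"indices": [frame indices], "frames": count}},
    matching the keys consumed by the statistics loop above.
    """
    element_counts = {}
    for frame_idx, frame_types in enumerate(atypes):
        for elem in np.unique(frame_types):
            entry = element_counts.setdefault(int(elem), {"indices": [], "frames": 0})
            entry["indices"].append(frame_idx)
            entry["frames"] += 1
    return element_counts

atypes = np.array([[0, 1, 1], [0, 2, 1]])
print(get_frame_index_for_elements(atypes))
# {0: {'indices': [0, 1], 'frames': 2}, 1: {'indices': [0, 1], 'frames': 2},
#  2: {'indices': [1], 'frames': 1}}
```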
Fix undefined variable and improve code structure.
- The variable `type_name` is used but not defined
- The nested conditions could be simplified
```diff
 if datasets[0].mixed_type and enable_element_completion:
-    element_counts, type_name = dataset.get_frame_index_for_elements()
+    element_counts, type_name = dataset.get_frame_index_for_elements()
+    if not isinstance(element_counts, dict):
+        log.warning(f"Invalid element counts for dataset {sys_index}")
+        continue
     for new_idx, elem_name in type_name.items():
         if new_idx not in global_type_name:
             global_type_name[new_idx] = elem_name
     for elem, data in element_counts.items():
+        if not isinstance(data, dict) or "indices" not in data or "frames" not in data:
+            log.warning(f"Invalid data format for element {elem}")
+            continue
         indices = data["indices"]
         count = data["frames"]
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
# get frame index
if datasets[0].mixed_type and enable_element_completion:
    element_counts, type_name = dataset.get_frame_index_for_elements()
    if not isinstance(element_counts, dict):
        log.warning(f"Invalid element counts for dataset {sys_index}")
        continue
    for new_idx, elem_name in type_name.items():
        if new_idx not in global_type_name:
            global_type_name[new_idx] = elem_name
    for elem, data in element_counts.items():
        if not isinstance(data, dict) or "indices" not in data or "frames" not in data:
            log.warning(f"Invalid data format for element {elem}")
            continue
        indices = data["indices"]
        count = data["frames"]
        total_element_types.add(elem)
        if elem not in global_element_counts:
            global_element_counts[elem] = {"count": 0, "indices": []}
            if count > min_frames_per_element_forstat:
                global_element_counts[elem]["count"] += (
                    min_frames_per_element_forstat
                )
                indices = indices[:min_frames_per_element_forstat]
                global_element_counts[elem]["indices"].append(
                    {"sys_index": sys_index, "frames": indices}
                )
            else:
                global_element_counts[elem]["count"] += count
                global_element_counts[elem]["indices"].append(
                    {"sys_index": sys_index, "frames": indices}
                )
        else:
            if (
                global_element_counts[elem]["count"]
                >= min_frames_per_element_forstat
            ):
                pass
            else:
                global_element_counts[elem]["count"] += count
                global_element_counts[elem]["indices"].append(
                    {"sys_index": sys_index, "frames": indices}
                )
# Complement
```
```python
def process_with_new_frame(sys_indices, newele_counter, miss):
    for sys_info in sys_indices:
        sys_index = sys_info["sys_index"]
        frames = sys_info["frames"]
        sys = datasets[sys_index]
        for frame in frames:
            newele_counter += 1
            if newele_counter <= min_frames_per_element_forstat:
                frame_data = sys.__getitem__(frame)
                assert miss in frame_data["atype"], (
                    "Element check failed. "
                    "If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. "
                    "If you encountered this error during model training, set 'enable_element_completion' to False "
                    "in the 'training' section of your input file."
                )
                sys_stat_new = {}
                for dd in frame_data:
                    if dd == "type":
                        continue
                    if frame_data[dd] is None:
                        sys_stat_new[dd] = None
                    elif isinstance(frame_data[dd], np.ndarray):
                        if dd not in sys_stat_new:
                            sys_stat_new[dd] = []
                        tensor_data = torch.from_numpy(frame_data[dd])
                        tensor_data = tensor_data.unsqueeze(0)
                        sys_stat_new[dd].append(tensor_data)
                    elif isinstance(frame_data[dd], np.float32):
                        sys_stat_new[dd] = frame_data[dd]
                    else:
                        pass
                finalize_stats(sys_stat_new)
                lst.append(sys_stat_new)
            else:
                break
```
🛠️ Refactor suggestion
Add validation and error handling for frame processing.
The function needs better validation and error handling:
- Add validation for empty sys_indices
- Add error handling for frame data retrieval
- Validate frame data before processing
```diff
 def process_with_new_frame(sys_indices, newele_counter, miss):
+    if not sys_indices:
+        log.warning(f"No system indices provided for element {miss}")
+        return
     for sys_info in sys_indices:
         sys_index = sys_info["sys_index"]
         frames = sys_info["frames"]
+        if not frames:
+            log.warning(f"No frames found for system {sys_index}")
+            continue
         sys = datasets[sys_index]
         for frame in frames:
             newele_counter += 1
             if newele_counter <= min_frames_per_element_forstat:
-                frame_data = sys.__getitem__(frame)
+                try:
+                    frame_data = sys.__getitem__(frame)
+                    if "atype" not in frame_data:
+                        log.warning(f"Frame {frame} does not contain type information")
+                        continue
+                    if miss not in frame_data["atype"]:
+                        log.warning(f"Frame {frame} does not contain element {miss}")
+                        continue
+                except Exception as e:
+                    log.error(f"Failed to get frame {frame} from system {sys_index}: {e}")
+                    continue
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def process_with_new_frame(sys_indices, newele_counter, miss):
    if not sys_indices:
        log.warning(f"No system indices provided for element {miss}")
        return
    for sys_info in sys_indices:
        sys_index = sys_info["sys_index"]
        frames = sys_info["frames"]
        if not frames:
            log.warning(f"No frames found for system {sys_index}")
            continue
        sys = datasets[sys_index]
        for frame in frames:
            newele_counter += 1
            if newele_counter <= min_frames_per_element_forstat:
                try:
                    frame_data = sys.__getitem__(frame)
                    if "atype" not in frame_data:
                        log.warning(f"Frame {frame} does not contain type information")
                        continue
                    if miss not in frame_data["atype"]:
                        log.warning(f"Frame {frame} does not contain element {miss}")
                        continue
                except Exception as e:
                    log.error(f"Failed to get frame {frame} from system {sys_index}: {e}")
                    continue
                assert miss in frame_data["atype"], (
                    "Element check failed. "
                    "If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. "
                    "If you encountered this error during model training, set 'enable_element_completion' to False "
                    "in the 'training' section of your input file."
                )
                sys_stat_new = {}
                for dd in frame_data:
                    if dd == "type":
                        continue
                    if frame_data[dd] is None:
                        sys_stat_new[dd] = None
                    elif isinstance(frame_data[dd], np.ndarray):
                        if dd not in sys_stat_new:
                            sys_stat_new[dd] = []
                        tensor_data = torch.from_numpy(frame_data[dd])
                        tensor_data = tensor_data.unsqueeze(0)
                        sys_stat_new[dd].append(tensor_data)
                    elif isinstance(frame_data[dd], np.float32):
                        sys_stat_new[dd] = frame_data[dd]
                    else:
                        pass
                finalize_stats(sys_stat_new)
                lst.append(sys_stat_new)
            else:
                break
```
Summary by CodeRabbit

- New Features
- Bug Fixes
- Tests
  - Tests for the `make_stat_input` function to ensure accurate processing of atomic types.