diff --git a/deepmd/main.py b/deepmd/main.py index 097588ca0a..0e024b5011 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -737,6 +737,18 @@ def main_parser() -> argparse.ArgumentParser: default=None, help="Model branch chosen for changing bias if multi-task model.", ) + parser_change_bias.add_argument( + "--skip-elementcheck", + action="store_true", + help="Enable this option to skip element checks if any error occurs while retrieving statistical data.", + ) + parser_change_bias.add_argument( + "-mf", + "--min-frames", + default=10, + type=int, + help="The minimum number of frames for each element used for statistics.", + ) # --version parser.add_argument( diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index fd4be73e84..5a0fd435f7 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -386,6 +386,8 @@ def change_bias( numb_batch: int = 0, model_branch: Optional[str] = None, output: Optional[str] = None, + skip_elem_check: bool = True, + min_frames: int = 10, ) -> None: if input_file.endswith(".pt"): old_state_dict = torch.load( @@ -472,6 +474,8 @@ def change_bias( data_single.systems, data_single.dataloaders, nbatches, + min_frames_per_element_forstat=min_frames, + enable_element_completion=not skip_elem_check, ) updated_model = training.model_change_out_bias( model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode @@ -555,6 +559,8 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None: numb_batch=FLAGS.numb_batch, model_branch=FLAGS.model_branch, output=FLAGS.output, + skip_elem_check=FLAGS.skip_elementcheck, + min_frames=FLAGS.min_frames, ) elif FLAGS.command == "compress": FLAGS.input = str(Path(FLAGS.input).with_suffix(".pth")) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f3b4548b05..0903603f3e 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -143,6 +143,12 @@ def __init__( self.max_ckpt_keep = training_params.get("max_ckpt_keep", 5) self.display_in_training = training_params.get("disp_training", True) self.timing_in_training = training_params.get("time_training", True) + self.min_frames_per_element_forstat = training_params.get( + "min_frames_per_element_forstat", 10 + ) + self.enable_element_completion = training_params.get( + "enable_element_completion", True + ) self.change_bias_after_training = training_params.get( "change_bias_after_training", False ) @@ -227,6 +233,8 @@ def get_sample(): _training_data.systems, _training_data.dataloaders, _data_stat_nbatch, + self.min_frames_per_element_forstat, + self.enable_element_completion, ) return sampled diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 3043839308..a59f46be3f 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: LGPL-3.0-or-later - +from collections import ( + defaultdict, +) from typing import ( Optional, ) +import numpy as np from torch.utils.data import ( Dataset, ) @@ -17,10 +20,10 @@ class DeepmdDataSetForLoader(Dataset): def __init__(self, system: str, type_map: Optional[list[str]] = None) -> None: - """Construct DeePMD-style dataset containing frames cross different systems. + """Construct DeePMD-style dataset containing frames across different systems. Args: - - systems: Paths to systems. + - system: Path to the system. - type_map: Atom types. """ self.system = system @@ -40,6 +43,48 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data + def get_frame_index_for_elements(self): + """ + Get the frame index and the number of frames with all the elements in the system. + Map the remapped atom_type_mix back to their element names in type_map, + This function is only used in the mixed type. + + Returns + ------- + element_counts : dict + A dictionary where: + - The key is the element type. + - The value is another dictionary with the following keys: + - "frames": int + The total number of frames in which the element appears. + - "indices": list of int + A list of row indices where the element is found in the dataset. + global_type_name : dict + The key is the element index and the value is the element name. + """ + element_counts = defaultdict(lambda: {"frames": 0, "indices": []}) + set_files = self._data_system.dirs + base_offset = 0 + global_type_name = {} + for set_file in set_files: + element_data = self._data_system._load_type_mix(set_file) + unique_elements = np.unique(element_data) + type_name = self._data_system.build_reidx_to_name_map( + element_data, set_file + ) + for new_idx, elem_name in type_name.items(): + if new_idx not in global_type_name: + global_type_name[new_idx] = elem_name + for elem in unique_elements: + frames_with_elem = np.any(element_data == elem, axis=1) + row_indices = np.where(frames_with_elem)[0] + row_indices_global = np.where(frames_with_elem)[0] + base_offset + element_counts[elem]["frames"] += len(row_indices) + element_counts[elem]["indices"].extend(row_indices_global.tolist()) + base_offset += element_data.shape[0] + element_counts = dict(element_counts) + return element_counts, global_type_name + def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" for data_item in data_requirement: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 710d392ac3..7657e84d75 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -35,51 +35,190 @@ log = logging.getLogger(__name__) -def make_stat_input(datasets, dataloaders, nbatches): +def make_stat_input( + datasets, + dataloaders, + nbatches, + min_frames_per_element_forstat=10, + enable_element_completion=True, +): """Pack data for statistics. + Element checking is only enabled with mixed_type. Args: - - dataset: A list of dataset to analyze. + - datasets: A list of datasets to analyze. + - dataloaders: Corresponding dataloaders for the datasets. - nbatches: Batch count for collecting stats. + - min_frames_per_element_forstat: Minimum frames required for statistics. + - enable_element_completion: Whether to perform missing element completion (default: True). Returns ------- - - a list of dicts, each of which contains data from a system + - A list of dicts, each of which contains data from a system. """ lst = [] log.info(f"Packing data for statistics from {len(datasets)} systems") - for i in range(len(datasets)): - sys_stat = {} - with torch.device("cpu"): - iterator = iter(dataloaders[i]) - numb_batches = min(nbatches, len(dataloaders[i])) - for _ in range(numb_batches): - try: - stat_data = next(iterator) - except StopIteration: - iterator = iter(dataloaders[i]) - stat_data = next(iterator) - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] - else: - pass + total_element_types = set() + global_element_counts = {} + global_type_name = {} + collect_ele = defaultdict(int) + if datasets[0].mixed_type: + if enable_element_completion: + log.info( + f"Element check enabled. " + f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}." + ) + else: + log.info( + "Element completion is disabled. Skipping missing element handling." + ) + + def process_batches(dataloader, sys_stat): + """Process batches from a dataloader to collect statistics.""" + iterator = iter(dataloader) + numb_batches = min(nbatches, len(dataloader)) + for _ in range(numb_batches): + try: + stat_data = next(iterator) + except StopIteration: + iterator = iter(dataloader) + stat_data = next(iterator) + for dd in stat_data: + if stat_data[dd] is None: + sys_stat[dd] = None + elif isinstance(stat_data[dd], torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] + sys_stat[dd].append(stat_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat[dd] = stat_data[dd] + else: + pass + + def process_with_new_frame(sys_indices, newele_counter, miss): + for sys_info in sys_indices: + sys_index = sys_info["sys_index"] + frames = sys_info["frames"] + sys = datasets[sys_index] + for frame in frames: + newele_counter += 1 + if newele_counter <= min_frames_per_element_forstat: + frame_data = sys.__getitem__(frame) + assert miss in frame_data["atype"], ( + "Element check failed. " + "If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. " + "If you encountered this error during model training, set 'enable_element_completion' to False " + "in the 'training' section of your input file." + ) + sys_stat_new = {} + for dd in frame_data: + if dd == "type": + continue + if frame_data[dd] is None: + sys_stat_new[dd] = None + elif isinstance(frame_data[dd], np.ndarray): + if dd not in sys_stat_new: + sys_stat_new[dd] = [] + tensor_data = torch.from_numpy(frame_data[dd]) + tensor_data = tensor_data.unsqueeze(0) + sys_stat_new[dd].append(tensor_data) + elif isinstance(frame_data[dd], np.float32): + sys_stat_new[dd] = frame_data[dd] + else: + pass + finalize_stats(sys_stat_new) + lst.append(sys_stat_new) + else: + break + def finalize_stats(sys_stat): + """Finalize statistics by concatenating tensors.""" for key in sys_stat: if isinstance(sys_stat[key], np.float32): pass - elif sys_stat[key] is None or sys_stat[key][0] is None: + elif sys_stat[key] is None or ( + isinstance(sys_stat[key], list) + and (len(sys_stat[key]) == 0 or sys_stat[key][0] is None) + ): sys_stat[key] = None - elif isinstance(stat_data[dd], torch.Tensor): + elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) + + for sys_index, (dataset, dataloader) in enumerate(zip(datasets, dataloaders)): + sys_stat = {} + with torch.device("cpu"): + process_batches(dataloader, sys_stat) + if datasets[0].mixed_type and enable_element_completion: + element_data = torch.cat(sys_stat["atype"], dim=0) + collect_values = torch.unique(element_data.flatten(), sorted=True) + for elem in collect_values.tolist(): + frames_with_elem = torch.any(element_data == elem, dim=1) + row_indices = torch.where(frames_with_elem)[0] + collect_ele[elem] += len(row_indices) + finalize_stats(sys_stat) lst.append(sys_stat) + + # get frame index + if datasets[0].mixed_type and enable_element_completion: + element_counts, type_name = dataset.get_frame_index_for_elements() + for new_idx, elem_name in type_name.items(): + if new_idx not in global_type_name: + global_type_name[new_idx] = elem_name + for elem, data in element_counts.items(): + indices = data["indices"] + count = data["frames"] + total_element_types.add(elem) + if elem not in global_element_counts: + global_element_counts[elem] = {"count": 0, "indices": []} + if count > min_frames_per_element_forstat: + global_element_counts[elem]["count"] += ( + min_frames_per_element_forstat + ) + indices = indices[:min_frames_per_element_forstat] + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) + else: + global_element_counts[elem]["count"] += count + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) + else: + if ( + global_element_counts[elem]["count"] + >= min_frames_per_element_forstat + ): + pass + else: + global_element_counts[elem]["count"] += count + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) + # Complement + if datasets[0].mixed_type and enable_element_completion: + for elem, data in global_element_counts.items(): + indices_count = data["count"] + element_name = global_type_name.get(elem, f"") + if indices_count < min_frames_per_element_forstat: + log.warning( + f"The number of frames in your datasets with element {element_name} is {indices_count}, " + f"which is less than the set {min_frames_per_element_forstat}" + ) + collect_elements = collect_ele.keys() + missing_elements = total_element_types - collect_elements + collect_miss_element = set() + for ele, count in collect_ele.items(): + if count < min_frames_per_element_forstat: + collect_miss_element.add(ele) + missing_elements.add(ele) + for miss in missing_elements: + sys_indices = global_element_counts[miss].get("indices", []) + if miss in collect_miss_element: + newele_counter = collect_ele.get(miss, 0) + else: + newele_counter = 0 + process_with_new_frame(sys_indices, newele_counter, miss) return lst diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 66a6308fc5..122759541c 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2787,6 +2787,10 @@ def training_args( "If the gradient norm exceeds this value, it will be clipped to this limit. " "No gradient clipping will occur if set to 0." ) + doc_min_frames_per_element_forstat = "The minimum number of frames per element used for statistics when using the mixed type." + doc_enable_element_completion = ( + "Whether to check elements when using the mixed type" + ) doc_stat_file = ( "The file path for saving the data statistics results. " "If set, the results will be saved and directly loaded during the next training session, " @@ -2893,6 +2897,20 @@ def training_args( optional=True, doc=doc_only_pt_supported + doc_gradient_max_norm, ), + Argument( + "min_frames_per_element_forstat", + int, + default=10, + optional=True, + doc=doc_only_pt_supported + doc_min_frames_per_element_forstat, + ), + Argument( + "enable_element_completion", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + doc_enable_element_completion, + ), ] variants = [ Variant( diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index d572efd321..5fe92c617b 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -547,14 +547,6 @@ def _load_set(self, set_name: DPPath): if self.mixed_type: # nframes x natoms atom_type_mix = self._load_type_mix(set_name) - if self.enforce_type_map: - try: - atom_type_mix_ = self.type_idx_map[atom_type_mix].astype(np.int32) - except IndexError as e: - raise IndexError( - f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!" - ) from e - atom_type_mix = atom_type_mix_ real_type = atom_type_mix.reshape([nframes, self.natoms]) data["type"] = real_type natoms = data["type"].shape[1] @@ -704,8 +696,30 @@ def _load_type(self, sys_path: DPPath): def _load_type_mix(self, set_name: DPPath): type_path = set_name / "real_atom_types.npy" real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms]) + if self.enforce_type_map: + try: + atom_type_mix_ = self.type_idx_map[real_type].astype(np.int32) + except IndexError as e: + raise IndexError( + f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!" + ) from e + real_type = atom_type_mix_ return real_type + def build_reidx_to_name_map(self, typemix, set_name: DPPath): + type_map = self.type_map + type_path = set_name / "real_atom_types.npy" + real_type = type_path.load_numpy().astype(np.int32).reshape([-1, self.natoms]) + type_map_array = np.array(type_map, dtype=object) + reidx_to_name = {} + N, M = real_type.shape + for i in range(N): + for j in range(M): + old_val = int(real_type[i, j]) + re_val = int(typemix[i, j]) + reidx_to_name[re_val] = type_map_array[old_val] + return reidx_to_name + def _make_idx_map(self, atom_type): natoms = atom_type.shape[0] idx = np.arange(natoms, dtype=np.int64) diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/box.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/box.npy new file mode 100644 index 0000000000..9a75fc0591 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.000/box.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/coord.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/coord.npy new file mode 100644 index 0000000000..8d6cb11195 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.000/coord.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/energy.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/energy.npy new file mode 100644 index 0000000000..831de792c4 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.000/energy.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/force.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/force.npy new file mode 100644 index 0000000000..3e11288329 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.000/force.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_numbs.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_numbs.npy new file mode 100644 index 0000000000..81e8df6b91 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_numbs.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_types.npy b/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_types.npy new file mode 100644 index 0000000000..1522e39c49 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.000/real_atom_types.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/box.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/box.npy new file mode 100644 index 0000000000..9a75fc0591 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.001/box.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/coord.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/coord.npy new file mode 100644 index 0000000000..8d6cb11195 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.001/coord.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/energy.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/energy.npy new file mode 100644 index 0000000000..831de792c4 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.001/energy.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/force.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/force.npy new file mode 100644 index 0000000000..3e11288329 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.001/force.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_numbs.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_numbs.npy new file mode 100644 index 0000000000..81e8df6b91 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_numbs.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_types.npy b/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_types.npy new file mode 100644 index 0000000000..1522e39c49 Binary files /dev/null and b/source/tests/pt/mixed_type_data/sys.000000/set.001/real_atom_types.npy differ diff --git a/source/tests/pt/mixed_type_data/sys.000000/type.raw b/source/tests/pt/mixed_type_data/sys.000000/type.raw new file mode 100644 index 0000000000..a21090781f --- /dev/null +++ b/source/tests/pt/mixed_type_data/sys.000000/type.raw @@ -0,0 +1,7 @@ +0 +0 +0 +0 +0 +0 +0 diff --git a/source/tests/pt/mixed_type_data/sys.000000/type_map.raw b/source/tests/pt/mixed_type_data/sys.000000/type_map.raw new file mode 100644 index 0000000000..0233506281 --- /dev/null +++ b/source/tests/pt/mixed_type_data/sys.000000/type_map.raw @@ -0,0 +1,57 @@ +Ag +Al +As +Au +B +Bi +C +Ca +Cd +Cl +Co +Cr +Cs +Cu +Fe +Ga +Ge +H +Hf +Hg +In +Ir +K +Mg +Mn +Mo +N +Na +Nb +Ni +O +Os +P +Pb +Pd +Pt +Rb +Re +Rh +Ru +S +Sb +Sc +Se +Si +Sn +Sr +Ta +Tc +Te +Ti +Tl +V +W +Y +Zn +Zr diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py new file mode 100644 index 0000000000..bdc450512f --- /dev/null +++ b/source/tests/pt/test_make_stat_input.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from pathlib import ( + Path, +) + +import numpy as np +import torch +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, + make_stat_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + + +def collate_fn(batch): + if isinstance(batch, dict): + batch = [batch] + + out = {} + for key in batch[0].keys(): + items = [sample[key] for sample in batch] + + if isinstance(items[0], torch.Tensor): + out[key] = torch.stack(items, dim=0) + elif isinstance(items[0], np.ndarray): + out[key] = torch.from_numpy(np.stack(items, axis=0)) + else: + try: + out[key] = torch.tensor(items) + except Exception: + out[key] = items + + return out + + +class TestMakeStatInput(unittest.TestCase): + @classmethod + def setUpClass(cls): + with torch.device("cpu"): + system_path = str(Path(__file__).parent / "mixed_type_data/sys.000000") + cls.real_ntypes = 6 + cls.datasets = DeepmdDataSetForLoader(system=system_path) + data_requirements = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + ), + ] + cls.datasets.add_data_requirement(data_requirements) + cls.datasets = [cls.datasets] + cls.dataloaders = [] + for dataset in cls.datasets: + dataloader = DataLoader( + dataset, + batch_size=1, + num_workers=0, + drop_last=False, + collate_fn=collate_fn, + pin_memory=False, + ) + cls.dataloaders.append(dataloader) + + def count_non_zero_elements(self, tensor, threshold=1e-8): + return torch.sum(torch.abs(tensor) > threshold).item() + + def test_make_stat_input(self): + # 3 frames would be count + lst = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=True, + ) + bias, _ = compute_output_stats(lst, ntypes=57) + energy = bias.get("energy") + non_zero_count = self.count_non_zero_elements(energy) + self.assertEqual( + non_zero_count, + self.real_ntypes, + f"Expected exactly {self.real_ntypes} non-zero elements, but got {non_zero_count}.", + ) + + def test_make_stat_input_nocomplete(self): + # missing element:13,31,37 + # only one frame would be count + + lst = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + bias, _ = compute_output_stats(lst, ntypes=57) + energy = bias.get("energy") + non_zero_count = self.count_non_zero_elements(energy) + self.assertLess( + non_zero_count, + self.real_ntypes, + f"Expected fewer than {self.real_ntypes} non-zero elements, but got {non_zero_count}.", + ) + + def test_bias(self): + lst_ori = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + lst_all = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=1, + min_frames_per_element_forstat=1, + enable_element_completion=True, + ) + + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) + bias_all, _ = compute_output_stats(lst_all, ntypes=57) + energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() + energy_all = np.array(bias_all.get("energy").cpu()).flatten() + + for i, (e_ori, e_all) in enumerate(zip(energy_ori, energy_all)): + if e_all == 0: + self.assertEqual( + e_ori, 0, f"Index {i}: energy_all=0, but energy_ori={e_ori}" + ) + else: + if e_ori != 0: + diff = abs(e_ori - e_all) + rel_diff = diff / abs(e_ori) + self.assertTrue( + rel_diff < 0.4, + f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " + f"relative difference {rel_diff:.2%} is too large", + ) + + def test_with_nomissing(self): + lst_ori = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=10, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + for dct in lst_ori: + for key in ["find_box", "find_coord", "find_numb_copy", "find_energy"]: + if key in dct: + val = dct[key] + if val.numel() > 1: + dct[key] = val[0] + lst_new = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=10, + min_frames_per_element_forstat=1, + enable_element_completion=True, + ) + for dct in lst_new: + for key in ["find_box", "find_coord", "find_numb_copy", "find_energy"]: + if key in dct: + val = dct[key] + if val.numel() > 1: + dct[key] = val[0] + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) + bias_new, _ = compute_output_stats(lst_new, ntypes=57) + energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() + energy_new = np.array(bias_new.get("energy").cpu()).flatten() + self.assertTrue( + np.array_equal(energy_ori, energy_new), + msg=f"energy_ori and energy_new are not exactly the same!\n" + f"energy_ori = {energy_ori}\nenergy_new = {energy_new}", + )