diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index 40a765f647..c3fcc5ebc1 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -23,17 +23,25 @@ def collate_fn(batch): + if isinstance(batch, dict): batch = [batch] - collated_batch = {} + + out = {} for key in batch[0].keys(): - data_list = [d[key] for d in batch] - if isinstance(data_list[0], np.ndarray): - data_np = np.stack(data_list) - collated_batch[key] = torch.from_numpy(data_np) + items = [sample[key] for sample in batch] + + if isinstance(items[0], torch.Tensor): + out[key] = torch.stack(items, dim=0) + elif isinstance(items[0], np.ndarray): + out[key] = torch.from_numpy(np.stack(items, axis=0)) else: - collated_batch[key] = torch.tensor(data_list) - return collated_batch + try: + out[key] = torch.tensor(items) + except Exception: + out[key] = items + + return out class TestMakeStatInput(unittest.TestCase): @@ -52,19 +60,10 @@ def setUpClass(cls): ] cls.datasets.add_data_requirement(data_requirements) cls.datasets = [cls.datasets] - weights_tensor = torch.tensor( - [0.1] * len(cls.datasets), dtype=torch.float64, device="cpu" - ) - sampler = torch.utils.data.WeightedRandomSampler( - weights_tensor, - num_samples=len(cls.datasets), - replacement=True, - ) cls.dataloaders = [] for dataset in cls.datasets: dataloader = DataLoader( dataset, - sampler=sampler, batch_size=1, num_workers=0, drop_last=False, @@ -129,6 +128,7 @@ def test_bias(self): min_frames_per_element_forstat=1, enable_element_completion=True, ) + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) bias_all, _ = compute_output_stats(lst_all, ntypes=57) energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() @@ -148,3 +148,42 @@ def test_bias(self): f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, " f"relative difference {rel_diff:.2%} is too large", ) + def test_with_nomissing(self): + lst_ori = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=10, + min_frames_per_element_forstat=1, + enable_element_completion=False, + ) + for dct in lst_ori: + for key in ["find_box", "find_coord", "find_numb_copy", "find_energy"]: + if key in dct: + val = dct[key] + if val.numel() > 1: + dct[key] = val[0] + lst_new = make_stat_input( + datasets=self.datasets, + dataloaders=self.dataloaders, + nbatches=10, + min_frames_per_element_forstat=1, + enable_element_completion=True, + ) + for dct in lst_new: + for key in ["find_box", "find_coord", "find_numb_copy", "find_energy"]: + if key in dct: + val = dct[key] + if val.numel() > 1: + dct[key] = val[0] + bias_ori, _ = compute_output_stats(lst_ori, ntypes=57) + bias_new, _ = compute_output_stats(lst_new, ntypes=57) + energy_ori = np.array(bias_ori.get("energy").cpu()).flatten() + energy_new = np.array(bias_new.get("energy").cpu()).flatten() + self.assertTrue( + np.array_equal(energy_ori, energy_new), + msg=f"energy_ori and energy_new are not exactly the same!\n" + f"energy_ori = {energy_ori}\nenergy_new = {energy_new}" + ) + +if __name__ == "__main__": + unittest.main()