Skip to content

Commit

Permalink
improve ut with all frames
Browse files Browse the repository at this point in the history
  • Loading branch information
SumGuo-88 committed Jan 10, 2025
1 parent c05ffb1 commit 139f037
Showing 1 changed file with 55 additions and 16 deletions.
71 changes: 55 additions & 16 deletions source/tests/pt/test_make_stat_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,25 @@


def collate_fn(batch):
    """Collate one sample dict, or a list of sample dicts, into a batched dict.

    Parameters
    ----------
    batch : dict or list of dict
        A single sample, or a list of samples that share the keys of the
        first sample. A bare dict is treated as a batch of one.

    Returns
    -------
    dict
        One entry per key of the first sample. ``torch.Tensor`` values are
        stacked along a new leading dimension, ``numpy.ndarray`` values are
        stacked and converted with ``torch.from_numpy``, and anything else is
        passed to ``torch.tensor``; values that cannot be converted are
        returned as the plain list of per-sample values.
    """
    if isinstance(batch, dict):
        batch = [batch]

    out = {}
    for key in batch[0].keys():
        items = [sample[key] for sample in batch]
        first = items[0]

        if isinstance(first, torch.Tensor):
            out[key] = torch.stack(items, dim=0)
        elif isinstance(first, np.ndarray):
            out[key] = torch.from_numpy(np.stack(items, axis=0))
        else:
            try:
                out[key] = torch.tensor(items)
            except (TypeError, ValueError, RuntimeError):
                # Best effort: data such as strings or None cannot become a
                # tensor, so keep the raw per-sample list for that key.
                out[key] = items

    return out


class TestMakeStatInput(unittest.TestCase):
Expand All @@ -52,19 +60,10 @@ def setUpClass(cls):
]
cls.datasets.add_data_requirement(data_requirements)
cls.datasets = [cls.datasets]
weights_tensor = torch.tensor(
[0.1] * len(cls.datasets), dtype=torch.float64, device="cpu"
)
sampler = torch.utils.data.WeightedRandomSampler(
weights_tensor,
num_samples=len(cls.datasets),
replacement=True,
)
cls.dataloaders = []
for dataset in cls.datasets:
dataloader = DataLoader(
dataset,
sampler=sampler,
batch_size=1,
num_workers=0,
drop_last=False,
Expand Down Expand Up @@ -129,6 +128,7 @@ def test_bias(self):
min_frames_per_element_forstat=1,
enable_element_completion=True,
)

bias_ori, _ = compute_output_stats(lst_ori, ntypes=57)
bias_all, _ = compute_output_stats(lst_all, ntypes=57)
energy_ori = np.array(bias_ori.get("energy").cpu()).flatten()
Expand All @@ -148,3 +148,42 @@ def test_bias(self):
f"Index {i}: energy_ori={e_ori}, energy_all={e_all}, "
f"relative difference {rel_diff:.2%} is too large",
)
def test_with_nomissing(self):
    """Check that element completion does not change the energy bias.

    The statistics input is built twice from the same datasets and
    dataloaders, once with ``enable_element_completion=False`` and once
    with ``True``. The flattened ``"energy"`` bias returned by
    ``compute_output_stats`` must be exactly equal for both runs.
    """
    find_keys = ("find_box", "find_coord", "find_numb_copy", "find_energy")

    def gather(enable_completion):
        # Build the stat input, then reduce every multi-element find_* entry
        # to its first element.
        # NOTE(review): presumably compute_output_stats expects one flag per
        # entry rather than one per frame -- confirm against its contract.
        stat_list = make_stat_input(
            datasets=self.datasets,
            dataloaders=self.dataloaders,
            nbatches=10,
            min_frames_per_element_forstat=1,
            enable_element_completion=enable_completion,
        )
        for dct in stat_list:
            for key in find_keys:
                if key in dct:
                    val = dct[key]
                    if val.numel() > 1:
                        dct[key] = val[0]
        return stat_list

    lst_ori = gather(False)
    lst_new = gather(True)

    bias_ori, _ = compute_output_stats(lst_ori, ntypes=57)
    bias_new, _ = compute_output_stats(lst_new, ntypes=57)
    energy_ori = np.array(bias_ori.get("energy").cpu()).flatten()
    energy_new = np.array(bias_new.get("energy").cpu()).flatten()

    # Exact (not approximate) equality is required between the two runs.
    self.assertTrue(
        np.array_equal(energy_ori, energy_new),
        msg=f"energy_ori and energy_new are not exactly the same!\n"
        f"energy_ori = {energy_ori}\nenergy_new = {energy_new}",
    )

# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 139f037

Please sign in to comment.