From 89a299026cebdb9c9cc008b953a1489d5831ea44 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Wed, 14 Aug 2024 19:38:08 +0200 Subject: [PATCH] add tests (failing) --- tests/test_inconsistent_data.py | 18 ++++++++++++++++++ znh5md/format.py | 14 ++++++++------ 2 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 tests/test_inconsistent_data.py diff --git a/tests/test_inconsistent_data.py b/tests/test_inconsistent_data.py new file mode 100644 index 00000000..c58087b7 --- /dev/null +++ b/tests/test_inconsistent_data.py @@ -0,0 +1,18 @@ +import znh5md +import numpy.testing as npt + + +def test_keys_missing(tmp_path, s22, s22_energy_forces): + io = znh5md.IO(tmp_path / "test.h5") + + + images = s22_energy_forces + s22 + io.extend(images) + assert len(io) == len(images) + assert len(list(io)) == len(images) + + for a, b in zip(images, io): + assert a == b + b.get_potential_energy() + assert a.get_potential_energy() == b.get_potential_energy() + npt.assert_array_equal(a.get_forces(), b.get_forces()) diff --git a/znh5md/format.py b/znh5md/format.py index 677f6ae2..db5a38b8 100644 --- a/znh5md/format.py +++ b/znh5md/format.py @@ -288,10 +288,12 @@ def _combine_dicts(dicts: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]: """Helper function to combine dictionaries containing numpy arrays.""" combined = {} for key in dicts[0]: - combined[key] = concatenate_varying_shape_arrays( - [ - d[key] if isinstance(d[key], np.ndarray) else np.array(d[key]) - for d in dicts - ] - ) + data = [] + for d in dicts: + if key in d: + data.append(d[key]) + if data: + combined[key] = concatenate_varying_shape_arrays(data) + else: + raise ValueError(f"Key {key} is missing in one of the data objects.") return combined