Review notes. This text precedes the first "diff --git" header; patch tools
are expected to skip it as leading text (confirm for the tool in use).

tests/test_inconsistent_data.py (new file)
  * Adds test_keys_missing: extends a znh5md.IO with s22_energy_forces + s22,
    checks the stored length, then compares every stored frame with its
    input (equality, potential energy, forces).

znh5md/format.py, _combine_dicts
  * For each key of dicts[0], values are now collected only from the dicts
    that contain that key; a dict lacking the key contributes no entry to
    the list handed to concatenate_varying_shape_arrays.
  * NOTE(review): `key` is taken from dicts[0], so `data` always holds at
    least dicts[0][key]; the `else: raise ValueError(...)` branch cannot be
    reached as written.
  * NOTE(review): keys that occur only in later dicts are still never
    visited, because the outer loop iterates dicts[0] alone.
  * NOTE(review): the former np.array(...) coercion of non-ndarray values is
    no longer applied before concatenation; confirm the helper accepts them.

diff --git a/tests/test_inconsistent_data.py b/tests/test_inconsistent_data.py
new file mode 100644
index 00000000..c58087b7
--- /dev/null
+++ b/tests/test_inconsistent_data.py
@@ -0,0 +1,18 @@
+import znh5md
+import numpy.testing as npt
+
+
+def test_keys_missing(tmp_path, s22, s22_energy_forces):
+    io = znh5md.IO(tmp_path / "test.h5")
+
+
+    images = s22_energy_forces + s22
+    io.extend(images)
+    assert len(io) == len(images)
+    assert len(list(io)) == len(images)
+
+    for a, b in zip(images, io):
+        assert a == b
+        b.get_potential_energy()
+        assert a.get_potential_energy() == b.get_potential_energy()
+        npt.assert_array_equal(a.get_forces(), b.get_forces())
diff --git a/znh5md/format.py b/znh5md/format.py
index 677f6ae2..db5a38b8 100644
--- a/znh5md/format.py
+++ b/znh5md/format.py
@@ -288,10 +288,12 @@ def _combine_dicts(dicts: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
     """Helper function to combine dictionaries containing numpy arrays."""
     combined = {}
     for key in dicts[0]:
-        combined[key] = concatenate_varying_shape_arrays(
-            [
-                d[key] if isinstance(d[key], np.ndarray) else np.array(d[key])
-                for d in dicts
-            ]
-        )
+        data = []
+        for d in dicts:
+            if key in d:
+                data.append(d[key])
+        if data:
+            combined[key] = concatenate_varying_shape_arrays(data)
+        else:
+            raise ValueError(f"Key {key} is missing in one of the data objects.")
     return combined