Skip to content

Commit

Permalink
add tests (failing)
Browse files Browse the repository at this point in the history
  • Loading branch information
PythonFZ committed Aug 14, 2024
1 parent 1c8b362 commit 89a2990
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
18 changes: 18 additions & 0 deletions tests/test_inconsistent_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import znh5md
import numpy.testing as npt


def test_keys_missing(tmp_path, s22, s22_energy_forces):
io = znh5md.IO(tmp_path / "test.h5")


images = s22_energy_forces + s22
io.extend(images)
assert len(io) == len(images)
assert len(list(io)) == len(images)

for a, b in zip(images, io):
assert a == b
b.get_potential_energy()
assert a.get_potential_energy() == b.get_potential_energy()
npt.assert_array_equal(a.get_forces(), b.get_forces())
14 changes: 8 additions & 6 deletions znh5md/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,12 @@ def _combine_dicts(dicts: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
"""Helper function to combine dictionaries containing numpy arrays."""
combined = {}
for key in dicts[0]:
combined[key] = concatenate_varying_shape_arrays(
[
d[key] if isinstance(d[key], np.ndarray) else np.array(d[key])
for d in dicts
]
)
data = []
for d in dicts:
if key in d:
data.append(d[key])
if data:
combined[key] = concatenate_varying_shape_arrays(data)
else:
raise ValueError(f"Key {key} is missing in one of the data objects.")
return combined

1 comment on commit 89a2990

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Write: Varying number of images

Write: Varying number of atoms

Read: Varying number of images

Read: Varying number of atoms

Please sign in to comment.