Skip to content

Commit

Permalink
bugfix for json / list in info data (#145)
Browse files Browse the repository at this point in the history
* bugfix for json / list in info data

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Oct 6, 2024
1 parent d4d1bfb commit c615e4d
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 4 deletions.
65 changes: 64 additions & 1 deletion tests/test_data_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import ase.build
import ase.io
import numpy.testing as npt
import pytest

import znh5md
Expand Down Expand Up @@ -39,9 +41,70 @@ def test_int_info_data(tmp_path):


def test_dict_data(tmp_path):
    """Check that dict-valued ``info`` entries can be stored and read back.

    A water molecule carrying a dict under ``info["test"]`` is written to an
    extxyz file and read back with ASE. The re-read molecule is then appended
    to a znh5md file three times: unchanged, with a larger dict under the same
    key, and with an additional dict under a new key ``"b"``. Each stored
    frame must return exactly the dict it was appended with.

    Args:
        tmp_path: pytest-provided temporary directory for the output files.
    """
    molecule = ase.build.molecule("H2O")
    molecule.info["test"] = {"a": 1, "b": 2}

    # Write to extxyz format
    extxyz_path = tmp_path / "molecule.extxyz"
    ase.io.write(extxyz_path, molecule, format="extxyz")

    # Read from extxyz format
    molecule = ase.io.read(extxyz_path, format="extxyz")

    # The IO handle is created once, right before the first append.
    io = znh5md.IO(tmp_path / "test.h5")
    io.append(molecule)
    # Grow the dict under the existing key between appends.
    molecule.info["test"] = {"a": 1, "b": 2, "c": 3}
    io.append(molecule)
    # Introduce a second dict-valued key that earlier frames do not have.
    molecule.info["b"] = {"a": 1, "b": 2, "c": 3, "d": 4}
    io.append(molecule)

    assert io[0].info["test"] == {"a": 1, "b": 2}
    assert io[1].info["test"] == {"a": 1, "b": 2, "c": 3}
    assert io[2].info["b"] == {"a": 1, "b": 2, "c": 3, "d": 4}


def test_list_data(tmp_path):
    """Check that list-valued ``info`` entries can be stored and read back.

    A water molecule with a list under ``info["test"]`` is written to extxyz,
    re-read with ASE and verified, then appended to a znh5md file three times
    while the info data is changed between appends. Every stored frame must
    return the values it was appended with.

    Args:
        tmp_path: pytest-provided temporary directory for the output files.
    """
    water = ase.build.molecule("H2O")
    water.info["test"] = [1, 2]

    # Round-trip through an extxyz file on disk.
    xyz_file = tmp_path / "molecule.extxyz"
    ase.io.write(xyz_file, water, format="extxyz")
    water = ase.io.read(xyz_file, format="extxyz")
    npt.assert_array_equal(water.info["test"], [1, 2])

    store = znh5md.IO(tmp_path / "test.h5")
    store.append(water)
    # Each update is applied to the same Atoms object and appended right away.
    updates = (
        ("test", [1, 2, 3]),
        ("b", [1, 2, 3, 4]),
    )
    for info_key, values in updates:
        water.info[info_key] = values
        store.append(water)

    # (frame index, info key, expected values) for every appended frame.
    expectations = (
        (0, "test", [1, 2]),
        (1, "test", [1, 2, 3]),
        (2, "b", [1, 2, 3, 4]),
    )
    for frame_index, info_key, values in expectations:
        npt.assert_array_equal(store[frame_index].info[info_key], values)


def test_multiple_molecules_with_diff_length_dicts(tmp_path):
    """Check frames whose ``info["test"]`` dicts differ in length.

    Two molecules (H2O and CH4) get dicts with two and one entries, are
    written together to one extxyz file, read back as a list, and appended
    one by one to a znh5md file. Each stored frame must return its own dict.

    Args:
        tmp_path: pytest-provided temporary directory for the output files.
    """
    expected_dicts = [{"a": 1, "b": 2}, {"a": 1}]

    frames = [ase.build.molecule(formula) for formula in ("H2O", "CH4")]
    for atoms, payload in zip(frames, expected_dicts):
        # Store a copy so the expected values stay independent of the input.
        atoms.info["test"] = dict(payload)

    xyz_file = tmp_path / "molecules.extxyz"
    ase.io.write(xyz_file, frames, format="extxyz")
    # index=":" reads every frame in the file.
    loaded = ase.io.read(xyz_file, index=":", format="extxyz")

    store = znh5md.IO(tmp_path / "test.h5")
    for atoms in loaded:
        store.append(atoms)

    for frame_index, payload in enumerate(expected_dicts):
        assert store[frame_index].info["test"] == payload
2 changes: 1 addition & 1 deletion znh5md/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _combine_dicts(dicts: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
data = []
for d in dicts:
if key in d:
data.append(d[key])
data.append(np.array(d[key]))
else:
dims = dicts[0][key].ndim
# Create an array with the appropriate number of dimensions.
Expand Down
8 changes: 7 additions & 1 deletion znh5md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import znh5md.format as fmt
from znh5md import utils
from znh5md.config import STRING_FILL_VALUE

__version__ = importlib.metadata.version("znh5md")

Expand Down Expand Up @@ -204,7 +205,12 @@ def _extract_additional_data(self, f, index, arrays_data, calc_data, info_data):
].attrs.get("ZNH5MD_TYPE")
== "json"
):
info_data[key] = [json.loads(x) for x in data]
info_data[key] = [
json.loads(x)
if x != STRING_FILL_VALUE
else STRING_FILL_VALUE
for x in data
]
else:
info_data[key] = data
except IndexError:
Expand Down
2 changes: 1 addition & 1 deletion znh5md/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from znh5md.config import NUMERIC_FILL_VALUE, STRING_FILL_VALUE

NUMPY_STRING_DTYPE = np.dtype("S512")
NUMPY_STRING_DTYPE = np.dtype("S4096") # TODO this has to be tested!


def concatenate_varying_shape_arrays(arrays: list[np.ndarray]) -> np.ndarray:
Expand Down

0 comments on commit c615e4d

Please sign in to comment.