From c615e4d8a4b922fcac537fed9cfadec6c25a412b Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Sun, 6 Oct 2024 19:19:32 +0200 Subject: [PATCH] bugfix for json / list in info data (#145) * bugfix for json / list in info data * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/test_data_types.py | 65 +++++++++++++++++++++++++++++++++++++++- znh5md/format.py | 2 +- znh5md/io.py | 8 ++++- znh5md/utils.py | 2 +- 4 files changed, 73 insertions(+), 4 deletions(-) diff --git a/tests/test_data_types.py b/tests/test_data_types.py index 47c3dd60..c1444241 100644 --- a/tests/test_data_types.py +++ b/tests/test_data_types.py @@ -1,4 +1,6 @@ import ase.build +import ase.io +import numpy.testing as npt import pytest import znh5md @@ -39,9 +41,70 @@ def test_int_info_data(tmp_path): def test_dict_data(tmp_path): - io = znh5md.IO(tmp_path / "test.h5") molecule = ase.build.molecule("H2O") molecule.info["test"] = {"a": 1, "b": 2} + # Write to extxyz format + extxyz_path = tmp_path / "molecule.extxyz" + ase.io.write(extxyz_path, molecule, format="extxyz") + + # Read from extxyz format + molecule = ase.io.read(extxyz_path, format="extxyz") + + io = znh5md.IO(tmp_path / "test.h5") io.append(molecule) + molecule.info["test"] = {"a": 1, "b": 2, "c": 3} + io.append(molecule) + molecule.info["b"] = {"a": 1, "b": 2, "c": 3, "d": 4} + io.append(molecule) + + assert io[0].info["test"] == {"a": 1, "b": 2} + assert io[1].info["test"] == {"a": 1, "b": 2, "c": 3} + assert io[2].info["b"] == {"a": 1, "b": 2, "c": 3, "d": 4} + + +def test_list_data(tmp_path): + molecule = ase.build.molecule("H2O") + molecule.info["test"] = [1, 2] + + # Write to extxyz format + extxyz_path = tmp_path / "molecule.extxyz" + ase.io.write(extxyz_path, molecule, format="extxyz") + + # Read from extxyz format + molecule = ase.io.read(extxyz_path, format="extxyz") + npt.assert_array_equal(molecule.info["test"], [1, 2]) + + io = znh5md.IO(tmp_path / "test.h5") + io.append(molecule) + molecule.info["test"] = [1, 2, 3] + io.append(molecule) + molecule.info["b"] = [1, 2, 3, 4] + io.append(molecule) + + npt.assert_array_equal(io[0].info["test"], [1, 2]) + npt.assert_array_equal(io[1].info["test"], [1, 2, 3]) + npt.assert_array_equal(io[2].info["b"], [1, 2, 3, 4]) + + +def test_multiple_molecules_with_diff_length_dicts(tmp_path): + molecules = [ + ase.build.molecule("H2O"), + ase.build.molecule("CH4"), + ] + + # Assign different length dicts to molecule info + molecules[0].info["test"] = {"a": 1, "b": 2} + molecules[1].info["test"] = {"a": 1} + + extxyz_path = tmp_path / "molecules.extxyz" + ase.io.write(extxyz_path, molecules, format="extxyz") + + read_molecules = ase.io.read(extxyz_path, index=":", format="extxyz") + + io = znh5md.IO(tmp_path / "test.h5") + for mol in read_molecules: + io.append(mol) + assert io[0].info["test"] == {"a": 1, "b": 2} + assert io[1].info["test"] == {"a": 1} diff --git a/znh5md/format.py b/znh5md/format.py index 05cb7e32..73e8c046 100644 --- a/znh5md/format.py +++ b/znh5md/format.py @@ -298,7 +298,7 @@ def _combine_dicts(dicts: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]: data = [] for d in dicts: if key in d: - data.append(d[key]) + data.append(np.array(d[key])) else: dims = dicts[0][key].ndim # Create an array with the appropriate number of dimensions. diff --git a/znh5md/io.py b/znh5md/io.py index 22c17174..3f726efc 100644 --- a/znh5md/io.py +++ b/znh5md/io.py @@ -17,6 +17,7 @@ import znh5md.format as fmt from znh5md import utils +from znh5md.config import STRING_FILL_VALUE __version__ = importlib.metadata.version("znh5md") @@ -204,7 +205,12 @@ def _extract_additional_data(self, f, index, arrays_data, calc_data, info_data): ].attrs.get("ZNH5MD_TYPE") == "json" ): - info_data[key] = [json.loads(x) for x in data] + info_data[key] = [ + json.loads(x) + if x != STRING_FILL_VALUE + else STRING_FILL_VALUE + for x in data + ] else: info_data[key] = data except IndexError: diff --git a/znh5md/utils.py b/znh5md/utils.py index fc17097a..c71a8814 100644 --- a/znh5md/utils.py +++ b/znh5md/utils.py @@ -5,7 +5,7 @@ from znh5md.config import NUMERIC_FILL_VALUE, STRING_FILL_VALUE -NUMPY_STRING_DTYPE = np.dtype("S512") +NUMPY_STRING_DTYPE = np.dtype("S4096") # TODO this has to be tested! def concatenate_varying_shape_arrays(arrays: list[np.ndarray]) -> np.ndarray: