diff --git a/tests/conftest.py b/tests/conftest.py
index 5f4a2bad..5e651fde 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+import ase.build
 import ase.collections
 import numpy as np
 import pytest
@@ -132,3 +133,24 @@ def s22_no_ascii() -> list[ase.Atoms]:
         atoms.info["config"] = "βγ"
         images.append(atoms)
     return images
+
+
+@pytest.fixture
+def frames_with_residuenames() -> list[ase.Atoms]:
+    """Return water and ethane frames with PDB-style per-atom string arrays.
+
+    ``atomtypes`` is derived from each frame's chemical symbols, so every
+    array holds exactly one entry per atom.
+    """
+    # non-ASCII prefixes keep the labels outside plain ASCII
+    prefixes = {"O": "γ", "C": "γ", "H": "β"}
+    frames = []
+    for formula in ("H2O", "C2H6"):
+        atoms = ase.build.molecule(formula)
+        # typical PDB array data
+        atoms.arrays["residuenames"] = np.array([formula] * len(atoms))
+        atoms.arrays["atomtypes"] = np.array(
+            [prefixes[symbol] + symbol for symbol in atoms.get_chemical_symbols()]
+        )
+        frames.append(atoms)
+    return frames
diff --git a/tests/test_string_arrays.py b/tests/test_string_arrays.py
new file mode 100644
index 00000000..0bb15268
--- /dev/null
+++ b/tests/test_string_arrays.py
@@ -0,0 +1,18 @@
+import numpy.testing as npt
+
+import znh5md
+
+
+def test_info_non_ascii(tmp_path, frames_with_residuenames):
+    """Check that per-atom string arrays survive a write/read round trip."""
+    io = znh5md.IO(tmp_path / "test.h5")
+    io.extend(frames_with_residuenames)
+
+    loaded = list(io)
+    # zip() stops at the shorter input; make sure no frame was dropped
+    assert len(loaded) == len(frames_with_residuenames)
+
+    for a, b in zip(loaded, frames_with_residuenames):
+        for key in b.arrays:
+            npt.assert_array_equal(a.arrays[key], b.arrays[key])
+    # TODO: test json per atom?
diff --git a/tests/test_znh5md.py b/tests/test_znh5md.py index 9b86103a..3874083b 100644 --- a/tests/test_znh5md.py +++ b/tests/test_znh5md.py @@ -3,3 +3,10 @@ def test_version(): assert znh5md.__version__ == "0.3.7" + + +def test_creator(tmp_path): + io = znh5md.IO(tmp_path / "test.h5") + # These are the defaults + assert io.creator == "znh5md" + assert io.creator_version == znh5md.__version__ diff --git a/znh5md/config.py b/znh5md/config.py new file mode 100644 index 00000000..cffb14e9 --- /dev/null +++ b/znh5md/config.py @@ -0,0 +1,4 @@ +import numpy as np + +NUMERIC_FILL_VALUE = np.nan +STRING_FILL_VALUE = b"" diff --git a/znh5md/format.py b/znh5md/format.py index 6b7990fb..05cb7e32 100644 --- a/znh5md/format.py +++ b/znh5md/format.py @@ -218,6 +218,12 @@ def extract_atoms_data(atoms: Atoms, use_ase_calc: bool = True) -> ASEData: # n if use_ase_calc and key in all_properties: raise ValueError(f"Key {key} is reserved for ASE calculator results.") if key not in ASE_TO_H5MD: + if isinstance(value, np.ndarray): + # check if dtype is np.ndarray: max_n_particles = max(x.shape[0] for x in arrays) dimensions = arrays[0].shape[1:] - result = np.full((len(arrays), max_n_particles, *dimensions), np.nan) - for i, x in enumerate(arrays): - result[i, : x.shape[0], ...] = x + if arrays[0].dtype == NUMPY_STRING_DTYPE: + result = np.full( + (len(arrays), max_n_particles), STRING_FILL_VALUE, dtype=NUMPY_STRING_DTYPE + ) + for i, x in enumerate(arrays): + result[i, : x.shape[0]] = x + else: + result = np.full( + (len(arrays), max_n_particles, *dimensions), NUMERIC_FILL_VALUE + ) + for i, x in enumerate(arrays): + result[i, : x.shape[0], ...] = x return result @@ -58,6 +69,8 @@ def remove_nan_rows(array: np.ndarray) -> np.ndarray | None: 1 """ + if isinstance(array, np.ndarray) and array.dtype == object: + return np.array([x.decode() for x in array if x != STRING_FILL_VALUE]) if np.isnan(array).all(): return None if len(np.shape(array)) == 0: