Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix str in atoms.arrays #142

Merged
merged 8 commits into from
Sep 18, 2024
Merged
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ase.build
import ase.collections
import numpy as np
import pytest
Expand Down Expand Up @@ -132,3 +133,16 @@ def s22_no_ascii() -> list[ase.Atoms]:
atoms.info["config"] = "βγ"
images.append(atoms)
return images


@pytest.fixture
def frames_with_residuenames() -> list[ase.Atoms]:
    """Return a water and an ethane frame carrying per-atom string arrays.

    Each frame gets two entries in ``atoms.arrays``:

    * ``residuenames`` -- the molecule formula repeated once per atom.
    * ``atomtypes`` -- a non-ASCII label per atom (``"γO"``, ``"γC"``, ``"βH"``).

    The labels are derived from the chemical symbols so that every array has
    exactly ``len(atoms)`` entries; assigning into ``atoms.arrays`` directly
    performs no length validation, so a hand-written list of the wrong length
    would go unnoticed.
    """
    # Non-ASCII prefixes on purpose: the arrays must survive string encoding.
    type_prefix = {"O": "γ", "C": "γ", "H": "β"}

    frames = []
    for formula in ("H2O", "C2H6"):
        atoms = ase.build.molecule(formula)
        # typical PDB array data
        atoms.arrays["residuenames"] = np.array([formula] * len(atoms))
        atoms.arrays["atomtypes"] = np.array(
            [type_prefix[symbol] + symbol for symbol in atoms.get_chemical_symbols()]
        )
        frames.append(atoms)
    return frames
14 changes: 14 additions & 0 deletions tests/test_string_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import numpy.testing as npt

import znh5md


def test_info_non_ascii(tmp_path, frames_with_residuenames):
    """Round-trip per-atom string arrays through an H5MD file.

    Writes the frames to a file inside ``tmp_path``, reads them back through
    the same ``znh5md.IO`` object and checks that every array present on the
    original frames is returned unchanged.
    """
    io = znh5md.IO(tmp_path / "test.h5")
    io.extend(frames_with_residuenames)

    # zip() stops at the shorter input, so without this check the loop below
    # would pass vacuously if fewer frames were read back than were written.
    assert len(io) == len(frames_with_residuenames)

    for a, b in zip(io, frames_with_residuenames):
        for key in b.arrays:
            npt.assert_array_equal(a.arrays[key], b.arrays[key])
    # TODO: test json per atom?
7 changes: 7 additions & 0 deletions tests/test_znh5md.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@

def test_version():
    """Check that the package reports the expected version string."""
    expected_version = "0.3.7"
    assert znh5md.__version__ == expected_version


def test_creator(tmp_path):
    """Check the default creator metadata of a freshly constructed ``IO``."""
    target_file = tmp_path / "test.h5"
    handle = znh5md.IO(target_file)

    # No creator arguments were passed, so both attributes hold their defaults.
    default_creator = handle.creator
    assert default_creator == "znh5md"
    default_creator_version = handle.creator_version
    assert default_creator_version == znh5md.__version__
4 changes: 4 additions & 0 deletions znh5md/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import numpy as np

# Fill value for padded numeric arrays: NaN marks slots that hold no data.
NUMERIC_FILL_VALUE = np.nan
# Fill value for padded byte-string arrays: the empty byte string marks slots
# that hold no data.
STRING_FILL_VALUE = b""
6 changes: 6 additions & 0 deletions znh5md/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ def extract_atoms_data(atoms: Atoms, use_ase_calc: bool = True) -> ASEData: # n
if use_ase_calc and key in all_properties:
raise ValueError(f"Key {key} is reserved for ASE calculator results.")
if key not in ASE_TO_H5MD:
if isinstance(value, np.ndarray):
# check if dtype is <U1 or <U2
if value.dtype.kind == "U":
value = np.array(
[x.encode() for x in value.tolist()], dtype=NUMPY_STRING_DTYPE
)
particles[key] = value

time: Optional[float] = atoms.info.get(CustomINFOData.h5md_time.name, None)
Expand Down
7 changes: 5 additions & 2 deletions znh5md/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import dataclasses
import importlib.metadata
import json
import os
import pathlib
Expand All @@ -17,6 +18,8 @@
import znh5md.format as fmt
from znh5md import utils

# Version string of the installed "znh5md" distribution, looked up from the
# package metadata when this module is imported.
__version__ = importlib.metadata.version("znh5md")

# TODO: use pint to convert the units in the h5md file to ase units


Expand All @@ -41,8 +44,8 @@ class IO(MutableSequence):
save_units: bool = True # Export ASE units into the H5MD file
author: str = "N/A"
author_email: str = "N/A"
creator: str = "N/A"
creator_version: str = "N/A"
creator: str = "znh5md"
creator_version: str = __version__
particle_group: Optional[str] = None
compression: Optional[str] = "gzip"
compression_opts: Optional[int] = None
Expand Down
19 changes: 16 additions & 3 deletions znh5md/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
from ase.calculators.singlepoint import SinglePointCalculator

from znh5md.config import NUMERIC_FILL_VALUE, STRING_FILL_VALUE

NUMPY_STRING_DTYPE = np.dtype("S512")


Expand Down Expand Up @@ -32,9 +34,18 @@ def concatenate_varying_shape_arrays(arrays: list[np.ndarray]) -> np.ndarray:
max_n_particles = max(x.shape[0] for x in arrays)
dimensions = arrays[0].shape[1:]

result = np.full((len(arrays), max_n_particles, *dimensions), np.nan)
for i, x in enumerate(arrays):
result[i, : x.shape[0], ...] = x
if arrays[0].dtype == NUMPY_STRING_DTYPE:
result = np.full(
(len(arrays), max_n_particles), STRING_FILL_VALUE, dtype=NUMPY_STRING_DTYPE
)
for i, x in enumerate(arrays):
result[i, : x.shape[0]] = x
else:
result = np.full(
(len(arrays), max_n_particles, *dimensions), NUMERIC_FILL_VALUE
)
for i, x in enumerate(arrays):
result[i, : x.shape[0], ...] = x
return result


Expand All @@ -58,6 +69,8 @@ def remove_nan_rows(array: np.ndarray) -> np.ndarray | None:
1

"""
if isinstance(array, np.ndarray) and array.dtype == object:
return np.array([x.decode() for x in array if x != STRING_FILL_VALUE])
if np.isnan(array).all():
return None
if len(np.shape(array)) == 0:
Expand Down
Loading