Skip to content

Commit

Permalink
support string data (#127)
Browse files Browse the repository at this point in the history
* support string data

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* raise an error to avoid silence string truncation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* poetry and version update

* fix #125

* additional testing for wrong input types.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Aug 13, 2024
1 parent 1ad615e commit 51b9b60
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 41 deletions.
80 changes: 46 additions & 34 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znh5md"
version = "0.3.4"
version = "0.3.5"
description = "ASE Interface for the H5MD format."
authors = ["zincwarecode <[email protected]>"]
license = "Apache-2.0"
Expand Down
38 changes: 38 additions & 0 deletions tests/test_data_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import ase.build
import pytest

import znh5md


def test_smiles(tmp_path):
io = znh5md.IO(tmp_path / "test.h5")
molecule = ase.build.molecule("H2O")
molecule.info["smiles"] = "O"

io.append(molecule)
assert io[0].info["smiles"] == "O"

molecule = ase.build.molecule("H2O2")
molecule.info["smiles"] = "OO"

io.append(molecule)
assert io[0].info["smiles"] == "O"
assert io[1].info["smiles"] == "OO"


def test_very_long_text_data(tmp_path):
io = znh5md.IO(tmp_path / "test.h5")
molecule = ase.build.molecule("H2O")

molecule.info["test"] = f"{list(range(1_000))}"
with pytest.raises(ValueError, match="String test is too long to be stored."):
io.append(molecule)


def test_int_info_data(tmp_path):
io = znh5md.IO(tmp_path / "test.h5")
molecule = ase.build.molecule("H2O")
molecule.info["test"] = 123

io.append(molecule)
assert io[0].info["test"] == 123
20 changes: 20 additions & 0 deletions tests/test_recreate_h5py_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import ase.build
import pytest

import znh5md


def test_extend_wrong_error(tmp_path):
io = znh5md.IO(tmp_path / "test.h5")
molecule = ase.build.molecule("H2O")

with pytest.raises(ValueError, match="images must be a list of ASE Atoms objects"):
io.extend(molecule)


def test_append_wrong_error(tmp_path):
io = znh5md.IO(tmp_path / "test.h5")
molecule = ase.build.molecule("H2O")

with pytest.raises(ValueError, match="atoms must be an ASE Atoms object"):
io.append([molecule])
2 changes: 1 addition & 1 deletion tests/test_znh5md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


def test_version():
assert znh5md.__version__ == "0.3.4"
assert znh5md.__version__ == "0.3.5"
9 changes: 7 additions & 2 deletions znh5md/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ase import Atoms
from ase.calculators.calculator import all_properties

from .utils import concatenate_varying_shape_arrays
from .utils import NUMPY_STRING_DTYPE, concatenate_varying_shape_arrays


class ASEKeyMetaData(TypedDict):
Expand Down Expand Up @@ -201,7 +201,12 @@ def extract_atoms_data(atoms: Atoms, use_ase_calc: bool = True) -> ASEData: # n
if use_ase_calc and key in all_properties:
raise ValueError(f"Key {key} is reserved for ASE calculator results.")
if key not in ASE_TO_H5MD and key not in CustomINFOData.__members__:
info_data[key] = value
if isinstance(value, str):
if len(value) > NUMPY_STRING_DTYPE.itemsize:
raise ValueError(f"String {key} is too long to be stored.")
info_data[key] = np.array(value, dtype=NUMPY_STRING_DTYPE)
else:
info_data[key] = value

for key, value in atoms.arrays.items():
if use_ase_calc and key in all_properties:
Expand Down
8 changes: 6 additions & 2 deletions znh5md/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def _extract_additional_data(self, f, index, arrays_data, calc_data, info_data):
)

def extend(self, images: List[ase.Atoms]):
if not isinstance(images, list):
raise ValueError("images must be a list of ASE Atoms objects")
if len(images) == 0:
warnings.warn("No data provided")
return
Expand Down Expand Up @@ -262,7 +264,7 @@ def _create_group(
ds_value = g_grp.create_dataset(
"value",
data=data,
dtype=np.float64,
dtype=utils.get_h5py_dtype(data),
chunks=True
if self.chunk_size is None
else tuple([self.chunk_size] + list(data.shape[1:])),
Expand Down Expand Up @@ -336,7 +338,7 @@ def _create_observables(
ds_value = g_observable.create_dataset(
"value",
data=value,
dtype=np.float64,
dtype=utils.get_h5py_dtype(value),
chunks=True
if self.chunk_size is None
else tuple([self.chunk_size] + list(value.shape[1:])),
Expand Down Expand Up @@ -430,6 +432,8 @@ def _extend_observables(
utils.fill_dataset(g_val["step"], step)

def append(self, atoms: ase.Atoms):
if not isinstance(atoms, ase.Atoms):
raise ValueError("atoms must be an ASE Atoms object")
self.extend([atoms])

def __delitem__(self, index):
Expand Down
17 changes: 16 additions & 1 deletion znh5md/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import ase
import h5py
import numpy as np
from ase.calculators.singlepoint import SinglePointCalculator

NUMPY_STRING_DTYPE = np.dtype("S512")


def concatenate_varying_shape_arrays(arrays: list[np.ndarray]) -> np.ndarray:
"""Concatenate arrays of varying lengths into a numpy array.
Expand Down Expand Up @@ -122,7 +125,12 @@ def build_atoms(args) -> ase.Atoms:
key: remove_nan_rows(value) for key, value in arrays_data.items()
}
if info_data is not None:
info_data = {key: remove_nan_rows(value) for key, value in info_data.items()}
# We update the info_data in place
for key, value in info_data.items():
if isinstance(value, bytes):
info_data[key] = value.decode("utf-8")
else:
info_data[key] = remove_nan_rows(value)

atoms = ase.Atoms(
symbols=atomic_numbers,
Expand Down Expand Up @@ -170,3 +178,10 @@ def build_structures(
)
structures.append(build_atoms(args))
return structures


def get_h5py_dtype(data: np.ndarray):
if data.dtype == NUMPY_STRING_DTYPE:
return h5py.string_dtype(encoding="utf-8")
else:
return data.dtype

0 comments on commit 51b9b60

Please sign in to comment.