From 51b9b60bce9c28d6f83d04550ee58efbe2c8dacd Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Tue, 13 Aug 2024 14:55:07 +0200 Subject: [PATCH] support string data (#127) * support string data * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * raise an error to avoid silence string truncation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * poetry and version update * fix #125 * additional testing for wrong input types. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- poetry.lock | 80 ++++++++++++++++++-------------- pyproject.toml | 2 +- tests/test_data_types.py | 38 +++++++++++++++ tests/test_recreate_h5py_file.py | 20 ++++++++ tests/test_znh5md.py | 2 +- znh5md/format.py | 9 +++- znh5md/io.py | 8 +++- znh5md/utils.py | 17 ++++++- 8 files changed, 135 insertions(+), 41 deletions(-) create mode 100644 tests/test_data_types.py create mode 100644 tests/test_recreate_h5py_file.py diff --git a/poetry.lock b/poetry.lock index 5a167fe0..279d6a57 100644 --- a/poetry.lock +++ b/poetry.lock @@ -612,40 +612,51 @@ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"] [[package]] name = "matplotlib" -version = "3.9.1.post1" +version = "3.9.2" description = "Python plotting package" optional = false python-versions = ">=3.9" files = [ - {file = "matplotlib-3.9.1.post1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3779ad3e8b72df22b8a622c5796bbcfabfa0069b835412e3c1dec8ee3de92d0c"}, - {file = "matplotlib-3.9.1.post1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ec400340f8628e8e2260d679078d4e9b478699f386e5cc8094e80a1cb0039c7c"}, - {file = "matplotlib-3.9.1.post1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82c18791b8862ea095081f745b81f896b011c5a5091678fb33204fef641476af"}, - {file = "matplotlib-3.9.1.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:621a628389c09a6b9f609a238af8e66acecece1cfa12febc5fe4195114ba7446"}, - {file = "matplotlib-3.9.1.post1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9a54734ca761ebb27cd4f0b6c2ede696ab6861052d7d7e7b8f7a6782665115f5"}, - {file = "matplotlib-3.9.1.post1-cp310-cp310-win_amd64.whl", hash = "sha256:0721f93db92311bb514e446842e2b21c004541dcca0281afa495053e017c5458"}, - {file = "matplotlib-3.9.1.post1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:b08b46058fe2a31ecb81ef6aa3611f41d871f6a8280e9057cb4016cb3d8e894a"}, - {file = "matplotlib-3.9.1.post1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:22b344e84fcc574f561b5731f89a7625db8ef80cdbb0026a8ea855a33e3429d1"}, - {file = "matplotlib-3.9.1.post1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b49fee26d64aefa9f061b575f0f7b5fc4663e51f87375c7239efa3d30d908fa"}, - {file = "matplotlib-3.9.1.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89eb7e89e2b57856533c5c98f018aa3254fa3789fcd86d5f80077b9034a54c9a"}, - {file = "matplotlib-3.9.1.post1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c06e742bade41fda6176d4c9c78c9ea016e176cd338e62a1686384cb1eb8de41"}, - {file = "matplotlib-3.9.1.post1-cp311-cp311-win_amd64.whl", hash = "sha256:c44edab5b849e0fc1f1c9d6e13eaa35ef65925f7be45be891d9784709ad95561"}, - {file = "matplotlib-3.9.1.post1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:bf28b09986aee06393e808e661c3466be9c21eff443c9bc881bce04bfbb0c500"}, - {file = "matplotlib-3.9.1.post1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:92aeb8c439d4831510d8b9d5e39f31c16c7f37873879767c26b147cef61e54cd"}, - {file = "matplotlib-3.9.1.post1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f15798b0691b45c80d3320358a88ce5a9d6f518b28575b3ea3ed31b4bd95d009"}, - {file = "matplotlib-3.9.1.post1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d59fc6096da7b9c1df275f9afc3fef5cbf634c21df9e5f844cba3dd8deb1847d"}, - {file = "matplotlib-3.9.1.post1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ab986817a32a70ce22302438691e7df4c6ee4a844d47289db9d583d873491e0b"}, - {file = "matplotlib-3.9.1.post1-cp312-cp312-win_amd64.whl", hash = "sha256:0d78e7d2d86c4472da105d39aba9b754ed3dfeaeaa4ac7206b82706e0a5362fa"}, - {file = "matplotlib-3.9.1.post1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:bd07eba6431b4dc9253cce6374a28c415e1d3a7dc9f8aba028ea7592f06fe172"}, - {file = "matplotlib-3.9.1.post1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ca230cc4482010d646827bd2c6d140c98c361e769ae7d954ebf6fff2a226f5b1"}, - {file = "matplotlib-3.9.1.post1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ace27c0fdeded399cbc43f22ffa76e0f0752358f5b33106ec7197534df08725a"}, - {file = "matplotlib-3.9.1.post1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9a4f3aeb7ba14c497dc6f021a076c48c2e5fbdf3da1e7264a5d649683e284a2f"}, - {file = "matplotlib-3.9.1.post1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:23f96fbd4ff4cfa9b8a6b685a65e7eb3c2ced724a8d965995ec5c9c2b1f7daf5"}, - {file = "matplotlib-3.9.1.post1-cp39-cp39-win_amd64.whl", hash = "sha256:2808b95452b4ffa14bfb7c7edffc5350743c31bda495f0d63d10fdd9bc69e895"}, - {file = "matplotlib-3.9.1.post1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ffc91239f73b4179dec256b01299d46d0ffa9d27d98494bc1476a651b7821cbe"}, - {file = "matplotlib-3.9.1.post1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f965ebca9fd4feaaca45937c4849d92b70653057497181100fcd1e18161e5f29"}, - {file = "matplotlib-3.9.1.post1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:801ee9323fd7b2da0d405aebbf98d1da77ea430bbbbbec6834c0b3af15e5db44"}, - {file = "matplotlib-3.9.1.post1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:50113e9b43ceb285739f35d43db36aa752fb8154325b35d134ff6e177452f9ec"}, - {file = "matplotlib-3.9.1.post1.tar.gz", hash = "sha256:c91e585c65092c975a44dc9d4239ba8c594ba3c193d7c478b6d178c4ef61f406"}, + {file = "matplotlib-3.9.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb"}, + {file = "matplotlib-3.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64"}, + {file = "matplotlib-3.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66"}, + {file = "matplotlib-3.9.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a"}, + {file = "matplotlib-3.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772"}, + {file = "matplotlib-3.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41"}, + {file = "matplotlib-3.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f"}, + {file = "matplotlib-3.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447"}, + {file = "matplotlib-3.9.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e"}, + {file = "matplotlib-3.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9"}, + {file = "matplotlib-3.9.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7"}, + {file = "matplotlib-3.9.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c"}, + {file = "matplotlib-3.9.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e"}, + {file = "matplotlib-3.9.2-cp312-cp312-win_amd64.whl", hash = "sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9"}, + {file = "matplotlib-3.9.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b"}, + {file = "matplotlib-3.9.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413"}, + {file = "matplotlib-3.9.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b"}, + {file = "matplotlib-3.9.2-cp313-cp313-win_amd64.whl", hash = "sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03"}, + {file = "matplotlib-3.9.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51"}, + {file = "matplotlib-3.9.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c"}, + {file = "matplotlib-3.9.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2"}, + {file = "matplotlib-3.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5"}, + {file = "matplotlib-3.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca"}, + {file = "matplotlib-3.9.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea"}, + {file = "matplotlib-3.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc"}, + {file = "matplotlib-3.9.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697"}, + {file = "matplotlib-3.9.2.tar.gz", hash = "sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92"}, ] [package.dependencies] @@ -1175,13 +1186,13 @@ files = [ [[package]] name = "pygal" -version = "3.0.4" +version = "3.0.5" description = "A Python svg graph plotting library" optional = false python-versions = ">=3.8" files = [ - {file = "pygal-3.0.4-py2.py3-none-any.whl", hash = "sha256:e931caf08b4be0e6ec119a4c0e20dbed2d77829c641b7dea0ed21fe6ec81f2ea"}, - {file = "pygal-3.0.4.tar.gz", hash = "sha256:6c5da33f1041e8b30cbc980f8a34910d9edc584b833240298f6a25df65425289"}, + {file = "pygal-3.0.5-py3-none-any.whl", hash = "sha256:a3268a5667b470c8fbbb0eca7e987561a7321caeba589d40e4c1bc16dbe71393"}, + {file = "pygal-3.0.5.tar.gz", hash = "sha256:c0a0f34e5bc1c01975c2bfb8342ad521e293ad42e525699dd00c4d7a52c14b71"}, ] [package.dependencies] @@ -1190,8 +1201,9 @@ importlib-metadata = "*" [package.extras] docs = ["pygal-sphinx-directives", "sphinx", "sphinx-rtd-theme"] lxml = ["lxml"] +moulinrouge = ["flask", "pygal-maps-ch", "pygal-maps-fr", "pygal-maps-world"] png = ["cairosvg"] -test = ["cairosvg", "coveralls", "flake8", "flask", "lxml", "pygal-maps-ch", "pygal-maps-fr", "pygal-maps-world", "pyquery", "pytest", "pytest-cov", "pytest-isort", "pytest-runner"] +test = ["cairosvg", "coveralls", "lxml", "pyquery", "pytest", "pytest-cov", "ruff (>=0.5.6)"] [[package]] name = "pygaljs" diff --git a/pyproject.toml b/pyproject.toml index 7f09ed0a..53311a48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "znh5md" -version = "0.3.4" +version = "0.3.5" description = "ASE Interface for the H5MD format." authors = ["zincwarecode "] license = "Apache-2.0" diff --git a/tests/test_data_types.py b/tests/test_data_types.py new file mode 100644 index 00000000..693f1849 --- /dev/null +++ b/tests/test_data_types.py @@ -0,0 +1,38 @@ +import ase.build +import pytest + +import znh5md + + +def test_smiles(tmp_path): + io = znh5md.IO(tmp_path / "test.h5") + molecule = ase.build.molecule("H2O") + molecule.info["smiles"] = "O" + + io.append(molecule) + assert io[0].info["smiles"] == "O" + + molecule = ase.build.molecule("H2O2") + molecule.info["smiles"] = "OO" + + io.append(molecule) + assert io[0].info["smiles"] == "O" + assert io[1].info["smiles"] == "OO" + + +def test_very_long_text_data(tmp_path): + io = znh5md.IO(tmp_path / "test.h5") + molecule = ase.build.molecule("H2O") + + molecule.info["test"] = f"{list(range(1_000))}" + with pytest.raises(ValueError, match="String test is too long to be stored."): + io.append(molecule) + + +def test_int_info_data(tmp_path): + io = znh5md.IO(tmp_path / "test.h5") + molecule = ase.build.molecule("H2O") + molecule.info["test"] = 123 + + io.append(molecule) + assert io[0].info["test"] == 123 diff --git a/tests/test_recreate_h5py_file.py b/tests/test_recreate_h5py_file.py new file mode 100644 index 00000000..6cd34a54 --- /dev/null +++ b/tests/test_recreate_h5py_file.py @@ -0,0 +1,20 @@ +import ase.build +import pytest + +import znh5md + + +def test_extend_wrong_error(tmp_path): + io = znh5md.IO(tmp_path / "test.h5") + molecule = ase.build.molecule("H2O") + + with pytest.raises(ValueError, match="images must be a list of ASE Atoms objects"): + io.extend(molecule) + + +def test_append_wrong_error(tmp_path): + io = znh5md.IO(tmp_path / "test.h5") + molecule = ase.build.molecule("H2O") + + with pytest.raises(ValueError, match="atoms must be an ASE Atoms object"): + io.append([molecule]) diff --git a/tests/test_znh5md.py b/tests/test_znh5md.py index 212b47a5..e9d0ab63 100644 --- a/tests/test_znh5md.py +++ b/tests/test_znh5md.py @@ -2,4 +2,4 @@ def test_version(): - assert znh5md.__version__ == "0.3.4" + assert znh5md.__version__ == "0.3.5" diff --git a/znh5md/format.py b/znh5md/format.py index 5304f25d..dbd02738 100644 --- a/znh5md/format.py +++ b/znh5md/format.py @@ -7,7 +7,7 @@ from ase import Atoms from ase.calculators.calculator import all_properties -from .utils import concatenate_varying_shape_arrays +from .utils import NUMPY_STRING_DTYPE, concatenate_varying_shape_arrays class ASEKeyMetaData(TypedDict): @@ -201,7 +201,12 @@ def extract_atoms_data(atoms: Atoms, use_ase_calc: bool = True) -> ASEData: # n if use_ase_calc and key in all_properties: raise ValueError(f"Key {key} is reserved for ASE calculator results.") if key not in ASE_TO_H5MD and key not in CustomINFOData.__members__: - info_data[key] = value + if isinstance(value, str): + if len(value) > NUMPY_STRING_DTYPE.itemsize: + raise ValueError(f"String {key} is too long to be stored.") + info_data[key] = np.array(value, dtype=NUMPY_STRING_DTYPE) + else: + info_data[key] = value for key, value in atoms.arrays.items(): if use_ase_calc and key in all_properties: diff --git a/znh5md/io.py b/znh5md/io.py index 2782f007..afb26f52 100644 --- a/znh5md/io.py +++ b/znh5md/io.py @@ -180,6 +180,8 @@ def _extract_additional_data(self, f, index, arrays_data, calc_data, info_data): ) def extend(self, images: List[ase.Atoms]): + if not isinstance(images, list): + raise ValueError("images must be a list of ASE Atoms objects") if len(images) == 0: warnings.warn("No data provided") return @@ -262,7 +264,7 @@ def _create_group( ds_value = g_grp.create_dataset( "value", data=data, - dtype=np.float64, + dtype=utils.get_h5py_dtype(data), chunks=True if self.chunk_size is None else tuple([self.chunk_size] + list(data.shape[1:])), @@ -336,7 +338,7 @@ def _create_observables( ds_value = g_observable.create_dataset( "value", data=value, - dtype=np.float64, + dtype=utils.get_h5py_dtype(value), chunks=True if self.chunk_size is None else tuple([self.chunk_size] + list(value.shape[1:])), @@ -430,6 +432,8 @@ def _extend_observables( utils.fill_dataset(g_val["step"], step) def append(self, atoms: ase.Atoms): + if not isinstance(atoms, ase.Atoms): + raise ValueError("atoms must be an ASE Atoms object") self.extend([atoms]) def __delitem__(self, index): diff --git a/znh5md/utils.py b/znh5md/utils.py index 3e0da0f7..ffdcffa8 100644 --- a/znh5md/utils.py +++ b/znh5md/utils.py @@ -1,7 +1,10 @@ import ase +import h5py import numpy as np from ase.calculators.singlepoint import SinglePointCalculator +NUMPY_STRING_DTYPE = np.dtype("S512") + def concatenate_varying_shape_arrays(arrays: list[np.ndarray]) -> np.ndarray: """Concatenate arrays of varying lengths into a numpy array. @@ -122,7 +125,12 @@ def build_atoms(args) -> ase.Atoms: key: remove_nan_rows(value) for key, value in arrays_data.items() } if info_data is not None: - info_data = {key: remove_nan_rows(value) for key, value in info_data.items()} + # We update the info_data in place + for key, value in info_data.items(): + if isinstance(value, bytes): + info_data[key] = value.decode("utf-8") + else: + info_data[key] = remove_nan_rows(value) atoms = ase.Atoms( symbols=atomic_numbers, @@ -170,3 +178,10 @@ def build_structures( ) structures.append(build_atoms(args)) return structures + + +def get_h5py_dtype(data: np.ndarray): + if data.dtype == NUMPY_STRING_DTYPE: + return h5py.string_dtype(encoding="utf-8") + else: + return data.dtype