diff --git a/.gitignore b/.gitignore index 94fd6dc..d18fa2a 100644 --- a/.gitignore +++ b/.gitignore @@ -27,12 +27,6 @@ share/python-wheels/ *.egg MANIFEST -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - # Installer logs pip-log.txt pip-delete-this-directory.txt @@ -52,26 +46,6 @@ coverage.xml .pytest_cache/ cover/ -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - # PyBuilder .pybuilder/ target/ @@ -83,43 +57,9 @@ target/ profile_default/ ipython_config.py -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - # Environments .env .venv @@ -133,29 +73,11 @@ venv.bak/ .spyderproject .spyproject -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - # mypy .mypy_cache/ .dmypy.json dmypy.json -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ +# VS Code +.vscode -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ diff --git a/mlspm/data_generation.py b/mlspm/data_generation.py new file mode 100644 index 0000000..58734fd --- /dev/null +++ b/mlspm/data_generation.py @@ -0,0 +1,109 @@ + +import io +import os +import tarfile +import time +from os import PathLike +from pathlib import Path +from typing import List, Optional + +import numpy as np +from PIL import Image + + +class TarWriter: + ''' + Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with + :meth:`add_sample`. + + Each tar file has a maximum number of samples, and whenever that maximum is reached, a new tar file is created. + The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder. The current tar file + handle is always available in the attribute :attr:`ft`, and is automatically closed when the context ends. + + Arguments: + base_path: Path to directory where tar files are saved. + base_name: Base name for output tar files. The number of the tar file is appended to the name. + max_count: Maximum number of samples per tar file. + png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower + write speed. + ''' + + def __init__(self, base_path: PathLike='./', base_name: str='', max_count: int=100, png_compress_level=4): + self.base_path = Path(base_path) + self.base_name = base_name + self.max_count = max_count + self.png_compress_level = png_compress_level + + def __enter__(self): + self.sample_count = 0 + self.total_count = 0 + self.tar_count = 0 + self.ft = self._get_tar_file() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.ft.close() + + def _get_tar_file(self): + file_path = self.base_path / f'{self.base_name}_{self.tar_count}.tar' + if os.path.exists(file_path): + raise RuntimeError(f'Tar file already exists at `{file_path}`') + return tarfile.open(file_path, 'w', format=tarfile.GNU_FORMAT) + + def add_sample(self, X: List[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray]=None, comment_str: str=''): + """ + Add a sample to the current tar file. + + Arguments: + X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz). + xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element]. + Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny). + comment_str: Comment line (second line) to add to the xyz file. + """ + + if self.sample_count >= self.max_count: + self.tar_count += 1 + self.sample_count = 0 + self.ft.close() + self.ft = self._get_tar_file() + + # Write AFM images + for i, x in enumerate(X): + for j in range(x.shape[-1]): + xj = x[:, :, j] + xj = ((xj - xj.min()) / np.ptp(xj) * (2**8 - 1)).astype(np.uint8) # Convert range to 0-255 integers + img_bytes = io.BytesIO() + Image.fromarray(xj.T[::-1], mode='L').save(img_bytes, 'png', compress_level=self.png_compress_level) + img_bytes.seek(0) # Return stream to start so that addfile can read it correctly + self.ft.addfile(get_tarinfo(f'{self.total_count}.{j:02d}.{i}.png', img_bytes), img_bytes) + img_bytes.close() + + # Write xyz file + xyz_bytes = io.BytesIO() + xyz_bytes.write(bytearray(f'{len(xyzs)}\n{comment_str}\n', 'utf-8')) + for xyz in xyzs: + xyz_bytes.write(bytearray(f'{int(xyz[-1])}\t', 'utf-8')) + for i in range(len(xyz)-1): + xyz_bytes.write(bytearray(f'{xyz[i]:10.8f}\t', 'utf-8')) + xyz_bytes.write(bytearray('\n', 'utf-8')) + xyz_bytes.seek(0) # Return stream to start so that addfile can read it correctly + self.ft.addfile(get_tarinfo(f'{self.total_count}.xyz', xyz_bytes), xyz_bytes) + xyz_bytes.close() + + # Write image descriptors (if any) + if Y is not None: + for i, y in enumerate(Y): + img_bytes = io.BytesIO() + np.save(img_bytes, y.astype(np.float32)) + img_bytes.seek(0) # Return stream to start so that addfile can read it correctly + self.ft.addfile(get_tarinfo(f'{self.total_count}.desc_{i}.npy', img_bytes), img_bytes) + img_bytes.close() + + self.sample_count += 1 + self.total_count += 1 + +def get_tarinfo(fname: str, file_bytes: io.BytesIO): + info = tarfile.TarInfo(fname) + info.size = file_bytes.getbuffer().nbytes + info.mtime = time.time() + return info \ No newline at end of file diff --git a/tests/test_data_generation.py b/tests/test_data_generation.py new file mode 100644 index 0000000..ff77e6a --- /dev/null +++ b/tests/test_data_generation.py @@ -0,0 +1,46 @@ + +from pathlib import Path +from shutil import rmtree +import tarfile +import numpy as np +import pytest + + +def test_tar_writer(): + + from mlspm.data_generation import TarWriter + + base_path = Path('./test_writer') + base_name = 'test' + + base_path.mkdir(exist_ok=True) + + with TarWriter(base_path, base_name, max_count=10) as tar_writer: + for _ in range(20): + X = [np.random.rand(128, 128, 10), np.random.rand(128, 128, 10)] + Y = [np.random.rand(128, 128), np.random.rand(128, 128)] + xyzs = np.concatenate([np.random.rand(10, 3), np.random.randint(1, 10, (10, 1))], axis=1) + tar_writer.add_sample(X, xyzs, Y, comment_str='test comment') + + assert (base_path / 'test_0.tar').exists() + assert (base_path / 'test_1.tar').exists() + + with tarfile.open(base_path / 'test_0.tar') as ft: + names = [m.name for m in ft.getmembers()] + assert len(names) == 10 * (2 * 10 + 1 + 2) + assert "0.00.0.png" in names + assert "0.00.1.png" in names + assert "0.09.1.png" in names + assert "0.00.2.png" not in names + assert "0.10.0.png" not in names + assert "0.xyz" in names + assert "0.desc_0.npy" in names + + with pytest.raises(RuntimeError): + # Cannot overwrite an existing file + with TarWriter(base_path, base_name, max_count=10) as tar_writer: + pass + + rmtree(base_path) + +test_tar_writer() \ No newline at end of file