Skip to content

Commit

Permalink
Added data generation utils
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Jan 24, 2024
1 parent 114d3f2 commit 230af32
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 80 deletions.
82 changes: 2 additions & 80 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ share/python-wheels/
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt
Expand All @@ -52,26 +46,6 @@ coverage.xml
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/
Expand All @@ -83,43 +57,9 @@ target/
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
Expand All @@ -133,29 +73,11 @@ venv.bak/
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/
# VS Code
.vscode

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
109 changes: 109 additions & 0 deletions mlspm/data_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@

import io
import os
import tarfile
import time
from os import PathLike
from pathlib import Path
from typing import List, Optional

import numpy as np
from PIL import Image


class TarWriter:
'''
Write samples of AFM images, molecules and descriptors to tar files. Use as a context manager and add samples with
:meth:`add_sample`.
Each tar file has a maximum number of samples, and whenever that maximum is reached, a new tar file is created.
The generated tar files are named as ``{base_name}_{n}.tar`` and saved into the specified folder. The current tar file
handle is always available in the attribute :attr:`ft`, and is automatically closed when the context ends.
Arguments:
base_path: Path to directory where tar files are saved.
base_name: Base name for output tar files. The number of the tar file is appended to the name.
max_count: Maximum number of samples per tar file.
png_compress_level: Compression level 1-9 for saved png images. Larger value for smaller file size but slower
write speed.
'''

def __init__(self, base_path: PathLike='./', base_name: str='', max_count: int=100, png_compress_level=4):
self.base_path = Path(base_path)
self.base_name = base_name
self.max_count = max_count
self.png_compress_level = png_compress_level

def __enter__(self):
self.sample_count = 0
self.total_count = 0
self.tar_count = 0
self.ft = self._get_tar_file()
return self

def __exit__(self, exc_type, exc_value, exc_traceback):
self.ft.close()

def _get_tar_file(self):
file_path = self.base_path / f'{self.base_name}_{self.tar_count}.tar'
if os.path.exists(file_path):
raise RuntimeError(f'Tar file already exists at `{file_path}`')
return tarfile.open(file_path, 'w', format=tarfile.GNU_FORMAT)

def add_sample(self, X: List[np.ndarray], xyzs: np.ndarray, Y: Optional[np.ndarray]=None, comment_str: str=''):
"""
Add a sample to the current tar file.
Arguments:
X: AFM images. Each list item corresponds to an AFM tip and is an array of shape (nx, ny, nz).
xyzs: Atom coordinates and elements. Each row is one atom and is of the form [x, y, z, element].
Y: Image descriptors. Each list item is one descriptor and is an array of shape (nx, ny).
comment_str: Comment line (second line) to add to the xyz file.
"""

if self.sample_count >= self.max_count:
self.tar_count += 1
self.sample_count = 0
self.ft.close()
self.ft = self._get_tar_file()

# Write AFM images
for i, x in enumerate(X):
for j in range(x.shape[-1]):
xj = x[:, :, j]
xj = ((xj - xj.min()) / np.ptp(xj) * (2**8 - 1)).astype(np.uint8) # Convert range to 0-255 integers
img_bytes = io.BytesIO()
Image.fromarray(xj.T[::-1], mode='L').save(img_bytes, 'png', compress_level=self.png_compress_level)
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
self.ft.addfile(get_tarinfo(f'{self.total_count}.{j:02d}.{i}.png', img_bytes), img_bytes)
img_bytes.close()

# Write xyz file
xyz_bytes = io.BytesIO()
xyz_bytes.write(bytearray(f'{len(xyzs)}\n{comment_str}\n', 'utf-8'))
for xyz in xyzs:
xyz_bytes.write(bytearray(f'{int(xyz[-1])}\t', 'utf-8'))
for i in range(len(xyz)-1):
xyz_bytes.write(bytearray(f'{xyz[i]:10.8f}\t', 'utf-8'))
xyz_bytes.write(bytearray('\n', 'utf-8'))
xyz_bytes.seek(0) # Return stream to start so that addfile can read it correctly
self.ft.addfile(get_tarinfo(f'{self.total_count}.xyz', xyz_bytes), xyz_bytes)
xyz_bytes.close()

# Write image descriptors (if any)
if Y is not None:
for i, y in enumerate(Y):
img_bytes = io.BytesIO()
np.save(img_bytes, y.astype(np.float32))
img_bytes.seek(0) # Return stream to start so that addfile can read it correctly
self.ft.addfile(get_tarinfo(f'{self.total_count}.desc_{i}.npy', img_bytes), img_bytes)
img_bytes.close()

self.sample_count += 1
self.total_count += 1

def get_tarinfo(fname: str, file_bytes: io.BytesIO):
info = tarfile.TarInfo(fname)
info.size = file_bytes.getbuffer().nbytes
info.mtime = time.time()
return info
46 changes: 46 additions & 0 deletions tests/test_data_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

from pathlib import Path
from shutil import rmtree
import tarfile
import numpy as np
import pytest


def test_tar_writer():

from mlspm.data_generation import TarWriter

base_path = Path('./test_writer')
base_name = 'test'

base_path.mkdir(exist_ok=True)

with TarWriter(base_path, base_name, max_count=10) as tar_writer:
for _ in range(20):
X = [np.random.rand(128, 128, 10), np.random.rand(128, 128, 10)]
Y = [np.random.rand(128, 128), np.random.rand(128, 128)]
xyzs = np.concatenate([np.random.rand(10, 3), np.random.randint(1, 10, (10, 1))], axis=1)
tar_writer.add_sample(X, xyzs, Y, comment_str='test comment')

assert (base_path / 'test_0.tar').exists()
assert (base_path / 'test_1.tar').exists()

with tarfile.open(base_path / 'test_0.tar') as ft:
names = [m.name for m in ft.getmembers()]
assert len(names) == 10 * (2 * 10 + 1 + 2)
assert "0.00.0.png" in names
assert "0.00.1.png" in names
assert "0.09.1.png" in names
assert "0.00.2.png" not in names
assert "0.10.0.png" not in names
assert "0.xyz" in names
assert "0.desc_0.npy" in names

with pytest.raises(RuntimeError):
# Cannot overwrite an existing file
with TarWriter(base_path, base_name, max_count=10) as tar_writer:
pass

rmtree(base_path)

test_tar_writer()

0 comments on commit 230af32

Please sign in to comment.