Merge branch 'main' into ci-tests
NikoOinonen committed Jan 15, 2024
2 parents a7e9eb3 + 0eba5b4 commit b41cf8b
Showing 30 changed files with 2,215 additions and 45 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/ci_pr.yml
@@ -0,0 +1,32 @@

name: CI on PR
on: pull_request

jobs:
  run-tests:

    runs-on: ubuntu-latest

    steps:

      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11

      - name: Install package
        run: |
          bash ./build.sh
          pip install -e .[dev]
      - name: Run pytest
        run: pytest --junitxml=pytest.xml --cov-report=term-missing --cov mlspm tests | tee pytest-coverage.txt

      - name: Pytest coverage comment
        uses: MishaKav/pytest-coverage-comment@main
        with:
          pytest-coverage-path: ./pytest-coverage.txt
          junitxml-path: ./pytest.xml
33 changes: 33 additions & 0 deletions .github/workflows/ci_push.yml
@@ -0,0 +1,33 @@

name: CI on push
on: push

jobs:
  run-tests:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ['3.10', '3.11'] # Add 3.12 when pytorch has a wheel

    steps:

      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install package
        run: |
          bash ./build.sh
          pip install -e .[dev]
      - name: Lint with flake8
        run: flake8 .

      - name: Run pytest
        run: pytest

46 changes: 46 additions & 0 deletions .github/workflows/coverage_badge.yml
@@ -0,0 +1,46 @@

name: Update coverage badge
on:
  push:
    branches:
      - main

jobs:
  update-badge:

    runs-on: ubuntu-latest

    steps:

      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: 3.11

      - name: Install package
        run: |
          bash ./build.sh
          pip install -e .[dev]
      - name: Run pytest
        run: pytest --cov-report=term-missing --cov=mlspm tests | tee pytest-coverage.txt

      - name: Pytest coverage comment
        id: coverage_comment
        uses: MishaKav/pytest-coverage-comment@main
        with:
          hide-comment: true
          pytest-coverage-path: ./pytest-coverage.txt

      - name: Create coverage badge
        uses: schneegans/[email protected]
        with:
          auth: ${{ secrets.COVERAGE_GIST_SECRET }}
          gistID: 913d30e2a2e333eb407353072948042d
          filename: coverage.json
          label: Coverage
          message: ${{ steps.coverage_comment.outputs.coverage }}
          color: ${{ steps.coverage_comment.outputs.color }}
2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@
Machine learning for scanning probe microscopy

[![Documentation Status](https://readthedocs.org/projects/ml-spm/badge/?version=latest)](https://ml-spm.readthedocs.io/en/latest/?badge=latest)

![badge](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/NikoOinonen/913d30e2a2e333eb407353072948042d/raw/coverage.json)
## Installation

Install pre-requisites:
1 change: 1 addition & 0 deletions environment.yml
@@ -11,6 +11,7 @@ dependencies:
- torchvision
- torchaudio
- cuda
- ninja
- matplotlib
- numpy
- scipy
8 changes: 6 additions & 2 deletions mlspm/cli.py
@@ -1,4 +1,5 @@
import argparse
from typing import Optional


def _bool_type(value):
@@ -9,10 +10,13 @@ def _bool_type(value):
raise KeyError(f"`{value}` can't be interpreted as a boolean.")


def parse_args() -> dict:
def parse_args(argv: Optional[list[str]] = None) -> dict:
"""
Parse some useful CLI arguments for use in training scripts.
Arguments:
argv: List of argument values. Defaults to ``sys.argv``.
Returns:
A dictionary of the argument values.
"""
@@ -68,5 +72,5 @@ def parse_args() -> dict:
parser.add_argument(
"--avg_best_epochs", type=int, default=3, help="Number of epochs to average the best validation loss over. Default = 3."
)
args = parser.parse_args()
args = parser.parse_args(argv)
return vars(args)
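
Passing `argv` through to `parser.parse_args` lets tests call the function with an explicit argument list instead of patching `sys.argv`; when `argv` is `None`, argparse falls back to `sys.argv[1:]`. The following is a minimal sketch of that pattern, not the full `mlspm.cli.parse_args`: it registers only the `--avg_best_epochs` option shown in the hunk above and omits every other option.

```python
import argparse
from typing import Optional


def parse_args(argv: Optional[list[str]] = None) -> dict:
    # Reduced stand-in for mlspm.cli.parse_args: only one option is registered here.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--avg_best_epochs", type=int, default=3, help="Number of epochs to average the best validation loss over. Default = 3."
    )
    # argv=None makes argparse read sys.argv[1:]; an explicit list overrides that.
    args = parser.parse_args(argv)
    return vars(args)


# An empty list yields the defaults, independent of how the interpreter was invoked.
defaults = parse_args([])
# An explicit list behaves like the same tokens typed on the command line.
overridden = parse_args(["--avg_best_epochs", "5"])
```

A test can then assert on the returned dictionary directly, which is presumably what the new CI test suite relies on.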
6 changes: 3 additions & 3 deletions mlspm/data_loading.py
@@ -136,7 +136,7 @@ def decode_xyz(key: str, data: Any) -> Tuple[np.ndarray, np.ndarray] | Tuple[Non
sw = get_scan_window_from_comment(comment)
xyz = []
while line := data.readline().decode("utf-8"):
e, x, y, z, _ = line.strip().split()
e, x, y, z = line.strip().split()[:4]
try:
e = int(e)
except ValueError:
@@ -184,7 +184,7 @@ def get_scan_window_from_comment(comment: str) -> np.ndarray:
return sw


def _rotate_and_stack(src: Iterable[dict], reverse: bool = True) -> Generator[dict, None, None]:
def _rotate_and_stack(src: Iterable[dict], reverse: bool = False) -> Generator[dict, None, None]:
"""
Take a sample in dict format and update it with fields containing an image stack, xyz coordinates and scan window.
Rotate the images to be xy-indexing convention and stack them into a single array.
@@ -194,7 +194,7 @@ def _rotate_and_stack(src: Iterable[dict], reverse: bool = True) -> Generator[di
Arguments:
src: Iterable of dicts with the fields:
- ``'{000..0xx}.jpg'`` - :class:`PIL.Image.Image` of one slice of the simulation.
- ``'{000..0xx}.{jpg,png}'`` - :class:`PIL.Image.Image` of one slice of the simulation.
- ``'xyz'`` - Tuple(:class:`np.ndarray`, :class:`np.ndarray`) of the xyz data and the scan window.
reverse: Whether the order of the image stack is reversed.
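
The `decode_xyz` change earlier in this file swaps a fixed five-way unpacking for a slice of the first four columns, so lines with exactly four columns, or with more than five, no longer raise. A minimal sketch of the difference follows; `parse_xyz_line` is a hypothetical helper, and leaving a non-integer element as a string is an assumption made here, since the original handling inside the `except` branch is not shown in the diff.

```python
def parse_xyz_line(line: str):
    # Take only the first four whitespace-separated columns; extras (e.g. charge) are ignored.
    e, x, y, z = line.strip().split()[:4]
    try:
        e = int(e)
    except ValueError:
        pass  # Assumption: keep the element symbol as-is; the real code may convert it.
    return e, float(x), float(y), float(z)


four_cols = parse_xyz_line("8 0.0 1.0 2.0")
five_cols = parse_xyz_line("O 0.0 1.0 2.0 -0.5")
six_cols = parse_xyz_line("1 0.5 0.5 0.5 0.1 extra")

# The previous form requires exactly five columns and fails otherwise.
try:
    e, x, y, z, _ = "8 0.0 1.0 2.0".split()
    old_form_failed = False
except ValueError:
    old_form_failed = True
```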
14 changes: 13 additions & 1 deletion mlspm/datasets.py
@@ -12,6 +12,7 @@
"AFM-ice-Au111-monolayer": "https://zenodo.org/records/10049832/files/AFM-ice-Au111-monolayer.tar.gz?download=1",
"AFM-ice-Au111-bilayer": "https://zenodo.org/records/10049856/files/AFM-ice-Au111-bilayer.tar.gz?download=1",
"AFM-ice-exp": "https://zenodo.org/records/10054847/files/exp_data_ice.tar.gz?download=1",
"AFM-ice-relaxed": "https://zenodo.org/records/10362511/files/relaxed_structures.tar.gz?download=1",
}


@@ -29,6 +30,16 @@ def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
raise Exception("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner=numeric_owner)

def _common_parent(paths):
    path_parts = [list(Path(p).parts) for p in paths]
    common_part = Path()
    for parts in zip(*path_parts):
        p = parts[0]
        if all(part == p for part in parts):
            common_part /= p
        else:
            break
    return common_part
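
The reason for replacing `os.path.commonprefix` with this helper is most likely that `commonprefix` compares strings character by character, so it can return a prefix that ends in the middle of a path component. The sketch below repeats the helper from the diff and compares the two on made-up archive member names (the file names are illustrative only).

```python
import os
from pathlib import Path


def _common_parent(paths):
    # Compare whole path components rather than characters.
    path_parts = [list(Path(p).parts) for p in paths]
    common_part = Path()
    for parts in zip(*path_parts):
        p = parts[0]
        if all(part == p for part in parts):
            common_part /= p
        else:
            break
    return common_part


names = ["data/train/a.txt", "data/test/b.txt"]

char_prefix = os.path.commonprefix(names)  # "data/t": not a real directory
component_prefix = _common_parent(names)   # Path("data")
```

With the character-wise prefix, a later `relative_to(base_dir)` call would fail because `data/t` is not a parent of either member; the component-wise version returns a path that is.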

def download_dataset(name: str, target_dir: PathLike):
"""
@@ -40,6 +51,7 @@ def download_dataset(name: str, target_dir: PathLike):
- ``'AFM-ice-Au111-monolayer'``: https://doi.org/10.5281/zenodo.10049832
- ``'AFM-ice-Au111-bilayer'``: https://doi.org/10.5281/zenodo.10049856
- ``'AFM-ice-exp'``: https://doi.org/10.5281/zenodo.10054847
- ``'AFM-ice-relaxed'``: https://doi.org/10.5281/zenodo.10362511
Arguments:
name: Name of dataset to download.
@@ -64,7 +76,7 @@ def download_dataset(name: str, target_dir: PathLike):
with tarfile.open(temp_file, "r") as ft:
print("Reading archive files...")
members = []
base_dir = os.path.commonprefix(ft.getnames())
base_dir = _common_parent(ft.getnames())
for m in ft.getmembers():
if m.isfile():
# relative_to(base_dir) here gets rid of a common parent directory within the archive (if any),
9 changes: 6 additions & 3 deletions mlspm/graph/_molecule_graph.py
@@ -4,7 +4,8 @@

from ..utils import elements

from typing import Iterable, Tuple, Self
from typing import Iterable, Tuple
from typing_extensions import Self


class Atom:
@@ -44,7 +45,7 @@ def __init__(

if q is None:
q = 0
self.q = 0
self.q = q

if classes is not None:
assert class_weights is None, "Cannot have both classes and class_weights not be None."
@@ -275,8 +276,10 @@ def transform_xy(
"""
Transform atom positions in the xy plane.
Transformations are performed in the order: shift, rotate, flip x, flip y.
Arguments:
shift: Shift atom positions in xy plane. Performed before rotation and flip.
shift: Shift atom positions in xy plane.
rot_xy: Rotate atoms in xy plane by rot_xy degrees around center point.
flip_x: Mirror atom positions in x direction with respect to the center point.
flip_y: Mirror atom positions in y direction with respect to the center point.
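
The order stated in the docstring matters because shifting and rotating do not commute. The following is a rough, standalone sketch of that order on plain `(x, y)` tuples. It is not the `transform_xy` method itself: the function name, the explicit `center` parameter, and the counterclockwise rotation direction are all assumptions made for illustration.

```python
import math


def transform_xy_sketch(points, shift=(0.0, 0.0), rot_xy=0.0, flip_x=False, flip_y=False, center=(0.0, 0.0)):
    # Hypothetical illustration of the documented order: shift, rotate, flip x, flip y.
    angle = math.radians(rot_xy)
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    cx, cy = center
    out = []
    for x, y in points:
        # 1. Shift.
        x, y = x + shift[0], y + shift[1]
        # 2. Rotate around the center point (counterclockwise is an assumption).
        dx, dy = x - cx, y - cy
        x, y = cx + cos_a * dx - sin_a * dy, cy + sin_a * dx + cos_a * dy
        # 3. Mirror in x, then 4. mirror in y, both with respect to the center point.
        if flip_x:
            x = 2 * cx - x
        if flip_y:
            y = 2 * cy - y
        out.append((x, y))
    return out


# Shift by (1, 0) first, then rotate by 90 degrees: (1, 0) -> (2, 0) -> approximately (0, 2).
shifted_then_rotated = transform_xy_sketch([(1.0, 0.0)], shift=(1.0, 0.0), rot_xy=90.0)
# Mirroring in x about the origin negates the x coordinate.
mirrored = transform_xy_sketch([(1.0, 2.0)], flip_x=True)
```

Rotating first and shifting afterwards would instead give approximately (1, 1) for the first call, which is why the docstring spells the order out.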
1 change: 1 addition & 0 deletions mlspm/graph/_utils.py
@@ -419,6 +419,7 @@ def save_graphs_to_xyzs(
Arguments:
molecules: Molecule graphs to save.
classes: Chemical elements for atom classification. Either atomic numbers or chemical symbols.
The element for each atom in the graph is the first element in the corresponding class.
outfile_format: Formatting string for saved files. Sample index is available in variable ``ind``.
start_ind: Index where file numbering starts.
verbose: Whether to print output information.
4 changes: 0 additions & 4 deletions mlspm/logging.py
@@ -204,10 +204,6 @@ def __init__(
self._synced_losses = {"train": SyncedLoss(len(self.loss_labels)), "val": SyncedLoss(len(self.loss_labels))}
self._init_log(init_epoch)

def __del__(self):
if self.stream is not sys.stdout:
self.stream.close()

def _init_log(self, init_epoch: Optional[int]):
log_exists = os.path.isfile(self.log_path)
if self.world_size > 1:
8 changes: 7 additions & 1 deletion papers/ice_structure_discovery/README.md
@@ -12,7 +12,7 @@ This folder contains the source code and links to the datasets that were used fo

The subdirectories contain various scripts for training and running predictions with the models:
- `training`: Scripts for training the atom position and graph construction models, and evaluating the trained models.
- `prediction`: Scripts for reproducing the result in Fig. 2 of the paper using the pretrained models.
- `predictions`: Scripts for reproducing the results figures of the paper using the pretrained models.

## Data

@@ -25,4 +25,10 @@ Training datasets:

Experimental data: https://doi.org/10.5281/zenodo.10054847

Final relaxed geometries: https://doi.org/10.5281/zenodo.10362511

Pretrained weights for the models: https://doi.org/10.5281/zenodo.10054348

Training data and configuration files for training the Nequip neural network potentials:
- Cu(111): https://doi.org/10.5281/zenodo.10371802
- Au(111): https://doi.org/10.5281/zenodo.10371791
6 changes: 4 additions & 2 deletions papers/ice_structure_discovery/predictions/README.md
@@ -1,3 +1,5 @@
The scripts here can be used to reproduce the results in Fig. 2 of the paper.
The scripts here can be used to reproduce the results figures in the paper.
- `predict_experiments.py`: Runs the prediction for all of the experimental AFM images of ice on Cu(111) and Au(111) using the three models pretrained on the Cu(111), Au(111)-monolayer, and Au(111)-bilayer datasets, and saves them on disk.
- `plot_predictions.py`: Picks the appropriate predictions for each experiment and plots them to a figure.
- `plot_predictions.py`: Picks the appropriate predictions for each experiment and plots them to a figure as in Fig. 2 of the paper.
- `plot_relaxed_structures.py`: Plots the on-surface structures relaxed with a neural network potential and DFT as well as the corresponding simulations and experimental images as in Fig. 3 of the paper.
- `plot_prediction_extra.py`: Plots the prediction and the relaxed structure with corresponding simulations and experimental images for the one extra ice cluster not in the main results figure.