From 114d3f2e7cebce09c0ac9782db68cb0c005ae904 Mon Sep 17 00:00:00 2001 From: NikoOinonen Date: Mon, 15 Jan 2024 19:18:07 +0200 Subject: [PATCH] Added pretrained weights for ASDAFMNet. --- docs/source/reference/index.rst | 1 + docs/source/reference/mlspm.image.rst | 15 +++++++++++ docs/source/reference/mlspm.models.rst | 4 +++ mlspm/_weights.py | 14 ++++++---- mlspm/graph/models.py | 6 ++--- mlspm/image/__init__.py | 1 + mlspm/image/models.py | 37 +++++++++++++++++++++++--- mlspm/models.py | 1 + mlspm/modules.py | 3 --- 9 files changed, 68 insertions(+), 14 deletions(-) create mode 100644 docs/source/reference/mlspm.image.rst create mode 100755 mlspm/image/__init__.py diff --git a/docs/source/reference/index.rst b/docs/source/reference/index.rst index 769a30c..9ac8eb4 100644 --- a/docs/source/reference/index.rst +++ b/docs/source/reference/index.rst @@ -7,6 +7,7 @@ Reference mlspm.data_loading mlspm.datasets mlspm.graph + mlspm.image mlspm.logging mlspm.losses mlspm.models diff --git a/docs/source/reference/mlspm.image.rst b/docs/source/reference/mlspm.image.rst new file mode 100644 index 0000000..3921b3e --- /dev/null +++ b/docs/source/reference/mlspm.image.rst @@ -0,0 +1,15 @@ +mlspm.image +=========== + +.. automodule:: mlspm.image + :members: + :undoc-members: + :show-inheritance: + +mlspm.image.models +------------------ + +.. automodule:: mlspm.image.models + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/reference/mlspm.models.rst b/docs/source/reference/mlspm.models.rst index 8c1249d..de0705f 100644 --- a/docs/source/reference/mlspm.models.rst +++ b/docs/source/reference/mlspm.models.rst @@ -13,4 +13,8 @@ mlspm.models Alias of :class:`mlspm.graph.models.GraphImgNetIce` +.. class:: mlspm.models.ASDAFMNet + + Alias of :class:`mlspm.image.models.ASDAFMNet` + .. autofunction:: mlspm.models.download_weights diff --git a/mlspm/_weights.py b/mlspm/_weights.py index 5baa12d..2a21828 100644 --- a/mlspm/_weights.py +++ b/mlspm/_weights.py @@ -9,6 +9,8 @@ "graph-ice-cu111": "https://zenodo.org/records/10054348/files/weights_ice-cu111.pth?download=1", "graph-ice-au111-monolayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-monolayer.pth?download=1", "graph-ice-au111-bilayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-bilayer.pth?download=1", + "asdafm-light": "https://zenodo.org/records/10514470/files/weights_asdafm_light.pth?download=1", + "asdafm-heavy": "https://zenodo.org/records/10514470/files/weights_asdafm_heavy.pth?download=1", } @@ -18,18 +20,20 @@ def download_weights(weights_name: str, target_path: Optional[PathLike] = None) The following weights are available: - - ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111). - - ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111). - - ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111). + - ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111). (https://doi.org/10.5281/zenodo.10054348) + - ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348) + - ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348) + - ``'asdafm-light'``: ASDAFMNet trained on molecules containing the elements H, C, N, O, and F. (https://doi.org/10.5281/zenodo.10514470) + - ``'asdafm-heavy'``: ASDAFMNet trained on molecules additionally containing Si, P, S, Cl, and Br. (https://doi.org/10.5281/zenodo.10514470) + Arguments: weights_name: Name of weights to download. target_path: Path where the weights file will be saved. If specified, the parent directory for the file has to exists. - If not specified, a location in cache directory is chosen. If the target file already exists, the download is skipped + If not specified, a location in a cache directory is chosen. If the target file already exists, the download is skipped Returns: Path where the weights were saved. - """ try: weights_url = WEIGHTS_URLS[weights_name] diff --git a/mlspm/graph/models.py b/mlspm/graph/models.py index e4855db..b960079 100644 --- a/mlspm/graph/models.py +++ b/mlspm/graph/models.py @@ -654,9 +654,9 @@ class GraphImgNetIce(GraphImgNet): Three sets of pretrained weights are available: - - 'cu111' - - 'au111-monolayer' - - 'au111-bilayer' + - ``'cu111'``: trained on images of ice clusters on Cu(111) + - ``'au111-monolayer'``: trained on images of ice clusters on monolayer Au(111) + - ``'au111-bilayer'``: trained on images of ice clusters on bilayer Au(111) Arguments: pretrained_weights: Name of pretrained weights. If specified, load pretrained weights. Otherwise, weights are initialized diff --git a/mlspm/image/__init__.py b/mlspm/image/__init__.py new file mode 100755 index 0000000..e5a0d9b --- /dev/null +++ b/mlspm/image/__init__.py @@ -0,0 +1 @@ +#!/usr/bin/env python3 diff --git a/mlspm/image/models.py b/mlspm/image/models.py index d097be3..b5d3bd1 100644 --- a/mlspm/image/models.py +++ b/mlspm/image/models.py @@ -1,18 +1,44 @@ from turtle import forward -from typing import Tuple +from typing import Literal, Optional, Tuple +import torch from torch import nn -from ..modules import _get_padding, _flatten_z_to_channels +from ..modules import _get_padding +from .._weights import download_weights + + +def _flatten_z_to_channels(x): + return x.permute(0, 4, 1, 2, 3).reshape(x.size(0), -1, x.size(2), x.size(3)) class ASDAFMNet(nn.Module): + """ + The model used in the paper "Automated structure discovery in atomic force microscopy": https://doi.org/10.1126/sciadv.aay6913. + + Two different sets of pretrained weights are available: + + - ``'asdafm-light'``: trained on a set of molecules containing the elements H, C, N, O, and F. + - ``'asdafm-heavy'``: trained on a set of molecules additionally containing the elements Si, P, S, Cl, and Br. + + Arguments: + n_out: number of output branches. + activation: Activation function used after each layer. + padding_mode: Type of padding in each convolution layer. ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. + last_relu: Whether to use a ReLU layer after the last convolution in each output branch. Either provide a single value + that applies to all output branches or a list of values corresponding to each output branch. + pretrained_weights: Name of pretrained weights. If specified, load pretrained weights. Otherwise, weights are initialized + randomly. If loading the weights, the other arguments should be left in their default values, so that the model + hyperparameters match the training. + """ + def __init__( self, n_out: int = 3, activation: nn.Module = nn.LeakyReLU(0.1), - padding_mode: str = "reflect", + padding_mode: str = "replicate", last_relu: bool | Tuple[bool, ...] = True, + pretrained_weights: Optional[Literal["asdafm-light", "asdafm-heavy"]] = None, ): super().__init__() @@ -68,6 +94,11 @@ def __init__( ] ) + if pretrained_weights is not None: + weights_path = download_weights(pretrained_weights) + weights = torch.load(weights_path) + self.load_state_dict(weights) + def forward(self, x): x = self.encoder(x) x = _flatten_z_to_channels(x) diff --git a/mlspm/models.py b/mlspm/models.py index 71e0051..b0c6b40 100644 --- a/mlspm/models.py +++ b/mlspm/models.py @@ -1,2 +1,3 @@ from .graph.models import PosNet, GraphImgNet, GraphImgNetIce +from .image.models import ASDAFMNet from ._weights import download_weights diff --git a/mlspm/modules.py b/mlspm/modules.py index 8f2df0f..4f4c7e2 100644 --- a/mlspm/modules.py +++ b/mlspm/modules.py @@ -13,9 +13,6 @@ def _get_padding(kernel_size: int | Tuple[int, ...], nd: int) -> Tuple[int, ...] padding += [(kernel_size[i] - 1) // 2] return tuple(padding) -def _flatten_z_to_channels(x): - return x.permute(0,1,4,2,3).reshape(x.size(0), -1, x.size(2), x.size(3)) - class _ConvNdBlock(nn.Module): def __init__( self,