Skip to content

Commit

Permalink
Added pretrained weights for ASDAFMNet.
Browse files Browse the repository at this point in the history
  • Loading branch information
NikoOinonen committed Jan 15, 2024
1 parent 6af6d42 commit 114d3f2
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Reference
mlspm.data_loading
mlspm.datasets
mlspm.graph
mlspm.image
mlspm.logging
mlspm.losses
mlspm.models
Expand Down
15 changes: 15 additions & 0 deletions docs/source/reference/mlspm.image.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
mlspm.image
===========

.. automodule:: mlspm.image
:members:
:undoc-members:
:show-inheritance:

mlspm.image.models
------------------

.. automodule:: mlspm.image.models
:members:
:undoc-members:
:show-inheritance:
4 changes: 4 additions & 0 deletions docs/source/reference/mlspm.models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@ mlspm.models

Alias of :class:`mlspm.graph.models.GraphImgNetIce`

.. class:: mlspm.models.ASDAFMNet

Alias of :class:`mlspm.image.models.ASDAFMNet`

.. autofunction:: mlspm.models.download_weights
14 changes: 9 additions & 5 deletions mlspm/_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"graph-ice-cu111": "https://zenodo.org/records/10054348/files/weights_ice-cu111.pth?download=1",
"graph-ice-au111-monolayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-monolayer.pth?download=1",
"graph-ice-au111-bilayer": "https://zenodo.org/records/10054348/files/weights_ice-au111-bilayer.pth?download=1",
"asdafm-light": "https://zenodo.org/records/10514470/files/weights_asdafm_light.pth?download=1",
"asdafm-heavy": "https://zenodo.org/records/10514470/files/weights_asdafm_heavy.pth?download=1",
}


Expand All @@ -18,18 +20,20 @@ def download_weights(weights_name: str, target_path: Optional[PathLike] = None)
The following weights are available:
- ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111).
- ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111).
- ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111).
- ``'graph-ice-cu111'``: PosNet trained on ice clusters on Cu(111). (https://doi.org/10.5281/zenodo.10054348)
- ``'graph-ice-au111-monolayer'``: PosNet trained on monolayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348)
- ``'graph-ice-au111-bilayer'``: PosNet trained on bilayer ice clusters on Au(111). (https://doi.org/10.5281/zenodo.10054348)
- ``'asdafm-light'``: ASDAFMNet trained on molecules containing the elements H, C, N, O, and F. (https://doi.org/10.5281/zenodo.10514470)
- ``'asdafm-heavy'``: ASDAFMNet trained on molecules additionally containing Si, P, S, Cl, and Br. (https://doi.org/10.5281/zenodo.10514470)
Arguments:
weights_name: Name of weights to download.
target_path: Path where the weights file will be saved. If specified, the parent directory for the file has to exists.
If not specified, a location in cache directory is chosen. If the target file already exists, the download is skipped
If not specified, a location in a cache directory is chosen. If the target file already exists, the download is skipped
Returns:
Path where the weights were saved.
"""
try:
weights_url = WEIGHTS_URLS[weights_name]
Expand Down
6 changes: 3 additions & 3 deletions mlspm/graph/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,9 +654,9 @@ class GraphImgNetIce(GraphImgNet):
Three sets of pretrained weights are available:
- 'cu111'
- 'au111-monolayer'
- 'au111-bilayer'
- ``'cu111'``: trained on images of ice clusters on Cu(111)
- ``'au111-monolayer'``: trained on images of ice clusters on monolayer Au(111)
- ``'au111-bilayer'``: trained on images of ice clusters on bilayer Au(111)
Arguments:
pretrained_weights: Name of pretrained weights. If specified, load pretrained weights. Otherwise, weights are initialized
Expand Down
1 change: 1 addition & 0 deletions mlspm/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
#!/usr/bin/env python3
37 changes: 34 additions & 3 deletions mlspm/image/models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,44 @@
from turtle import forward
from typing import Tuple
from typing import Literal, Optional, Tuple

import torch
from torch import nn

from ..modules import _get_padding, _flatten_z_to_channels
from ..modules import _get_padding
from .._weights import download_weights


def _flatten_z_to_channels(x):
return x.permute(0, 4, 1, 2, 3).reshape(x.size(0), -1, x.size(2), x.size(3))


class ASDAFMNet(nn.Module):
"""
The model used in the paper "Automated structure discovery in atomic force microscopy": https://doi.org/10.1126/sciadv.aay6913.
Two different sets of pretrained weights are available:
- ``'asdafm-light'``: trained on a set of molecules containing the elements H, C, N, O, and F.
- ``'asdafm-heavy'``: trained on a set of molecules additionally containing the elements Si, P, S, Cl, and Br.
Arguments:
n_out: number of output branches.
activation: Activation function used after each layer.
padding_mode: Type of padding in each convolution layer. ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
last_relu: Whether to use a ReLU layer after the last convolution in each output branch. Either provide a single value
that applies to all output branches or a list of values corresponding to each output branch.
pretrained_weights: Name of pretrained weights. If specified, load pretrained weights. Otherwise, weights are initialized
randomly. If loading the weights, the other arguments should be left in their default values, so that the model
hyperparameters match the training.
"""

def __init__(
self,
n_out: int = 3,
activation: nn.Module = nn.LeakyReLU(0.1),
padding_mode: str = "reflect",
padding_mode: str = "replicate",
last_relu: bool | Tuple[bool, ...] = True,
pretrained_weights: Optional[Literal["asdafm-light", "asdafm-heavy"]] = None,
):
super().__init__()

Expand Down Expand Up @@ -68,6 +94,11 @@ def __init__(
]
)

if pretrained_weights is not None:
weights_path = download_weights(pretrained_weights)
weights = torch.load(weights_path)
self.load_state_dict(weights)

def forward(self, x):
x = self.encoder(x)
x = _flatten_z_to_channels(x)
Expand Down
1 change: 1 addition & 0 deletions mlspm/models.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .graph.models import PosNet, GraphImgNet, GraphImgNetIce
from .image.models import ASDAFMNet
from ._weights import download_weights
3 changes: 0 additions & 3 deletions mlspm/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ def _get_padding(kernel_size: int | Tuple[int, ...], nd: int) -> Tuple[int, ...]
padding += [(kernel_size[i] - 1) // 2]
return tuple(padding)

def _flatten_z_to_channels(x):
return x.permute(0,1,4,2,3).reshape(x.size(0), -1, x.size(2), x.size(3))

class _ConvNdBlock(nn.Module):
def __init__(
self,
Expand Down

0 comments on commit 114d3f2

Please sign in to comment.