Skip to content

Commit

Permalink
Merge pull request #37 from anton-bushuiev/main
Browse files Browse the repository at this point in the history
Enhance baselines
  • Loading branch information
anton-bushuiev authored Aug 16, 2024
2 parents 543b3e6 + d9dc2c9 commit 7be2799
Show file tree
Hide file tree
Showing 10 changed files with 2,310 additions and 946 deletions.
3 changes: 2 additions & 1 deletion massspecgym/models/de_novo/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
max_top_k: int = 10,
enforce_connectivity: bool = True,
cache_results: bool = True,
**kwargs
):
"""
Expand All @@ -208,7 +209,7 @@ def __init__(
When set to True, for each unique formula the set of random molecules is cached to avoid
recomputation.
"""
super(RandomDeNovo, self).__init__()
super(RandomDeNovo, self).__init__(**kwargs)
self.formula_known = formula_known
self.count_of_valid_valence_assignments = count_of_valid_valence_assignments
self.estimate_chem_element_stats = estimate_chem_element_stats
Expand Down
2 changes: 1 addition & 1 deletion massspecgym/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class FourierFeatures(nn.Module):

def __init__(
self,
strategy,
strategy='dreams',
x_min=1e-4,
x_max=1000,
trainable=False,
Expand Down
69 changes: 59 additions & 10 deletions massspecgym/models/retrieval/deepsets.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,78 @@
import typing as T

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MLP

from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel
from massspecgym.models.layers import FourierFeatures
from massspecgym.utils import CosSimLoss


class DeepSetsRetrieval(RetrievalMassSpecGymModel):
def __init__(self, **kwargs):
def __init__(
self,
in_channels: int = 2, # m/z and intensity of a peak
hidden_channels: int = 512, # hidden layer size
out_channels: int = 4096, # fingerprint size
num_layers_per_mlp: int = 2,
dropout: float = 0.0,
norm: T.Optional[str] = None,
fourier_features: bool = True,
fourier_features_mz_channels: T.Optional[int] = None,
fourier_features_kwargs: T.Optional[dict] = None,
**kwargs
):
super().__init__(**kwargs)

self.phi = nn.Sequential(
nn.Linear(2, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
self.fourier_features = fourier_features
if fourier_features:
if fourier_features_kwargs is None:
fourier_features_kwargs = {}
self.ff = FourierFeatures(**fourier_features_kwargs)

if fourier_features_mz_channels is None:
fourier_features_mz_channels = int(0.8 * hidden_channels)
else:
assert fourier_features_mz_channels < hidden_channels
self.ff_proj_mz = nn.Linear(self.ff.num_features, fourier_features_mz_channels)
self.ff_proj_i = nn.Linear(1, hidden_channels - fourier_features_mz_channels)
in_channels = hidden_channels

self.phi = MLP(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=hidden_channels,
num_layers=num_layers_per_mlp,
dropout=dropout,
norm=norm
)
self.rho = nn.Sequential(
nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 2048), nn.Sigmoid()

self.rho = MLP(
in_channels=hidden_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
num_layers=num_layers_per_mlp,
dropout=dropout,
norm=norm
)

self.loss_fn = CosSimLoss()

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.fourier_features:
x_mz = x[:, :, 0].unsqueeze(-1)
x_mz = self.ff(x_mz)
x_mz = self.ff_proj_mz(x_mz)
x_i = x[:, :, 1].unsqueeze(-1)
x_i = self.ff_proj_i(x_i)
x = torch.cat((x_mz, x_i), dim=-1)
x = self.phi(x)
x = x.sum(dim=-2) # sum over peaks
x = self.rho(x)
x = F.sigmoid(x) # predict proper fingerprint
return x

def step(
Expand All @@ -32,14 +82,13 @@ def step(
x = batch["spec"]
fp_true = batch["mol"]
cands = batch["candidates"]
labels = batch["labels"]
batch_ptr = batch["batch_ptr"]

# Predict fingerprint
fp_pred = self.forward(x)

# Calculate loss
loss = nn.functional.mse_loss(fp_true, fp_pred)
loss = self.loss_fn(fp_true, fp_pred)

# Evaluation performance on fingerprint prediction (optional)
self.evaluate_fingerprint_step(fp_true, fp_pred, stage=stage)
Expand Down
16 changes: 7 additions & 9 deletions massspecgym/models/retrieval/fingerprint_ffn.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
import typing as T

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MLP

from massspecgym.models.base import Stage
from massspecgym.models.retrieval.base import RetrievalMassSpecGymModel


class CosSimLoss(nn.Module):
def __init__(self):
super(CosSimLoss, self).__init__()

def forward(self, inputs, targets):
return 1 - F.cosine_similarity(inputs, targets).mean()
from massspecgym.utils import CosSimLoss


class FingerprintFFNRetrieval(RetrievalMassSpecGymModel):
Expand All @@ -23,6 +18,7 @@ def __init__(
out_channels: int = 4096, # fingerprint size
num_layers: int = 2,
dropout: float = 0.0,
norm: T.Optional[str] = None,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -32,13 +28,15 @@ def __init__(
hidden_channels=hidden_channels,
out_channels=out_channels,
num_layers=num_layers,
dropout=dropout
dropout=dropout,
norm=norm
)

self.loss_fn = CosSimLoss()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ffn(x)
x = F.sigmoid(x) # predict proper fingerprint
return x

def step(
Expand Down
10 changes: 10 additions & 0 deletions massspecgym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import typing as T
import pulp
import torch
import torch.nn as nn
import torch.nn.functional as F
from itertools import groupby
from pathlib import Path
from myopic_mces.myopic_mces import MCES
Expand Down Expand Up @@ -405,6 +407,14 @@ def unbatch_list(batch_list: list, batch_idx: torch.Tensor) -> list:
]


class CosSimLoss(nn.Module):
def __init__(self):
super(CosSimLoss, self).__init__()

def forward(self, inputs, targets):
return 1 - F.cosine_similarity(inputs, targets).mean()


def parse_sirius_ms(spectra_file: str) -> T.Tuple[dict, T.List[T.Tuple[str, np.ndarray]]]:
"""
Parses spectra from the SIRIUS .ms file.
Expand Down
Loading

0 comments on commit 7be2799

Please sign in to comment.