-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
608 additions
and
144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .fcd import get_fcd, get_predictions, load_ref_model | ||
from .utils import calculate_frechet_distance, canonical_smiles | ||
|
||
__version__ = "1.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import os | ||
import pkgutil | ||
import tempfile | ||
from functools import lru_cache | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
import torch | ||
from torch import nn | ||
from torch.utils.data import DataLoader | ||
|
||
from .utils import ( | ||
SmilesDataset, | ||
calculate_frechet_distance, | ||
load_imported_model, | ||
todevice, | ||
) | ||
|
||
|
||
@lru_cache(maxsize=1) | ||
def load_ref_model(model_path: Optional[str] = None): | ||
"""Loads chemnet model | ||
Args: | ||
model_path (str | None, optional): Path to model file. Defaults to None. | ||
Returns: | ||
Chemnet as torch model | ||
""" | ||
|
||
if model_path is None: | ||
chemnet_model_filename = "ChemNet_v0.13_pretrained.pt" | ||
model_bytes = pkgutil.get_data("fcd_torch", chemnet_model_filename) | ||
|
||
tmpdir = tempfile.TemporaryDirectory() | ||
model_path = os.path.join(tmpdir.name, chemnet_model_filename) | ||
with open(model_path, "wb") as f: | ||
f.write(model_bytes) | ||
|
||
model_config = torch.load(model_path) | ||
model = load_imported_model(model_config) | ||
model.eval() | ||
return model | ||
|
||
|
||
def get_predictions( | ||
model: nn.Module, | ||
smiles_list: List[str], | ||
batch_size: int = 128, | ||
n_jobs: int = 1, | ||
device: str = "cpu", | ||
) -> np.ndarray: | ||
"""Calculate Chemnet activations | ||
Args: | ||
model (nn.Module): Chemnet model | ||
smiles_list (List[str]): List of smiles to process | ||
batch_size (int, optional): Which batch size to use for inference. Defaults to 128. | ||
n_jobs (int, optional): How many jobs to use for preprocessing. Defaults to 1. | ||
device (str, optional): On which device the chemnet model is run. Defaults to "cpu". | ||
Returns: | ||
np.ndarray: The activation for the input list | ||
""" | ||
if len(smiles_list) == 0: | ||
return np.zeros((0, 512)) | ||
|
||
dataloader = DataLoader( | ||
SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs | ||
) | ||
with todevice(model, device), torch.no_grad(): | ||
chemnet_activations = [] | ||
for batch in dataloader: | ||
chemnet_activations.append( | ||
model(batch.transpose(1, 2).float().to(device)) | ||
.to("cpu") | ||
.detach() | ||
.numpy() | ||
) | ||
return np.row_stack(chemnet_activations) | ||
|
||
|
||
def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> float: | ||
"""Calculate FCD between two sets of Smiles | ||
Args: | ||
smiles1 (List[str]): First set of smiles | ||
smiles2 (List[str]): Second set of smiles | ||
model (nn.Module, optional): The model to use. Loads default model if None. | ||
Returns: | ||
float: The FCD score | ||
""" | ||
if model is None: | ||
model = load_ref_model() | ||
|
||
act1 = get_predictions(model, smiles1) | ||
act2 = get_predictions(model, smiles2) | ||
|
||
mu1 = np.mean(act1, axis=0) | ||
sigma1 = np.cov(act1.T) | ||
|
||
mu2 = np.mean(act2, axis=0) | ||
sigma2 = np.cov(act2.T) | ||
|
||
fcd_score = calculate_frechet_distance( | ||
mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2 | ||
) | ||
|
||
return fcd_score |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class Reverse(nn.Module): | ||
def forward(self, x): | ||
return torch.flip(x, [1]) | ||
|
||
|
||
class IndexTuple(nn.Module): | ||
def __init__(self, pos): | ||
super().__init__() | ||
self.pos = pos | ||
|
||
def forward(self, x): | ||
return x[self.pos] | ||
|
||
|
||
class IndexTensor(nn.Module): | ||
def __init__(self, pos, dim): | ||
super().__init__() | ||
self.pos = pos | ||
self.dim = dim | ||
|
||
def forward(self, x): | ||
return torch.select(x, self.dim, self.pos) | ||
|
||
|
||
class Transpose(nn.Module): | ||
def forward(self, x): | ||
return x.transpose(1, 2) | ||
|
||
|
||
# https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch | ||
class SamePadding1d(nn.Module): | ||
def __init__(self, kernel_size, stride): | ||
super().__init__() | ||
self.kernel_size = kernel_size | ||
self.stride = stride | ||
|
||
def forward(self, x): | ||
if x.shape[2] % self.stride == 0: | ||
pad = max(self.kernel_size - self.stride, 0) | ||
else: | ||
pad = max(self.kernel_size - (x.shape[2] % self.stride), 0) | ||
|
||
if pad % self.stride == 0: | ||
pad_val = pad // self.stride | ||
padding = (pad_val, pad_val) | ||
else: | ||
pad_val_start = pad // self.stride | ||
pad_val_end = pad - pad_val_start | ||
padding = (pad_val_start, pad_val_end) | ||
return torch.nn.functional.pad(x, padding, "constant", 0) |
Oops, something went wrong.