Skip to content

Commit

Permalink
Port to pytorch
Browse files Browse the repository at this point in the history
  • Loading branch information
renzph committed Aug 22, 2023
1 parent 345b2f8 commit c7e7e99
Show file tree
Hide file tree
Showing 12 changed files with 608 additions and 144 deletions.
85 changes: 18 additions & 67 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,20 @@
"cells": [
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2020-04-23T09:11:07.345354Z",
"start_time": "2020-04-23T09:11:07.321510Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
}
],
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import os\n",
"\n",
"from rdkit import RDLogger \n",
"import numpy as np\n",
"import pandas as pd\n",
"import pickle\n",
"from fcd import get_fcd, load_ref_model,canonical_smiles, get_predictions, calculate_frechet_distance\n",
"\n",
"RDLogger.DisableLog('rdApp.*')\n",
"\n",
"np.random.seed(0)\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]= '0' #set gpu "
"os.environ[\"CUDA_VISIBLE_DEVICES\"]= '0' #set gpu"
]
},
{
Expand All @@ -46,47 +31,14 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2020-04-23T09:13:50.403933Z",
"start_time": "2020-04-23T09:13:47.310624Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"RDKit ERROR: [11:13:48] Explicit valence for atom # 18 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 6 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 5 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 2 N, 5, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 10 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 22 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 1 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 6 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 13 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 13 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 5 N, 5, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 10 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 14 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 14 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:48] Explicit valence for atom # 9 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 11 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 15 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 23 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 7 N, 5, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 21 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 14 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 8 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 9 N, 5, is greater than permitted\n",
"RDKit ERROR: [11:13:49] Explicit valence for atom # 6 N, 4, is greater than permitted\n",
"RDKit ERROR: [11:13:50] Explicit valence for atom # 17 N, 5, is greater than permitted\n",
"RDKit ERROR: [11:13:50] Explicit valence for atom # 20 N, 5, is greater than permitted\n"
]
}
],
"outputs": [],
"source": [
"# Load chemnet model\n",
"model = load_ref_model()\n",
Expand All @@ -113,7 +65,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2020-04-23T09:11:27.207953Z",
Expand All @@ -125,7 +77,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"FCD: 0.3338613001233881\n"
"FCD: 0.333862289051325\n"
]
}
],
Expand All @@ -151,7 +103,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2020-04-23T09:11:38.873496Z",
Expand All @@ -163,20 +115,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"FCD: 0.3338613001233881\n"
"FCD: 0.333862289051325\n"
]
}
],
"source": [
"\"\"\"if you don't need to store the activations you can also take a shortcut.\"\"\"\n",
"fcd_score = get_fcd(model, can_sample1, can_sample2)\n",
"fcd_score = get_fcd(can_sample1, can_sample2, model)\n",
"\n",
"print('FCD: ',fcd_score)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2020-04-23T09:11:49.760022Z",
Expand All @@ -188,14 +140,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"FCD: 25.63927611890624\n"
"FCD: 25.635578193222216\n"
]
}
],
"source": [
"\"\"\"This is what happens if you do not canonicalize the smiles\"\"\"\n",
"fcd_score = get_fcd(model, can_sample1, sample2)\n",
"\n",
"fcd_score = get_fcd(can_sample1, sample2, model)\n",
"print('FCD: ',fcd_score)"
]
}
Expand All @@ -216,7 +167,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
"version": "3.11.3"
}
},
"nbformat": 4,
Expand Down
Binary file added fcd_torch/ChemNet_v0.13_pretrained.pt
Binary file not shown.
4 changes: 4 additions & 0 deletions fcd_torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .fcd import get_fcd, get_predictions, load_ref_model
from .utils import calculate_frechet_distance, canonical_smiles

# Version string of the fcd_torch package.
__version__ = "1.2"
110 changes: 110 additions & 0 deletions fcd_torch/fcd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import os
import pkgutil
import tempfile
from functools import lru_cache
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader

from .utils import (
SmilesDataset,
calculate_frechet_distance,
load_imported_model,
todevice,
)


@lru_cache(maxsize=1)
def load_ref_model(model_path: Optional[str] = None):
    """Load the ChemNet reference model.

    The result is memoised with ``lru_cache(maxsize=1)``, so calling the
    function again with the same argument returns the same model object.

    Args:
        model_path (str | None, optional): Path to a model file that
            ``torch.load`` can read. When ``None`` (the default) the
            ``ChemNet_v0.13_pretrained.pt`` file shipped inside the
            ``fcd_torch`` package is used.

    Returns:
        The model built by ``load_imported_model``, switched to eval mode.

    Raises:
        FileNotFoundError: If the bundled weights cannot be read from the
            package.
    """
    # NOTE: torch.load unpickles the file, so only pass trusted paths.
    if model_path is not None:
        model_config = torch.load(model_path)
    else:
        chemnet_model_filename = "ChemNet_v0.13_pretrained.pt"
        model_bytes = pkgutil.get_data("fcd_torch", chemnet_model_filename)
        if model_bytes is None:
            # pkgutil.get_data returns None when the package cannot be
            # located or its loader does not support get_data.
            raise FileNotFoundError(
                f"Could not read bundled model file {chemnet_model_filename!r}"
            )

        # The package data is only available as bytes, so it is written to a
        # scratch file for torch.load. The context manager removes the
        # directory deterministically instead of relying on the finalizer.
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_model_path = os.path.join(tmpdir, chemnet_model_filename)
            with open(tmp_model_path, "wb") as f:
                f.write(model_bytes)
            model_config = torch.load(tmp_model_path)

    model = load_imported_model(model_config)
    model.eval()
    return model


def get_predictions(
    model: nn.Module,
    smiles_list: List[str],
    batch_size: int = 128,
    n_jobs: int = 1,
    device: str = "cpu",
) -> np.ndarray:
    """Calculate Chemnet activations for a list of SMILES.

    Args:
        model (nn.Module): Chemnet model.
        smiles_list (List[str]): List of smiles to process.
        batch_size (int, optional): Batch size used for inference.
            Defaults to 128.
        n_jobs (int, optional): Passed to the ``DataLoader`` as
            ``num_workers`` for preprocessing. Defaults to 1.
        device (str, optional): Device the chemnet model is run on.
            Defaults to "cpu".

    Returns:
        np.ndarray: The activations of all batches stacked row-wise. For an
        empty input an array of shape ``(0, 512)`` is returned.
    """
    # Keep the explicit length check: it also works for array-like inputs
    # whose truth value is ambiguous.
    if len(smiles_list) == 0:
        return np.zeros((0, 512))

    dataloader = DataLoader(
        SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs
    )
    with todevice(model, device), torch.no_grad():
        # Each batch has dims 1 and 2 swapped before it is fed to the model
        # (presumably to get a channels-first layout — TODO confirm against
        # SmilesDataset). Results are moved back to the CPU as numpy arrays.
        chemnet_activations = [
            model(batch.transpose(1, 2).float().to(device))
            .detach()
            .cpu()
            .numpy()
            for batch in dataloader
        ]
    # np.row_stack is a deprecated alias of np.vstack in NumPy 2.0.
    return np.vstack(chemnet_activations)


def get_fcd(
    smiles1: List[str],
    smiles2: List[str],
    model: Optional[nn.Module] = None,
    *,
    batch_size: int = 128,
    n_jobs: int = 1,
    device: str = "cpu",
) -> float:
    """Calculate FCD between two sets of Smiles.

    Args:
        smiles1 (List[str]): First set of smiles.
        smiles2 (List[str]): Second set of smiles.
        model (nn.Module, optional): The model to use. Loads the default
            model via ``load_ref_model`` if None.
        batch_size (int, optional): Forwarded to ``get_predictions``.
            Defaults to 128.
        n_jobs (int, optional): Forwarded to ``get_predictions``.
            Defaults to 1.
        device (str, optional): Forwarded to ``get_predictions``.
            Defaults to "cpu".

    Returns:
        float: The FCD score.
    """
    if model is None:
        model = load_ref_model()

    # Mean and covariance of the activations of each set.
    statistics = []
    for smiles in (smiles1, smiles2):
        activations = get_predictions(
            model, smiles, batch_size=batch_size, n_jobs=n_jobs, device=device
        )
        # np.cov expects variables in rows, hence the transpose.
        statistics.append((np.mean(activations, axis=0), np.cov(activations.T)))

    (mu1, sigma1), (mu2, sigma2) = statistics
    return calculate_frechet_distance(
        mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2
    )
54 changes: 54 additions & 0 deletions fcd_torch/torch_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
from torch import nn


class Reverse(nn.Module):
    """Layer that reverses the order of the entries along dimension 1."""

    def forward(self, x):
        """Return ``x`` flipped along dimension 1."""
        return x.flip(dims=[1])


class IndexTuple(nn.Module):
    """Layer that picks a single element out of its (indexable) input."""

    def __init__(self, pos):
        """Store ``pos``, the index applied to the input in ``forward``."""
        super().__init__()
        self.pos = pos

    def forward(self, x):
        """Return ``x[self.pos]``."""
        index = self.pos
        selected = x[index]
        return selected


class IndexTensor(nn.Module):
    """Layer that selects one index along one dimension of a tensor."""

    def __init__(self, pos, dim):
        """Store the index ``pos`` and the dimension ``dim`` to select from."""
        super().__init__()
        self.dim = dim
        self.pos = pos

    def forward(self, x):
        """Return the slice of ``x`` at index ``self.pos`` of ``self.dim``.

        The selected dimension is removed from the result.
        """
        return x.select(self.dim, self.pos)


class Transpose(nn.Module):
    """Layer that swaps dimensions 1 and 2 of its input."""

    def forward(self, x):
        """Return ``x`` with dimensions 1 and 2 exchanged."""
        return torch.transpose(x, 1, 2)


# https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
class SamePadding1d(nn.Module):
    """Zero-pad the last dimension like TensorFlow's "SAME" padding.

    The total amount of padding is derived from the kernel size, the stride
    and the input length (read from ``x.shape[2]``). It is split as evenly as
    possible between both ends; when the total is odd the extra element goes
    to the end.
    """

    def __init__(self, kernel_size, stride):
        """Store the kernel size and stride of the convolution to pad for."""
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride

    def forward(self, x):
        """Return ``x`` with zeros added to both ends of its last dimension."""
        remainder = x.shape[2] % self.stride
        if remainder == 0:
            pad = max(self.kernel_size - self.stride, 0)
        else:
            pad = max(self.kernel_size - remainder, 0)

        # The total padding is halved, independent of the stride. Dividing by
        # the stride instead only gives the right split for stride == 2.
        pad_start = pad // 2
        pad_end = pad - pad_start
        return torch.nn.functional.pad(x, (pad_start, pad_end), "constant", 0)
Loading

0 comments on commit c7e7e99

Please sign in to comment.