Skip to content

Commit

Permalink
Merge pull request #19 from bioinf-jku/dev
Browse files Browse the repository at this point in the history
Fix numerical issues and efficiency issues
  • Loading branch information
renzph committed Apr 1, 2024
2 parents b4bcc22 + 2c45937 commit 66d7ad3
Show file tree
Hide file tree
Showing 10 changed files with 377 additions and 96 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/test_dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tests (dev)

on:
push:
branches: [ "dev" ]
pull_request:
branches: [ "dev" ]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
40 changes: 40 additions & 0 deletions .github/workflows/test_master.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Tests (master)

on:
push:
branches: [ "master"]
pull_request:
branches: [ "master"]

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
# Fréchet ChemNet Distance
![PyPI](https://img.shields.io/pypi/v/fcd)
![Tests (master)](https://github.com/bioinf-jku/fcd/actions/workflows/test_master.yml/badge.svg?branch=dev)
![Tests (dev)](https://github.com/bioinf-jku/fcd/actions/workflows/test_dev.yml/badge.svg?branch=dev)
![PyPI - Downloads](https://img.shields.io/pypi/dm/fcd)
![GitHub release (latest by date)](https://img.shields.io/github/v/release/bioinf-jku/fcd)
![GitHub release date](https://img.shields.io/github/release-date/bioinf-jku/fcd)
![GitHub](https://img.shields.io/github/license/bioinf-jku/fcd)


Code for the paper "Fréchet ChemNet Distance: A Metric for Generative Models for Molecules in Drug Discovery"
[JCIM](https://pubs.acs.org/doi/10.1021/acs.jcim.8b00234) /
[ArXiv](https://arxiv.org/abs/1803.09518)


## Installation
You can install the FCD using
You can install FCD using
```
pip install fcd
```
or run the example notebook on Google Colab <a href="https://colab.research.google.com/github/bioinf-jku/FCD/blob/master/example.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg">
</a>.


# Requirements
```
Expand Down
48 changes: 33 additions & 15 deletions example.ipynb
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"!pip install fcd"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!mkdir generated_smiles\n",
"!wget https://raw.githubusercontent.com/bioinf-jku/FCD/master/generated_smiles/LSTM_Segler.smi -o generated_smiles/LSTM_Segler.smi"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from rdkit import RDLogger \n",
"from rdkit import RDLogger\n",
"import numpy as np\n",
"import pandas as pd\n",
"from fcd import get_fcd, load_ref_model,canonical_smiles, get_predictions, calculate_frechet_distance\n",
"from fcd import get_fcd, load_ref_model, canonical_smiles, get_predictions, calculate_frechet_distance\n",
"\n",
"RDLogger.DisableLog('rdApp.*')\n",
"RDLogger.DisableLog(\"rdApp.*\")\n",
"\n",
"np.random.seed(0)\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]= '0' #set gpu"
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" # set gpu"
]
},
{
Expand Down Expand Up @@ -44,8 +64,10 @@
"model = load_ref_model()\n",
"\n",
"# Load generated molecules\n",
"gen_mol_file = \"generated_smiles/LSTM_Segler.smi\" #input file which contains one generated SMILES per line\n",
"gen_mol = pd.read_csv(gen_mol_file,header=None)[0] #IMPORTANT: take at least 10000 molecules as FCD can vary with sample size \n",
"gen_mol_file = \"generated_smiles/LSTM_Segler.smi\" # input file which contains one generated SMILES per line\n",
"gen_mol = pd.read_csv(gen_mol_file, header=None)[\n",
" 0\n",
"] # IMPORTANT: take at least 10000 molecules as FCD can vary with sample size\n",
"sample1 = np.random.choice(gen_mol, 10000, replace=False)\n",
"sample2 = np.random.choice(gen_mol, 10000, replace=False)\n",
"\n",
Expand Down Expand Up @@ -82,7 +104,7 @@
}
],
"source": [
"#get CHEBMLNET activations of generated molecules \n",
"# get CHEBMLNET activations of generated molecules\n",
"act1 = get_predictions(model, can_sample1)\n",
"act2 = get_predictions(model, can_sample2)\n",
"\n",
Expand All @@ -92,13 +114,9 @@
"mu2 = np.mean(act2, axis=0)\n",
"sigma2 = np.cov(act2.T)\n",
"\n",
"fcd_score = calculate_frechet_distance(\n",
" mu1=mu1,\n",
" mu2=mu2, \n",
" sigma1=sigma1,\n",
" sigma2=sigma2)\n",
"fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)\n",
"\n",
"print('FCD: ',fcd_score)"
"print(\"FCD: \", fcd_score)"
]
},
{
Expand All @@ -123,7 +141,7 @@
"\"\"\"if you don't need to store the activations you can also take a shortcut.\"\"\"\n",
"fcd_score = get_fcd(can_sample1, can_sample2, model)\n",
"\n",
"print('FCD: ',fcd_score)"
"print(\"FCD: \", fcd_score)"
]
},
{
Expand All @@ -147,7 +165,7 @@
"source": [
"\"\"\"This is what happens if you do not canonicalize the smiles\"\"\"\n",
"fcd_score = get_fcd(can_sample1, sample2, model)\n",
"print('FCD: ',fcd_score)"
"print(\"FCD: \", fcd_score)"
]
}
],
Expand Down
16 changes: 13 additions & 3 deletions fcd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .fcd import get_fcd, get_predictions, load_ref_model
from .utils import calculate_frechet_distance, canonical_smiles
# ruff: noqa: F401

__version__ = "1.2"
from fcd.fcd import get_fcd, get_predictions, load_ref_model
from fcd.utils import calculate_frechet_distance, canonical_smiles

__all__ = [
"get_fcd",
"get_predictions",
"load_ref_model",
"calculate_frechet_distance",
"canonical_smiles",
]

__version__ = "1.2.1"
47 changes: 29 additions & 18 deletions fcd/fcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import nn
from torch.utils.data import DataLoader

from .utils import (
from fcd.utils import (
SmilesDataset,
calculate_frechet_distance,
load_imported_model,
Expand All @@ -31,6 +31,8 @@ def load_ref_model(model_path: Optional[str] = None):
if model_path is None:
chemnet_model_filename = "ChemNet_v0.13_pretrained.pt"
model_bytes = pkgutil.get_data("fcd", chemnet_model_filename)
if model_bytes is None:
raise FileNotFoundError(f"Could not find model file {chemnet_model_filename}")

tmpdir = tempfile.TemporaryDirectory()
model_path = os.path.join(tmpdir.name, chemnet_model_filename)
Expand All @@ -48,7 +50,7 @@ def get_predictions(
smiles_list: List[str],
batch_size: int = 128,
n_jobs: int = 1,
device: str = "cpu",
device: Optional[str] = None,
) -> np.ndarray:
"""Calculate Chemnet activations
Expand All @@ -65,46 +67,55 @@ def get_predictions(
if len(smiles_list) == 0:
return np.zeros((0, 512))

dataloader = DataLoader(
SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs
)
dataloader = DataLoader(SmilesDataset(smiles_list), batch_size=batch_size, num_workers=n_jobs)

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

with todevice(model, device), torch.no_grad():
chemnet_activations = []
for batch in dataloader:
chemnet_activations.append(
model(batch.transpose(1, 2).float().to(device))
.to("cpu")
.detach()
.numpy()
model(batch.transpose(1, 2).float().to(device)).to("cpu").detach().numpy().astype(np.float32)
)
return np.row_stack(chemnet_activations)


def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module = None) -> float:
def get_fcd(smiles1: List[str], smiles2: List[str], model: Optional[nn.Module] = None, device=None) -> float:
"""Calculate FCD between two sets of Smiles
Args:
smiles1 (List[str]): First set of smiles
smiles2 (List[str]): Second set of smiles
smiles1 (List[str]): First set of SMILES.
smiles2 (List[str]): Second set of SMILES.
model (nn.Module, optional): The model to use. Loads default model if None.
device: The device to use for computation.
Returns:
float: The FCD score
float: The FCD score.
Raises:
ValueError: If the input SMILES lists are empty.
Example:
>>> smiles1 = ['CCO', 'CCN']
>>> smiles2 = ['CCO', 'CCC']
>>> fcd_score = get_fcd(smiles1, smiles2)
"""
if not smiles1 or not smiles2:
raise ValueError("Input SMILES lists cannot be empty.")

if model is None:
model = load_ref_model()

act1 = get_predictions(model, smiles1)
act2 = get_predictions(model, smiles2)
act1 = get_predictions(model, smiles1, device=device)
act2 = get_predictions(model, smiles2, device=device)

mu1 = np.mean(act1, axis=0)
sigma1 = np.cov(act1.T)

mu2 = np.mean(act2, axis=0)
sigma2 = np.cov(act2.T)

fcd_score = calculate_frechet_distance(
mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2
)
fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)

return fcd_score
Loading

0 comments on commit 66d7ad3

Please sign in to comment.