Skip to content

Commit

Permalink
Multiple changes are part of this commit.
Browse files Browse the repository at this point in the history
1. Changes to `get_one_hot`
Problems are given in:
- #14
- #17
- #13

I discarded the changes in the PRs and added more comprehensive handling of the input data in the
`SmilesDataset` class and the `get_one_hot` function.

2. Imaginary components
Frechet distance calculation fails to work for some cases because of badly conditioned matrices,
as described here #15.

Could not reproduce the error locally, but could do so on colab.

Fixed it in `calculate_frechet_distance` by checking whether the first `covmean` computation is real and finite; if it is not, a small value is added to the diagonals and `covmean` is recomputed.
This made it work for me and I got the same result as the original implementation run locally.

3. Added some more tests and changed to pytest

4. As described in #16
I changed the data type of the activations to float32 in the `get_predictions` function,
which saves memory for larger datasets.
  • Loading branch information
renzph committed Apr 1, 2024
1 parent 1249378 commit ad2fb7a
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 18 deletions.
14 changes: 12 additions & 2 deletions fcd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from .fcd import get_fcd, get_predictions, load_ref_model
from .utils import calculate_frechet_distance, canonical_smiles
# ruff: noqa: F401

from fcd.fcd import get_fcd, get_predictions, load_ref_model
from fcd.utils import calculate_frechet_distance, canonical_smiles

__all__ = [
"get_fcd",
"get_predictions",
"load_ref_model",
"calculate_frechet_distance",
"canonical_smiles",
]

__version__ = "1.2"
25 changes: 17 additions & 8 deletions fcd/fcd.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ast
import os
import pkgutil
import tempfile
Expand All @@ -10,7 +9,7 @@
from torch import nn
from torch.utils.data import DataLoader

from .utils import (
from fcd.utils import (
SmilesDataset,
calculate_frechet_distance,
load_imported_model,
Expand Down Expand Up @@ -94,13 +93,25 @@ def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module | None = No
"""Calculate FCD between two sets of Smiles
Args:
smiles1 (List[str]): First set of smiles
smiles2 (List[str]): Second set of smiles
smiles1 (List[str]): First set of SMILES.
smiles2 (List[str]): Second set of SMILES.
model (nn.Module, optional): The model to use. Loads default model if None.
device: The device to use for computation.
Returns:
float: The FCD score
float: The FCD score.
Raises:
ValueError: If the input SMILES lists are empty.
Example:
>>> smiles1 = ['CCO', 'CCN']
>>> smiles2 = ['CCO', 'CCC']
>>> fcd_score = get_fcd(smiles1, smiles2)
"""
if not smiles1 or not smiles2:
raise ValueError("Input SMILES lists cannot be empty.")

if model is None:
model = load_ref_model()

Expand All @@ -114,8 +125,6 @@ def get_fcd(smiles1: List[str], smiles2: List[str], model: nn.Module | None = No
mu2 = np.mean(act2, axis=0)
sigma2 = np.cov(act2.T)

fcd_score = calculate_frechet_distance(
mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2
)
fcd_score = calculate_frechet_distance(mu1=mu1, mu2=mu2, sigma1=sigma1, sigma2=sigma2)

return fcd_score
23 changes: 15 additions & 8 deletions fcd/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import re
import warnings
from contextlib import contextmanager
from multiprocessing import Pool
from typing import List
import warnings

import numpy as np
import torch
Expand All @@ -11,7 +11,7 @@
from torch import nn
from torch.utils.data import Dataset

from .torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose
from fcd.torch_layers import IndexTensor, IndexTuple, Reverse, SamePadding1d, Transpose

# fmt: off
__vocab = ["C","N","O","H","F","Cl","P","B","Br","S","I","Si","#","(",")","+","-","1","2","3","4","5","6","7","8","=","[","]","@","c","n","o","s","X","."]
Expand Down Expand Up @@ -156,7 +156,13 @@ def __len__(self):
return len(self.smiles_list)


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
def calculate_frechet_distance(
mu1: np.ndarray,
sigma1: np.ndarray,
mu2: np.ndarray,
sigma2: np.ndarray,
eps: float = 1e-6,
) -> float:
"""Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
Expand Down Expand Up @@ -202,7 +208,8 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
if not np.isfinite(covmean).all() or not is_real:
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))


assert isinstance(covmean, np.ndarray)
# numerical error might give slight imaginary component
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
Expand All @@ -212,7 +219,7 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):

tr_covmean = np.trace(covmean)

return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
return float(diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)


@contextmanager
Expand All @@ -225,11 +232,11 @@ def todevice(model, device):

def canonical(smi):
try:
return Chem.MolToSmiles(Chem.MolFromSmiles(smi))
except:
return Chem.MolToSmiles(Chem.MolFromSmiles(smi)) # type: ignore
except Exception:
return None


def canonical_smiles(smiles, njobs=32):
def canonical_smiles(smiles, njobs=-1):
    """Canonicalize a collection of SMILES strings in parallel.

    Args:
        smiles: Iterable of SMILES strings; each one is passed to `canonical`.
        njobs: Number of worker processes. Any value below 1 (including the
            default of -1) or None means "use all available CPUs".

    Returns:
        list: One entry per input, in input order, holding whatever
        `canonical` returned for it (None where canonicalization failed).
    """
    # multiprocessing.Pool raises ValueError for process counts below 1, so
    # the "-1 means all cores" convention has to be translated to None, which
    # makes Pool fall back to os.cpu_count().
    processes = njobs if njobs is not None and njobs >= 1 else None
    with Pool(processes) as pool:
        return pool.map(canonical, smiles)
77 changes: 77 additions & 0 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".git-rewrite",
".hg",
".ipynb_checkpoints",
".mypy_cache",
".nox",
".pants.d",
".pyenv",
".pytest_cache",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
".vscode",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"site-packages",
"venv",
]

# Maximum line length (wider than Black's default of 88).
line-length = 120
indent-width = 4

# Assume Python 3.8
target-version = "py38"

[lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or
# McCabe complexity (`C901`) by default.
select = ["E4", "E7", "E9", "F"]
ignore = []

# Allow fix for all enabled rules (when `--fix` is provided).
fixable = ["ALL"]
unfixable = []

# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[format]
# Like Black, use double quotes for strings.
quote-style = "double"

# Like Black, indent with spaces, rather than tabs.
indent-style = "space"

# Like Black, respect magic trailing commas.
skip-magic-trailing-comma = false

# Like Black, automatically detect the appropriate line ending.
line-ending = "auto"

# Enable auto-formatting of code examples in docstrings. Markdown,
# reStructuredText code/literal blocks and doctests are all supported.
#
# This is currently disabled by default, but it is planned for this
# to be opt-out in the future.
docstring-code-format = false

# Set the line length limit used when formatting code snippets in
# docstrings.
#
# This only has an effect when the `docstring-code-format` setting is
# enabled.
docstring-code-line-length = "dynamic"

0 comments on commit ad2fb7a

Please sign in to comment.