Skip to content

Commit

Permalink
Merge pull request #21 from vsheg/sklearn-api
Browse files Browse the repository at this point in the history
Add basic `sklearn` API support
  • Loading branch information
vsheg authored Jul 6, 2024
2 parents a117770 + 54c048d commit 15023f3
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 40 deletions.
32 changes: 23 additions & 9 deletions moll/pick/_online_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike, DTypeLike
from loguru import logger
from jax.typing import DTypeLike
from numpy.typing import NDArray
from public import public
from sklearn.base import BaseEstimator, TransformerMixin

from ..typing import (
DistanceFnCallable,
DistanceFnLiteral,
Indexable,
PotentialFnCallable,
PotentialFnLiteral,
LossFnCallable,
LossFnLiteral,
SimilarityFnCallable,
SimilarityFnLiteral,
)
Expand All @@ -28,7 +28,7 @@


@public
class OnlineVectorPicker:
class OnlineVectorPicker(BaseEstimator, TransformerMixin):
"""
Greedy algorithm for picking a subset of vectors in an online fashion.
"""
Expand All @@ -39,7 +39,7 @@ def __init__(
*,
dist_fn: DistanceFnCallable | DistanceFnLiteral = "euclidean",
sim_fn: SimilarityFnCallable | SimilarityFnLiteral = "identity",
loss_fn: PotentialFnCallable | PotentialFnLiteral = "power",
loss_fn: LossFnCallable | LossFnLiteral = "power",
p: float | int = -1,
k_neighbors: int | float = 5, # TODO: add heuristic for better default
threshold: float = -jnp.inf,
Expand All @@ -61,7 +61,7 @@ def __init__(
self.p: float | int = p
loss_fn = get_function_from_literal(loss_fn, module="moll.measures._loss")
loss_fn = partial(loss_fn, p=p) if hasarg(loss_fn, "p") else loss_fn
self.loss_fn: PotentialFnCallable = loss_fn
self.loss_fn: LossFnCallable = loss_fn

self.k_neighbors: int = self._init_k_neighbors(k_neighbors, capacity)

Expand Down Expand Up @@ -150,7 +150,7 @@ def _convert_data(data: Iterable, dtype: DTypeLike | None) -> Array:

return jnp.array(data, dtype=dtype)

def update(
def partial_fit(
self, vectors: Iterable, labels: Indexable[Hashable] | None = None
) -> int:
"""
Expand Down Expand Up @@ -219,17 +219,31 @@ def update(

return n_accepted

def fit(self, X, y=None):
    """
    Pick a subset of vectors based on their similarity.

    Scikit-learn style entry point. `X` is forwarded to `partial_fit` as the
    vectors and `y` as the labels; the accepted-count returned by
    `partial_fit` is discarded.

    Note: this method does no resetting of its own before delegating.

    Returns:
        The picker itself, so that calls can be chained.
    """
    self.partial_fit(X, y)
    return self

def add(self, vector: Array, label: Hashable | None = None) -> bool:
"""
Add a vector to the picker.
"""
n_accepted = self.update(
n_accepted = self.partial_fit(
vectors=[vector],
labels=[label] if label else None, # type: ignore
)
is_accepted = n_accepted > 0
return is_accepted

def transform(self, X):
    """
    Return the vectors currently held by the picker.

    `X` is accepted for scikit-learn API compatibility but is not used: the
    result is always `self.vectors`, whatever input is passed in.
    """
    return self.vectors

def fit_transform(self, X, y=None):
    """
    Fit the picker on `X` (with optional `y`) and return the transformed result.

    Calls `fit(X, y)` followed by `transform(X)` on this same instance.
    """
    self.fit(X, y)
    return self.transform(X)

def warm(self, vectors: Array, labels: Indexable[Hashable] | None = None):
"""
Initialize the picker with a set of vectors.
Expand Down
69 changes: 42 additions & 27 deletions moll/pick/tests/test_online_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from ...utils import dists_to_nearest_neighbor, globs, random_grid_points
from .._online_picker import (
DistanceFnLiteral,
LossFnLiteral,
OnlineVectorPicker,
PotentialFnLiteral,
)

RANDOM_SEED = 42
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_add_many_random(picker, centers_and_vectors):
)


def test_update_many_random(picker, centers_and_vectors, n_batches):
def test_partial_fit_many_random(picker, centers_and_vectors, n_batches):
centers, vectors = centers_and_vectors

assert picker.is_empty() is True
Expand All @@ -180,7 +180,7 @@ def test_update_many_random(picker, centers_and_vectors, n_batches):
n_accepted_total = 0

for batch in batches:
n_accepted = picker.update(batch)
n_accepted = picker.partial_fit(batch)
assert n_accepted >= 0
assert n_accepted <= picker.capacity
n_accepted_total += n_accepted
Expand All @@ -200,7 +200,7 @@ def test_update_many_random(picker, centers_and_vectors, n_batches):
)


def test_update_same_vectors(picker_euclidean: OnlineVectorPicker):
def test_partial_fit_same_vectors(picker_euclidean: OnlineVectorPicker):
center1 = jnp.array([0, 0, 0])
center2 = jnp.array([0, 10, 0])
same_centers = [jnp.array([10, 10, 10]) for _ in range(10)]
Expand Down Expand Up @@ -228,6 +228,20 @@ def test_update_same_vectors(picker_euclidean: OnlineVectorPicker):
assert len(picker_euclidean.labels) == n_accepted


def test_fit_and_partial_fit(picker, centers_and_vectors):
    """`fit` and `partial_fit` must select the same vectors from the same data."""
    from copy import deepcopy

    _, vectors = centers_and_vectors

    # Clone before any data is seen so both pickers start from the same state.
    via_fit = deepcopy(picker)
    via_fit.fit(vectors)

    via_partial_fit = picker
    via_partial_fit.partial_fit(vectors)

    same_selection = via_fit.vectors == via_partial_fit.vectors
    assert same_selection.all()


@pytest.fixture
def circles(factor=0.1, random_state=42, n_samples=20):
"""
Expand Down Expand Up @@ -257,7 +271,7 @@ def test_labels_add(picker_euclidean: OnlineVectorPicker, circles):
def test_manual_labels_update(picker_euclidean: OnlineVectorPicker, circles):
vectors, labels = circles

_n_accepted = picker_euclidean.update(vectors, labels=labels)
_n_accepted = picker_euclidean.partial_fit(vectors, labels=labels)

assert picker_euclidean.labels
counts = Counter(circle for circle, idx in picker_euclidean.labels)
Expand All @@ -272,7 +286,7 @@ def test_auto_labels_update(picker_euclidean: OnlineVectorPicker, circles):
large_circle_idxs = {idx for tag, idx in _labels if tag == "large"}
small_circle_idxs = {idx for tag, idx in _labels if tag == "small"}

_n_accepted = picker_euclidean.update(vectors)
_n_accepted = picker_euclidean.partial_fit(vectors)

assert picker_euclidean.labels
labels_generated = set(picker_euclidean.labels)
Expand All @@ -297,7 +311,7 @@ def test_fast_init(picker_euclidean: OnlineVectorPicker, circles, init_batch_siz
labels = labels[init_batch_size:]

picker_euclidean.warm(batch_init, labels_init)
picker_euclidean.update(batch, labels)
picker_euclidean.partial_fit(batch, labels)

assert picker_euclidean.labels
counts = Counter(circle for circle, idx in picker_euclidean.labels)
Expand All @@ -306,21 +320,21 @@ def test_fast_init(picker_euclidean: OnlineVectorPicker, circles, init_batch_siz
assert counts["small"] <= 1


# Test picker custom similarity functions
# Test picker custom distance functions

similarity_fns: tuple = get_args(DistanceFnLiteral) + (
lambda x, y: euclidean(x, y) + 10, # similarities must be ordered, shift is ok
lambda x, y: euclidean(x, y) - 10, # similarities must be ordered, negative is ok
dist_fns: tuple = get_args(DistanceFnLiteral) + (
lambda x, y: euclidean(x, y) + 10, # distances must be ordered, shift is ok
lambda x, y: euclidean(x, y) - 10, # distances must be ordered, negative is ok
)


@pytest.fixture(params=similarity_fns)
def picker_similarity_fn(request):
similarity_fn = request.param
@pytest.fixture(params=dist_fns)
def picker_dist_fn(request):
dist_fn = request.param
return OnlineVectorPicker(
capacity=5,
dist_fn=similarity_fn,
loss_fn="exponential", # exp potential is used to treat negative similarities
dist_fn=dist_fn,
loss_fn="exponential", # exp potential is used to treat negative distances
)


Expand All @@ -334,28 +348,29 @@ def integer_vectors(n_vectors=1_000, dim=10, seed: int = RANDOM_SEED):
)


def test_custom_similarity_fn(picker_similarity_fn, integer_vectors):
picker_similarity_fn.update(integer_vectors)
def test_custom_similarity_fn(picker_dist_fn, integer_vectors):
picker_dist_fn.partial_fit(integer_vectors)

assert picker_similarity_fn.n_seen == len(integer_vectors)
assert picker_similarity_fn.n_accepted == 5
assert picker_dist_fn.n_seen == len(integer_vectors)
assert picker_dist_fn.n_accepted == 5

min_dist_orig = dists_to_nearest_neighbor(integer_vectors, euclidean).min()
min_dist_new = dists_to_nearest_neighbor(
picker_similarity_fn.vectors, euclidean
).min()
min_dist_new = dists_to_nearest_neighbor(picker_dist_fn.vectors, euclidean).min()

# Check that the min pairwise distance is increased by at least a factor:
factor = 1.5

assert (min_dist_new > factor * min_dist_orig).all()


# Test custom potential functions
# TODO: Test custom similarity functions


# Test custom loss functions


loss_fns: tuple = get_args(PotentialFnLiteral) + (
lambda d: jnp.exp(d), # potentials must be ordered, negative is ok
loss_fns: tuple = get_args(LossFnLiteral) + (
lambda d: jnp.exp(d) - 100, # losses must be ordered, negative is ok
)


Expand All @@ -372,7 +387,7 @@ def uniform_rectangle(n_vectors=1_000, dim=2, seed: int = RANDOM_SEED):

def test_custom_loss_fn(picker_loss_fn, uniform_rectangle):
picker = picker_loss_fn
picker.update(uniform_rectangle)
picker.partial_fit(uniform_rectangle)

min_dist_orig = dists_to_nearest_neighbor(uniform_rectangle, euclidean).min()
min_dist_new = dists_to_nearest_neighbor(picker_loss_fn.vectors, euclidean).min()
Expand Down
46 changes: 46 additions & 0 deletions moll/pick/tests/test_online_picker_sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from copy import deepcopy

import numpy as np
import pytest
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

from moll.pick._online_picker import OnlineVectorPicker


@pytest.fixture
def X():
    """Return a reproducible 10x3 matrix of uniform samples from [0, 1)."""
    # Seed the generator so a failing run can be reproduced exactly; the
    # unseeded global RNG yields different data on every test session.
    return np.random.default_rng(42).random((10, 3))


@pytest.fixture
def y():
    """Return the integers 0..9, one per row of the 10-row `X` fixture."""
    return np.arange(10)


def test_pipeline(X, y):
    """Picking inside a scikit-learn `Pipeline` must agree with picking directly."""
    steps = [
        ("scaler", StandardScaler()),
        ("picker", OnlineVectorPicker(capacity=10)),
        # Trailing step only checks that the picker can sit mid-pipeline;
        # there is nothing to impute here.
        ("imp", SimpleImputer(strategy="mean")),
    ]
    pipe = Pipeline(steps)
    pipe.fit(X, y)
    via_pipeline = pipe.transform(X)

    # Same picker configuration, fitted on the raw (unscaled) data.
    standalone = OnlineVectorPicker(capacity=10)
    standalone.fit(X, y)
    picked_raw = standalone.transform(X)

    # Both routes must select the same labels.
    assert standalone.labels == pipe.named_steps["picker"].labels

    # Standardize the directly picked vectors with the statistics of X so they
    # are comparable with the pipeline output.
    column_mean = np.mean(X, axis=0)
    column_std = np.std(X, axis=0)
    picked_scaled = (picked_raw - column_mean) / column_std

    assert np.allclose(via_pipeline, picked_scaled)
8 changes: 4 additions & 4 deletions moll/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"DistanceFnCallable",
"SimilarityFnLiteral",
"SimilarityFnCallable",
"PotentialFnLiteral",
"PotentialFnCallable",
"LossFnLiteral",
"LossFnCallable",
"Indexable",
"OneOrMany",
]
Expand All @@ -36,13 +36,13 @@
SimilarityFnCallable: TypeAlias = Callable[[Array], ArrayLike]

# TODO: maybe use abbreviations?
PotentialFnLiteral = Literal[
LossFnLiteral = Literal[
"power",
"exponential",
"lennard_jones",
"logarithmic",
]
PotentialFnCallable: TypeAlias = Callable[[float], ArrayLike]
LossFnCallable: TypeAlias = Callable[[float], ArrayLike]


RDKitMol: TypeAlias = Chem.rdchem.Mol
Expand Down

0 comments on commit 15023f3

Please sign in to comment.