Skip to content

Commit

Permalink
Add maximize argument (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
vsheg authored Jul 22, 2024
2 parents a74c46a + c3cd1ff commit 40871e8
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 43 deletions.
12 changes: 10 additions & 2 deletions moll/measures/_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ def power(diff: ArrayLike, p: float = 2.0) -> Array:
Array(1000., dtype=...)
"""
diff = jnp.asarray(diff)
return jnp.where(diff > 0, jnp.power(diff, p), jnp.inf)

return jnp.where(
jnp.isclose(diff, 0),
jnp.inf,
jnp.power(diff, p),
)


@public
Expand Down Expand Up @@ -60,12 +65,15 @@ def lennard_jones(diff: ArrayLike, p: float = 1.0) -> Array:
"""
diff = jnp.asarray(diff)
sigma: float = p * 0.5 ** (1 / 6)
return jnp.where(

acc = jnp.where(
diff > 0,
jnp.power(sigma / diff, 12.0) - jnp.power(sigma / diff, 6.0),
jnp.inf,
)

return jnp.where(jnp.isnan(diff), jnp.nan, acc)


@public
@partial(jax.jit, inline=True)
Expand Down
65 changes: 40 additions & 25 deletions moll/pick/_online_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,39 @@
from ..utils import dist_matrix, fill_diagonal, matrix_cross_sum


@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn"])
@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn", "maximize"])
def _needless_vector_idx(
vicinity: Array,
dist_fn: Callable[[Array, Array], Array],
sim_fn: Callable[[Array], Array],
loss_fn: Callable[[Array], Array],
maximize: bool = False,
) -> int:
"""
Find a vector in `X` removing which would decrease the total potential the most.
"""
# Calculate matrix of pairwise distances and corresponding matrix of potentials
dist_mat = dist_matrix(vicinity, dist_fn)
sim_mat = jax.vmap(sim_fn)(dist_mat)
potent_mat = jax.vmap(loss_fn)(sim_mat)
loss_mat = jax.vmap(loss_fn)(sim_mat)

# Inherent potential of a vector is 0, it means that the vector by itself does not
# contribute to the total potential. In the future, the penalty for the vector itself
# can be added
potent_mat = fill_diagonal(potent_mat, 0)
loss_mat = fill_diagonal(loss_mat, 0)

# Compute potential decrease when each vector is deleted from the set
deltas = jax.vmap(lambda i: matrix_cross_sum(potent_mat, i, i, row_only=True))(
deltas = jax.vmap(lambda i: matrix_cross_sum(loss_mat, i, i, row_only=True))(
jnp.arange(vicinity.shape[0])
)

# If the goal is to maximize the potential, return the vector with the largest delta
if maximize:
deltas = -deltas

# Find vector that decreases total potential the most (deltas are negative)
idx = deltas.argmax()

return idx


Expand All @@ -51,7 +57,7 @@ def sim_i(i):
return lax.cond(
i < n_valid,
lambda i: sim_fn(dist_fn(x, X[i])),
lambda _: jnp.inf,
lambda _: jnp.nan,
i,
)

Expand All @@ -60,6 +66,7 @@ def sim_i(i):

@partial(jax.jit, static_argnames=["k_neighbors"])
def _k_neighbors(similarities: Array, k_neighbors: int):
# TODO: test approx_min_k vs argpartition
k_neighbors_idxs = lax.approx_min_k(similarities, k=k_neighbors)[1]
return k_neighbors_idxs

Expand All @@ -71,7 +78,9 @@ def _k_neighbors(similarities: Array, k_neighbors: int):
"sim_fn",
"loss_fn",
"k_neighbors",
"min_sim",
"sim_min",
"sim_max",
"maximize",
],
donate_argnames=["x", "X"],
inline=True,
Expand All @@ -84,8 +93,11 @@ def _add_vector(
loss_fn: Callable[[Array], Array],
k_neighbors: int,
n_valid_vectors: int,
min_sim: float,
sim_min: float = -jnp.inf,
sim_max: float = +jnp.inf,
X_pinned: Array | None = None,
maximize: bool = False,
reject_inf: bool = True,
) -> tuple[Array, int]:
"""
Adds a vector `x` to a fixed-size set of vectors `X`.
Expand All @@ -98,25 +110,24 @@ def _add_vector(
n_valid_vectors += n_pinned
min_changeable_idx = n_pinned

def below_threshold_or_infinite_potential(X, _):
def outside_sim_limits_or_loss_positive_inf(X, _):
return X, -1

def above_threshold_and_not_full(X, _):
def within_sim_limits_and_not_full(X, _):
updated_vector_idx = n_valid_vectors
X = X.at[updated_vector_idx].set(x)
return X, updated_vector_idx

def above_threshold_and_full(X, sims):
# TODO: test approx_min_k vs argpartition

def within_sim_limits_and_full(X, sims):
# Find `k_neighbors` most similar vectors to `x`
k_neighbors_idxs = _k_neighbors(sims, k_neighbors=k_neighbors)

# Define a neighborhood of `x`
vicinity = lax.concatenate((jnp.array([x]), X[k_neighbors_idxs]), 0)

# Vector in the vicinity removing which decreases the total potential the most:
needless_vector_vicinity_idx = _needless_vector_idx(
vicinity, dist_fn, sim_fn, loss_fn
vicinity, dist_fn, sim_fn, loss_fn, maximize=maximize
)

# If the needless vector is not `x`, replace it with `x`
Expand All @@ -137,19 +148,19 @@ def above_threshold_and_full(X, sims):
is_full = X.shape[0] == n_valid_vectors

sims = _similarities(x, X, dist_fn, sim_fn, n_valid_vectors)
is_above_threshold = sims.min() > min_sim
losses = jax.vmap(loss_fn)(sims)

has_loss_positive_inf = jnp.nanmax(losses) == jnp.inf
is_within_sim_limits = (sim_min <= jnp.nanmin(sims)) & (jnp.nanmax(sims) <= sim_max)

branches = [
below_threshold_or_infinite_potential,
above_threshold_and_not_full,
above_threshold_and_full,
outside_sim_limits_or_loss_positive_inf,
within_sim_limits_and_not_full,
within_sim_limits_and_full,
]

branch_idx = 0 + (is_above_threshold) + (is_full & is_above_threshold)

# If the potential is infinite, the vector is always rejected
is_potential_infinite = jnp.isinf(loss_fn(sims.min()))
branch_idx *= ~is_potential_infinite
branch_idx = 0 + (is_within_sim_limits) + (is_within_sim_limits & is_full)
branch_idx *= ~(has_loss_positive_inf * reject_inf)

X, updated_vector_idx = lax.switch(branch_idx, branches, X, sims)

Expand Down Expand Up @@ -188,7 +199,7 @@ def _finalize_updates(changes: Array) -> Array:

@partial(
jax.jit,
static_argnames=["dist_fn", "sim_fn", "loss_fn", "k_neighbors"],
static_argnames=["dist_fn", "sim_fn", "loss_fn", "k_neighbors", "maximize"],
donate_argnames=["X", "X_pinned", "xs"],
)
def update_vectors(
Expand All @@ -199,9 +210,11 @@ def update_vectors(
sim_fn: Callable,
loss_fn: Callable,
k_neighbors: int,
min_sim: float,
sim_min: float,
sim_max: float,
n_valid: int,
X_pinned: Array | None = None,
maximize: bool = False,
) -> tuple[Array, Array, Array, int, int]:
assert xs.shape[0] > 0
# assert X.dtype == xs.dtype # TODO: fix dtype
Expand All @@ -220,8 +233,10 @@ def body_fun(carry, x):
sim_fn=sim_fn,
loss_fn=loss_fn,
k_neighbors=k_neighbors,
min_sim=min_sim,
sim_min=sim_min,
sim_max=sim_max,
n_valid_vectors=n_valid_new,
maximize=maximize,
)

n_valid_new = lax.cond(
Expand Down
13 changes: 10 additions & 3 deletions moll/pick/_online_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def __init__(
loss_fn: LossFnCallable | LossFnLiteral = "power",
p: float | int = -1,
k_neighbors: int | float = 5, # TODO: add heuristic for better default
min_sim: float | None = None,
sim_min: float = -jnp.inf,
sim_max: float = +jnp.inf,
maximize: bool = False,
dtype: DTypeLike | None = None,
):
"""
Expand All @@ -69,7 +71,10 @@ def __init__(

self.k_neighbors: int = self._init_k_neighbors(k_neighbors, capacity)

self.min_sim: float = min_sim or -jnp.inf
self.maximize: bool = maximize

self.sim_min: float = sim_min
self.sim_max: float = sim_max

# Inferred dtype
self.dtype: DTypeLike | None
Expand Down Expand Up @@ -223,8 +228,10 @@ def partial_fit(
sim_fn=self.sim_fn,
loss_fn=self.loss_fn,
k_neighbors=self.k_neighbors,
min_sim=self.min_sim,
sim_min=self.sim_min,
sim_max=self.sim_max,
n_valid=self._n_valid,
maximize=self.maximize,
)

# Update vectors data
Expand Down
64 changes: 61 additions & 3 deletions moll/pick/tests/test_online_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,34 @@ def test_find_needless_vector(array, expected, dist_fn):
assert idx == expected


# Test needless vector search with maximization


@pytest.mark.parametrize(
"array, idx_expected",
[
([0, 0.1, 1], 2),
([0, 1.1, 1], 0),
([0, 0, 0, 1], 3),
([(0.1, 0.1), (0, 0), (1, 1)], 2),
],
)
@pytest.mark.parametrize(
"dist_fn",
[
euclidean,
lambda x, y: euclidean(x, y) - 10, # negative distance is ok
],
)
def test_find_needless_vector_with_maximize(array, idx_expected, dist_fn):
array = jnp.array(array)
# exp potential is used to treat negative distances
idx = _needless_vector_idx(
array, dist_fn, sim_fn=lambda d: d, loss_fn=lambda s: jnp.exp(-s), maximize=True
)
assert idx == idx_expected


# Test add vectors


Expand Down Expand Up @@ -97,7 +125,7 @@ def test_add_vector(X, dist_fn):
loss_fn=lambda s: jnp.exp(-s),
k_neighbors=5,
n_valid_vectors=5,
min_sim=-jnp.inf,
sim_min=-jnp.inf,
)
assert updated_idx >= 0
assert updated_idx == 4
Expand Down Expand Up @@ -134,7 +162,36 @@ def test_add_vector_with_pinned(X, X_pinned, x, updated_idx):
loss_fn=lambda s: s**-1,
k_neighbors=5,
n_valid_vectors=5,
min_sim=0.0,
)
assert upd_idx == updated_idx

if updated_idx == -1:
assert (X_copy == X_updated).all()
else:
assert (X_updated == X_copy.at[updated_idx].set(x)).all()


@pytest.mark.parametrize(
"x, updated_idx",
[
([1.1, 1.1], 4),
([0.1, 0.1], 4),
([4.1, 4.1], 0),
],
)
def test_add_vector_with_maximize(X, x, updated_idx):
X_copy = X.copy()
x = jnp.array(x)

X_updated, upd_idx = _add_vector(
x=x,
X=X,
dist_fn=euclidean,
sim_fn=lambda d: d,
loss_fn=lambda s: s**-1,
k_neighbors=5,
n_valid_vectors=len(X),
maximize=True,
)
assert upd_idx == updated_idx

Expand Down Expand Up @@ -189,7 +246,8 @@ def test_update_vectors(X, X_pinned, xs, acc_mask):
loss_fn=lambda s: s**-1,
k_neighbors=5,
n_valid=5,
min_sim=0.0,
sim_min=0.0,
sim_max=jnp.inf,
)

assert X_copy.shape == X_updated.shape
Expand Down
32 changes: 27 additions & 5 deletions moll/pick/tests/test_online_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn import datasets

from ...measures import euclidean
from ...utils import dists_to_nearest_neighbor, globs, random_grid_points
from ...utils import dist_matrix, dists_to_others, globs, random_grid_points
from .._online_picker import (
DistanceFnLiteral,
LossFnLiteral,
Expand Down Expand Up @@ -354,8 +354,8 @@ def test_custom_similarity_fn(picker_dist_fn, integer_vectors):
assert picker_dist_fn.n_seen == len(integer_vectors)
assert picker_dist_fn.n_accepted == 5

min_dist_orig = dists_to_nearest_neighbor(integer_vectors, euclidean).min()
min_dist_new = dists_to_nearest_neighbor(picker_dist_fn.vectors, euclidean).min()
min_dist_orig = dists_to_others(integer_vectors, euclidean).min()
min_dist_new = dists_to_others(picker_dist_fn.vectors, euclidean).min()

# Check that the min pairwise distance is increased by at least a factor:
factor = 1.5
Expand Down Expand Up @@ -389,8 +389,8 @@ def test_custom_loss_fn(picker_loss_fn, uniform_rectangle):
picker = picker_loss_fn
picker.partial_fit(uniform_rectangle)

min_dist_orig = dists_to_nearest_neighbor(uniform_rectangle, euclidean).min()
min_dist_new = dists_to_nearest_neighbor(picker_loss_fn.vectors, euclidean).min()
min_dist_orig = dists_to_others(uniform_rectangle).min()
min_dist_new = dists_to_others(picker_loss_fn.vectors).min()

# Check that the min pairwise distance is increased by at least a factor:
factor = 1.5
Expand Down Expand Up @@ -427,3 +427,25 @@ def test_picker_with_pinned_vectors(picker_with_pinned_vectors, vectors):

values, indices = jax.lax.top_k(-picker.vectors[:, 0], k=3)
assert jnp.allclose(picker.vectors, vectors[indices])


# Test with maximize argument


@pytest.fixture
def picker_with_loss_maximization():
return OnlineVectorPicker(capacity=3, k_neighbors=3, maximize=True)


def test_picker_with_loss_maximization(
picker_with_loss_maximization, uniform_rectangle
):
picker = picker_with_loss_maximization
vectors = uniform_rectangle
picker.fit(vectors)

dist_median_mean_before = dists_to_others(vectors).mean()
dist_median_min_after = dists_to_others(picker.vectors).mean()

# Test that the mean median distance is decreased by at least a factor of 1.5
assert dist_median_min_after < dist_median_mean_before / 1.5
Loading

0 comments on commit 40871e8

Please sign in to comment.