diff --git a/moll/measures/_loss.py b/moll/measures/_loss.py index 64fe5bc..3de4149 100644 --- a/moll/measures/_loss.py +++ b/moll/measures/_loss.py @@ -23,7 +23,12 @@ def power(diff: ArrayLike, p: float = 2.0) -> Array: Array(1000., dtype=...) """ diff = jnp.asarray(diff) - return jnp.where(diff > 0, jnp.power(diff, p), jnp.inf) + + return jnp.where( + jnp.isclose(diff, 0), + jnp.inf, + jnp.power(diff, p), + ) @public @@ -60,12 +65,15 @@ def lennard_jones(diff: ArrayLike, p: float = 1.0) -> Array: """ diff = jnp.asarray(diff) sigma: float = p * 0.5 ** (1 / 6) - return jnp.where( + + acc = jnp.where( diff > 0, jnp.power(sigma / diff, 12.0) - jnp.power(sigma / diff, 6.0), jnp.inf, ) + return jnp.where(jnp.isnan(diff), jnp.nan, acc) + @public @partial(jax.jit, inline=True) diff --git a/moll/pick/_online_add.py b/moll/pick/_online_add.py index 3d566f7..498b80a 100644 --- a/moll/pick/_online_add.py +++ b/moll/pick/_online_add.py @@ -13,12 +13,13 @@ from ..utils import dist_matrix, fill_diagonal, matrix_cross_sum -@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn"]) +@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn", "maximize"]) def _needless_vector_idx( vicinity: Array, dist_fn: Callable[[Array, Array], Array], sim_fn: Callable[[Array], Array], loss_fn: Callable[[Array], Array], + maximize: bool = False, ) -> int: """ Find a vector in `X` removing which would decrease the total potential the most. @@ -26,20 +27,25 @@ def _needless_vector_idx( # Calculate matrix of pairwise distances and corresponding matrix of potentials dist_mat = dist_matrix(vicinity, dist_fn) sim_mat = jax.vmap(sim_fn)(dist_mat) - potent_mat = jax.vmap(loss_fn)(sim_mat) + loss_mat = jax.vmap(loss_fn)(sim_mat) # Inherent potential of a vector is 0, it means that the vector by itself does not # contribute to the total potential. 
In the future, the penalty for the vector itself # can be added - potent_mat = fill_diagonal(potent_mat, 0) + loss_mat = fill_diagonal(loss_mat, 0) # Compute potential decrease when each vector is deleted from the set - deltas = jax.vmap(lambda i: matrix_cross_sum(potent_mat, i, i, row_only=True))( + deltas = jax.vmap(lambda i: matrix_cross_sum(loss_mat, i, i, row_only=True))( jnp.arange(vicinity.shape[0]) ) + # To maximize the potential, negate deltas so that argmax picks the smallest delta + if maximize: + deltas = -deltas + # Find vector that decreases total potential the most (deltas are negative) idx = deltas.argmax() + return idx @@ -51,7 +57,7 @@ def sim_i(i): return lax.cond( i < n_valid, lambda i: sim_fn(dist_fn(x, X[i])), - lambda _: jnp.inf, + lambda _: jnp.nan, i, ) @@ -60,6 +66,7 @@ def sim_i(i): @partial(jax.jit, static_argnames=["k_neighbors"]) def _k_neighbors(similarities: Array, k_neighbors: int): + # TODO: test approx_min_k vs argpartition k_neighbors_idxs = lax.approx_min_k(similarities, k=k_neighbors)[1] return k_neighbors_idxs @@ -71,7 +78,9 @@ def _k_neighbors(similarities: Array, k_neighbors: int): "sim_fn", "loss_fn", "k_neighbors", - "min_sim", + "sim_min", + "sim_max", + "maximize", ], donate_argnames=["x", "X"], inline=True, @@ -84,8 +93,11 @@ def _add_vector( loss_fn: Callable[[Array], Array], k_neighbors: int, n_valid_vectors: int, - min_sim: float, + sim_min: float = -jnp.inf, + sim_max: float = +jnp.inf, X_pinned: Array | None = None, + maximize: bool = False, + reject_inf: bool = True, ) -> tuple[Array, int]: """ Adds a vector `x` to a fixed-size set of vectors `X`. 
@@ -98,17 +110,16 @@ def _add_vector( n_valid_vectors += n_pinned min_changeable_idx = n_pinned - def below_threshold_or_infinite_potential(X, _): + def outside_sim_limits_or_loss_positive_inf(X, _): return X, -1 - def above_threshold_and_not_full(X, _): + def within_sim_limits_and_not_full(X, _): updated_vector_idx = n_valid_vectors X = X.at[updated_vector_idx].set(x) return X, updated_vector_idx - def above_threshold_and_full(X, sims): - # TODO: test approx_min_k vs argpartition - + def within_sim_limits_and_full(X, sims): + # Find `k_neighbors` most similar vectors to `x` k_neighbors_idxs = _k_neighbors(sims, k_neighbors=k_neighbors) # Define a neighborhood of `x` @@ -116,7 +127,7 @@ def above_threshold_and_full(X, sims): # Vector in the vicinity removing which decreases the total potential the most: needless_vector_vicinity_idx = _needless_vector_idx( - vicinity, dist_fn, sim_fn, loss_fn + vicinity, dist_fn, sim_fn, loss_fn, maximize=maximize ) # If the needless vector is not `x`, replace it with `x` @@ -137,19 +148,19 @@ def above_threshold_and_full(X, sims): is_full = X.shape[0] == n_valid_vectors sims = _similarities(x, X, dist_fn, sim_fn, n_valid_vectors) - is_above_threshold = sims.min() > min_sim + losses = jax.vmap(loss_fn)(sims) + + has_loss_positive_inf = jnp.nanmax(losses) == jnp.inf + is_within_sim_limits = (sim_min <= jnp.nanmin(sims)) & (jnp.nanmax(sims) <= sim_max) branches = [ - below_threshold_or_infinite_potential, - above_threshold_and_not_full, - above_threshold_and_full, + outside_sim_limits_or_loss_positive_inf, + within_sim_limits_and_not_full, + within_sim_limits_and_full, ] - branch_idx = 0 + (is_above_threshold) + (is_full & is_above_threshold) - - # If the potential is infinite, the vector is always rejected - is_potential_infinite = jnp.isinf(loss_fn(sims.min())) - branch_idx *= ~is_potential_infinite + branch_idx = 0 + (is_within_sim_limits) + (is_within_sim_limits & is_full) + branch_idx *= ~(has_loss_positive_inf * reject_inf) X, 
updated_vector_idx = lax.switch(branch_idx, branches, X, sims) @@ -188,7 +199,7 @@ def _finalize_updates(changes: Array) -> Array: @partial( jax.jit, - static_argnames=["dist_fn", "sim_fn", "loss_fn", "k_neighbors"], + static_argnames=["dist_fn", "sim_fn", "loss_fn", "k_neighbors", "maximize"], donate_argnames=["X", "X_pinned", "xs"], ) def update_vectors( @@ -199,9 +210,11 @@ def update_vectors( sim_fn: Callable, loss_fn: Callable, k_neighbors: int, - min_sim: float, + sim_min: float, + sim_max: float, n_valid: int, X_pinned: Array | None = None, + maximize: bool = False, ) -> tuple[Array, Array, Array, int, int]: assert xs.shape[0] > 0 # assert X.dtype == xs.dtype # TODO: fix dtype @@ -220,8 +233,10 @@ def body_fun(carry, x): sim_fn=sim_fn, loss_fn=loss_fn, k_neighbors=k_neighbors, - min_sim=min_sim, + sim_min=sim_min, + sim_max=sim_max, n_valid_vectors=n_valid_new, + maximize=maximize, ) n_valid_new = lax.cond( diff --git a/moll/pick/_online_picker.py b/moll/pick/_online_picker.py index d3db54e..aaa4f05 100644 --- a/moll/pick/_online_picker.py +++ b/moll/pick/_online_picker.py @@ -44,7 +44,9 @@ def __init__( loss_fn: LossFnCallable | LossFnLiteral = "power", p: float | int = -1, k_neighbors: int | float = 5, # TODO: add heuristic for better default - min_sim: float | None = None, + sim_min: float = -jnp.inf, + sim_max: float = +jnp.inf, + maximize: bool = False, dtype: DTypeLike | None = None, ): """ @@ -69,7 +71,10 @@ def __init__( self.k_neighbors: int = self._init_k_neighbors(k_neighbors, capacity) - self.min_sim: float = min_sim or -jnp.inf + self.maximize: bool = maximize + + self.sim_min: float = sim_min + self.sim_max: float = sim_max # Inferred dtype self.dtype: DTypeLike | None @@ -223,8 +228,10 @@ def partial_fit( sim_fn=self.sim_fn, loss_fn=self.loss_fn, k_neighbors=self.k_neighbors, - min_sim=self.min_sim, + sim_min=self.sim_min, + sim_max=self.sim_max, n_valid=self._n_valid, + maximize=self.maximize, ) # Update vectors data diff --git 
a/moll/pick/tests/test_online_add.py b/moll/pick/tests/test_online_add.py index c955027..2d74778 100644 --- a/moll/pick/tests/test_online_add.py +++ b/moll/pick/tests/test_online_add.py @@ -61,6 +61,34 @@ def test_find_needless_vector(array, expected, dist_fn): assert idx == expected +# Test needless vector search with maximization + + +@pytest.mark.parametrize( + "array, idx_expected", + [ + ([0, 0.1, 1], 2), + ([0, 1.1, 1], 0), + ([0, 0, 0, 1], 3), + ([(0.1, 0.1), (0, 0), (1, 1)], 2), + ], +) +@pytest.mark.parametrize( + "dist_fn", + [ + euclidean, + lambda x, y: euclidean(x, y) - 10, # negative distance is ok + ], +) +def test_find_needless_vector_with_maximize(array, idx_expected, dist_fn): + array = jnp.array(array) + # exp potential is used to treat negative distances + idx = _needless_vector_idx( + array, dist_fn, sim_fn=lambda d: d, loss_fn=lambda s: jnp.exp(-s), maximize=True + ) + assert idx == idx_expected + + # Test add vectors @@ -97,7 +125,7 @@ def test_add_vector(X, dist_fn): loss_fn=lambda s: jnp.exp(-s), k_neighbors=5, n_valid_vectors=5, - min_sim=-jnp.inf, + sim_min=-jnp.inf, ) assert updated_idx >= 0 assert updated_idx == 4 @@ -134,7 +162,36 @@ def test_add_vector_with_pinned(X, X_pinned, x, updated_idx): loss_fn=lambda s: s**-1, k_neighbors=5, n_valid_vectors=5, - min_sim=0.0, + ) + assert upd_idx == updated_idx + + if updated_idx == -1: + assert (X_copy == X_updated).all() + else: + assert (X_updated == X_copy.at[updated_idx].set(x)).all() + + +@pytest.mark.parametrize( + "x, updated_idx", + [ + ([1.1, 1.1], 4), + ([0.1, 0.1], 4), + ([4.1, 4.1], 0), + ], +) +def test_add_vector_with_maximize(X, x, updated_idx): + X_copy = X.copy() + x = jnp.array(x) + + X_updated, upd_idx = _add_vector( + x=x, + X=X, + dist_fn=euclidean, + sim_fn=lambda d: d, + loss_fn=lambda s: s**-1, + k_neighbors=5, + n_valid_vectors=len(X), + maximize=True, ) assert upd_idx == updated_idx @@ -189,7 +246,8 @@ def test_update_vectors(X, X_pinned, xs, acc_mask): loss_fn=lambda 
s: s**-1, k_neighbors=5, n_valid=5, - min_sim=0.0, + sim_min=0.0, + sim_max=jnp.inf, ) assert X_copy.shape == X_updated.shape diff --git a/moll/pick/tests/test_online_picker.py b/moll/pick/tests/test_online_picker.py index 2281e3c..22ce20e 100644 --- a/moll/pick/tests/test_online_picker.py +++ b/moll/pick/tests/test_online_picker.py @@ -9,7 +9,7 @@ from sklearn import datasets from ...measures import euclidean -from ...utils import dists_to_nearest_neighbor, globs, random_grid_points +from ...utils import dist_matrix, dists_to_others, globs, random_grid_points from .._online_picker import ( DistanceFnLiteral, LossFnLiteral, @@ -354,8 +354,8 @@ def test_custom_similarity_fn(picker_dist_fn, integer_vectors): assert picker_dist_fn.n_seen == len(integer_vectors) assert picker_dist_fn.n_accepted == 5 - min_dist_orig = dists_to_nearest_neighbor(integer_vectors, euclidean).min() - min_dist_new = dists_to_nearest_neighbor(picker_dist_fn.vectors, euclidean).min() + min_dist_orig = dists_to_others(integer_vectors, euclidean).min() + min_dist_new = dists_to_others(picker_dist_fn.vectors, euclidean).min() # Check that the min pairwise distance is increased by at least a factor: factor = 1.5 @@ -389,8 +389,8 @@ def test_custom_loss_fn(picker_loss_fn, uniform_rectangle): picker = picker_loss_fn picker.partial_fit(uniform_rectangle) - min_dist_orig = dists_to_nearest_neighbor(uniform_rectangle, euclidean).min() - min_dist_new = dists_to_nearest_neighbor(picker_loss_fn.vectors, euclidean).min() + min_dist_orig = dists_to_others(uniform_rectangle).min() + min_dist_new = dists_to_others(picker_loss_fn.vectors).min() # Check that the min pairwise distance is increased by at least a factor: factor = 1.5 @@ -427,3 +427,25 @@ def test_picker_with_pinned_vectors(picker_with_pinned_vectors, vectors): values, indices = jax.lax.top_k(-picker.vectors[:, 0], k=3) assert jnp.allclose(picker.vectors, vectors[indices]) + + +# Test with maximize argument + + +@pytest.fixture +def 
picker_with_loss_maximization(): + return OnlineVectorPicker(capacity=3, k_neighbors=3, maximize=True) + + +def test_picker_with_loss_maximization( + picker_with_loss_maximization, uniform_rectangle +): + picker = picker_with_loss_maximization + vectors = uniform_rectangle + picker.fit(vectors) + + dist_median_mean_before = dists_to_others(vectors).mean() + dist_median_min_after = dists_to_others(picker.vectors).mean() + + # Test that the mean nearest-neighbor distance is decreased by at least a factor of 1.5 + assert dist_median_min_after < dist_median_mean_before / 1.5 diff --git a/moll/utils/_utils.py b/moll/utils/_utils.py index f0b16b4..ba27f7e 100644 --- a/moll/utils/_utils.py +++ b/moll/utils/_utils.py @@ -13,6 +13,7 @@ from numpy.typing import NDArray from public import public +from ..measures._distance import euclidean from ._decorators import listify Seed: TypeAlias = int | Array | None @@ -218,12 +219,23 @@ def matrix_cross_sum(X: Array, i: int, j: int, row_only=False, crossover=True): @public -@partial(jax.jit, static_argnames="dist_fn") -def dists_to_nearest_neighbor(vectors, dist_fn): - """Compute pairwise distances between vectors.""" +@partial(jax.jit, static_argnames=["dist_fn", "reduce_fn"]) +def dists_to_others(vectors, dist_fn=euclidean, reduce_fn=jnp.nanmin): + """ + Compute the distance of each vector to all other vectors and reduce them to a single value per vector. + + Examples: + >>> vectors = jnp.array([[0, 0], [1, 0], [0, 1]]) + + >>> dists_to_others(vectors).tolist() + [1.0, 1.0, 1.0] + + >>> dists_to_others(vectors, reduce_fn=jnp.nanmax).tolist() + [1.0, 1.41..., 1.41...] + """ dists_ = dist_matrix(vectors, dist_fn) - dists_ = fill_diagonal(dists_, jnp.inf) - return jnp.min(dists_, axis=0) + dists_ = fill_diagonal(dists_, jnp.nan) + return reduce_fn(dists_, axis=0) @public