diff --git a/moll/pick/_online_add.py b/moll/pick/_online_add.py index 3d566f7..a2edf6b 100644 --- a/moll/pick/_online_add.py +++ b/moll/pick/_online_add.py @@ -13,12 +13,13 @@ from ..utils import dist_matrix, fill_diagonal, matrix_cross_sum -@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn"]) +@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn", "reject_inf"]) def _needless_vector_idx( vicinity: Array, dist_fn: Callable[[Array, Array], Array], sim_fn: Callable[[Array], Array], loss_fn: Callable[[Array], Array], + reject_inf: bool = True, ) -> int: """ Find a vector in `X` removing which would decrease the total potential the most. @@ -28,18 +29,33 @@ def _needless_vector_idx( sim_mat = jax.vmap(sim_fn)(dist_mat) potent_mat = jax.vmap(loss_fn)(sim_mat) - # Inherent potential of a vector is 0, it means that the vector by itself does not - # contribute to the total potential. In the future, the penalty for the vector itself - # can be added - potent_mat = fill_diagonal(potent_mat, 0) + is_loss_positive_inf = potent_mat.max() is jnp.inf - # Compute potential decrease when each vector is deleted from the set - deltas = jax.vmap(lambda i: matrix_cross_sum(potent_mat, i, i, row_only=True))( - jnp.arange(vicinity.shape[0]) + def reject(_): + return 0 + + def do_work(potent_mat): + # Inherent potential of a vector is 0, it means that the vector by itself does not + # contribute to the total potential. In the future, the penalty for the vector itself + # can be added + potent_mat = fill_diagonal(potent_mat, 0) + + # Compute potential decrease when each vector is deleted from the set + deltas = jax.vmap(lambda i: matrix_cross_sum(potent_mat, i, i, row_only=True))( + jnp.arange(vicinity.shape[0]) + ) + + # Find vector that decreases total potential the most (deltas are negative) + idx = deltas.argmax() + return idx + + idx = lax.cond( + reject_inf & is_loss_positive_inf, + reject, + do_work, + potent_mat, ) - # Find vector that decreases total potential the most (deltas are negative) - idx = deltas.argmax() return idx @@ -60,6 +76,7 @@ def sim_i(i): @partial(jax.jit, static_argnames=["k_neighbors"]) def _k_neighbors(similarities: Array, k_neighbors: int): + # TODO: test approx_min_k vs argpartition k_neighbors_idxs = lax.approx_min_k(similarities, k=k_neighbors)[1] return k_neighbors_idxs @@ -71,7 +88,9 @@ def _k_neighbors(similarities: Array, k_neighbors: int): "sim_fn", "loss_fn", "k_neighbors", - "min_sim", + "sim_min", + "sim_max", + "maximize", ], donate_argnames=["x", "X"], inline=True, @@ -84,8 +103,11 @@ def _add_vector( loss_fn: Callable[[Array], Array], k_neighbors: int, n_valid_vectors: int, - min_sim: float, + sim_min: float = -jnp.inf, + sim_max: float = +jnp.inf, X_pinned: Array | None = None, + maximize: bool = False, + reject_inf: bool = True, ) -> tuple[Array, int]: """ Adds a vector `x` to a fixed-size set of vectors `X`. @@ -98,17 +120,16 @@ def _add_vector( n_valid_vectors += n_pinned min_changeable_idx = n_pinned - def below_threshold_or_infinite_potential(X, _): + def outside_sim_limits_or_loss_positive_inf(X, _): return X, -1 - def above_threshold_and_not_full(X, _): + def within_sim_limits_and_not_full(X, _): updated_vector_idx = n_valid_vectors X = X.at[updated_vector_idx].set(x) return X, updated_vector_idx - def above_threshold_and_full(X, sims): - # TODO: test approx_min_k vs argpartition - + def within_sim_limits_and_full(X, sims): + # Find `k_neighbors` most similar vectors to `x` k_neighbors_idxs = _k_neighbors(sims, k_neighbors=k_neighbors) # Define a neighborhood of `x` @@ -137,19 +158,19 @@ def above_threshold_and_full(X, sims): is_full = X.shape[0] == n_valid_vectors sims = _similarities(x, X, dist_fn, sim_fn, n_valid_vectors) - is_above_threshold = sims.min() > min_sim + losses = jax.vmap(loss_fn)(sims) + + has_loss_positive_inf = losses.max() == jnp.inf + is_within_sim_limits = (sim_min <= sims.min()) & (sims.max() <= sim_max) branches = [ - below_threshold_or_infinite_potential, - above_threshold_and_not_full, - above_threshold_and_full, + outside_sim_limits_or_loss_positive_inf, + within_sim_limits_and_not_full, + within_sim_limits_and_full, ] - branch_idx = 0 + (is_above_threshold) + (is_full & is_above_threshold) - - # If the potential is infinite, the vector is always rejected - is_potential_infinite = jnp.isinf(loss_fn(sims.min())) - branch_idx *= ~is_potential_infinite + branch_idx = 0 + (is_within_sim_limits) + (is_within_sim_limits & is_full) + branch_idx *= ~(has_loss_positive_inf * reject_inf) X, updated_vector_idx = lax.switch(branch_idx, branches, X, sims) @@ -199,7 +220,8 @@ def update_vectors( sim_fn: Callable, loss_fn: Callable, k_neighbors: int, - min_sim: float, + sim_min: float, + sim_max: float, n_valid: int, X_pinned: Array | None = None, ) -> tuple[Array, Array, Array, int, int]: @@ -220,7 +242,8 @@ def body_fun(carry, x): sim_fn=sim_fn, loss_fn=loss_fn, k_neighbors=k_neighbors, - min_sim=min_sim, + sim_min=sim_min, + sim_max=sim_max, n_valid_vectors=n_valid_new, ) diff --git a/moll/pick/_online_picker.py b/moll/pick/_online_picker.py index d3db54e..b7c2316 100644 --- a/moll/pick/_online_picker.py +++ b/moll/pick/_online_picker.py @@ -44,7 +44,9 @@ def __init__( loss_fn: LossFnCallable | LossFnLiteral = "power", p: float | int = -1, k_neighbors: int | float = 5, # TODO: add heuristic for better default - min_sim: float | None = None, + sim_min: float = -jnp.inf, + sim_max: float = +jnp.inf, + maximize: bool = False, dtype: DTypeLike | None = None, ): """ @@ -69,7 +71,10 @@ def __init__( self.k_neighbors: int = self._init_k_neighbors(k_neighbors, capacity) - self.min_sim: float = min_sim or -jnp.inf + self.maximize: bool = maximize + + self.sim_min: float = sim_min + self.sim_max: float = sim_max # Inferred dtype self.dtype: DTypeLike | None @@ -223,7 +228,8 @@ def partial_fit( sim_fn=self.sim_fn, loss_fn=self.loss_fn, k_neighbors=self.k_neighbors, - min_sim=self.min_sim, + sim_min=self.sim_min, + sim_max=self.sim_max, n_valid=self._n_valid, )