Skip to content

Commit

Permalink
feat: implement better support for sim_min and sim_max arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
vsheg committed Jul 16, 2024
1 parent a74c46a commit 1f183bc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 31 deletions.
79 changes: 51 additions & 28 deletions moll/pick/_online_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from ..utils import dist_matrix, fill_diagonal, matrix_cross_sum


@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn"])
@partial(jax.jit, static_argnames=["dist_fn", "sim_fn", "loss_fn", "reject_inf"])
def _needless_vector_idx(
vicinity: Array,
dist_fn: Callable[[Array, Array], Array],
sim_fn: Callable[[Array], Array],
loss_fn: Callable[[Array], Array],
reject_inf: bool = True,
) -> int:
"""
Find the vector in `vicinity` whose removal would decrease the total potential the most.
Expand All @@ -28,18 +29,33 @@ def _needless_vector_idx(
sim_mat = jax.vmap(sim_fn)(dist_mat)
potent_mat = jax.vmap(loss_fn)(sim_mat)

# The inherent potential of a vector is 0, meaning that the vector by itself does not
# contribute to the total potential. In the future, a penalty for the vector itself
# could be added
potent_mat = fill_diagonal(potent_mat, 0)
is_loss_positive_inf = potent_mat.max() is jnp.inf

# Compute potential decrease when each vector is deleted from the set
deltas = jax.vmap(lambda i: matrix_cross_sum(potent_mat, i, i, row_only=True))(
jnp.arange(vicinity.shape[0])
def reject(_):
return 0

def do_work(potent_mat):
# The inherent potential of a vector is 0, meaning that the vector by itself does not
# contribute to the total potential. In the future, a penalty for the vector itself
# could be added
potent_mat = fill_diagonal(potent_mat, 0)

# Compute potential decrease when each vector is deleted from the set
deltas = jax.vmap(lambda i: matrix_cross_sum(potent_mat, i, i, row_only=True))(
jnp.arange(vicinity.shape[0])
)

# Find vector that decreases total potential the most (deltas are negative)
idx = deltas.argmax()
return idx

idx = lax.cond(
reject_inf & is_loss_positive_inf,
reject,
do_work,
potent_mat,
)

# Find vector that decreases total potential the most (deltas are negative)
idx = deltas.argmax()
return idx


Expand All @@ -60,6 +76,7 @@ def sim_i(i):

@partial(jax.jit, static_argnames=["k_neighbors"])
def _k_neighbors(similarities: Array, k_neighbors: int):
# TODO: test approx_min_k vs argpartition
k_neighbors_idxs = lax.approx_min_k(similarities, k=k_neighbors)[1]
return k_neighbors_idxs

Expand All @@ -71,7 +88,9 @@ def _k_neighbors(similarities: Array, k_neighbors: int):
"sim_fn",
"loss_fn",
"k_neighbors",
"min_sim",
"sim_min",
"sim_max",
"maximize",
],
donate_argnames=["x", "X"],
inline=True,
Expand All @@ -84,8 +103,11 @@ def _add_vector(
loss_fn: Callable[[Array], Array],
k_neighbors: int,
n_valid_vectors: int,
min_sim: float,
sim_min: float = -jnp.inf,
sim_max: float = +jnp.inf,
X_pinned: Array | None = None,
maximize: bool = False,
reject_inf: bool = True,
) -> tuple[Array, int]:
"""
Adds a vector `x` to a fixed-size set of vectors `X`.
Expand All @@ -98,17 +120,16 @@ def _add_vector(
n_valid_vectors += n_pinned
min_changeable_idx = n_pinned

def below_threshold_or_infinite_potential(X, _):
def outside_sim_limits_or_loss_positive_inf(X, _):
return X, -1

def above_threshold_and_not_full(X, _):
def within_sim_limits_and_not_full(X, _):
updated_vector_idx = n_valid_vectors
X = X.at[updated_vector_idx].set(x)
return X, updated_vector_idx

def above_threshold_and_full(X, sims):
# TODO: test approx_min_k vs argpartition

def within_sim_limits_and_full(X, sims):
# Find `k_neighbors` most similar vectors to `x`
k_neighbors_idxs = _k_neighbors(sims, k_neighbors=k_neighbors)

# Define a neighborhood of `x`
Expand Down Expand Up @@ -137,19 +158,19 @@ def above_threshold_and_full(X, sims):
is_full = X.shape[0] == n_valid_vectors

sims = _similarities(x, X, dist_fn, sim_fn, n_valid_vectors)
is_above_threshold = sims.min() > min_sim
losses = jax.vmap(loss_fn)(sims)

has_loss_positive_inf = losses.max() == jnp.inf
is_within_sim_limits = (sim_min <= sims.min()) & (sims.max() <= sim_max)

branches = [
below_threshold_or_infinite_potential,
above_threshold_and_not_full,
above_threshold_and_full,
outside_sim_limits_or_loss_positive_inf,
within_sim_limits_and_not_full,
within_sim_limits_and_full,
]

branch_idx = 0 + (is_above_threshold) + (is_full & is_above_threshold)

# If the potential is infinite, the vector is always rejected
is_potential_infinite = jnp.isinf(loss_fn(sims.min()))
branch_idx *= ~is_potential_infinite
branch_idx = 0 + (is_within_sim_limits) + (is_within_sim_limits & is_full)
branch_idx *= ~(has_loss_positive_inf * reject_inf)

X, updated_vector_idx = lax.switch(branch_idx, branches, X, sims)

Expand Down Expand Up @@ -199,7 +220,8 @@ def update_vectors(
sim_fn: Callable,
loss_fn: Callable,
k_neighbors: int,
min_sim: float,
sim_min: float,
sim_max: float,
n_valid: int,
X_pinned: Array | None = None,
) -> tuple[Array, Array, Array, int, int]:
Expand All @@ -220,7 +242,8 @@ def body_fun(carry, x):
sim_fn=sim_fn,
loss_fn=loss_fn,
k_neighbors=k_neighbors,
min_sim=min_sim,
sim_min=sim_min,
sim_max=sim_max,
n_valid_vectors=n_valid_new,
)

Expand Down
12 changes: 9 additions & 3 deletions moll/pick/_online_picker.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def __init__(
loss_fn: LossFnCallable | LossFnLiteral = "power",
p: float | int = -1,
k_neighbors: int | float = 5, # TODO: add heuristic for better default
min_sim: float | None = None,
sim_min: float = -jnp.inf,
sim_max: float = +jnp.inf,
maximize: bool = False,
dtype: DTypeLike | None = None,
):
"""
Expand All @@ -69,7 +71,10 @@ def __init__(

self.k_neighbors: int = self._init_k_neighbors(k_neighbors, capacity)

self.min_sim: float = min_sim or -jnp.inf
self.maximize: bool = maximize

self.sim_min: float = sim_min
self.sim_max: float = sim_max

# Inferred dtype
self.dtype: DTypeLike | None
Expand Down Expand Up @@ -223,7 +228,8 @@ def partial_fit(
sim_fn=self.sim_fn,
loss_fn=self.loss_fn,
k_neighbors=self.k_neighbors,
min_sim=self.min_sim,
sim_min=self.sim_min,
sim_max=self.sim_max,
n_valid=self._n_valid,
)

Expand Down

0 comments on commit 1f183bc

Please sign in to comment.