Skip to content

Commit

Permalink
feat: Add support for maximizing loss in _add_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
vsheg committed Jul 16, 2024
1 parent 9c03735 commit 31c455b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
2 changes: 1 addition & 1 deletion moll/pick/_online_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def within_sim_limits_and_full(X, sims):

# Vector in the vicinity removing which decreases the total potential the most:
needless_vector_vicinity_idx = _needless_vector_idx(
vicinity, dist_fn, sim_fn, loss_fn
vicinity, dist_fn, sim_fn, loss_fn, maximize=maximize
)

# If the needless vector is not `x`, replace it with `x`
Expand Down
34 changes: 32 additions & 2 deletions moll/pick/tests/test_online_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,36 @@ def test_add_vector_with_pinned(X, X_pinned, x, updated_idx):
loss_fn=lambda s: s**-1,
k_neighbors=5,
n_valid_vectors=5,
min_sim=0.0,
)
assert upd_idx == updated_idx

if updated_idx == -1:
assert (X_copy == X_updated).all()
else:
assert (X_updated == X_copy.at[updated_idx].set(x)).all()


@pytest.mark.parametrize(
"x, updated_idx",
[
([1.1, 1.1], 4),
([0.1, 0.1], 4),
([4.1, 4.1], 0),
],
)
def test_add_vector_with_maximize(X, x, updated_idx):
X_copy = X.copy()
x = jnp.array(x)

X_updated, upd_idx = _add_vector(
x=x,
X=X,
dist_fn=euclidean,
sim_fn=lambda d: d,
loss_fn=lambda s: s**-1,
k_neighbors=5,
n_valid_vectors=len(X),
maximize=True,
)
assert upd_idx == updated_idx

Expand Down Expand Up @@ -217,7 +246,8 @@ def test_update_vectors(X, X_pinned, xs, acc_mask):
loss_fn=lambda s: s**-1,
k_neighbors=5,
n_valid=5,
min_sim=0.0,
sim_min=0.0,
sim_max=jnp.inf,
)

assert X_copy.shape == X_updated.shape
Expand Down

0 comments on commit 31c455b

Please sign in to comment.