Skip to content

Commit

Permalink
refactor: handle values corresponding to non-valid vector as nans
Browse files Browse the repository at this point in the history
  • Loading branch information
vsheg committed Jul 22, 2024
1 parent f4a8068 commit 67dda2b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions moll/pick/_online_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def sim_i(i):
return lax.cond(
i < n_valid,
lambda i: sim_fn(dist_fn(x, X[i])),
lambda _: jnp.inf,
lambda _: jnp.nan,
i,
)

Expand Down Expand Up @@ -150,8 +150,8 @@ def within_sim_limits_and_full(X, sims):
sims = _similarities(x, X, dist_fn, sim_fn, n_valid_vectors)
losses = jax.vmap(loss_fn)(sims)

has_loss_positive_inf = losses.max() == jnp.inf
is_within_sim_limits = (sim_min <= sims.min()) & (sims.max() <= sim_max)
has_loss_positive_inf = jnp.nanmax(losses) == jnp.inf
is_within_sim_limits = (sim_min <= jnp.nanmin(sims)) & (jnp.nanmax(sims) <= sim_max)

branches = [
outside_sim_limits_or_loss_positive_inf,
Expand Down

0 comments on commit 67dda2b

Please sign in to comment.