Skip to content

Commit

Permalink
Increase CUDA performance for OnlineDiversityPicker (#12)
Browse files Browse the repository at this point in the history
* chore: update pyproject.toml
* perf: refactor picker
* chore: bump version
  • Loading branch information
vsheg authored Nov 19, 2023
1 parent d7a2fd8 commit 5bcda4a
Show file tree
Hide file tree
Showing 6 changed files with 914 additions and 759 deletions.
283 changes: 94 additions & 189 deletions moll/core/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax
import jax.numpy as jnp
from jax import lax
from loguru import logger

from moll.core.utils import fill_diagonal

Expand All @@ -30,15 +29,12 @@ def tanimoto(a: jnp.ndarray, b: jnp.ndarray) -> float:
bitwise_and = jnp.bitwise_and(a, b).sum().astype(float)

# Check for the case where both vectors are all zeros and return 0.0 in that case
dist = jax.lax.cond(
return jax.lax.cond(
bitwise_or == 0.0,
lambda _: 0.0,
lambda _: 1 - (bitwise_and / bitwise_or),
None,
lambda: 0.0,
lambda: 1 - (bitwise_and / bitwise_or),
)

return dist


@jax.jit
def euclidean(p1, p2):
Expand All @@ -48,25 +44,7 @@ def euclidean(p1, p2):
return jnp.linalg.norm(p1 - p2)


@jax.jit
def is_in_bag(x: jnp.ndarray, X: jnp.ndarray) -> jnp.ndarray:
"""
Checks if a point is already in the bag.
"""
return (X == x).all(axis=1).any()


@partial(jax.jit, static_argnames="dist_fn")
def one_to_many_dists(x: jnp.ndarray, X: jnp.ndarray, dist_fn: Callable) -> jnp.ndarray:
"""
Computes the distance between one point and many points.
"""
dists_from_x_to_X = jax.vmap(dist_fn, in_axes=(None, 0))(x, X)
return dists_from_x_to_X


@partial(jax.jit, static_argnames="dist_fn")
def pairwise_distances(X, dist_fn):
def _pairwise_distances(X, dist_fn: Callable):
"""Compute pairwise distances between points in X using a custom distance function."""

def x_to_X_dists(x):
Expand All @@ -77,43 +55,35 @@ def x_to_X_dists(x):
return dists


def submatrix(X: jnp.ndarray, remove_row: int, remove_col: int) -> jnp.ndarray:
def _matrix_cross_sum(X: jnp.ndarray, i: int, j: int, row_only=False):
"""
Returns a submatrix of `X` removing the specified row and column.
Computes the sum of the elements in the row `i` and the column `j` of the matrix `X`.
"""
return jnp.delete(
jnp.delete(X, remove_row, axis=0, assume_unique_indices=True),
remove_col,
axis=1,
assume_unique_indices=True,

return lax.cond(
row_only,
lambda X: X[i, :].sum(),
lambda X: X[i, :].sum() + X[:, j].sum() - X[i, j],
X,
)


def matrix_cross_sum(X: jnp.ndarray, i: int, j: int, row_only=False):
"""
Computes the sum of the elements in the row `i` and the column `j` of the matrix `X`.
"""
X_row_i = X[i, :]
if row_only:
return X_row_i.sum()
X_col_j = X[:, j]
X_element_ij = X[i, j]
return X_row_i.sum() + X_col_j.sum() - X_element_ij
matrix_cross_sum = jax.jit(_matrix_cross_sum, static_argnames=["row_only"])


def needless_point_idx(
def _needless_point_idx(
X: jnp.ndarray, dist_fn: Callable, potential_fn: Callable
) -> int:
"""
Find a point in `X` removing which would decrease the total potential the most.
"""

dists = pairwise_distances(X, dist_fn)
dists = _pairwise_distances(X, dist_fn)
potentials = jax.vmap(potential_fn)(dists)
potentials = fill_diagonal(potentials, 0) # replace diagonal elements with 0

total_potentials_without_each_point = jax.vmap(
lambda i: matrix_cross_sum(potentials, i, i, row_only=True)
lambda i: _matrix_cross_sum(potentials, i, i, row_only=True)
)(jnp.arange(X.shape[0]))

# Compute the decrease in the total potential
Expand All @@ -125,64 +95,19 @@ def needless_point_idx(
return idx


def _min_dist(x, X, dist_fn, n_valid, threshold=0.0):
# # Initialize the distances with the first distance calculation
dists = jnp.full((X.shape[0],), jnp.inf)

# TODO: test both implementations: GPU, CPU, datasets

def early_stop(dists):
# Loop condition function
def cond_fun(args):
i, _dists, min_dist = args
return (min_dist > threshold) & (i < n_valid)

# Loop body function
def body_fun(args):
i, dists, min_dist = args
disti = dist_fn(x, X[i])
dists = dists.at[i].set(disti)
min_dist = lax.min(min_dist, disti)
return i + 1, dists, min_dist

# Run the while loop
_, dists, min_dist = lax.while_loop(cond_fun, body_fun, (0, dists, jnp.inf))

return dists, min_dist

def no_early_stop(dists):
def body_fun(i, args):
min_dist, dists = args
disti = dist_fn(x, X[i])
min_dist = lax.min(min_dist, disti)
dists = dists.at[i].set(disti)
return min_dist, dists

min_dist, dists = lax.fori_loop(0, n_valid, body_fun, (jnp.inf, dists))
return dists, min_dist
needless_point_idx = jax.jit(
_needless_point_idx, static_argnames=["dist_fn", "potential_fn"]
)

dists, min_dist = lax.cond(
threshold > 0.0,
lambda: early_stop(dists),
lambda: no_early_stop(dists),
)

is_above_threshold = min_dist > threshold
return is_above_threshold, dists, min_dist
def dists(x, X, dist_fn, n_valid, threshold=0.0):
ds = jax.vmap(dist_fn, in_axes=(None, 0))(x, X)
mask = jnp.arange(X.shape[0]) < n_valid
ds = jnp.where(mask, ds, jnp.inf)
return ds.min() > threshold, ds, ds.min()


@partial(
jax.jit,
static_argnames=[
"dist_fn",
"k_neighbors",
"threshold",
"power",
"approx_min",
"n_valid_points",
],
)
def add_point_to_bag(
def _add_point_to_bag(
x: jnp.ndarray,
X: jnp.ndarray,
dist_fn: Callable,
Expand All @@ -197,77 +122,71 @@ def add_point_to_bag(
of the replaced point (or -1 if no point was replaced).
"""

assert k_neighbors > 0

is_full = X.shape[0] == n_valid_points

def below_threshold(X):
return X, False, -1

def above_threshold(X):
def above_and_not_full(X):
changed_item_idx = n_valid_points
X = X.at[changed_item_idx].set(x)
return X, True, changed_item_idx

def above_and_full(X):
# Find closest points in `X` to `x`
# TODO: test approx_min_k vs argpartition
if approx_min:
k_closest_points_dists, k_closest_points_indices = lax.approx_min_k(
dists_from_x_to_X, k=k_neighbors
)
else:
k_closest_points_dists, k_closest_points_indices = lax.top_k(
-dists_from_x_to_X, k_neighbors
)
k_closest_points_dists = -k_closest_points_dists

# Define a neighborhood of `x`
N = jnp.concatenate((jnp.array([x]), X[k_closest_points_indices]))

# Find a point in `N` removing which would decrease the total potential the most
needless_point_local_idx = needless_point_idx(
N, dist_fn, lambda d: jnp.power(d, -power)
)

# If the needless point is not `x`, replace it with `x`
is_accepted = needless_point_local_idx > 0
changed_item_idx = k_closest_points_indices[needless_point_local_idx - 1]

X, changed_item_idx = lax.cond(
is_accepted,
lambda X, idx: (X.at[changed_item_idx].set(x), idx),
lambda X, idx: (X, -1),
X,
changed_item_idx,
)

return X, is_accepted, changed_item_idx

result = lax.cond(
is_full,
lambda: above_and_full(X),
lambda: above_and_not_full(X),
def above_and_not_full(X):
changed_item_idx = n_valid_points
X = X.at[changed_item_idx].set(x)
return X, True, changed_item_idx

def above_and_full(X):
# Find closest points in `X` to `x`
# TODO: test approx_min_k vs argpartition

k_closest_points_dists, k_closest_points_indices = lax.cond(
approx_min,
lambda ds: lax.approx_min_k(ds, k=k_neighbors),
lambda ds: lax.top_k(-ds, k=k_neighbors),
dists_from_x_to_X,
)

k_closest_points_dists = lax.abs(k_closest_points_dists) # for top_k

# Define a neighborhood of `x`
N = jnp.concatenate((jnp.array([x]), X[k_closest_points_indices]))

# Find a point in `N` removing which would decrease the total potential the most
needless_point_local_idx = _needless_point_idx(
N, dist_fn, lambda d: d**-power
)

# If the needless point is not `x`, replace it with `x`
is_accepted = needless_point_local_idx > 0
changed_item_idx = k_closest_points_indices[needless_point_local_idx - 1]

X, changed_item_idx = lax.cond(
is_accepted,
lambda X, idx: (X.at[changed_item_idx].set(x), idx),
lambda X, _: (X, -1),
X,
changed_item_idx,
)

return result
return X, is_accepted, changed_item_idx

is_full = X.shape[0] == n_valid_points

is_above_threshold, dists_from_x_to_X, min_dist = _min_dist(
is_above_threshold, dists_from_x_to_X, min_dist = dists(
x, X, dist_fn, n_valid_points, threshold=threshold
)

X, is_accepted, updated_idx = lax.cond(
is_above_threshold,
above_threshold,
below_threshold,
X,
)
return X, is_accepted, updated_idx
branches = [below_threshold, above_and_not_full, above_and_full]
branch_idx = 0 + is_above_threshold + (is_full & is_above_threshold)

return lax.switch(branch_idx, branches, X)

@jax.jit
def finalize_updates(changes: jnp.ndarray) -> jnp.ndarray:

add_point_to_bag = jax.jit(
_add_point_to_bag,
static_argnames=[
"dist_fn",
"k_neighbors",
],
)


def _finalize_updates(changes: jnp.ndarray) -> jnp.ndarray:
"""
Given an array where each element represents whether a change occurred or
not, -1 means no change, and a positive integer represents ID of changed
Expand All @@ -285,26 +204,15 @@ def finalize_updates(changes: jnp.ndarray) -> jnp.ndarray:
changes_reversed, return_index=True, size=changes.shape[0]
)

def keep_only_unique_change(i, changes_reversed):
return lax.cond(
jnp.isin(i, unique_idxs),
lambda _: changes_reversed,
lambda _: changes_reversed.at[i].set(-1),
changes_reversed,
)
idxs = jnp.arange(changes.shape[0])
mask = jnp.isin(idxs, unique_idxs)[::-1]
return jnp.where(mask, changes, -1)

changes_reversed = lax.fori_loop(
0, changes_reversed.shape[0], keep_only_unique_change, changes_reversed
)

return changes_reversed[::-1]
finalize_updates = jax.jit(_finalize_updates)


@partial(
jax.jit,
static_argnames=["dist_fn", "k_neighbors", "power", "threshold", "n_valid_points"],
)
def add_points_to_bag(
def _add_points_to_bag(
*,
X: jnp.ndarray,
xs: jnp.ndarray,
Expand All @@ -314,25 +222,15 @@ def add_points_to_bag(
power: float,
n_valid_points: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
assert X.shape[0] >= n_valid_points
assert xs.shape[0] > 0
assert k_neighbors > 0
assert power > 0
assert threshold >= 0
assert n_valid_points >= 0

if X.dtype != xs.dtype:
logger.warning(
f"X and xs have different dtypes: {X.dtype} and {xs.dtype}, casting to {X.dtype}"
)
xs = xs.astype(X.dtype)
assert X.dtype == xs.dtype

# Initialize array to store the information about the changes
changed_item_idxs = -jnp.ones(xs.shape[0], dtype=int) # -1 means not changed

def body_fun(i, args):
X, changed_items_idxs, n_valid_points = args
X_updated, is_accepted, changed_item_idx = add_point_to_bag(
X_updated, is_accepted, changed_item_idx = _add_point_to_bag(
xs[i],
X,
dist_fn,
Expand All @@ -359,7 +257,14 @@ def body_fun(i, args):
)

# Some points might have been accepted and then replaced by another point
changed_item_idxs = finalize_updates(changed_item_idxs)
changed_item_idxs = _finalize_updates(changed_item_idxs)
acceptance_mask = changed_item_idxs >= 0

return changed_item_idxs, X_new, acceptance_mask


add_points_to_bag = jax.jit(
_add_points_to_bag,
static_argnames=["dist_fn", "k_neighbors"],
donate_argnames=["X"],
)
Loading

0 comments on commit 5bcda4a

Please sign in to comment.