Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Increase CUDA performance for OnlineDiversityPicker #12

Merged
merged 3 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 94 additions & 189 deletions moll/core/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import jax
import jax.numpy as jnp
from jax import lax
from loguru import logger

from moll.core.utils import fill_diagonal

Expand All @@ -30,15 +29,12 @@ def tanimoto(a: jnp.ndarray, b: jnp.ndarray) -> float:
bitwise_and = jnp.bitwise_and(a, b).sum().astype(float)

# Check for the case where both vectors are all zeros and return 0.0 in that case
dist = jax.lax.cond(
return jax.lax.cond(
bitwise_or == 0.0,
lambda _: 0.0,
lambda _: 1 - (bitwise_and / bitwise_or),
None,
lambda: 0.0,
lambda: 1 - (bitwise_and / bitwise_or),
)

return dist


@jax.jit
def euclidean(p1, p2):
Expand All @@ -48,25 +44,7 @@ def euclidean(p1, p2):
return jnp.linalg.norm(p1 - p2)


@jax.jit
def is_in_bag(x: jnp.ndarray, X: jnp.ndarray) -> jnp.ndarray:
"""
Checks if a point is already in the bag.
"""
return (X == x).all(axis=1).any()


@partial(jax.jit, static_argnames="dist_fn")
def one_to_many_dists(x: jnp.ndarray, X: jnp.ndarray, dist_fn: Callable) -> jnp.ndarray:
"""
Computes the distance between one point and many points.
"""
dists_from_x_to_X = jax.vmap(dist_fn, in_axes=(None, 0))(x, X)
return dists_from_x_to_X


@partial(jax.jit, static_argnames="dist_fn")
def pairwise_distances(X, dist_fn):
def _pairwise_distances(X, dist_fn: Callable):
"""Compute pairwise distances between points in X using a custom distance function."""

def x_to_X_dists(x):
Expand All @@ -77,43 +55,35 @@ def x_to_X_dists(x):
return dists


def submatrix(X: jnp.ndarray, remove_row: int, remove_col: int) -> jnp.ndarray:
def _matrix_cross_sum(X: jnp.ndarray, i: int, j: int, row_only=False):
"""
Returns a submatrix of `X` removing the specified row and column.
Computes the sum of the elements in the row `i` and the column `j` of the matrix `X`.
"""
return jnp.delete(
jnp.delete(X, remove_row, axis=0, assume_unique_indices=True),
remove_col,
axis=1,
assume_unique_indices=True,

return lax.cond(
row_only,
lambda X: X[i, :].sum(),
lambda X: X[i, :].sum() + X[:, j].sum() - X[i, j],
X,
)


def matrix_cross_sum(X: jnp.ndarray, i: int, j: int, row_only=False):
"""
Computes the sum of the elements in the row `i` and the column `j` of the matrix `X`.
"""
X_row_i = X[i, :]
if row_only:
return X_row_i.sum()
X_col_j = X[:, j]
X_element_ij = X[i, j]
return X_row_i.sum() + X_col_j.sum() - X_element_ij
matrix_cross_sum = jax.jit(_matrix_cross_sum, static_argnames=["row_only"])


def needless_point_idx(
def _needless_point_idx(
X: jnp.ndarray, dist_fn: Callable, potential_fn: Callable
) -> int:
"""
Find a point in `X` removing which would decrease the total potential the most.
"""

dists = pairwise_distances(X, dist_fn)
dists = _pairwise_distances(X, dist_fn)
potentials = jax.vmap(potential_fn)(dists)
potentials = fill_diagonal(potentials, 0) # replace diagonal elements with 0

total_potentials_without_each_point = jax.vmap(
lambda i: matrix_cross_sum(potentials, i, i, row_only=True)
lambda i: _matrix_cross_sum(potentials, i, i, row_only=True)
)(jnp.arange(X.shape[0]))

# Compute the decrease in the total potential
Expand All @@ -125,64 +95,19 @@ def needless_point_idx(
return idx


def _min_dist(x, X, dist_fn, n_valid, threshold=0.0):
# # Initialize the distances with the first distance calculation
dists = jnp.full((X.shape[0],), jnp.inf)

# TODO: test both implementations: GPU, CPU, datasets

def early_stop(dists):
# Loop condition function
def cond_fun(args):
i, _dists, min_dist = args
return (min_dist > threshold) & (i < n_valid)

# Loop body function
def body_fun(args):
i, dists, min_dist = args
disti = dist_fn(x, X[i])
dists = dists.at[i].set(disti)
min_dist = lax.min(min_dist, disti)
return i + 1, dists, min_dist

# Run the while loop
_, dists, min_dist = lax.while_loop(cond_fun, body_fun, (0, dists, jnp.inf))

return dists, min_dist

def no_early_stop(dists):
def body_fun(i, args):
min_dist, dists = args
disti = dist_fn(x, X[i])
min_dist = lax.min(min_dist, disti)
dists = dists.at[i].set(disti)
return min_dist, dists

min_dist, dists = lax.fori_loop(0, n_valid, body_fun, (jnp.inf, dists))
return dists, min_dist
needless_point_idx = jax.jit(
_needless_point_idx, static_argnames=["dist_fn", "potential_fn"]
)

dists, min_dist = lax.cond(
threshold > 0.0,
lambda: early_stop(dists),
lambda: no_early_stop(dists),
)

is_above_threshold = min_dist > threshold
return is_above_threshold, dists, min_dist
def dists(x, X, dist_fn, n_valid, threshold=0.0):
ds = jax.vmap(dist_fn, in_axes=(None, 0))(x, X)
mask = jnp.arange(X.shape[0]) < n_valid
ds = jnp.where(mask, ds, jnp.inf)
return ds.min() > threshold, ds, ds.min()


@partial(
jax.jit,
static_argnames=[
"dist_fn",
"k_neighbors",
"threshold",
"power",
"approx_min",
"n_valid_points",
],
)
def add_point_to_bag(
def _add_point_to_bag(
x: jnp.ndarray,
X: jnp.ndarray,
dist_fn: Callable,
Expand All @@ -197,77 +122,71 @@ def add_point_to_bag(
of the replaced point (or -1 if no point was replaced).
"""

assert k_neighbors > 0

is_full = X.shape[0] == n_valid_points

def below_threshold(X):
return X, False, -1

def above_threshold(X):
def above_and_not_full(X):
changed_item_idx = n_valid_points
X = X.at[changed_item_idx].set(x)
return X, True, changed_item_idx

def above_and_full(X):
# Find closest points in `X` to `x`
# TODO: test approx_min_k vs argpartition
if approx_min:
k_closest_points_dists, k_closest_points_indices = lax.approx_min_k(
dists_from_x_to_X, k=k_neighbors
)
else:
k_closest_points_dists, k_closest_points_indices = lax.top_k(
-dists_from_x_to_X, k_neighbors
)
k_closest_points_dists = -k_closest_points_dists

# Define a neighborhood of `x`
N = jnp.concatenate((jnp.array([x]), X[k_closest_points_indices]))

# Find a point in `N` removing which would decrease the total potential the most
needless_point_local_idx = needless_point_idx(
N, dist_fn, lambda d: jnp.power(d, -power)
)

# If the needless point is not `x`, replace it with `x`
is_accepted = needless_point_local_idx > 0
changed_item_idx = k_closest_points_indices[needless_point_local_idx - 1]

X, changed_item_idx = lax.cond(
is_accepted,
lambda X, idx: (X.at[changed_item_idx].set(x), idx),
lambda X, idx: (X, -1),
X,
changed_item_idx,
)

return X, is_accepted, changed_item_idx

result = lax.cond(
is_full,
lambda: above_and_full(X),
lambda: above_and_not_full(X),
def above_and_not_full(X):
changed_item_idx = n_valid_points
X = X.at[changed_item_idx].set(x)
return X, True, changed_item_idx

def above_and_full(X):
# Find closest points in `X` to `x`
# TODO: test approx_min_k vs argpartition

k_closest_points_dists, k_closest_points_indices = lax.cond(
approx_min,
lambda ds: lax.approx_min_k(ds, k=k_neighbors),
lambda ds: lax.top_k(-ds, k=k_neighbors),
dists_from_x_to_X,
)

k_closest_points_dists = lax.abs(k_closest_points_dists) # for top_k

# Define a neighborhood of `x`
N = jnp.concatenate((jnp.array([x]), X[k_closest_points_indices]))

# Find a point in `N` removing which would decrease the total potential the most
needless_point_local_idx = _needless_point_idx(
N, dist_fn, lambda d: d**-power
)

# If the needless point is not `x`, replace it with `x`
is_accepted = needless_point_local_idx > 0
changed_item_idx = k_closest_points_indices[needless_point_local_idx - 1]

X, changed_item_idx = lax.cond(
is_accepted,
lambda X, idx: (X.at[changed_item_idx].set(x), idx),
lambda X, _: (X, -1),
X,
changed_item_idx,
)

return result
return X, is_accepted, changed_item_idx

is_full = X.shape[0] == n_valid_points

is_above_threshold, dists_from_x_to_X, min_dist = _min_dist(
is_above_threshold, dists_from_x_to_X, min_dist = dists(
x, X, dist_fn, n_valid_points, threshold=threshold
)

X, is_accepted, updated_idx = lax.cond(
is_above_threshold,
above_threshold,
below_threshold,
X,
)
return X, is_accepted, updated_idx
branches = [below_threshold, above_and_not_full, above_and_full]
branch_idx = 0 + is_above_threshold + (is_full & is_above_threshold)

return lax.switch(branch_idx, branches, X)

@jax.jit
def finalize_updates(changes: jnp.ndarray) -> jnp.ndarray:

add_point_to_bag = jax.jit(
_add_point_to_bag,
static_argnames=[
"dist_fn",
"k_neighbors",
],
)


def _finalize_updates(changes: jnp.ndarray) -> jnp.ndarray:
"""
Given an array where each element represents whether a change occurred or
not, -1 means no change, and a positive integer represents ID of changed
Expand All @@ -285,26 +204,15 @@ def finalize_updates(changes: jnp.ndarray) -> jnp.ndarray:
changes_reversed, return_index=True, size=changes.shape[0]
)

def keep_only_unique_change(i, changes_reversed):
return lax.cond(
jnp.isin(i, unique_idxs),
lambda _: changes_reversed,
lambda _: changes_reversed.at[i].set(-1),
changes_reversed,
)
idxs = jnp.arange(changes.shape[0])
mask = jnp.isin(idxs, unique_idxs)[::-1]
return jnp.where(mask, changes, -1)

changes_reversed = lax.fori_loop(
0, changes_reversed.shape[0], keep_only_unique_change, changes_reversed
)

return changes_reversed[::-1]
finalize_updates = jax.jit(_finalize_updates)


@partial(
jax.jit,
static_argnames=["dist_fn", "k_neighbors", "power", "threshold", "n_valid_points"],
)
def add_points_to_bag(
def _add_points_to_bag(
*,
X: jnp.ndarray,
xs: jnp.ndarray,
Expand All @@ -314,25 +222,15 @@ def add_points_to_bag(
power: float,
n_valid_points: int,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
assert X.shape[0] >= n_valid_points
assert xs.shape[0] > 0
assert k_neighbors > 0
assert power > 0
assert threshold >= 0
assert n_valid_points >= 0

if X.dtype != xs.dtype:
logger.warning(
f"X and xs have different dtypes: {X.dtype} and {xs.dtype}, casting to {X.dtype}"
)
xs = xs.astype(X.dtype)
assert X.dtype == xs.dtype

# Initialize array to store the information about the changes
changed_item_idxs = -jnp.ones(xs.shape[0], dtype=int) # -1 means not changed

def body_fun(i, args):
X, changed_items_idxs, n_valid_points = args
X_updated, is_accepted, changed_item_idx = add_point_to_bag(
X_updated, is_accepted, changed_item_idx = _add_point_to_bag(
xs[i],
X,
dist_fn,
Expand All @@ -359,7 +257,14 @@ def body_fun(i, args):
)

# Some points might have been accepted and then replaced by another point
changed_item_idxs = finalize_updates(changed_item_idxs)
changed_item_idxs = _finalize_updates(changed_item_idxs)
acceptance_mask = changed_item_idxs >= 0

return changed_item_idxs, X_new, acceptance_mask


add_points_to_bag = jax.jit(
_add_points_to_bag,
static_argnames=["dist_fn", "k_neighbors"],
donate_argnames=["X"],
)
Loading