Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HVG seurat_v3 numba kernel #3017

Merged
merged 18 commits into from
May 31, 2024
Merged
1 change: 1 addition & 0 deletions docs/release-notes/1.10.2.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
```

* `sparse_mean_variance_axis` now uses all cores for the calculations {pr}`3015` {smaller}`S Dicks`
* `pp.highly_variable_genes` with `flavor=seurat_v3` now uses all cores for clipping {pr}`3017` {smaller}`S Dicks`
Intron7 marked this conversation as resolved.
Show resolved Hide resolved
* Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
40 changes: 32 additions & 8 deletions scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from inspect import signature
from typing import TYPE_CHECKING, cast

import numba
import numpy as np
import pandas as pd
import scipy.sparse as sp_sparse
Expand Down Expand Up @@ -98,19 +99,22 @@
estimat_var[not_const] = model.outputs.fitted_values
reg_std = np.sqrt(10**estimat_var)

batch_counts = data_batch.astype(np.float64).copy()
# clip large values as in Seurat
N = data_batch.shape[0]
vmax = np.sqrt(N)
clip_val = reg_std * vmax + mean
if sp_sparse.issparse(batch_counts):
batch_counts = sp_sparse.csr_matrix(batch_counts)
mask = batch_counts.data > clip_val[batch_counts.indices]
batch_counts.data[mask] = clip_val[batch_counts.indices[mask]]

squared_batch_counts_sum = np.array(batch_counts.power(2).sum(axis=0))
batch_counts_sum = np.array(batch_counts.sum(axis=0))
if sp_sparse.issparse(data_batch):
batch_counts = sp_sparse.csr_matrix(data_batch)

squared_batch_counts_sum, batch_counts_sum = _clip_sparse(
batch_counts.indices,
batch_counts.data,
n_cols=batch_counts.shape[1],
clip_val=clip_val,
nnz=batch_counts.nnz,
)
else:
batch_counts = data_batch.astype(np.float64).copy()

Check warning on line 117 in scanpy/preprocessing/_highly_variable_genes.py

View check run for this annotation

Codecov / codecov/patch

scanpy/preprocessing/_highly_variable_genes.py#L117

Added line #L117 was not covered by tests
clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
np.putmask(
batch_counts,
Expand Down Expand Up @@ -193,6 +197,26 @@
return df


@numba.njit(cache=True)
def _sum_and_sum_squares_clipped(
    indices: NDArray[np.integer],
    data: NDArray[np.floating],
    *,
    n_cols: int,
    clip_val: NDArray[np.float64],
    nnz: int,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Per-column sum and sum of squares of sparse values, clipped per column.

    Walks the first ``nnz`` entries of a CSR matrix's ``data``/``indices``
    arrays. Each value is capped at ``clip_val`` for its column before being
    accumulated, mirroring Seurat's clipping of large counts.

    Returns
    -------
    Tuple of ``(sum of squared clipped values, sum of clipped values)``,
    each a float64 array of length ``n_cols``.
    """
    sums_sq = np.zeros(n_cols, dtype=np.float64)
    sums = np.zeros(n_cols, dtype=np.float64)
    for k in range(nnz):
        col = indices[k]
        value = np.float64(data[k])
        cap = clip_val[col]
        # clip large values as in Seurat before accumulating
        if value > cap:
            value = cap
        sums_sq[col] += value * value
        sums[col] += value

    return sums_sq, sums


@dataclass
class _Cutoffs:
min_disp: float
Expand Down
Loading