diff --git a/docs/release-notes/1.10.2.md b/docs/release-notes/1.10.2.md
index 9cc5547d63..a030e494d0 100644
--- a/docs/release-notes/1.10.2.md
+++ b/docs/release-notes/1.10.2.md
@@ -22,4 +22,5 @@
 ```
 
 * `sparse_mean_variance_axis` now uses all cores for the calculations {pr}`3015` {smaller}`S Dicks`
+* `pp.highly_variable_genes` with `flavor=seurat_v3` now uses a numba kernel {pr}`3017` {smaller}`S Dicks`
 * Speed up {func}`~scanpy.pp.scrublet` {pr}`3044` {smaller}`S Dicks` and {pr}`3056` {smaller}`P Angerer`
diff --git a/scanpy/preprocessing/_highly_variable_genes.py b/scanpy/preprocessing/_highly_variable_genes.py
index bc5c2cadf2..1838f3f2ba 100644
--- a/scanpy/preprocessing/_highly_variable_genes.py
+++ b/scanpy/preprocessing/_highly_variable_genes.py
@@ -5,6 +5,7 @@
 from inspect import signature
 from typing import TYPE_CHECKING, cast
 
+import numba
 import numpy as np
 import pandas as pd
 import scipy.sparse as sp_sparse
@@ -98,19 +99,25 @@ def _highly_variable_genes_seurat_v3(
         estimat_var[not_const] = model.outputs.fitted_values
         reg_std = np.sqrt(10**estimat_var)
 
-        batch_counts = data_batch.astype(np.float64).copy()
         # clip large values as in Seurat
         N = data_batch.shape[0]
         vmax = np.sqrt(N)
         clip_val = reg_std * vmax + mean
-        if sp_sparse.issparse(batch_counts):
-            batch_counts = sp_sparse.csr_matrix(batch_counts)
-            mask = batch_counts.data > clip_val[batch_counts.indices]
-            batch_counts.data[mask] = clip_val[batch_counts.indices[mask]]
-
-            squared_batch_counts_sum = np.array(batch_counts.power(2).sum(axis=0))
-            batch_counts_sum = np.array(batch_counts.sum(axis=0))
+        if sp_sparse.issparse(data_batch):
+            if sp_sparse.isspmatrix_csr(data_batch):
+                batch_counts = data_batch
+            else:
+                batch_counts = sp_sparse.csr_matrix(data_batch)
+
+            squared_batch_counts_sum, batch_counts_sum = _sum_and_sum_squares_clipped(
+                batch_counts.indices,
+                batch_counts.data,
+                n_cols=batch_counts.shape[1],
+                clip_val=clip_val,
+                nnz=batch_counts.nnz,
+            )
         else:
+            batch_counts = data_batch.astype(np.float64).copy()
             clip_val_broad = np.broadcast_to(clip_val, batch_counts.shape)
             np.putmask(
                 batch_counts,
@@ -193,6 +200,26 @@ def _highly_variable_genes_seurat_v3(
     return df
 
 
+@numba.njit(cache=True)
+def _sum_and_sum_squares_clipped(
+    indices: NDArray[np.integer],
+    data: NDArray[np.floating],
+    *,
+    n_cols: int,
+    clip_val: NDArray[np.float64],
+    nnz: int,
+) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
+    squared_batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+    batch_counts_sum = np.zeros(n_cols, dtype=np.float64)
+    for i in range(nnz):
+        idx = indices[i]
+        element = min(np.float64(data[i]), clip_val[idx])
+        squared_batch_counts_sum[idx] += element**2
+        batch_counts_sum[idx] += element
+
+    return squared_batch_counts_sum, batch_counts_sum
+
+
 @dataclass
 class _Cutoffs:
     min_disp: float