use multithreading for linear regression and add tqdm
mumichae committed Apr 22, 2024
1 parent 25fb1bf commit 7651163
Showing 2 changed files with 64 additions and 45 deletions.
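In short, the commit parallelises the per-PC linear regressions with a ThreadPoolExecutor and wraps the iteration in a tqdm progress bar. Below is a minimal, self-contained sketch of that pattern; it is not part of the commit, and the toy r2_of_pc helper plus the random data are assumptions for illustration only.

from concurrent.futures import ThreadPoolExecutor

import numpy as np
from tqdm import tqdm


def r2_of_pc(covariate, pc):
    # toy stand-in for a linreg_method: R^2 of regressing one PC on the covariate
    _, residuals, _, _ = np.linalg.lstsq(covariate, pc, rcond=None)
    tss = np.sum((pc - pc.mean()) ** 2)
    return max(0.0, 1 - residuals[0] / tss)


rng = np.random.default_rng(0)
X_pca = rng.normal(size=(500, 50))     # 500 cells, 50 PCs
covariate = rng.normal(size=(500, 3))  # dummy-encoded covariate

# map the regression over the columns of X_pca, one PC per task,
# mirroring executor.map(linreg_method, [covariate] * n_comps, X_pca.T) in the diff below
n_comps = X_pca.shape[1]
with ThreadPoolExecutor(max_workers=4) as executor:
    jobs = executor.map(r2_of_pc, [covariate] * n_comps, X_pca.T)
    r2 = list(tqdm(jobs, total=n_comps))
print(round(float(np.mean(r2)), 4))

A thread pool (rather than a process pool) is a reasonable fit here because NumPy and scikit-learn release the GIL inside their linear-algebra routines, so the per-PC regressions can run concurrently.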
35 changes: 27 additions & 8 deletions scib/metrics/pcr.py
@@ -1,7 +1,10 @@
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse
from tqdm import tqdm

from ..utils import check_adata, check_batch

@@ -16,6 +19,7 @@ def pcr_comparison(
    recompute_pca=False,
    scale=True,
    verbose=False,
    n_threads=1,
):
    """Principal component regression score
@@ -64,6 +68,7 @@
        recompute_pca=recompute_pca,
        n_comps=n_comps,
        linreg_method=linreg_method,
        n_threads=n_threads,
        verbose=verbose,
    )

@@ -74,6 +79,7 @@
        recompute_pca=recompute_pca,
        n_comps=n_comps,
        linreg_method=linreg_method,
        n_threads=n_threads,
        verbose=verbose,
    )

@@ -98,6 +104,7 @@ def pcr(
    recompute_pca=False,
    linreg_method="sklearn",
    verbose=False,
    n_threads=1,
):
    """Principal component regression for anndata object
@@ -151,6 +158,7 @@ def pcr(
            covariate_values,
            n_comps=n_comps,
            linreg_method=linreg_method,
            n_threads=n_threads,
        )

    # use existing PCA computation
@@ -162,6 +170,7 @@
            covariate_values,
            pca_var=adata.uns["pca"]["variance"],
            linreg_method=linreg_method,
            n_threads=n_threads,
        )

    # recompute PCA
@@ -173,6 +182,7 @@
            covariate_values,
            n_comps=n_comps,
            linreg_method=linreg_method,
            n_threads=n_threads,
        )


@@ -184,6 +194,7 @@ def pc_regression(
    svd_solver="arpack",
    linreg_method="sklearn",
    verbose=False,
    n_threads=1,
):
    """Principal component regression
@@ -210,7 +221,6 @@
    :return:
        Variance contribution of regression
    """

    if isinstance(data, (np.ndarray, sparse.csr_matrix, sparse.csc_matrix)):
        matrix = data
    else:
@@ -264,11 +274,17 @@
        covariate = pd.get_dummies(covariate).to_numpy()

    # fit linear model for n_comps PCs
    r2 = []
    for i in range(n_comps):
        pc = X_pca[:, [i]]
        r2_score = linreg_method(X=covariate, y=pc)
        r2.append(np.maximum(0, r2_score))
    if verbose:
        print(f"Use {n_threads} threads for regression...")
    if n_threads == 1:
        r2 = []
        for i in tqdm(range(n_comps), total=n_comps):
            r2_score = linreg_method(X=covariate, y=X_pca[:, [i]])
            r2.append(np.maximum(0, r2_score))
    else:
        with ThreadPoolExecutor(max_workers=n_threads) as executor:
            run_r2 = executor.map(linreg_method, [covariate] * n_comps, X_pca.T)
            r2 = list(tqdm(run_r2, total=n_comps))

    Var = pca_var / sum(pca_var) * 100
    R2Var = sum(r2 * Var) / 100
@@ -281,10 +297,13 @@ def linreg_sklearn(X, y):

    lm = LinearRegression()
    lm.fit(X, y)
    return lm.score(X, y)
    r2_score = lm.score(X, y)
    r2_score = np.maximum(0, r2_score)
    return r2_score


def linreg_np(X, y):
    coefficients, residuals, _, _ = np.linalg.lstsq(X, y, rcond=None)
    tss = np.sum((y - y.mean()) ** 2)
    return 1 - (residuals[0] / tss)
    r2_score = 1 - (residuals[0] / tss)
    return np.maximum(0, r2_score)
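For reference, the new n_threads argument is forwarded from pcr_comparison and pcr down to pc_regression, so callers only pass it once. A hedged usage sketch follows; it is not from this commit, and the scanpy example dataset and the "bulk_labels" covariate are assumptions (any AnnData with a categorical .obs column would work the same way).

import scanpy as sc

import scib

# example data shipped with scanpy (assumption for illustration)
adata = sc.datasets.pbmc68k_reduced()
sc.pp.pca(adata, n_comps=50)

# principal component regression of the PCs on the covariate,
# parallelised over PCs with 4 worker threads and a tqdm progress bar
score = scib.me.pcr(
    adata,
    covariate="bulk_labels",
    linreg_method="numpy",
    verbose=True,
    n_threads=4,
)
print(score)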
74 changes: 37 additions & 37 deletions tests/metrics/test_pcr_metrics.py
@@ -1,57 +1,57 @@
import pytest
from scipy.sparse import csr_matrix

import scib
from tests.common import LOGGER, add_embed, assert_near_exact


def test_pc_regression(adata):
    score = scib.me.pc_regression(adata.X, adata.obs["batch"])
    LOGGER.info(score)
    assert_near_exact(score, 0, diff=1e-4)


def test_pc_regression_sparse(adata):
@pytest.mark.parametrize("sparse", [False, True])
def test_pc_regression(adata, sparse):
    if sparse:
        adata.X = csr_matrix(adata.X)
    score = scib.me.pc_regression(
        csr_matrix(adata.X),
        adata.X,
        covariate=adata.obs["batch"],
        n_comps=adata.n_vars,
    )
    LOGGER.info(score)
    assert_near_exact(score, 0, diff=1e-4)


def test_pcr_sklearn(adata_pca):
    score = scib.me.pcr(
        adata_pca, covariate="celltype", linreg_method="sklearn", verbose=True
    assert_near_exact(score, 0, diff=1e-3)


@pytest.mark.parametrize("linreg_method", ["numpy", "sklearn"])
def test_pcr_timing(adata_pca, linreg_method):
    import timeit

    import anndata as ad
    import scanpy as sc

    # scale up anndata
    adata = ad.concat([adata_pca] * 100)
    print(f"compute PCA on {adata.n_obs} cells...")
    sc.pp.pca(adata)

    timing = timeit.timeit(
        lambda: scib.me.pcr(
            adata,
            covariate="celltype",
            linreg_method=linreg_method,
            verbose=False,
            n_threads=10,
        ),
        number=10,
    )
    LOGGER.info(score)
    assert_near_exact(score, 0.3371261556141021, diff=1e-3)
    LOGGER.info(f"timeit: {timing}")


def test_pcr_numpy(adata_pca):
    # test pcr value
    score = scib.me.pcr(
        adata_pca, covariate="celltype", linreg_method="numpy", verbose=True
    )
    LOGGER.info(score)
    assert_near_exact(score, 0.3371261556141021, diff=1e-3)


def test_pcr_implementations(adata_pca):
    score_sklearn = scib.me.pcr(
        adata_pca,
        covariate="celltype",
        linreg_method="sklearn",
        linreg_method=linreg_method,
        verbose=True,
        n_threads=1,
    )
    LOGGER.info(f"sklearn score: {score_sklearn}")

    score_numpy = scib.me.pcr(
        adata_pca,
        covariate="celltype",
        linreg_method="numpy",
    )
    LOGGER.info(f"numpy score: {score_numpy}")

    assert_near_exact(score_sklearn, score_numpy, diff=1e-3)
    LOGGER.info(score)
    assert_near_exact(score, 0.33401529220865844, diff=1e-3)


def test_pcr_comparison_batch(adata):