From 4db0c859172b70858c0fdb7b350002ad76df9ed7 Mon Sep 17 00:00:00 2001 From: PauBadiaM Date: Thu, 8 Aug 2024 14:59:52 +0200 Subject: [PATCH] Refactored gsva --- decoupler/method_gsva.py | 152 +++++++++++++++++++++++---------------- 1 file changed, 90 insertions(+), 62 deletions(-) diff --git a/decoupler/method_gsva.py b/decoupler/method_gsva.py index 4d29985..32570ce 100644 --- a/decoupler/method_gsva.py +++ b/decoupler/method_gsva.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from scipy.sparse import csr_matrix +from scipy.sparse import issparse import math from .pre import extract, rename_net, filt_min_n, return_data, break_ties @@ -140,76 +140,99 @@ def density(mat, kcdf): return mat -@nb.njit(nb.types.Tuple((nb.f8[:, :], nb.i8[:, :]))(nb.f8[:, :]), parallel=True, cache=True) -def nb_get_D_I(mat): - n = mat.shape[1] - rev_idx = np.abs(np.arange(n, 0, -1, nb.f8) - n / 2) - Idx = np.zeros(mat.shape, dtype=nb.i8) - for i in nb.prange(mat.shape[0]): - Idx[i] = np.argsort(-mat[i]) - tmp = np.zeros(n, dtype=nb.f8) - tmp[Idx[i]] = rev_idx - mat[i] = tmp - return mat, Idx - - -@nb.njit(nb.f8(nb.f8[:], nb.i8[:], nb.i8, nb.i8[:], nb.i8[:], nb.i8, nb.f8), cache=True) -def ks_sample(D, Idx, n_genes, geneset_mask, fset, n_geneset, dec): - - sum_gset = 0.0 - for i in nb.prange(n_geneset): - sum_gset += D[fset[i]] - - mx_value_sign = 0.0 - cum_sum = 0.0 - mx_pos = 0.0 - mx_neg = 0.0 - - for i in nb.prange(n_genes): - idx = Idx[i] - if geneset_mask[idx] == 1: - cum_sum += D[idx] / sum_gset +@nb.njit(nb.types.Tuple((nb.i8[:, :], nb.i8[:, :]))(nb.f8[:, :]), parallel=True, cache=True) +def order_rankstat(mat): + n_rows, n_cols = mat.shape + ord_mat = np.zeros((n_rows, n_cols), dtype=nb.i8) + rst_mat = np.zeros((n_rows, n_cols), dtype=nb.i8) + for i in range(n_rows): + ord = np.argsort(-mat[i, :]) + 1 + rst = np.zeros(n_cols, dtype=nb.i8) + for j in range(n_cols): + rst[ord[j] - 1] = abs(n_cols - j - (n_cols // 2)) + ord_mat[i, :] = ord + rst_mat[i, :] = rst + return ord_mat, rst_mat + + +@nb.njit(nb.types.UniTuple(nb.f8, 2)(nb.i8[:], nb.i8, nb.i8[:], nb.i8[:], nb.i8), cache=True) +def rnd_walk(gsetidx, k, generanking, rankstat, n): + stepcdfingeneset = np.zeros(n, dtype=np.int32) + stepcdfoutgeneset = np.ones(n, dtype=np.int32) + for i in range(k): + idx = gsetidx[i] - 1 + stepcdfingeneset[idx] = rankstat[generanking[idx] - 1] + stepcdfoutgeneset[idx] = 0 + for i in range(1, n): + stepcdfingeneset[i] += stepcdfingeneset[i-1] + stepcdfoutgeneset[i] += stepcdfoutgeneset[i-1] + walkstatpos = -np.inf + walkstatneg = np.inf + walkstat = np.zeros(n, dtype=np.float64) + for i in range(n): + wlkstat = (stepcdfingeneset[i] / stepcdfingeneset[-1]) - (stepcdfoutgeneset[i] / stepcdfoutgeneset[-1]) + walkstat[i] = wlkstat + if wlkstat > walkstatpos: + walkstatpos = wlkstat + if wlkstat < walkstatneg: + walkstatneg = wlkstat + return walkstatpos, walkstatneg + + +@nb.njit(nb.f8(nb.i8[:], nb.i8[:], nb.i8[:], nb.b1, nb.b1), cache=True) +def score_geneset(gsetidx, generanking, rankstat, maxdiff, absrnk): + n = len(generanking) + k = len(gsetidx) + walkstatpos, walkstatneg = rnd_walk(gsetidx, k, generanking, rankstat, n) + if maxdiff: + if absrnk: + es = walkstatpos - walkstatneg else: - cum_sum -= dec - - if cum_sum > mx_pos: - mx_pos = cum_sum - if cum_sum < mx_neg: - mx_neg = cum_sum - - mx_value_sign = mx_pos + mx_neg - - return mx_value_sign - - -@nb.njit(nb.f8[:](nb.f8[:, :], nb.i8[:, :], nb.i8[:]), parallel=True, cache=True) -def ks_matrix(D, Idx, fset): - n_samples, n_genes = D.shape - n_geneset = fset.shape[0] - - geneset_mask = np.zeros(n_genes, dtype=nb.i8) - geneset_mask[fset] = 1 - - dec = 1.0 / (n_genes - n_geneset) - + es = walkstatpos + walkstatneg + else: + es = walkstatpos if abs(walkstatpos) > abs(walkstatneg) else walkstatneg + return es + + +@nb.njit(nb.i8[:](nb.i8[:], nb.i8[:]), cache=True) +def match(a, b): + max_b = np.max(b) if len(b) > 0 else 0 + index_array = np.full(max_b + 1, -1, dtype=nb.i8) + for idx, value in enumerate(b): + if 0 <= value <= max_b: + index_array[value] = idx + result = np.full(len(a), -1, dtype=nb.i8) + for i in range(len(a)): + if 0 <= a[i] <= max_b: + result[i] = index_array[a[i]] + return result + 1 + + +@nb.njit(nb.f8[:](nb.i8[:, :], nb.i8[:, :], nb.i8[:], nb.b1, nb.b1), parallel=True, cache=True) +def ks_fset(ord, rst, fset, maxdiff, absrnk): + n_samples, n_genes = ord.shape res = np.zeros(n_samples, dtype=nb.f8) for i in nb.prange(n_samples): - res[i] = ks_sample(D[i], Idx[i], n_genes, geneset_mask, fset, n_geneset, dec) - + generanking = ord[i] + rankstat = rst[i] + genesetsrankidx = match(fset, generanking) + res[i] = score_geneset(genesetsrankidx, generanking, rankstat, maxdiff, absrnk) return res -def gsva(mat, net, kcdf=False, verbose=False): +def gsva(mat, net, kcdf=False, maxdiff=True, absrnk=False, verbose=False): + if issparse(mat): + mat = mat.toarray() # Get feature Density mat = density(mat, kcdf=kcdf) - mat, Idx = nb_get_D_I(mat) + ord, rst = order_rankstat(mat) # Run GSVA for each feature set - acts = np.zeros((mat.shape[0], len(net))) + acts = np.zeros((ord.shape[0], len(net))) for j in tqdm(range(len(net)), disable=not verbose): - fset = net.iloc[j] - acts[:, j] = ks_matrix(mat, Idx, fset) + fset = net.iloc[j] + 1 + acts[:, j] = ks_fset(ord, rst, fset, maxdiff, absrnk) return acts @@ -265,6 +288,13 @@ def run_gsva(mat, net, source='source', target='target', kcdf='gaussian', mx_dif # Extract sparse matrix and array of genes m, r, c = extract(mat, use_raw=use_raw, verbose=verbose) + # Remove repeated features + if issparse(m): + m = m.toarray() + msk = ~np.all(m == m[0, :], axis=0) + m = m[:, msk] + c = c[msk] + # Transform net net = rename_net(net, source=source, target=target, weight=None) net = filt_min_n(c, net, min_n=min_n) @@ -281,9 +311,7 @@ def run_gsva(mat, net, source='source', target='target', kcdf='gaussian', mx_dif print('Running gsva on mat with {0} samples and {1} targets for {2} sources.'.format(m.shape[0], len(c), len(net))) # Run GSVA - if isinstance(m, csr_matrix): - m = m.toarray() - estimate = gsva(m, net, kcdf=kcdf, verbose=verbose) + estimate = gsva(m, net, kcdf=kcdf, maxdiff=mx_diff, absrnk=abs_rnk, verbose=verbose) # Transform to df estimate = pd.DataFrame(estimate, index=r, columns=net.index)