Skip to content

Commit

Permalink
unify definitions
Browse files (browse the repository at this point in the history)
  • Loading branch information
flying-sheep committed Jan 10, 2025
1 parent c3cdd25 commit db3aa89
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 45 deletions.
6 changes: 3 additions & 3 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,14 +629,14 @@ def axis_mul_or_truediv(
@axis_mul_or_truediv.register(sparse.csr_matrix)
@axis_mul_or_truediv.register(sparse.csc_matrix)
def _(
X: sparse.csr_matrix | sparse.csc_matrix,
X: _CSMatrix,
scaling_array,
axis: Literal[0, 1],
op: Callable[[Any, Any], Any],
*,
allow_divide_by_zero: bool = True,
out: sparse.csr_matrix | sparse.csc_matrix | None = None,
) -> sparse.csr_matrix | sparse.csc_matrix:
out: _CSMatrix | None = None,
) -> _CSMatrix:
check_op(op)
if out is not None and X.data is not out.data:
raise ValueError(
Expand Down
6 changes: 3 additions & 3 deletions src/scanpy/preprocessing/_deprecated/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@

from ..._compat import _LegacyRandom

CSMatrix = csr_matrix | csc_matrix
_CSMatrix = csr_matrix | csc_matrix


@old_positionals("n_obs", "random_state", "copy")
def subsample(
data: AnnData | np.ndarray | CSMatrix,
data: AnnData | np.ndarray | _CSMatrix,
fraction: float | None = None,
*,
n_obs: int | None = None,
random_state: _LegacyRandom = 0,
copy: bool = False,
) -> AnnData | tuple[np.ndarray | CSMatrix, NDArray[np.int64]] | None:
) -> AnnData | tuple[np.ndarray | _CSMatrix, NDArray[np.int64]] | None:
"""\
Subsample to a fraction of the number of observations.
Expand Down
4 changes: 2 additions & 2 deletions src/scanpy/preprocessing/_pca/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

from ..._compat import _LegacyRandom

CSMatrix = sparse.csr_matrix | sparse.csc_matrix
_CSMatrix = sparse.csr_matrix | sparse.csc_matrix


def _pca_compat_sparse(
x: CSMatrix,
x: _CSMatrix,
n_pcs: int,
*,
solver: Literal["arpack", "lobpcg"],
Expand Down
8 changes: 4 additions & 4 deletions src/scanpy/preprocessing/_pca/_dask_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ..._compat import DaskArray

CSMatrix = sparse.csr_matrix | sparse.csc_matrix
_CSMatrix = sparse.csr_matrix | sparse.csc_matrix


@dataclass
Expand Down Expand Up @@ -120,7 +120,7 @@ def transform(self, x: DaskArray) -> DaskArray:
import dask.array as da

def transform_block(
x_part: CSMatrix,
x_part: _CSMatrix,
mean_: NDArray[np.floating],
components_: NDArray[np.floating],
):
Expand Down Expand Up @@ -191,8 +191,8 @@ def _cov_sparse_dask(
else:
dtype = np.dtype(dtype)

def gram_block(x_part: CSMatrix):
gram_matrix: CSMatrix = x_part.T @ x_part
def gram_block(x_part: _CSMatrix):
gram_matrix: _CSMatrix = x_part.T @ x_part
return gram_matrix.toarray()[None, ...] # need new axis for summing

gram_matrix_dask: DaskArray = da.map_blocks(
Expand Down
14 changes: 8 additions & 6 deletions src/scanpy/preprocessing/_qc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

from anndata import AnnData

_CSMatrix = csr_matrix | csc_matrix


def _choose_mtx_rep(adata, *, use_raw: bool = False, layer: str | None = None):
is_layer = layer is not None
Expand Down Expand Up @@ -102,9 +104,9 @@ def describe_obs(
# Handle whether X is passed
if X is None:
X = _choose_mtx_rep(adata, use_raw=use_raw, layer=layer)
if isinstance(X, spmatrix) and not isinstance(X, csr_matrix | csc_matrix):
if isinstance(X, spmatrix) and not isinstance(X, _CSMatrix):
X = csr_matrix(X) # COO not subscriptable
if isinstance(X, csr_matrix | csc_matrix):
if isinstance(X, _CSMatrix):
X.eliminate_zeros()
obs_metrics = pd.DataFrame(index=adata.obs_names)
obs_metrics[f"n_{var_type}_by_{expr_type}"] = materialize_as_ndarray(
Expand Down Expand Up @@ -189,9 +191,9 @@ def describe_var(
# Handle whether X is passed
if X is None:
X = _choose_mtx_rep(adata, use_raw=use_raw, layer=layer)
if isinstance(X, spmatrix) and not isinstance(X, csr_matrix | csc_matrix):
if isinstance(X, spmatrix) and not isinstance(X, _CSMatrix):
X = csr_matrix(X) # COO not subscriptable
if isinstance(X, csr_matrix | csc_matrix):
if isinstance(X, _CSMatrix):
X.eliminate_zeros()
var_metrics = pd.DataFrame(index=adata.var_names)
var_metrics[f"n_cells_by_{expr_type}"], var_metrics[f"mean_{expr_type}"] = (
Expand Down Expand Up @@ -298,9 +300,9 @@ def calculate_qc_metrics(
)
# Pass X so I only have to do it once
X = _choose_mtx_rep(adata, use_raw=use_raw, layer=layer)
if isinstance(X, spmatrix) and not isinstance(X, csr_matrix | csc_matrix):
if isinstance(X, spmatrix) and not isinstance(X, _CSMatrix):
X = csr_matrix(X) # COO not subscriptable
if isinstance(X, csr_matrix | csc_matrix):
if isinstance(X, _CSMatrix):
X.eliminate_zeros()

# Convert qc_vars to list if str
Expand Down
10 changes: 5 additions & 5 deletions src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from numpy.typing import NDArray
from scipy import sparse as sp

CSMatrix = sp.csr_matrix | sp.csc_matrix
_CSMatrix = sp.csr_matrix | sp.csc_matrix


@njit
Expand Down Expand Up @@ -66,7 +66,7 @@ def clip_array(
return X


def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMatrix:
def clip_set(x: _CSMatrix, *, max_value: float, zero_center: bool = True) -> _CSMatrix:
x = x.copy()
x[x > max_value] = max_value
if zero_center:
Expand All @@ -78,15 +78,15 @@ def clip_set(x: CSMatrix, *, max_value: float, zero_center: bool = True) -> CSMa
@old_positionals("zero_center", "max_value", "copy", "layer", "obsm")
@singledispatch
def scale(
data: AnnData | csr_matrix | csc_matrix | np.ndarray | DaskArray,
data: AnnData | _CSMatrix | np.ndarray | DaskArray,
*,
zero_center: bool = True,
max_value: float | None = None,
copy: bool = False,
layer: str | None = None,
obsm: str | None = None,
mask_obs: NDArray[np.bool_] | str | None = None,
) -> AnnData | csr_matrix | csc_matrix | np.ndarray | DaskArray | None:
) -> AnnData | _CSMatrix | np.ndarray | DaskArray | None:
"""\
Scale data to unit variance and zero mean.
Expand Down Expand Up @@ -233,7 +233,7 @@ def scale_array(
@scale.register(csr_matrix)
@scale.register(csc_matrix)
def scale_sparse(
X: csr_matrix | csc_matrix,
X: _CSMatrix,
*,
zero_center: bool = True,
max_value: float | None = None,
Expand Down
19 changes: 8 additions & 11 deletions src/scanpy/preprocessing/_scrublet/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
__all__ = ["Scrublet"]


_CSMatrix = sparse.csr_matrix | sparse.csc_matrix


@dataclass(kw_only=True)
class Scrublet:
"""\
Expand Down Expand Up @@ -65,9 +68,7 @@ class Scrublet:

# init fields

counts_obs: InitVar[sparse.csr_matrix | sparse.csc_matrix | NDArray[np.integer]] = (
field(kw_only=False)
)
counts_obs: InitVar[_CSMatrix | NDArray[np.integer]] = field(kw_only=False)
total_counts_obs: InitVar[NDArray[np.integer] | None] = None
sim_doublet_ratio: float = 2.0
n_neighbors: InitVar[int | None] = None
Expand All @@ -82,15 +83,11 @@ class Scrublet:

_counts_obs: sparse.csc_matrix = field(init=False, repr=False)
_total_counts_obs: NDArray[np.integer] = field(init=False, repr=False)
_counts_obs_norm: sparse.csr_matrix | sparse.csc_matrix = field(
init=False, repr=False
)
_counts_obs_norm: _CSMatrix = field(init=False, repr=False)

_counts_sim: sparse.csr_matrix | sparse.csc_matrix = field(init=False, repr=False)
_counts_sim: _CSMatrix = field(init=False, repr=False)
_total_counts_sim: NDArray[np.integer] = field(init=False, repr=False)
_counts_sim_norm: sparse.csr_matrix | sparse.csc_matrix | None = field(
default=None, init=False, repr=False
)
_counts_sim_norm: _CSMatrix | None = field(default=None, init=False, repr=False)

# Fields set by methods

Expand Down Expand Up @@ -171,7 +168,7 @@ class Scrublet:

def __post_init__(
self,
counts_obs: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.integer],
counts_obs: _CSMatrix | NDArray[np.integer],
total_counts_obs: NDArray[np.integer] | None,
n_neighbors: int | None,
random_state: _LegacyRandom,
Expand Down
14 changes: 8 additions & 6 deletions src/scanpy/preprocessing/_scrublet/sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

from .._compat import _LegacyRandom

_CSMatrix = sparse.csr_matrix | sparse.csc_matrix


def sparse_multiply(
E: sparse.csr_matrix | sparse.csc_matrix | NDArray[np.float64],
E: _CSMatrix | NDArray[np.float64],
a: float | NDArray[np.float64],
) -> sparse.csr_matrix | sparse.csc_matrix:
) -> _CSMatrix:
"""multiply each row of E by a scalar"""

nrow = E.shape[0]
Expand All @@ -30,11 +32,11 @@ def sparse_multiply(


def sparse_zscore(
E: sparse.csr_matrix | sparse.csc_matrix,
E: _CSMatrix,
*,
gene_mean: NDArray[np.float64] | None = None,
gene_stdev: NDArray[np.float64] | None = None,
) -> sparse.csr_matrix | sparse.csc_matrix:
) -> _CSMatrix:
"""z-score normalize each column of E"""
if gene_mean is None or gene_stdev is None:
gene_means, gene_stdevs = _get_mean_var(E, axis=0)
Expand All @@ -43,12 +45,12 @@ def sparse_zscore(


def subsample_counts(
E: sparse.csr_matrix | sparse.csc_matrix,
E: _CSMatrix,
*,
rate: float,
original_totals,
random_seed: _LegacyRandom = 0,
) -> tuple[sparse.csr_matrix | sparse.csc_matrix, NDArray[np.int64]]:
) -> tuple[_CSMatrix, NDArray[np.int64]]:
if rate < 1:
random_seed = _get_legacy_random(random_seed)
E.data = random_seed.binomial(np.round(E.data).astype(int), rate)
Expand Down
3 changes: 2 additions & 1 deletion src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from scipy import sparse

_CorrMethod = Literal["benjamini-hochberg", "bonferroni"]
_CSMatrix = sparse.csr_matrix | sparse.csc_matrix

# Used with get_literal_vals
_Method = Literal["logreg", "t-test", "wilcoxon", "t-test_overestim_var"]
Expand All @@ -45,7 +46,7 @@ def _select_top_n(scores: NDArray, n_top: int):


def _ranks(
X: np.ndarray | sparse.csr_matrix | sparse.csc_matrix,
X: np.ndarray | _CSMatrix,
mask_obs: NDArray[np.bool_] | None = None,
mask_obs_rest: NDArray[np.bool_] | None = None,
) -> Generator[tuple[pd.DataFrame, int, int], None, None]:
Expand Down
7 changes: 3 additions & 4 deletions src/scanpy/tools/_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
_StrIdx = pd.Index[str]
except TypeError: # Sphinx
_StrIdx = pd.Index
_GetSubset = Callable[[_StrIdx], np.ndarray | csr_matrix | csc_matrix]
_CSMatrix = csr_matrix | csc_matrix
_GetSubset = Callable[[_StrIdx], np.ndarray | _CSMatrix]


def _sparse_nanmean(
X: csr_matrix | csc_matrix, axis: Literal[0, 1]
) -> NDArray[np.float64]:
def _sparse_nanmean(X: _CSMatrix, axis: Literal[0, 1]) -> NDArray[np.float64]:
"""
np.nanmean equivalent for sparse matrices
"""
Expand Down

0 comments on commit db3aa89

Please sign in to comment.