Skip to content

Commit

Permalink
Test full Dask support for log1p, normalize_per_cell, `filter_cells`/`filter_genes` (#2814)
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep authored Jan 16, 2024
1 parent 4f4b1c3 commit e00932b
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 41 deletions.
8 changes: 8 additions & 0 deletions scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ class DaskArray:
pass


try:
    # Optional dependency: zappy provides distributed (lazy) arrays.
    from zappy.base import ZappyArray
except ImportError:
    # Stand-in so isinstance() checks against ZappyArray are always valid
    # (and simply never match) when zappy is not installed.
    class ZappyArray:
        """Placeholder for the real ``zappy.base.ZappyArray``."""


__all__ = ["cache", "DaskArray", "fullname", "pkg_metadata", "pkg_version"]


Expand Down
51 changes: 41 additions & 10 deletions scanpy/preprocessing/_distributed.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,49 @@
from __future__ import annotations

from typing import TYPE_CHECKING, overload

import numpy as np

# install dask if available
try:
import dask.array as da
except ImportError:
da = None
from scanpy._compat import DaskArray, ZappyArray

if TYPE_CHECKING:
from numpy.typing import ArrayLike


@overload
def materialize_as_ndarray(a: ArrayLike) -> np.ndarray:
    ...


@overload
def materialize_as_ndarray(a: tuple[ArrayLike]) -> tuple[np.ndarray]:
    ...


@overload
def materialize_as_ndarray(
    a: tuple[ArrayLike, ArrayLike],
) -> tuple[np.ndarray, np.ndarray]:
    ...


@overload
def materialize_as_ndarray(
    a: tuple[ArrayLike, ArrayLike, ArrayLike],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    ...


def materialize_as_ndarray(
    a: ArrayLike | tuple[ArrayLike | ZappyArray | DaskArray, ...],
) -> tuple[np.ndarray] | np.ndarray:
    """Convert distributed arrays to ndarrays.

    A bare array-like is converted with :func:`numpy.asarray`.
    A tuple is converted element-wise, except that if any element is a
    Dask array, all elements are evaluated in a single ``da.compute``
    call so that shared intermediates are computed only once.
    """
    if not isinstance(a, tuple):
        return np.asarray(a)

    if not any(isinstance(arr, DaskArray) for arr in a):
        return tuple(np.asarray(arr) for arr in a)

    # Imported lazily: dask is an optional dependency, and reaching this
    # line means at least one element is already a Dask array.
    import dask.array as da

    return da.compute(*a, sync=True)
12 changes: 6 additions & 6 deletions scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def filter_cells(
if max_number is not None:
cell_subset = number_per_cell <= max_number

s = np.sum(~cell_subset)
s = materialize_as_ndarray(np.sum(~cell_subset))
if s > 0:
msg = f"filtered out {s} cells that have "
if min_genes is not None or min_counts is not None:
Expand Down Expand Up @@ -354,7 +354,7 @@ def log1p(


@log1p.register(spmatrix)
def log1p_sparse(X, *, base: Number | None = None, copy: bool = False):
def log1p_sparse(X: spmatrix, *, base: Number | None = None, copy: bool = False):
X = check_array(
X, accept_sparse=("csr", "csc"), dtype=(np.float64, np.float32), copy=copy
)
Expand All @@ -363,7 +363,7 @@ def log1p_sparse(X, *, base: Number | None = None, copy: bool = False):


@log1p.register(np.ndarray)
def log1p_array(X, *, base: Number | None = None, copy: bool = False):
def log1p_array(X: np.ndarray, *, base: Number | None = None, copy: bool = False):
# Can force arrays to be np.ndarrays, but would be useful to not
# X = check_array(X, dtype=(np.float64, np.float32), ensure_2d=False, copy=copy)
if copy:
Expand All @@ -381,7 +381,7 @@ def log1p_array(X, *, base: Number | None = None, copy: bool = False):

@log1p.register(AnnData)
def log1p_anndata(
adata,
adata: AnnData,
*,
base: Number | None = None,
copy: bool = False,
Expand Down Expand Up @@ -564,7 +564,7 @@ def normalize_per_cell( # noqa: PLR0917
else:
raise ValueError('use_rep should be "after", "X" or None')
for layer in layers:
subset, counts = filter_cells(adata.layers[layer], min_counts=min_counts)
_subset, counts = filter_cells(adata.layers[layer], min_counts=min_counts)
temp = normalize_per_cell(adata.layers[layer], after, counts, copy=True)
adata.layers[layer] = temp

Expand All @@ -589,7 +589,7 @@ def normalize_per_cell( # noqa: PLR0917
counts_per_cell += counts_per_cell == 0
counts_per_cell /= counts_per_cell_after
if not issparse(X):
X /= materialize_as_ndarray(counts_per_cell[:, np.newaxis])
X /= counts_per_cell[:, np.newaxis]
else:
sparsefuncs.inplace_row_scale(X, 1 / counts_per_cell)
return X if copy else None
Expand Down
90 changes: 65 additions & 25 deletions scanpy/tests/test_preprocessing_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from pathlib import Path

import anndata as ad
import numpy.testing as npt
import pytest
from anndata import AnnData, OldFormatWarning, read_zarr

from scanpy._compat import DaskArray, ZappyArray
from scanpy.preprocessing import (
filter_cells,
filter_genes,
Expand All @@ -17,15 +18,19 @@
from scanpy.testing._pytest.marks import needs

HERE = Path(__file__).parent / Path("_data/")
input_file = str(Path(HERE, "10x-10k-subset.zarr"))
input_file = Path(HERE, "10x-10k-subset.zarr")

DIST_TYPES = (DaskArray, ZappyArray)


pytestmark = [needs.zarr]


@pytest.fixture()
def adata() -> AnnData:
    """In-memory AnnData: the 10x subset with X materialized as numpy."""
    # The store is in the old zarr format; assert the warning here so it
    # does not leak into the individual tests.
    with pytest.warns(OldFormatWarning):
        a = read_zarr(input_file)  # regular anndata
    a.var_names_make_unique()
    a.X = a.X[:]  # convert to numpy array
    return a

Expand All @@ -36,78 +41,111 @@ def adata():
pytest.param("dask", marks=[needs.dask]),
]
)
def adata_dist(request):
def adata_dist(request: pytest.FixtureRequest) -> AnnData:
# regular anndata except for X, which we replace on the next line
a = ad.read_zarr(input_file)
with pytest.warns(OldFormatWarning):
a = read_zarr(input_file)
a.var_names_make_unique()
a.uns["dist-mode"] = request.param
input_file_X = f"{input_file}/X"
if request.param == "direct":
import zappy.direct

a.X = zappy.direct.from_zarr(input_file_X)
yield a
elif request.param == "dask":
import dask.array as da
return a

assert request.param == "dask"
import dask.array as da

a.X = da.from_zarr(input_file_X)
yield a
a.X = da.from_zarr(input_file_X)
return a


def test_log1p(adata: AnnData, adata_dist: AnnData):
    """log1p on a distributed X matches log1p on the in-memory X."""
    log1p(adata_dist)
    assert isinstance(adata_dist.X, DIST_TYPES)
    result = materialize_as_ndarray(adata_dist.X)
    log1p(adata)
    assert result.shape == (adata.n_obs, adata.n_vars)
    npt.assert_allclose(result, adata.X)


def test_normalize_per_cell(
    request: pytest.FixtureRequest, adata: AnnData, adata_dist: AnnData
):
    """normalize_per_cell on a distributed X matches the in-memory result."""
    if isinstance(adata_dist.X, DaskArray):
        # Mark dynamically so the other dist-modes still run normally.
        request.node.add_marker(
            pytest.mark.xfail(
                reason="normalize_per_cell deprecated and broken for Dask"
            )
        )
    normalize_per_cell(adata_dist)
    assert isinstance(adata_dist.X, DIST_TYPES)
    result = materialize_as_ndarray(adata_dist.X)
    normalize_per_cell(adata)
    assert result.shape == (adata.n_obs, adata.n_vars)
    npt.assert_allclose(result, adata.X)


def test_normalize_total(adata: AnnData, adata_dist: AnnData):
    """normalize_total on a distributed X matches the in-memory result."""
    normalize_total(adata_dist)
    assert isinstance(adata_dist.X, DIST_TYPES)
    result = materialize_as_ndarray(adata_dist.X)
    normalize_total(adata)
    assert result.shape == (adata.n_obs, adata.n_vars)
    npt.assert_allclose(result, adata.X)


def test_filter_cells_array(adata: AnnData, adata_dist: AnnData):
    """filter_cells on a bare distributed array returns distributed results."""
    cell_subset_dist, number_per_cell_dist = filter_cells(adata_dist.X, min_genes=3)
    assert isinstance(cell_subset_dist, DIST_TYPES)
    assert isinstance(number_per_cell_dist, DIST_TYPES)

    cell_subset, number_per_cell = filter_cells(adata.X, min_genes=3)
    npt.assert_allclose(materialize_as_ndarray(cell_subset_dist), cell_subset)
    npt.assert_allclose(materialize_as_ndarray(number_per_cell_dist), number_per_cell)


def test_filter_cells(adata: AnnData, adata_dist: AnnData):
    """filter_cells on a distributed AnnData matches the in-memory result."""
    filter_cells(adata_dist, min_genes=3)
    assert isinstance(adata_dist.X, DIST_TYPES)
    result = materialize_as_ndarray(adata_dist.X)
    filter_cells(adata, min_genes=3)

    assert result.shape == (adata.n_obs, adata.n_vars)
    npt.assert_array_equal(adata_dist.obs["n_genes"], adata.obs["n_genes"])
    npt.assert_allclose(result, adata.X)


def test_filter_genes_array(adata: AnnData, adata_dist: AnnData):
    """filter_genes on a bare distributed array returns distributed results."""
    gene_subset_dist, number_per_gene_dist = filter_genes(adata_dist.X, min_cells=2)
    assert isinstance(gene_subset_dist, DIST_TYPES)
    assert isinstance(number_per_gene_dist, DIST_TYPES)

    gene_subset, number_per_gene = filter_genes(adata.X, min_cells=2)
    npt.assert_allclose(materialize_as_ndarray(gene_subset_dist), gene_subset)
    npt.assert_allclose(materialize_as_ndarray(number_per_gene_dist), number_per_gene)


def test_filter_genes(adata: AnnData, adata_dist: AnnData):
    """filter_genes on a distributed AnnData matches the in-memory result."""
    filter_genes(adata_dist, min_cells=2)
    assert isinstance(adata_dist.X, DIST_TYPES)
    result = materialize_as_ndarray(adata_dist.X)
    filter_genes(adata, min_cells=2)
    assert result.shape == (adata.n_obs, adata.n_vars)
    npt.assert_allclose(result, adata.X)


def test_write_zarr(adata, adata_dist):
def test_write_zarr(adata: AnnData, adata_dist: AnnData):
import zarr

log1p(adata_dist)
assert isinstance(adata_dist.X, DIST_TYPES)
temp_store = zarr.TempStore()
chunks = adata_dist.X.chunks
if isinstance(chunks[0], tuple):
chunks = (chunks[0][0],) + chunks[1]

# write metadata using regular anndata
adata.write_zarr(temp_store, chunks)
if adata_dist.uns["dist-mode"] == "dask":
Expand All @@ -116,7 +154,9 @@ def test_write_zarr(adata, adata_dist):
adata_dist.X.to_zarr(temp_store.dir_path("X"), chunks)
else:
assert False, "add branch for new dist-mode"

# read back as zarr directly and check it is the same as adata.X
adata_log1p = ad.read_zarr(temp_store)
with pytest.warns(OldFormatWarning, match="without encoding metadata"):
adata_log1p = read_zarr(temp_store)
log1p(adata)
npt.assert_allclose(adata_log1p.X, adata.X)

0 comments on commit e00932b

Please sign in to comment.