Skip to content

Commit

Permalink
Allow probabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Dec 20, 2024
1 parent 91c3f15 commit 179c9f2
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 23 deletions.
28 changes: 21 additions & 7 deletions src/scanpy/get/get.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,8 @@ def _check_mask(
data: AnnData | np.ndarray | CSMatrix | DaskArray,
mask: str | M,
dim: Literal["obs", "var"],
*,
allow_probabilities: bool = False,
) -> M: # Could also be a series, but should be one or the other
"""
Validate mask argument
Expand All @@ -505,33 +507,45 @@ def _check_mask(
data
Annotated data matrix or numpy array.
mask
Mask or probabilities.
Mask (or probabilities if `allow_probabilities=True`).
Either an appropriatley sized array, or name of a column.
dim
The dimension being masked.
allow_probabilities
Whether to allow probabilities as `mask`
"""
if mask is None:
return mask
desc = "mask/probabilities" if allow_probabilities else "mask"

if isinstance(mask, str):
if not isinstance(data, AnnData):
msg = "Cannot refer to mask with string without providing anndata object as argument"
msg = f"Cannot refer to {desc} with string without providing anndata object as argument"
raise ValueError(msg)

annot: pd.DataFrame = getattr(data, dim)
if mask not in annot.columns:
msg = (
f"Did not find `adata.{dim}[{mask!r}]`. "
f"Either add the mask first to `adata.{dim}`"
"or consider using the mask argument with a boolean array."
f"Either add the {desc} first to `adata.{dim}`"
f"or consider using the {desc} argument with an array."
)
raise ValueError(msg)
mask_array = annot[mask].to_numpy()
else:
if len(mask) != data.shape[0 if dim == "obs" else 1]:
raise ValueError("The shape of the mask do not match the data.")
msg = f"The shape of the {desc} do not match the data."
raise ValueError(msg)
mask_array = mask

if not pd.api.types.is_bool_dtype(mask_array.dtype):
raise ValueError("Mask array must be boolean.")
is_bool = pd.api.types.is_bool_dtype(mask_array.dtype)
if not allow_probabilities and not is_bool:
msg = "Mask array must be boolean."
raise ValueError(msg)

Check warning on line 544 in src/scanpy/get/get.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/get/get.py#L543-L544

Added lines #L543 - L544 were not covered by tests
elif allow_probabilities and not (
is_bool or pd.api.types.is_float_dtype(mask_array.dtype)
):
msg = f"{desc} array must be boolean or floating point."
raise ValueError(msg)

Check warning on line 549 in src/scanpy/get/get.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/get/get.py#L548-L549

Added lines #L548 - L549 were not covered by tests

return mask_array
10 changes: 7 additions & 3 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,7 @@ def sample(
Rows correspond to cells and columns to genes.
fraction
Sample to this `fraction` of the number of observations or variables.
(All of them, even if there are `0`s/`False`s in `p`.)
This can be larger than 1.0, if `replace=True`.
See `axis` and `replace`.
n
Expand All @@ -899,8 +900,9 @@ def sample(
axis
Sample `obs`\\ ervations (axis 0) or `var`\\ iables (axis 1).
p
Drawing probabilities or mask.
Either an appropriatley sized array, or name of a column.
Drawing probabilities (floats) or mask (bools).
Either an `axis`-sized array, or the name of a column.
If `p` is an array of probabilities, it must sum to 1.
Returns
-------
Expand All @@ -917,7 +919,9 @@ def sample(
msg = "Inplace sampling (`copy=False`) is not implemented for backed objects."
raise NotImplementedError(msg)
axis, axis_name = _resolve_axis(axis)
p = _check_mask(data, p, dim=axis_name)
p = _check_mask(data, p, dim=axis_name, allow_probabilities=True)
if p is not None and p.dtype == bool:
p = p.astype(np.float64) / p.sum()
old_n = data.shape[axis]
match (fraction, n):
case (None, None):
Expand Down
53 changes: 40 additions & 13 deletions tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from collections.abc import Callable
from typing import Any, Literal

from numpy.typing import NDArray

CSMatrix = sp.csc_matrix | sp.csr_matrix


Expand Down Expand Up @@ -144,31 +146,55 @@ def test_normalize_per_cell():
assert adata.X.sum(axis=1).tolist() == adata_sparse.X.sum(axis=1).A1.tolist()


def _random_probs(n: int, frac_zero: float) -> NDArray[np.float64]:
"""
Generate a random probability distribution of `n` values between 0 and 1.
"""
probs = np.random.randint(0, 10000, n).astype(np.float64)
probs[probs < np.quantile(probs, frac_zero)] = 0
probs /= probs.sum()
np.testing.assert_almost_equal(probs.sum(), 1)
return probs


@pytest.mark.parametrize("array_type", ARRAY_TYPES)
@pytest.mark.parametrize("which", ["copy", "inplace", "array"])
@pytest.mark.parametrize(
("axis", "fraction", "n", "replace", "expected"),
("axis", "f_or_n", "replace"),
[
pytest.param(0, 40, False, id="obs-40-no_replace"),
pytest.param(0, 0.1, False, id="obs-0.1-no_replace"),
pytest.param(0, 201, True, id="obs-201-replace"),
pytest.param(0, 1, True, id="obs-1-replace"),
pytest.param(1, 10, False, id="var-10-no_replace"),
pytest.param(1, 11, True, id="var-11-replace"),
pytest.param(1, 2.0, True, id="var-2.0-replace"),
],
)
@pytest.mark.parametrize(
"ps",
[
pytest.param(0, None, 40, False, 40, id="obs-40-no_replace"),
pytest.param(0, 0.1, None, False, 20, id="obs-0.1-no_replace"),
pytest.param(0, None, 201, True, 201, id="obs-201-replace"),
pytest.param(0, None, 1, True, 1, id="obs-1-replace"),
pytest.param(1, None, 10, False, 10, id="var-10-no_replace"),
pytest.param(1, None, 11, True, 11, id="var-11-replace"),
pytest.param(1, 2.0, None, True, 20, id="var-2.0-replace"),
dict(obs=None, var=None),
dict(obs=np.tile([True, False], 100), var=np.tile([True, False], 5)),
dict(obs=_random_probs(200, 0.3), var=_random_probs(10, 0.7)),
],
ids=["all", "mask", "p"],
)
def test_sample(
*,
request: pytest.FixtureRequest,
array_type: Callable[[np.ndarray], np.ndarray | CSMatrix],
which: Literal["copy", "inplace", "array"],
axis: Literal[0, 1],
fraction: float | None,
n: int | None,
f_or_n: float | int, # noqa: PYI041
replace: bool,
expected: int,
ps: dict[Literal["obs", "var"], NDArray[np.bool_] | None],
):
adata = AnnData(array_type(np.ones((200, 10))))
p = ps["obs" if axis == 0 else "var"]
expected = int(adata.shape[axis] * f_or_n) if isinstance(f_or_n, float) else f_or_n
if p is not None and not replace and expected > (n_possible := (p != 0).sum()):
request.applymarker(pytest.xfail(f"Can’t draw {expected} out of {n_possible}"))

# ignoring this warning declaratively is a pain so do it here
if find_spec("dask"):
Expand All @@ -182,12 +208,13 @@ def test_sample(
)
rv = sc.pp.sample(
adata.X if which == "array" else adata,
fraction,
n=n,
f_or_n if isinstance(f_or_n, float) else None,
n=f_or_n if isinstance(f_or_n, int) else None,
replace=replace,
axis=axis,
# `copy` only effects AnnData inputs
copy=dict(copy=True, inplace=False, array=False)[which],
p=p,
)

match which:
Expand Down

0 comments on commit 179c9f2

Please sign in to comment.