Skip to content

Commit

Permalink
Add _Axis.getattr_from and _Axis.getitem_from. (#183)
Browse files Browse the repository at this point in the history
This aims to eliminate the annoyance (and potential error) of writing

    thing = whatever.obs if axis is _Axis.OBS else whatever.var

by letting you say

    thing = axis.getattr_from(whatever)

instead.
  • Loading branch information
thetorpedodog authored Nov 30, 2023
1 parent 7a922b0 commit c6f6fd8
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 19 deletions.
87 changes: 68 additions & 19 deletions python-spec/src/somacore/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Tuple,
TypeVar,
Union,
overload,
)

import anndata
Expand All @@ -19,7 +20,7 @@
import pandas as pd
import pyarrow as pa
from scipy import sparse
from typing_extensions import Literal, Protocol, Self, TypedDict, assert_never
from typing_extensions import Literal, Protocol, Self, TypedDict

from .. import data
from .. import measurement
Expand Down Expand Up @@ -368,14 +369,10 @@ def _read_axis_dataframe(
) -> pa.Table:
"""Reads the specified axis. Will cache join IDs if not present."""
column_names = axis_column_names.get(axis.value)
if axis is _Axis.OBS:
axis_df = self._obs_df
axis_query = self._matrix_axis_query.obs
elif axis is _Axis.VAR:
axis_df = self._var_df
axis_query = self._matrix_axis_query.var
else:
assert_never(axis) # must be obs or var

axis_df = axis.getattr_from(self, pre="_", suf="_df")
assert isinstance(axis_df, data.DataFrame)
axis_query = axis.getattr_from(self._matrix_axis_query)

# If we can cache join IDs, prepare to add them to the cache.
joinids_cached = self._joinids._is_cached(axis)
Expand Down Expand Up @@ -420,19 +417,24 @@ def _axisp_inner(
axis: "_Axis",
layer: str,
) -> data.SparseRead:
key = axis.value + "p"

if key not in self._ms:
raise ValueError(f"Measurement does not contain {key} data")
p_name = f"{axis.value}p"
try:
axisp = axis.getitem_from(self._ms, suf="p")
except KeyError as ke:
raise ValueError(f"Measurement does not contain {p_name} data") from ke

axisp = self._ms.obsp if axis is _Axis.OBS else self._ms.varp
if not (layer and layer in axisp):
raise ValueError(f"Must specify '{key}' layer")
if not isinstance(axisp[layer], data.SparseNDArray):
raise TypeError(f"Unexpected SOMA type stored in '{key}' layer")
try:
ap_layer = axisp[layer]
except KeyError as ke:
raise ValueError(f"layer {layer!r} is not available in {p_name}") from ke
if not isinstance(ap_layer, data.SparseNDArray):
raise TypeError(
f"Unexpected SOMA type {type(ap_layer).__name__}"
f" stored in {p_name} layer {layer!r}"
)

joinids = getattr(self._joinids, axis.value)
return axisp[layer].read((joinids, joinids))
return ap_layer.read((joinids, joinids))

@property
def _obs_df(self) -> data.DataFrame:
Expand Down Expand Up @@ -493,6 +495,30 @@ class _Axis(enum.Enum):
def value(self) -> Literal["obs", "var"]:
return super().value

@overload
def getattr_from(self, __source: "_HasObsVar[_T]") -> "_T":
    # No prefix/suffix: the source exposes ``obs`` and ``var`` of one type,
    # so the result type can be inferred precisely.
    ...

@overload
def getattr_from(
    self, __source: Any, *, pre: Literal[""], suf: Literal[""]
) -> object:
    # Explicitly-empty prefix and suffix on an arbitrary source.
    ...

@overload
def getattr_from(self, __source: Any, *, pre: str = ..., suf: str = ...) -> object:
    # Arbitrary prefix/suffix: the attribute name is only known at runtime,
    # so callers must narrow the result themselves.
    ...

def getattr_from(self, __source: Any, *, pre: str = "", suf: str = "") -> object:
    """Return the attribute of ``__source`` named after this axis.

    Equivalent to ``something.<pre><obs/var><suf>``.

    Args:
        __source: The object to read the attribute from.
        pre: Text placed before the axis name (``self.value``).
        suf: Text placed after the axis name.

    Returns:
        Whatever ``getattr`` finds under ``pre + self.value + suf``.

    Raises:
        AttributeError: If ``__source`` has no attribute with that name.
    """
    return getattr(__source, pre + self.value + suf)

def getitem_from(
    self, __source: Mapping[str, "_T"], *, pre: str = "", suf: str = ""
) -> "_T":
    """Return the entry of ``__source`` keyed after this axis.

    Equivalent to ``something[pre + "obs"/"var" + suf]``.

    Args:
        __source: The mapping to look the key up in.
        pre: Text placed before the axis name (``self.value``).
        suf: Text placed after the axis name.

    Returns:
        The value stored under ``pre + self.value + suf``.

    Raises:
        KeyError: If ``__source`` has no such key (for a plain ``dict``;
            other mappings raise whatever their ``__getitem__`` raises).
    """
    return __source[pre + self.value + suf]


@attrs.define(frozen=True)
class _MatrixAxisQuery:
Expand Down Expand Up @@ -605,6 +631,14 @@ def _to_numpy(it: _Numpyable) -> np.ndarray:
return it.to_numpy()


#
# Type shenanigans
#

# Generic placeholder referenced (as a string annotation) by
# ``_Axis.getattr_from`` and ``_Axis.getitem_from``.
_T = TypeVar("_T")
# Covariant type parameter of the ``_HasObsVar`` protocol.
_T_co = TypeVar("_T_co", covariant=True)


class _Experimentish(Protocol):
"""The API we need from an Experiment."""

Expand All @@ -615,3 +649,18 @@ def ms(self) -> Mapping[str, measurement.Measurement]:
@property
def obs(self) -> data.DataFrame:
...


class _HasObsVar(Protocol[_T_co]):
    """Something which has an ``obs`` and ``var`` field.

    Both fields share the same type, ``_T_co``.
    Used to give nicer type inference in :meth:`_Axis.getattr_from`.
    """

    @property
    def obs(self) -> _T_co:
        ...

    @property
    def var(self) -> _T_co:
        ...
23 changes: 23 additions & 0 deletions python-spec/testing/test_query_axis.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any, Tuple

import attrs
import numpy as np
import pytest
from pytest import mark

import somacore
from somacore import options
from somacore.query import query


@mark.parametrize(
Expand Down Expand Up @@ -49,3 +51,24 @@ def test_canonicalization_nparray() -> None:
def test_canonicalization_bad(coords) -> None:
with pytest.raises(TypeError):
somacore.AxisQuery(coords=coords)


@attrs.define(frozen=True)
class IHaveObsVarStuff:
    """Fixture object with obs/var-named attributes for the axis-helper test."""

    # Plain axis-named attributes (read with no prefix/suffix).
    obs: int
    var: int
    # Axis names wrapped in a ``the_`` prefix and ``_suf`` suffix.
    the_obs_suf: str
    the_var_suf: str


def test_axis_helpers() -> None:
    """Check ``_Axis.getattr_from`` and ``_Axis.getitem_from`` on both axes."""
    thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
    assert 1 == query._Axis.OBS.getattr_from(thing)
    assert 2 == query._Axis.VAR.getattr_from(thing)
    assert "observe" == query._Axis.OBS.getattr_from(thing, pre="the_", suf="_suf")
    assert "vary" == query._Axis.VAR.getattr_from(thing, pre="the_", suf="_suf")
    ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
    assert "erve" == query._Axis.OBS.getitem_from(ovdict)
    assert "y" == query._Axis.VAR.getitem_from(ovdict)
    assert "hide" == query._Axis.OBS.getitem_from(ovdict, pre="i_", suf="cure")
    assert "???" == query._Axis.VAR.getitem_from(ovdict, pre="i_", suf="cure")

    # Missing names must surface as the ordinary lookup errors; callers
    # such as ``_axisp_inner`` catch ``KeyError`` from ``getitem_from``.
    with pytest.raises(AttributeError):
        query._Axis.OBS.getattr_from(thing, pre="no_such_")
    with pytest.raises(KeyError):
        query._Axis.VAR.getitem_from(ovdict, suf="_missing")

0 comments on commit c6f6fd8

Please sign in to comment.