Skip to content

Commit

Permalink
Fix aggregate for more than two groups (#2965)
Browse files Browse the repository at this point in the history
* Fix aggregate for more than two groups

* release note
  • Loading branch information
ivirshup authored Mar 27, 2024
1 parent 65f567e commit 3ceb740
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 1 deletion.
13 changes: 13 additions & 0 deletions docs/release-notes/1.10.1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
### 1.10.1 {small}`the future`


```{rubric} Docs
```

```{rubric} Bug fixes
```

* Fix `aggregate` when aggregating by more than two groups {pr}`2965` {smaller}`I Virshup`

```{rubric} Performance
```
3 changes: 3 additions & 0 deletions docs/release-notes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

## Version 1.10

```{include} /release-notes/1.10.1.md
```

```{include} /release-notes/1.10.0.md
```

Expand Down
2 changes: 1 addition & 1 deletion scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _combine_categories(

# Calculating result codes
factors = np.ones(len(cols) + 1, dtype=np.int32) # First factor needs to be 1
np.cumsum(n_categories[::-1], out=factors[1:])
np.cumprod(n_categories[::-1], out=factors[1:])
factors = factors[:-1][::-1]

code_array = np.zeros((len(cols), df.shape[0]), dtype=np.int32)
Expand Down
16 changes: 16 additions & 0 deletions scanpy/tests/test_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,19 @@ def test_dispatch_not_implemented():
adata = sc.datasets.blobs()
with pytest.raises(NotImplementedError):
sc.get.aggregate(adata.X, adata.obs["blobs"], "sum")


def test_factors():
from itertools import product

obs = pd.DataFrame(
product(range(5), range(5), range(5), range(5)), columns=list("abcd")
)
obs.index = [f"cell_{i:04d}" for i in range(obs.shape[0])]
adata = ad.AnnData(
X=np.arange(obs.shape[0]).reshape(-1, 1),
obs=obs,
)

res = sc.get.aggregate(adata, by=["a", "b", "c", "d"], func="sum")
np.testing.assert_equal(res.layers["sum"], adata.X)

0 comments on commit 3ceb740

Please sign in to comment.