From 05dbb84a77c9e2e3db1d4354c3741c86d11a6f5b Mon Sep 17 00:00:00 2001
From: Tom White <tom.e.white@gmail.com>
Date: Mon, 20 Jan 2025 14:30:00 +0000
Subject: [PATCH] Fix windowing for cubed in the regular chunks case
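
The fast path for the regular chunks case (a single window spanning the
whole array) previously applied the statistic directly to the array,
which bypasses the distributed-array layer. It is now routed through
da.map_blocks, where da is sgkit.distarray and can be backed by cubed
(as exercised by the --use-cubed CI job) as well as Dask. In addition,
the blockwise moving statistic now receives block_id rather than the
Dask-specific block_info mapping, and map_overlap is called as a
module-level function instead of a Dask array method. The cubed CI job
now also runs the window tests, excluding the 12-5-4-4 parametrization.

A minimal sketch of the reworked fast path, shown with dask.array
purely for illustration (sgkit.distarray exposes the same map_blocks
and expand_dims interface); the summing statistic and the example
shapes below are placeholders, not part of this change:

    import functools

    import dask.array as da
    import numpy as np

    def statistic(x, axis=0):
        return x.sum(axis=axis)

    values = da.from_array(np.arange(20).reshape(10, 2), chunks=(10, 2))
    out = da.map_blocks(
        functools.partial(statistic, axis=0),
        values,
        dtype=values.dtype,
        chunks=values.chunks[1:],  # window axis is dropped per block
        drop_axis=0,
    )
    # add back the window dimension (size 1), as window_statistic does
    result = da.expand_dims(out, axis=0)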

---
 .github/workflows/cubed.yml |  5 +++--
 sgkit/tests/test_window.py  |  2 +-
 sgkit/window.py             | 21 +++++++++++++++------
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml
index 5a23813a0..d432640ad 100644
--- a/.github/workflows/cubed.yml
+++ b/.github/workflows/cubed.yml
@@ -30,12 +30,13 @@ jobs:
 
     - name: Test with pytest
       run: |
-        pytest -v sgkit/tests/test_{aggregation,association,hwe,pca}.py \
+        pytest -v sgkit/tests/test_{aggregation,association,hwe,pca,window}.py \
           -k "test_count_call_alleles or \
               test_gwas_linear_regression or \
               test_hwep or \
               test_sample_stats or \
               (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or \
               (test_variant_stats and not test_variant_stats__chunks[chunks2-False]) or \
-              (test_pca__array_backend and tsqr)" \
+              (test_pca__array_backend and tsqr) or \
+              (test_window and not 12-5-4-4)" \
           --use-cubed
diff --git a/sgkit/tests/test_window.py b/sgkit/tests/test_window.py
index 3c518f560..65b01168d 100644
--- a/sgkit/tests/test_window.py
+++ b/sgkit/tests/test_window.py
@@ -1,12 +1,12 @@
 import re
 
 import allel
-import dask.array as da
 import numpy as np
 import pandas as pd
 import pytest
 import xarray as xr
 
+import sgkit.distarray as da
 from sgkit import (
     simulate_genotype_call_dataset,
     window_by_interval,
diff --git a/sgkit/window.py b/sgkit/window.py
index 2f0ba5cae..140930011 100644
--- a/sgkit/window.py
+++ b/sgkit/window.py
@@ -1,9 +1,10 @@
+import functools
 from typing import Any, Callable, Hashable, Iterable, Optional, Tuple, Union
 
-import dask.array as da
 import numpy as np
 from xarray import Dataset
 
+import sgkit.distarray as da
 from sgkit import variables
 from sgkit.model import get_contigs, num_contigs
 from sgkit.utils import conditional_merge_datasets, create_dataset
@@ -510,8 +511,15 @@ def window_statistic(
         and len(window_stops) == 1
         and window_stops == np.array([values.shape[0]])
     ):
+        out = da.map_blocks(
+            functools.partial(statistic, **kwargs),
+            values,
+            dtype=dtype,
+            chunks=values.chunks[1:],
+            drop_axis=0,
+        )
         # call expand_dims to add back window dimension (size 1)
-        return da.expand_dims(statistic(values, **kwargs), axis=0)
+        return da.expand_dims(out, axis=0)
 
     window_lengths = window_stops - window_starts
     depth = np.max(window_lengths)  # type: ignore[no-untyped-call]
@@ -536,10 +544,10 @@ def window_statistic(
 
     chunk_offsets = _sizes_to_start_offsets(windows_per_chunk)
 
-    def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike:
-        if block_info is None or len(block_info) == 0:
+    def blockwise_moving_stat(x: ArrayLike, block_id: Any = None) -> ArrayLike:
+        if block_id is None:
             return np.array([])
-        chunk_number = block_info[0]["chunk-location"][0]
+        chunk_number = block_id[0]
         chunk_offset_start = chunk_offsets[chunk_number]
         chunk_offset_stop = chunk_offsets[chunk_number + 1]
         chunk_window_starts = rel_window_starts[chunk_offset_start:chunk_offset_stop]
@@ -559,8 +567,9 @@ def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike:
         depth = {0: depth}
         # new chunks are same except in first axis
         new_chunks = tuple([tuple(windows_per_chunk)] + list(desired_chunks[1:]))  # type: ignore
-    return values.map_overlap(
+    return da.map_overlap(
         blockwise_moving_stat,
+        values,
         dtype=dtype,
         chunks=new_chunks,
         depth=depth,