From 05dbb84a77c9e2e3db1d4354c3741c86d11a6f5b Mon Sep 17 00:00:00 2001 From: Tom White <tom.e.white@gmail.com> Date: Mon, 20 Jan 2025 14:30:00 +0000 Subject: [PATCH] Fix windowing for cubed for regular chunks case --- .github/workflows/cubed.yml | 5 +++-- sgkit/tests/test_window.py | 2 +- sgkit/window.py | 21 +++++++++++++++------ 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/.github/workflows/cubed.yml b/.github/workflows/cubed.yml index 5a23813a0..d432640ad 100644 --- a/.github/workflows/cubed.yml +++ b/.github/workflows/cubed.yml @@ -30,12 +30,13 @@ jobs: - name: Test with pytest run: | - pytest -v sgkit/tests/test_{aggregation,association,hwe,pca}.py \ + pytest -v sgkit/tests/test_{aggregation,association,hwe,pca,window}.py \ -k "test_count_call_alleles or \ test_gwas_linear_regression or \ test_hwep or \ test_sample_stats or \ (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or \ (test_variant_stats and not test_variant_stats__chunks[chunks2-False]) or \ - (test_pca__array_backend and tsqr)" or \ + (test_pca__array_backend and tsqr) or \ + (test_window and not 12-5-4-4)" \ --use-cubed diff --git a/sgkit/tests/test_window.py b/sgkit/tests/test_window.py index 3c518f560..65b01168d 100644 --- a/sgkit/tests/test_window.py +++ b/sgkit/tests/test_window.py @@ -1,12 +1,12 @@ import re import allel -import dask.array as da import numpy as np import pandas as pd import pytest import xarray as xr +import sgkit.distarray as da from sgkit import ( simulate_genotype_call_dataset, window_by_interval, diff --git a/sgkit/window.py b/sgkit/window.py index 2f0ba5cae..140930011 100644 --- a/sgkit/window.py +++ b/sgkit/window.py @@ -1,9 +1,10 @@ +import functools from typing import Any, Callable, Hashable, Iterable, Optional, Tuple, Union -import dask.array as da import numpy as np from xarray import Dataset +import sgkit.distarray as da from sgkit import variables from sgkit.model import get_contigs, num_contigs from sgkit.utils import conditional_merge_datasets, create_dataset @@ -510,8 +511,15 @@ def window_statistic( and len(window_stops) == 1 and window_stops == np.array([values.shape[0]]) ): + out = da.map_blocks( + functools.partial(statistic, **kwargs), + values, + dtype=dtype, + chunks=values.chunks[1:], + drop_axis=0, + ) # call expand_dims to add back window dimension (size 1) - return da.expand_dims(statistic(values, **kwargs), axis=0) + return da.expand_dims(out, axis=0) window_lengths = window_stops - window_starts depth = np.max(window_lengths) # type: ignore[no-untyped-call] @@ -536,10 +544,10 @@ def window_statistic( chunk_offsets = _sizes_to_start_offsets(windows_per_chunk) - def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike: - if block_info is None or len(block_info) == 0: + def blockwise_moving_stat(x: ArrayLike, block_id: Any = None) -> ArrayLike: + if block_id is None: return np.array([]) - chunk_number = block_info[0]["chunk-location"][0] + chunk_number = block_id[0] chunk_offset_start = chunk_offsets[chunk_number] chunk_offset_stop = chunk_offsets[chunk_number + 1] chunk_window_starts = rel_window_starts[chunk_offset_start:chunk_offset_stop] @@ -559,8 +567,9 @@ def blockwise_moving_stat(x: ArrayLike, block_info: Any = None) -> ArrayLike: depth = {0: depth} # new chunks are same except in first axis new_chunks = tuple([tuple(windows_per_chunk)] + list(desired_chunks[1:])) # type: ignore - return values.map_overlap( + return da.map_overlap( blockwise_moving_stat, + values, dtype=dtype, chunks=new_chunks, depth=depth,