From 091831c5b6df4e7d8f99fe82fa69ee0f665ffe52 Mon Sep 17 00:00:00 2001 From: BRAUN REMI Date: Thu, 20 Feb 2025 17:16:26 +0100 Subject: [PATCH] ENH: Add the ability to return a `dask.delayed` object in `rasters.write` --- CHANGES.md | 1 + ci/script_utils.py | 2 +- ci/test_rasters.py | 35 ++++++++++++++++++++++++++++++++++- sertit/rasters.py | 28 +++++++++++++++++++++------- 4 files changed, 57 insertions(+), 9 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 026ff46..3248eef 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,7 @@ ## 1.46.0 (2025-mm-dd) - **ENH: Add two functions for converting degrees to and from meters: `rasters.from_deg_to_meters` and `rasters.from_meters_to_deg`** +- **ENH: Add the ability to return a `dask.delayed` object in `rasters.write`** - FIX: Don't take nodata value into account in `ci.assert_raster_almost_equal_magnitude` ## 1.45.2 (2025-02-17) diff --git a/ci/script_utils.py b/ci/script_utils.py index 10ba6da..f1356bd 100644 --- a/ci/script_utils.py +++ b/ci/script_utils.py @@ -66,7 +66,7 @@ def dask_env_wrapper(*_args, **_kwargs): print("Using DASK multithreaded.") function(*_args, **_kwargs) - with dask.get_or_create_dask_client(): + with dask.get_or_create_dask_client(processes=True): print("Using DASK with local cluster") function(*_args, **_kwargs) except ImportError: diff --git a/ci/test_rasters.py b/ci/test_rasters.py index 4baaf0f..60505a1 100644 --- a/ci/test_rasters.py +++ b/ci/test_rasters.py @@ -200,7 +200,7 @@ def test_read_with_window(tmp_path, raster_path, mask_path, mask): @s3_env @dask_env -@pytest.mark.timeout(2, func_only=True) +@pytest.mark.timeout(4, func_only=True) def test_very_big_file(tmp_path, mask_path): """Test read with big files function (should be fast, if not there is something going on...)""" dem_path = unistra.get_geodatastore() / "GLOBAL" / "EUDEM_v2" / "eudem_wgs84.tif" @@ -238,6 +238,39 @@ def test_write_basic(tmp_path, raster_path, xda, xds, xda_dask, ds_dtype): ci.assert_raster_equal(raster_path, xda_dask_out) +@s3_env +@dask_env +def test_write_dask(tmp_path, raster_path, xda, xds, xda_dask, ds_dtype): + """Test write (basic) function""" + # DataArray + xda_out = os.path.join(tmp_path, "test_xda.tif") + delayed_1 = rasters.write(xda, xda_out, dtype=ds_dtype, compute=False) + # assert not os.path.isfile(xda_out) + + # Dataset + xds_out = os.path.join(tmp_path, "test_xds.tif") + delayed_2 = rasters.write(xds, xds_out, dtype=ds_dtype, compute=False) + # assert os.path.isfile(xds_out) + + # With dask + xda_dask_out = os.path.join(tmp_path, "test_xda_dask.tif") + delayed_3 = rasters.write(xda_dask, xda_dask_out, dtype=ds_dtype, compute=False) + # assert not os.path.isfile(xda_dask_out) + + # Compute + import dask + + dask.compute(delayed_1, delayed_2, delayed_3) + assert os.path.isfile(xda_out) + assert os.path.isfile(xds_out) + assert os.path.isfile(xda_dask_out) + + # Tests + ci.assert_raster_equal(raster_path, xda_out) + ci.assert_raster_equal(raster_path, xds_out) + ci.assert_raster_equal(raster_path, xda_dask_out) + + @s3_env @dask_env def test_mask(tmp_path, xda, xds, xda_dask, mask): diff --git a/sertit/rasters.py b/sertit/rasters.py index 66db53e..2d47e7e 100644 --- a/sertit/rasters.py +++ b/sertit/rasters.py @@ -1060,8 +1060,9 @@ def read( return xda -def __save_cog_with_dask(xds: AnyXrDataStructure, nodata, dtype, output_path, kwargs): - from dask import optimize +def __save_cog_with_dask( + xds: AnyXrDataStructure, nodata, dtype, output_path, compute, kwargs +): from odc.geo import cog, xr # noqa delayed = cog.save_cog_with_dask( @@ -1070,8 +1071,10 @@ def __save_cog_with_dask(xds: AnyXrDataStructure, nodata, dtype, output_path, kw **kwargs, ) - (delayed,) = optimize(delayed) - delayed.compute(optimize_graph=True) + if compute: + delayed.compute(optimize_graph=True) + + return delayed @any_raster_to_xr_ds @@ -1080,6 +1083,7 @@ def write( output_path: AnyPathStrType = None, tags: dict = None, write_cogs_with_dask: bool = True, + compute: bool = True, **kwargs, ) -> None: """ @@ -1110,6 +1114,9 @@ def write( tags (dict): Tags that will be written in your file write_cogs_with_dask (bool): If odc-geo and imagecodecs are installed, write your COGs with Dask. Otherwise, the array will be loaded into memory before writing it on disk (and can cause MemoryErrors). + compute (bool): If True (default) and data is a dask array, then compute and save the data immediately. + If False, return a dask Delayed object. Call ".compute()" on the Delayed object to compute the result later. + Call ``dask.compute(delayed1, delayed2)`` to save multiple delayed files at once. **kwargs: Overloading metadata, ie :code:`nodata=255` or :code:`dtype=np.uint8` Examples: @@ -1122,6 +1129,7 @@ def write( >>> # Rewrite it >>> write(xds, raster_out) """ + delayed = None if output_path is None: logs.deprecation_warning( "'path' is deprecated in 'rasters.write'. Use 'output_path' instead." @@ -1221,10 +1229,14 @@ def write( # Write cog on disk try: - __save_cog_with_dask(xds, nodata, dtype, output_path, da_kwargs) + delayed = __save_cog_with_dask( + xds, nodata, dtype, output_path, compute, da_kwargs + ) except Exception: da_kwargs["stats"] = False - __save_cog_with_dask(xds, nodata, dtype, output_path, da_kwargs) + delayed = __save_cog_with_dask( + xds, nodata, dtype, output_path, compute, da_kwargs + ) is_written = True @@ -1259,13 +1271,15 @@ def write( if "_FillValue" in xds.attrs: xds.attrs.pop("_FillValue") - xds.rio.to_raster( + delayed = xds.rio.to_raster( str(output_path), BIGTIFF=bigtiff, NUM_THREADS=MAX_CORES, tags=tags, + compute=compute, **misc.remove_empty_values(kwargs), ) + return delayed def _collocate_dataarray(