Skip to content

Commit

Permalink
ENH: Add the ability to return a dask.delayed object in `rasters.wr…
Browse files Browse the repository at this point in the history
…ite`
  • Loading branch information
remi-braun committed Feb 20, 2025
1 parent b3adeb0 commit 091831c
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 1.46.0 (2025-mm-dd)

- **ENH: Add two functions for converting degrees to and from meters: `rasters.from_deg_to_meters` and `rasters.from_meters_to_deg`**
- **ENH: Add the ability to return a `dask.delayed` object in `rasters.write`**
- FIX: Don't take nodata value into account in `ci.assert_raster_almost_equal_magnitude`

## 1.45.2 (2025-02-17)
Expand Down
2 changes: 1 addition & 1 deletion ci/script_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def dask_env_wrapper(*_args, **_kwargs):
print("Using DASK multithreaded.")
function(*_args, **_kwargs)

with dask.get_or_create_dask_client():
with dask.get_or_create_dask_client(processes=True):
print("Using DASK with local cluster")
function(*_args, **_kwargs)
except ImportError:
Expand Down
35 changes: 34 additions & 1 deletion ci/test_rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_read_with_window(tmp_path, raster_path, mask_path, mask):

@s3_env
@dask_env
@pytest.mark.timeout(2, func_only=True)
@pytest.mark.timeout(4, func_only=True)
def test_very_big_file(tmp_path, mask_path):
"""Test read with big files function (should be fast, if not there is something going on...)"""
dem_path = unistra.get_geodatastore() / "GLOBAL" / "EUDEM_v2" / "eudem_wgs84.tif"
Expand Down Expand Up @@ -238,6 +238,39 @@ def test_write_basic(tmp_path, raster_path, xda, xds, xda_dask, ds_dtype):
ci.assert_raster_equal(raster_path, xda_dask_out)


@s3_env
@dask_env
def test_write_dask(tmp_path, raster_path, xda, xds, xda_dask, ds_dtype):
    """Test the deferred mode (``compute=False``) of the write function"""
    # One output per kind of input: DataArray, Dataset and dask-backed DataArray
    inputs = (
        ("test_xda.tif", xda),
        ("test_xds.tif", xds),
        ("test_xda_dask.tif", xda_dask),
    )

    # Request every write with compute=False and keep what write() returns
    out_paths = []
    pending = []
    for file_name, data in inputs:
        out_path = os.path.join(tmp_path, file_name)
        pending.append(rasters.write(data, out_path, dtype=ds_dtype, compute=False))
        out_paths.append(out_path)

    # Compute all the returned objects in one single dask call
    import dask

    dask.compute(*pending)

    # Every file has to exist once the computation is done
    for out_path in out_paths:
        assert os.path.isfile(out_path)

    # Written rasters have to be equal to the input one
    for out_path in out_paths:
        ci.assert_raster_equal(raster_path, out_path)


@s3_env
@dask_env
def test_mask(tmp_path, xda, xds, xda_dask, mask):
Expand Down
28 changes: 21 additions & 7 deletions sertit/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,8 +1060,9 @@ def read(
return xda


def __save_cog_with_dask(xds: AnyXrDataStructure, nodata, dtype, output_path, kwargs):
from dask import optimize
def __save_cog_with_dask(
xds: AnyXrDataStructure, nodata, dtype, output_path, compute, kwargs
):
from odc.geo import cog, xr # noqa

delayed = cog.save_cog_with_dask(
Expand All @@ -1070,8 +1071,10 @@ def __save_cog_with_dask(xds: AnyXrDataStructure, nodata, dtype, output_path, kw
**kwargs,
)

(delayed,) = optimize(delayed)
delayed.compute(optimize_graph=True)
if compute:
delayed.compute(optimize_graph=True)

return delayed


@any_raster_to_xr_ds
Expand All @@ -1080,6 +1083,7 @@ def write(
output_path: AnyPathStrType = None,
tags: dict = None,
write_cogs_with_dask: bool = True,
compute: bool = True,
**kwargs,
) -> None:
"""
Expand Down Expand Up @@ -1110,6 +1114,9 @@ def write(
tags (dict): Tags that will be written in your file
write_cogs_with_dask (bool): If odc-geo and imagecodecs are installed, write your COGs with Dask.
Otherwise, the array will be loaded into memory before writing it on disk (and can cause MemoryErrors).
compute (bool): If True (default) and data is a dask array, compute and save the data immediately.
If False, return a dask Delayed object instead: call ``.compute()`` on this Delayed object to compute the result later,
or call ``dask.compute(delayed1, delayed2)`` to save multiple delayed files at once.
**kwargs: Overloading metadata, ie :code:`nodata=255` or :code:`dtype=np.uint8`
Examples:
Expand All @@ -1122,6 +1129,7 @@ def write(
>>> # Rewrite it
>>> write(xds, raster_out)
"""
delayed = None
if output_path is None:
logs.deprecation_warning(
"'path' is deprecated in 'rasters.write'. Use 'output_path' instead."
Expand Down Expand Up @@ -1221,10 +1229,14 @@ def write(

# Write cog on disk
try:
__save_cog_with_dask(xds, nodata, dtype, output_path, da_kwargs)
delayed = __save_cog_with_dask(
xds, nodata, dtype, output_path, compute, da_kwargs
)
except Exception:
da_kwargs["stats"] = False
__save_cog_with_dask(xds, nodata, dtype, output_path, da_kwargs)
delayed = __save_cog_with_dask(
xds, nodata, dtype, output_path, compute, da_kwargs
)

is_written = True

Expand Down Expand Up @@ -1259,13 +1271,15 @@ def write(
if "_FillValue" in xds.attrs:
xds.attrs.pop("_FillValue")

xds.rio.to_raster(
delayed = xds.rio.to_raster(
str(output_path),
BIGTIFF=bigtiff,
NUM_THREADS=MAX_CORES,
tags=tags,
compute=compute,
**misc.remove_empty_values(kwargs),
)
return delayed


def _collocate_dataarray(
Expand Down

0 comments on commit 091831c

Please sign in to comment.