diff --git a/CHANGES.md b/CHANGES.md index a709da95..20c4f315 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,7 @@ # 3.0.0a5 (TBD) * allow the definition of `geographic_crs` used in the `geographic_bounds` property (https://github.com/cogeotiff/rio-tiler/pull/458) +* use `contextlib.ExitStack` to better manager opening/closing rasterio dataset (https://github.com/cogeotiff/rio-tiler/pull/459) # 3.0.0a4 (2021-11-10) diff --git a/rio_tiler/io/cogeo.py b/rio_tiler/io/cogeo.py index 58ef9647..2d1e3e7f 100644 --- a/rio_tiler/io/cogeo.py +++ b/rio_tiler/io/cogeo.py @@ -1,5 +1,6 @@ """rio_tiler.io.cogeo: raster processing.""" +import contextlib import warnings from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -101,6 +102,9 @@ class COGReader(BaseReader): # _kwargs is used avoid having to set those values on each method call. _kwargs: Dict[str, Any] = attr.ib(init=False, factory=dict) + # Context Manager to handle rasterio open/close + _ctx_stack = attr.ib(init=False, factory=contextlib.ExitStack) + def __attrs_post_init__(self): """Define _kwargs, open dataset and get info.""" if self.nodata is not None: @@ -114,7 +118,9 @@ def __attrs_post_init__(self): if self.post_process is not None: self._kwargs["post_process"] = self.post_process - self.dataset = self.dataset or rasterio.open(self.input) + self.dataset = self.dataset or self._ctx_stack.enter_context( + rasterio.open(self.input) + ) self.bounds = tuple(self.dataset.bounds) self.crs = self.dataset.crs @@ -136,8 +142,7 @@ def __attrs_post_init__(self): def close(self): """Close rasterio dataset.""" - if self.input: - self.dataset.close() + self._ctx_stack.close() def __exit__(self, exc_type, exc_value, traceback): """Support using with Context Managers.""" @@ -696,22 +701,16 @@ class GCPCOGReader(COGReader): # for GCPCOGReader, dataset is not a input option. dataset: WarpedVRT = attr.ib(init=False) - # We use _kwargs to store values of nodata, unscale, vrt_options and resampling_method. - # _kwargs is used avoid having to set those values on each method call. - _kwargs: Dict[str, Any] = attr.ib(init=False, factory=dict) - def __attrs_post_init__(self): """Define _kwargs, open dataset and get info.""" - self.src_dataset = self.src_dataset or rasterio.open(self.input) - self.dataset = WarpedVRT( - self.src_dataset, - src_crs=self.src_dataset.gcps[1], - src_transform=transform.from_gcps(self.src_dataset.gcps[0]), + self.src_dataset = self.src_dataset or self._ctx_stack.enter_context( + rasterio.open(self.input) + ) + self.dataset = self._ctx_stack.enter_context( + WarpedVRT( + self.src_dataset, + src_crs=self.src_dataset.gcps[1], + src_transform=transform.from_gcps(self.src_dataset.gcps[0]), + ) ) super().__attrs_post_init__() - - def close(self): - """Close rasterio dataset.""" - self.dataset.close() - if self.input: - self.src_dataset.close()