Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expand xarray capabilities #755

Merged
merged 9 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 150 additions & 11 deletions rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional
import warnings
from typing import Dict, List, Optional

import attr
from morecantile import Tile, TileMatrixSet
Expand All @@ -21,8 +22,9 @@
)
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.types import BBox, NoData, WarpResampling
from rio_tiler.utils import CRS_to_uri, _validate_shape_input
from rio_tiler.reader import _get_width_height
from rio_tiler.types import BBox, NoData, RIOResampling, WarpResampling
from rio_tiler.utils import CRS_to_uri, _validate_shape_input, get_array_statistics

try:
import xarray
Expand Down Expand Up @@ -118,7 +120,7 @@ def info(self) -> Info:

meta = {
"bounds": self.bounds,
"crs": CRS_to_uri(self.crs),
"crs": CRS_to_uri(self.crs) or self.crs.to_wkt(),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This matches what we do in `io.rasterio.Reader`.

"band_metadata": [(f"b{ix}", v) for ix, v in enumerate(metadata, 1)],
"band_descriptions": [(f"b{ix}", v) for ix, v in enumerate(bands, 1)],
"dtype": str(self.input.dtype),
Expand All @@ -137,11 +139,29 @@ def statistics(
categories: Optional[List[float]] = None,
percentiles: Optional[List[int]] = None,
hist_options: Optional[Dict] = None,
max_size: int = 1024,
**kwargs: Any,
nodata: Optional[NoData] = None,
) -> Dict[str, BandStatistics]:
"""Return bands statistics from a dataset."""
raise NotImplementedError
"""Return statistics from a dataset."""
hist_options = hist_options or {}

ds = self.input
if nodata is not None:
ds = ds.rio.write_nodata(nodata)

data = ds.to_masked_array()
data.mask |= data.data == ds.rio.nodata

stats = get_array_statistics(
data,
categorical=categorical,
categories=categories,
percentiles=percentiles,
**hist_options,
)

return {
self.band_names[ix]: BandStatistics(**val) for ix, val in enumerate(stats)
}

def tile(
self,
Expand Down Expand Up @@ -219,6 +239,10 @@ def part(
reproject_method: WarpResampling = "nearest",
auto_expand: bool = True,
nodata: Optional[NoData] = None,
max_size: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
resampling_method: RIOResampling = "nearest",
) -> ImageData:
"""Read part of a dataset.

Expand All @@ -229,11 +253,21 @@ def part(
reproject_method (WarpResampling, optional): WarpKernel resampling algorithm. Defaults to `nearest`.
auto_expand (boolean, optional): When True, rioxarray's clip_box will expand clip search if only 1D raster found with clip. When False, will throw `OneDimensionalRaster` error if only 1 x or y data point is found. Defaults to True.
nodata (int or float, optional): Overwrite dataset internal nodata value.
max_size (int, optional): Limit the size of the longest dimension of the dataset read, respecting bounds X/Y aspect ratio.
height (int, optional): Output height of the array.
width (int, optional): Output width of the array.
resampling_method (RIOResampling, optional): RasterIO resampling algorithm. Defaults to `nearest`.

Returns:
rio_tiler.models.ImageData: ImageData instance with data, mask and input spatial info.

"""
if max_size and width and height:
warnings.warn(
"'max_size' will be ignored with with 'height' and 'width' set.",
UserWarning,
)

dst_crs = dst_crs or bounds_crs

ds = self.input
Expand Down Expand Up @@ -271,32 +305,109 @@ def part(
arr = ds.to_masked_array()
arr.mask |= arr.data == ds.rio.nodata

return ImageData(
img = ImageData(
arr,
bounds=ds.rio.bounds(),
crs=ds.rio.crs,
dataset_statistics=stats,
band_names=self.band_names,
)

output_height = height or img.height
output_width = width or img.width
if max_size and not (width and height):
output_height, output_width = _get_width_height(
max_size, img.height, img.width
)

if output_height != img.height or output_width != img.width:
img = img.resize(
output_height, output_width, resampling_method=resampling_method
)

return img

def preview(
    self,
    max_size: Optional[int] = 1024,
    height: Optional[int] = None,
    width: Optional[int] = None,
    nodata: Optional[NoData] = None,
    dst_crs: Optional[CRS] = None,
    reproject_method: WarpResampling = "nearest",
    resampling_method: RIOResampling = "nearest",
) -> ImageData:
    """Return a preview of a dataset.

    The whole input array is read (optionally reprojected to `dst_crs`),
    wrapped in an ImageData object and then resized when the requested
    output size differs from the size of the array that was read.

    Args:
        max_size (int, optional): Limit the size of the longest dimension of the dataset read, respecting bounds X/Y aspect ratio. Defaults to 1024. Pass `None` to disable the limit.
        height (int, optional): Output height of the array.
        width (int, optional): Output width of the array.
        nodata (int or float, optional): Overwrite dataset internal nodata value.
        dst_crs (rasterio.crs.CRS, optional): target coordinate reference system.
        reproject_method (WarpResampling, optional): WarpKernel resampling algorithm. Defaults to `nearest`.
        resampling_method (RIOResampling, optional): RasterIO resampling algorithm. Defaults to `nearest`.

    Returns:
        rio_tiler.models.ImageData: ImageData instance with data, mask and input spatial info.

    """
    # NOTE: message text is intentionally identical to the one emitted by
    # `part()` and `feature()`.
    if max_size and width and height:
        warnings.warn(
            "'max_size' will be ignored with with 'height' and 'width' set.",
            UserWarning,
        )

    ds = self.input
    if nodata is not None:
        ds = ds.rio.write_nodata(nodata)

    # Only reproject when a target CRS is requested and it differs from
    # the reader's CRS.
    if dst_crs and dst_crs != self.crs:
        dst_transform, w, h = calculate_default_transform(
            self.crs,
            dst_crs,
            ds.rio.width,
            ds.rio.height,
            *ds.rio.bounds(),
        )
        ds = ds.rio.reproject(
            dst_crs,
            shape=(h, w),
            transform=dst_transform,
            resampling=Resampling[reproject_method],
            nodata=nodata,
        )

    # Forward valid_min/valid_max to the ImageData object
    minv, maxv = ds.attrs.get("valid_min"), ds.attrs.get("valid_max")
    stats = None
    if minv is not None and maxv is not None:
        stats = ((minv, maxv),) * ds.rio.count

    # Mask every sample equal to the dataset nodata value, in addition to
    # whatever `to_masked_array()` already masked.
    arr = ds.to_masked_array()
    arr.mask |= arr.data == ds.rio.nodata

    img = ImageData(
        arr,
        bounds=ds.rio.bounds(),
        crs=ds.rio.crs,
        dataset_statistics=stats,
        band_names=self.band_names,
    )

    # Explicit height/width win over the size that was read; `max_size`
    # is only applied when height and width are not both provided.
    output_height = height or img.height
    output_width = width or img.width
    if max_size and not (width and height):
        output_height, output_width = _get_width_height(
            max_size, img.height, img.width
        )

    if output_height != img.height or output_width != img.width:
        img = img.resize(
            output_height, output_width, resampling_method=resampling_method
        )

    return img

def point(
self,
Expand Down Expand Up @@ -348,6 +459,10 @@ def feature(
shape_crs: CRS = WGS84_CRS,
reproject_method: WarpResampling = "nearest",
nodata: Optional[NoData] = None,
max_size: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
resampling_method: RIOResampling = "nearest",
) -> ImageData:
"""Read part of a dataset defined by a geojson feature.

Expand All @@ -357,11 +472,21 @@ def feature(
shape_crs (rasterio.crs.CRS, optional): Input geojson coordinate reference system. Defaults to `epsg:4326`.
reproject_method (WarpResampling, optional): WarpKernel resampling algorithm. Defaults to `nearest`.
nodata (int or float, optional): Overwrite dataset internal nodata value.
max_size (int, optional): Limit the size of the longest dimension of the dataset read, respecting bounds X/Y aspect ratio.
height (int, optional): Output height of the array.
width (int, optional): Output width of the array.
resampling_method (RIOResampling, optional): RasterIO resampling algorithm. Defaults to `nearest`.

Returns:
rio_tiler.models.ImageData: ImageData instance with data, mask and input spatial info.

"""
if max_size and width and height:
warnings.warn(
"'max_size' will be ignored with with 'height' and 'width' set.",
UserWarning,
)

if not dst_crs:
dst_crs = shape_crs

Expand Down Expand Up @@ -398,10 +523,24 @@ def feature(
arr = ds.to_masked_array()
arr.mask |= arr.data == ds.rio.nodata

return ImageData(
img = ImageData(
arr,
bounds=ds.rio.bounds(),
crs=ds.rio.crs,
dataset_statistics=stats,
band_names=self.band_names,
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`feature()` now uses the `part()` method (as we do in `io.rasterio.Reader`).


output_height = height or img.height
output_width = width or img.width
if max_size and not (width and height):
output_height, output_width = _get_width_height(
max_size, img.height, img.width
)

if output_height != img.height or output_width != img.width:
img = img.resize(
output_height, output_width, resampling_method=resampling_method
)

return img
55 changes: 54 additions & 1 deletion tests/test_io_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def test_xarray_reader():
assert info.attrs

with XarrayReader(data) as dst:
stats = dst.statistics()
assert stats["2022-01-01T00:00:00.000000000"]
assert stats["2022-01-01T00:00:00.000000000"].min == 0.0

img = dst.tile(0, 0, 0)
assert img.count == 1
assert img.width == 256
Expand Down Expand Up @@ -69,8 +73,49 @@ def test_xarray_reader():
assert img.dataset_statistics == ((arr.min(), arr.max()),)

img = dst.part((-160, -80, 160, 80))
assert img.crs == "epsg:4326"
assert img.count == 1
assert img.band_names == ["2022-01-01T00:00:00.000000000"]
assert img.array.shape == (1, 33, 33)

img = dst.part((-160, -80, 160, 80), dst_crs="epsg:3857")
assert img.crs == "epsg:3857"
assert img.count == 1
assert img.band_names == ["2022-01-01T00:00:00.000000000"]
assert img.array.shape == (1, 32, 34)

img = dst.part((-160, -80, 160, 80), max_size=15)
assert img.array.shape == (1, 15, 15)

img = dst.part((-160, -80, 160, 80), width=40, height=35)
assert img.array.shape == (1, 35, 40)

img = dst.part((-160, -80, 160, 80), max_size=15, resampling_method="bilinear")
assert img.array.shape == (1, 15, 15)

img = dst.preview()
assert img.crs == "epsg:4326"
assert img.count == 1
assert img.band_names == ["2022-01-01T00:00:00.000000000"]
assert img.array.shape == (1, 33, 35)

img = dst.preview(dst_crs="epsg:3857")
assert img.crs == "epsg:3857"
assert img.count == 1
assert img.band_names == ["2022-01-01T00:00:00.000000000"]
assert img.array.shape == (1, 32, 36)

img = dst.preview(max_size=None)
assert img.array.shape == (1, 33, 35)

img = dst.preview(max_size=15)
assert img.array.shape == (1, 15, 15)

img = dst.preview(max_size=15, resampling_method="bilinear")
assert img.array.shape == (1, 15, 15)

img = dst.preview(height=25, width=25, max_size=None)
assert img.array.shape == (1, 25, 25)

pt = dst.point(0, 0)
assert pt.count == 1
Expand Down Expand Up @@ -106,11 +151,19 @@ def test_xarray_reader():
img = dst.feature(feat)
assert img.count == 1
assert img.band_names == ["2022-01-01T00:00:00.000000000"]
assert img.array.shape == (1, 25, 30)

img = dst.feature(feat, dst_crs="epsg:3857")
assert img.count == 1
assert img.band_names == ["2022-01-01T00:00:00.000000000"]
assert img.crs.to_epsg() == 3857
assert img.crs == "epsg:3857"
assert img.array.shape == (1, 20, 33)

img = dst.feature(feat, max_size=15)
assert img.array.shape == (1, 13, 15)

img = dst.feature(feat, width=50, height=45)
assert img.array.shape == (1, 45, 50)

arr = numpy.zeros((1, 1000, 2000))
data = xarray.DataArray(
Expand Down
Loading