Skip to content

Commit

Permalink
New expression blocks (#479)
Browse files Browse the repository at this point in the history
* better multiblock expression

* update changelog

* Update CHANGES.md

* Update rio_tiler/expression.py

Co-authored-by: Kyle Barron <[email protected]>
  • Loading branch information
vincentsarago and kylebarron authored Feb 7, 2022
1 parent 90e5065 commit fc1af18
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 12 deletions.
14 changes: 14 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
# unreleased

* add support for setting the S3 endpoint url scheme via the `AWS_HTTPS` environment variable in the `aws_get_object` function using boto3 (https://github.com/cogeotiff/rio-tiler/pull/476)
* Add semicolon `;` support for multi-blocks expression (https://github.com/cogeotiff/rio-tiler/pull/479)
* add `rio_tiler.expression.get_expression_blocks` method to split expression (https://github.com/cogeotiff/rio-tiler/pull/479)

**future deprecation**

* using a comma `,` in an expression to define multiple blocks is deprecated and will be replaced by the semicolon `;`

```python
# before
expression = "b1+b2,b2"

# new
expression = "b1+b2;b2"
```

# 3.0.3 (2022-01-18)

Expand Down
31 changes: 30 additions & 1 deletion rio_tiler/expression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""rio-tiler.expression: Parse and Apply expression."""

import re
from typing import Sequence, Tuple, Union
import warnings
from typing import List, Sequence, Tuple, Union

import numexpr
import numpy
Expand Down Expand Up @@ -29,6 +30,33 @@ def parse_expression(expression: str, cast: bool = True) -> Tuple:
return tuple(map(int, bands)) if cast else tuple(bands)


def get_expression_blocks(expression: str) -> List[str]:
    """Split expression in blocks.

    Blocks are separated with a semicolon ``;``. When the expression holds no
    semicolon it is split on commas ``,`` instead, and a ``DeprecationWarning``
    is emitted if that split yields more than one block. Empty and
    whitespace-only blocks are dropped; kept blocks are returned unmodified.

    Args:
        expression (str): band math/combination expression.

    Returns:
        list: expression blocks.

    Examples:
        >>> get_expression_blocks("b1/b2;b2+b1")
        ['b1/b2', 'b2+b1']

    """
    # As soon as one semicolon is present, commas are treated as part of the
    # blocks themselves (e.g. `where(b1 > 0.5,1,0);`) and never as separators.
    separator = ";" if ";" in expression else ","

    # Filter on the stripped text so a stray " " (as in "b1; ") is not kept as
    # a block: it would be an empty expression once stripped for evaluation.
    blocks = [block for block in expression.split(separator) if block.strip()]

    if separator == "," and len(blocks) > 1:
        warnings.warn(
            "Using comma `,` for multiband expression will be deprecated in rio-tiler 4.0. Please use semicolon `;`.",
            DeprecationWarning,
            # Attribute the warning to the caller's line, not to this module.
            stacklevel=2,
        )

    return blocks


def apply_expression(
blocks: Sequence[str],
bands: Sequence[Union[str, int]],
Expand All @@ -52,5 +80,6 @@ def apply_expression(
numexpr.evaluate(bloc.strip(), local_dict=dict(zip(bands, data)))
)
for bloc in blocks
if bloc
]
)
22 changes: 11 additions & 11 deletions rio_tiler/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
MissingBands,
TileOutsideBounds,
)
from ..expression import apply_expression
from ..expression import apply_expression, get_expression_blocks
from ..models import BandStatistics, ImageData, Info
from ..tasks import multi_arrays, multi_values
from ..types import BBox, Indexes
Expand Down Expand Up @@ -526,7 +526,7 @@ def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData:
)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, assets, output.data)
output.band_names = blocks

Expand Down Expand Up @@ -590,7 +590,7 @@ def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData:
output = multi_arrays(assets, _reader, bbox, **kwargs)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, assets, output.data)
output.band_names = blocks

Expand Down Expand Up @@ -651,7 +651,7 @@ def _reader(asset: str, **kwargs: Any) -> ImageData:
output = multi_arrays(assets, _reader, **kwargs)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, assets, output.data)
output.band_names = blocks

Expand Down Expand Up @@ -716,7 +716,7 @@ def _reader(asset: str, *args, **kwargs: Any) -> Dict:

values = numpy.array([d for _, d in data.items()])
if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
values = apply_expression(blocks, assets, values)

return values.tolist()
Expand Down Expand Up @@ -779,7 +779,7 @@ def _reader(asset: str, *args: Any, **kwargs: Any) -> ImageData:
output = multi_arrays(assets, _reader, shape, **kwargs)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, assets, output.data)
output.band_names = blocks

Expand Down Expand Up @@ -991,7 +991,7 @@ def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData:
output = multi_arrays(bands, _reader, tile_x, tile_y, tile_z, **kwargs)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, bands, output.data)
output.band_names = blocks

Expand Down Expand Up @@ -1043,7 +1043,7 @@ def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData:
output = multi_arrays(bands, _reader, bbox, **kwargs)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, bands, output.data)
output.band_names = blocks

Expand Down Expand Up @@ -1093,7 +1093,7 @@ def _reader(band: str, **kwargs: Any) -> ImageData:
output = multi_arrays(bands, _reader, **kwargs)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, bands, output.data)
output.band_names = blocks

Expand Down Expand Up @@ -1146,7 +1146,7 @@ def _reader(band: str, *args, **kwargs: Any) -> Dict:

values = numpy.array([d for _, d in data.items()])
if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
values = apply_expression(blocks, bands, values)

return values.tolist()
Expand Down Expand Up @@ -1197,7 +1197,7 @@ def _reader(band: str, *args: Any, **kwargs: Any) -> ImageData:
output = multi_arrays(bands, _reader, shape, **kwargs)

if expression:
blocks = expression.split(",")
blocks = get_expression_blocks(expression)
output.data = apply_expression(blocks, bands, output.data)
output.band_names = blocks

Expand Down
85 changes: 85 additions & 0 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""test rio_tiler.expression functions."""

import numpy
import pytest

from rio_tiler.expression import (
apply_expression,
get_expression_blocks,
parse_expression,
)


@pytest.mark.parametrize(
    ("expr", "expected"),
    [
        ("b1,b2", [1, 2]),
        ("B1,b2", [1, 2]),
        ("B1,B2", [1, 2]),
        ("where((b1==1) | (b1 > 0.5),1,0);", [1]),
    ],
)
def test_parse(expr, expected):
    """parse_expression returns the expected band indexes (order-insensitive)."""
    # Results are sorted before comparison, so only the set of parsed
    # indexes matters here, not the order they come back in.
    indexes = parse_expression(expr)
    assert sorted(indexes) == expected


@pytest.mark.parametrize(
    ("expr", "expected"),
    [
        ("b1,b2", ["1", "2"]),
        ("B1,b2", ["1", "2"]),
        ("B1,B2", ["1", "2"]),
    ],
)
def test_parse_cast(expr, expected):
    """With cast=False, parse_expression returns band indexes as strings."""
    # Sorted before comparison: only the set of parsed values is checked.
    raw_indexes = parse_expression(expr, cast=False)
    assert sorted(raw_indexes) == expected


@pytest.mark.parametrize(
    "expr,expected",
    [
        ("b1,", ["b1"]),
        ("b1,b2", ["b1", "b2"]),
        ("where((b1==1) | (b1 > 0.5),1,0)", ["where((b1==1) | (b1 > 0.5)", "1", "0)"]),
        ("where((b1==1) | (b1 > 0.5),1,0);", ["where((b1==1) | (b1 > 0.5),1,0)"]),
    ],
)
def test_get_blocks(expr, expected):
    """test get_expression_blocks."""
    import warnings

    # Some cases use the legacy comma separator and emit a DeprecationWarning;
    # this test only checks the returned blocks, so that warning is silenced.
    # `pytest.warns(None)` is avoided: it is deprecated in pytest 7 and
    # rejected by pytest 8.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        assert get_expression_blocks(expr) == expected


def test_get_blocks_warn():
    """A comma-separated multi-block expression emits a DeprecationWarning."""
    with pytest.warns(DeprecationWarning):
        blocks = get_expression_blocks("b1,b2")
        # The call must still return a non-empty list of blocks.
        assert blocks


def test_apply_expression():
    """Check apply_expression on a ratio, a single `where` and two `where` blocks."""
    band_names = ["b1", "b2"]

    # Ratio: band 1 is filled with 1 and band 2 with 2, so b1/b2 is 0.5 everywhere.
    arr = numpy.zeros(shape=(2, 10, 10), dtype=numpy.uint8)
    arr[0] += 1
    arr[1] += 2
    ratio = apply_expression(["b1/b2"], band_names, arr)
    assert numpy.unique(ratio) == 0.5

    # Conditional expression: only the top-left 5x5 window of band 1 is set.
    arr = numpy.zeros(shape=(2, 10, 10), dtype=numpy.uint8)
    arr[0, 0:5, 0:5] += 1
    where_expr = "where((b1==1) | (b1 > 0.5),1,0)"
    masked = apply_expression([where_expr], band_names, arr)
    # The input has 2 bands, but the single block yields a single output band.
    assert masked.shape == (1, 10, 10)
    assert len(numpy.unique(masked)) == 2
    assert numpy.unique(masked[0, 0:5, 0:5]) == [1]

    # Two blocks give two output bands.
    arr = numpy.zeros(shape=(2, 10, 10), dtype=numpy.uint8)
    arr[0, 0:5, 0:5] += 1
    arr[1, 0:5, 0:5] += 5
    two_bands = apply_expression(
        [where_expr, "where(b2 > 5,1,0)"],
        band_names,
        arr,
    )
    assert two_bands.shape == (2, 10, 10)

0 comments on commit fc1af18

Please sign in to comment.