Skip to content

Commit

Permalink
Include swir datasetparams for custom swir band stretching.
Browse files Browse the repository at this point in the history
  • Loading branch information
sharkinsspatial committed Mar 1, 2022
1 parent 58da44d commit 8457367
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/eoapi/raster/eoapi/raster/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from eoapi.raster.factory import MultiBaseTilerFactory
from eoapi.raster.reader import STACReader
from eoapi.raster.version import __version__ as eoapi_raster_version
from eoapi.raster.datasetparams import DatasetParams

logging.getLogger("botocore.credentials").disabled = True
logging.getLogger("botocore.utils").disabled = True
Expand All @@ -34,7 +35,11 @@
add_exception_handlers(app, MOSAIC_STATUS_CODES)

# PgSTAC mosaic tiler
mosaic = MosaicTilerFactory(router_prefix="mosaic", optional_headers=optional_headers)
mosaic = MosaicTilerFactory(
router_prefix="mosaic",
optional_headers=optional_headers,
dataset_dependency=DatasetParams,
)
app.include_router(mosaic.router, prefix="/mosaic", tags=["PgSTAC Mosaic"])

# Custom STAC titiler endpoint (not added to the openapi docs)
Expand Down
45 changes: 45 additions & 0 deletions src/eoapi/raster/eoapi/raster/datasetparams.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from titiler.core import dependencies
from dataclasses import dataclass
import numpy
from fastapi import Query

from typing import Tuple, Optional

# https://github.com/cogeotiff/rio-tiler/blob/master/rio_tiler/reader.py#L35-L37

import math


def swir(data, mask) -> Tuple[numpy.ndarray, numpy.ndarray]:
low_value = math.e
high_value = 255

low_threshold = math.log(1000)
high_threshold = math.log(7500)

data = numpy.log(data)
data[numpy.where(data <= low_threshold)] = low_value
data[numpy.where(data >= high_threshold)] = high_value
indices = numpy.where(
(data > low_value) & (data < high_value)
)
data[indices] = (
high_value * (data[indices] - low_threshold) / (high_threshold - low_threshold)
)
return data.astype("uint8"), mask


pp_methods = {
"swir": swir,
}


@dataclass
class DatasetParams(dependencies.DatasetParams):
post_process: Optional[str] = Query(None, description="Post Process Name.")

def __post_init__(self):
super().__post_init__()

if self.post_process is not None:
self.post_process = pp_methods.get(self.post_process) # type: ignore

0 comments on commit 8457367

Please sign in to comment.