diff --git a/README.md b/README.md
index 8ec2feb..f288e3e 100644
--- a/README.md
+++ b/README.md
@@ -36,18 +36,22 @@
 python geo_inference -a
 ```
 - `-a`, `--args`: Path to arguments stored in yaml, consult ./config/sample_config.yaml
 ```bash
-python geo_inference -i -m -wd -bs -v -d -id
+python geo_inference -i -br -m -wd -ps -v -d -id -cls -mg
 ```
 - `-i`, `--image`: Path to Geotiff
 - `-bb`, `--bbox`: AOI bbox in this format "minx, miny, maxx, maxy" (Optional)
+- `-br`, `--bands_requested`: The requested bands from the provided Geotiff (if not provided, all bands are used)
 - `-m`, `--model`: Path or URL to the model file
 - `-wd`, `--work_dir`: Working Directory
-- `-bs`, `--batch_size`: The Batch Size
+- `-ps`, `--patch_size`: The patch size (the size of Dask chunks), Default = 1024
 - `-v`, `--vec`: Vector Conversion
 - `-y`, `--yolo`: Yolo Conversion
 - `-c`, `--coco`: Coco Conversion
 - `-d`, `--device`: CPU or GPU Device
 - `-id`, `--gpu_id`: GPU ID, Default = 0
+- `-cls`, `--classes`: The number of classes the model outputs, Default = 5
+- `-mg`, `--mgpu`: Whether to use multi-GPU processing, Default = False
+
 You can also use the `-h` option to get a list of supported arguments:
@@ -63,18 +67,21 @@
 from geo_inference.geo_inference import GeoInference
 
 geo_inference = GeoInference(
     model="/path/to/segformer_B5.pt",
     work_dir="/path/to/work/dir",
-    batch_size=4,
+    patch_size=1024,
     mask_to_vec=False,
-    vec_to_yolo=False,
-    vec_to_coco=False,
+    mask_to_yolo=False,
+    mask_to_coco=False,
     device="gpu",
-    gpu_id=0
+    multi_gpu=False,
+    gpu_id=0,
+    num_classes=5
 )
 
 # Perform feature extraction on a TIFF image
 image_path = "/path/to/image.tif"
 patch_size = 512
-geo_inference(tiff_image = image_path, patch_size = patch_size,)
+bands_requested = "1,2,3"
+geo_inference(inference_input=image_path, bands_requested=bands_requested, patch_size=patch_size)
 ```
 
 ## Parameters
 
@@ -83,13 +89,20 @@
 The `GeoInference` class takes the following parameters:
 
 - `model`: The path or URL to the model file (.pt for PyTorch models) to use for feature extraction.
 - `work_dir`: The path to the working directory. Default is `"~/.cache"`.
-- `batch_size`: The batch size to use for feature extraction. Default is `4`.
+- `patch_size`: The patch size to use for feature extraction. Default is `1024`.
 - `mask_to_vec`: If set to `"True"`, vector data will be created from mask. Default is `"False"`
-- `vec_to_yolo`: If set to `"True"`, vector data will be converted to YOLO format. Default is `"False"`
-- `vec_to_coco`: If set to `"True"`, vector data will be converted to COCO format. Default is `"False"`
+- `mask_to_yolo`: If set to `"True"`, vector data will be converted to YOLO format. Default is `"False"`
+- `mask_to_coco`: If set to `"True"`, vector data will be converted to COCO format. Default is `"False"`
 - `device`: The device to use for feature extraction. Can be `"cpu"` or `"gpu"`. Default is `"gpu"`.
+- `multi_gpu`: If set to `"True"`, inference runs on multiple GPUs. Default is `"False"`
 - `gpu_id`: The ID of the GPU to use for feature extraction. Default is `0`.
-
+- `num_classes`: The number of classes that the TorchScript model outputs. Default is `5`.
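+
+A minimal end-to-end CLI sketch (the image and model paths are placeholders):
+
+```bash
+python geo_inference -i aoi.tif -br 1,2,3 -m segformer_B5.pt -wd /path/to/work/dir -ps 1024 -cls 5 -d gpu -id 0
+```
+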
 ## Output
 
 The `GeoInference` class outputs the following files:
diff --git a/geo_inference/config/sample.yaml b/geo_inference/config/sample.yaml
index 5b3cb28..8cdc2ff 100644
--- a/geo_inference/config/sample.yaml
+++ b/geo_inference/config/sample.yaml
@@ -1,11 +1,14 @@
 arguments:
-  image: "./data/areial.tiff" # Path to Geotiff
+  image: "./NB_PointEscuminac_WV02_20220615_A-017161374010_01_P001-WV02-R-G-B_clahe25.tif"
+  model: ./deep_learning_model/4cls_RGB_5_1_2_3_scripted.pt # Name of Extraction Model: str
   bbox: None # "minx, miny, maxx, maxy"
-  model: "rgb-4class-segformer" # Name of Extraction Model: str
-  work_dir: None # Working Directory: str
-  batch_size: 1 # Batch size
+  work_dir: ./dask_geo_deep_learning/dask_geo_inference # Working Directory: str
+  bands_requested: '1,2,3' # requested Bands
   vec: False # Vector Conversion: bool
   yolo: False # YOLO Conversion: bool
   coco: False # COCO Conversion: bool
   device: "gpu" # cpu or gpu: str
-  gpu_id: 0 # GPU ID: int
+  gpu_id: 0 # GPU ID: int
+  mgpu: False # Multi-GPU: bool
+  classes: 5 # Number of Model Output Classes: int
+  patch_size: 1024 # Patch Size: int
diff --git a/geo_inference/geo_blocks.py b/geo_inference/geo_blocks.py
deleted file mode 100644
index 7e3f99f..0000000
--- a/geo_inference/geo_blocks.py
+++ /dev/null
@@ -1,404 +0,0 @@
-import logging
-import os
-import sys
-from typing import Any, Dict, Iterator, Optional, Tuple, Union, cast
-
-import numpy as np
-import rasterio as rio
-import scipy.signal.windows as w
-import torch
-from rasterio.crs import CRS
-from scipy.special import expit
-from rasterio.windows import Window
-from torch import Tensor
-from torch.nn import functional as F
-from torchgeo.datasets import GeoDataset
-from torchgeo.datasets.utils import BoundingBox
-from torchgeo.samplers import GeoSampler
-from torchgeo.samplers.utils import _to_tuple, tile_to_chips
-
-from .config.logging_config import logger
-from .utils.helpers import validate_asset_type
-
-logger = logging.getLogger(__name__)
-
-
-class RasterDataset(GeoDataset):
-    """
-    A dataset class for raster data.
-
-    Attributes:
-        image_asset: The path to the image asset.
-        src: The rasterio dataset object.
-        cmap: The colormap of the image asset.
-        crs: The coordinate reference system of the image asset.
-        res: The resolution of the image asset.
-        bands: The number of bands in the image asset.
-        index: The rtree index of the image asset.
-    """
-    def __init__(self, image_asset: str, bbox: str = None) -> None:
-        """Initializes a RasterDataset object.
-
-        Args:
-            image_asset (str): The path or rasterio dataset of image asset.
-            bounding_box (str): The bounding box of image asset.
- """ - super().__init__() - self.src = validate_asset_type(image_asset) - - try: - self.cmap = self.src.colormap(1) - except ValueError: - pass - - crs = self.src.crs - res = self.src.res[0] - bands = self.src.count - image_height = self.src.height - image_width = self.src.width - minx, miny, maxx, maxy = self.src.bounds - - mint: float = 0 - maxt: float = sys.maxsize - - if bbox is None: - bbox = BoundingBox(minx=minx, - miny=miny, - maxx=maxx, - maxy=maxy, - mint=mint, maxt=maxt) - else: - bbox = tuple(map(float, bbox.split(', '))) - bbox = BoundingBox(minx=bbox[0], - miny=bbox[1], - maxx=bbox[2], - maxy=bbox[3], - mint=mint, maxt=maxt) - - coords = (minx, maxx, miny, maxy, mint, maxt) - self.index.insert(0, coords, self.src.name) - - self._crs = cast(CRS, crs) - self.res = cast(float, res) - self.bands = bands - self.bbox = bbox - self.image_height = image_height - self.image_width = image_width - - def __getitem__(self, query: Dict[str, Any]) -> Dict[str, Any]: - """ - Get a sample from the dataset. - - Args: - query: A dictionary containing the query parameters. - - Returns: - A dictionary containing the sample data. - """ - filepath = query['path'] - window = query["window"] - pixel_coords = query["pixel_coords"] - patch_size = pixel_coords[-1] - - data = self._get_tensor(pixel_coords, patch_size) - sample = {"image": data, - "crs": self.crs, - "pixel_coords": pixel_coords, - "window": window, - "path": filepath} - - return sample - - def _get_tensor(self, query, size): - """ - Get a patch based on the given query (pixel coordinates). - - Args: - query: The pixel coordinates of the patch. - size: The desired patch size. - - Returns: - A torch tensor patch. - """ - (x_min, y_min, patch_width, patch_height) = query - - window = Window.from_slices(slice(y_min, y_min + patch_height), - slice(x_min, x_min + patch_width)) - - dest = self.src.read(window=window) - if dest.dtype == np.uint16: - dest = dest.astype(np.int32) - elif dest.dtype == np.uint32: - dest = dest.astype(np.int64) - - tensor = torch.tensor(dest) - tensor = self.pad_patch(tensor, size) - - return tensor - - @staticmethod - def pad_patch(x: Tensor, patch_size: int): - """ - Pad the patch to desired patch_size. - - Args: - x: The tensor patch to pad. - patch_size: The desired patch size. - - Returns: - The padded tensor patch. - """ - h, w = x.shape[-2:] - pad_h = patch_size - h - pad_w = patch_size - w - # pads are described starting from the last dimension and moving forward. - x = F.pad(x, (0, pad_w, 0, pad_h)) - return x - - -class InferenceSampler(GeoSampler): - """Class for creating an inference sampler. - - This class extends GeoSampler and is designed for generating patches - for inference on a GeoDataset. - - Attributes: - dataset (GeoDataset): The dataset to generate patches from. - size (Union[Tuple[float, float], float]): Dimensions of each patch. - stride (Union[Tuple[float, float], float]): Distance to skip between each patch. - roi (Optional[BoundingBox]): Region of interest to sample from. - """ - def __init__(self, - dataset: GeoDataset, - size: Union[Tuple[float, float], float], - stride: Union[Tuple[float, float], float], - roi: Optional[BoundingBox] = None, - ) -> None: - """ - Initializes an InferenceSampler object. - - Args: - dataset (GeoDataset): A GeoDataset object. - size (Union[Tuple[float, float], float]): The size of the patch. - stride (Union[Tuple[float, float], float]): The stride of the patch. - roi (Optional[BoundingBox], optional): A BoundingBox object. Defaults to None. 
- """ - super().__init__(dataset, roi) - self.size = _to_tuple(size) - self.patch_size = self.size - self.stride = _to_tuple(stride) - - # Generates 9 2D signal windows of patch size that covers edge and corner coordinates - self.windows = torch.tensor(self.generate_corner_windows(self.patch_size[0]), dtype=torch.float32) - self.size_in_crs_units = (self.size[0] * self.res, self.size[1] * self.res) - self.stride_in_crs_units = (self.stride[0] * self.res, self.stride[1] * self.res) - self.hits = [] - self.hits_small = [] - for hit in self.index.intersection(tuple(self.roi), objects=True): - bounds = BoundingBox(*hit.bounds) - if (bounds.maxx - bounds.minx >= self.size_in_crs_units[1] - and bounds.maxy - bounds.miny >= self.size_in_crs_units[0]): - self.hits.append(hit) - else: - self.hits_small.append(hit) - - self.length = 0 - for hit in self.hits: - bounds = BoundingBox(*hit.bounds) - rows, cols = tile_to_chips(bounds=bounds, size=self.size_in_crs_units, stride=self.stride_in_crs_units) - self.length += rows * cols - - for hit in self.hits_small: - bounds = BoundingBox(*hit.bounds) - self.length += 1 - - for hit in self.hits + self.hits_small: - if hit in self.hits: - bounds = BoundingBox(*hit.bounds) - self.im_height = round((bounds.maxy - bounds.miny) / self.res) - self.im_width = round((bounds.maxx - bounds.minx) / self.res) - else: - bounds = BoundingBox(*hit.bounds) - self.im_height = round((bounds.maxy - bounds.miny) / self.res) - self.im_width = round((bounds.maxx - bounds.minx) / self.res) - - def __iter__(self) -> Iterator[Dict[str, Any]]: - """ - Yields a dictionary containing the pixel coordinates, path, and window. - """ - for hit in self.hits + self.hits_small: - if hit in self.hits: - y_steps = int(np.ceil((self.im_height - self.size[0]) / self.stride[0]) + 1) - x_steps = int(np.ceil((self.im_width - self.size[1]) / self.stride[1]) + 1) - for y in range(y_steps): - if self.stride[0] * y + self.size[0] > self.im_height: - y_min = self.im_height - self.size[0] - else: - y_min = self.stride[0] * y - for x in range(x_steps): - if self.stride[1] * x + self.size[1] > self.im_width: - x_min = self.im_width - self.size[1] - else: - x_min = self.stride[1] * x - # get center, border and corner windows - border_x, border_y = 1, 1 - if y == 0: border_x = 0 - if x == 0: border_y = 0 - if y == y_steps - 1: border_x = 2 - if x == x_steps - 1: border_y = 2 - # Select the right window - current_window = self.windows[border_x, border_y] - query = {"pixel_coords": (x_min, y_min, self.patch_size[1], self.patch_size[0]), - "path": cast(str, hit.object), - "window": current_window} - yield query - else: - x_min, y_min = (0, 0) - current_window = torch.ones((self.patch_size[0], self.patch_size[1])) - query = {"pixel_coords": (x_min, y_min, self.patch_size[1], self.patch_size[0]), - "path": cast(str, hit.object), - "window": current_window} - yield query - - def __len__(self) -> int: - """ - Returns the number of samples over the ROI. - - Returns: - int: The number of patches that will be sampled. - """ - return self.length - - @staticmethod - def generate_corner_windows(window_size: int) -> np.ndarray: - """ - Generates 9 2D signal windows that covers edge and corner coordinates - - Args: - window_size (int): The size of the window. - - Returns: - np.ndarray: 9 2D signal windows stacked in array (3, 3). 
- """ - step = window_size >> 1 - window = np.matrix(w.hann(M=window_size, sym=False)) - window = window.T.dot(window) - window_u = np.vstack([np.tile(window[step:step+1, :], (step, 1)), window[step:, :]]) - window_b = np.vstack([window[:step, :], np.tile(window[step:step+1, :], (step, 1))]) - window_l = np.hstack([np.tile(window[:, step:step+1], (1, step)), window[:, step:]]) - window_r = np.hstack([window[:, :step], np.tile(window[:, step:step+1], (1, step))]) - window_ul = np.block([[np.ones((step, step)), window_u[:step, step:]], - [window_l[step:, :step], window_l[step:, step:]]]) - window_ur = np.block([[window_u[:step, :step], np.ones((step, step))], - [window_r[step:, :step], window_r[step:, step:]]]) - window_bl = np.block([[window_l[:step, :step], window_l[:step, step:]], - [np.ones((step, step)), window_b[step:, step:]]]) - window_br = np.block([[window_r[:step, :step], window_r[:step, step:]], - [window_b[step:, :step], np.ones((step, step))]]) - return np.array([[window_ul, window_u, window_ur], - [window_l, window, window_r], - [window_bl, window_b, window_br]]) - - -class InferenceMerge: - """ - A class for merging inference results. - - Attributes: - height (int): The padded height of roi. - width (int): The padded width of roi. - device (torch.device): The device to use for computation. - image (np.ndarray): The merged image. - norm_mask (np.ndarray): The normalization mask. - - """ - def __init__(self, - height: int, - width: int, - classes: int, - device: torch.device) -> None: - """ - Initializes a new instance of the InferenceMerge class. - - Args: - height (int): The padded height of roi. - width (int): The padded width of roi. - device (torch.device): The device to use for computation. - """ - self.height = height - self.width = width - self.classes = classes - self.device = device - self.image = np.zeros((self.classes, self.height, self.width), dtype=np.float16) - self.norm_mask = np.ones((1, self.height, self.width), dtype=np.float16) - # self.image = torch.zeros((self.classes, self.height, self.width), dtype=torch.float16, device=self.device) - # self.norm_mask = torch.ones((1, self.height, self.width), dtype=torch.float16, device=self.device) - - @torch.no_grad() - def merge_on_cpu(self, batch: torch.Tensor, windows: torch.Tensor, pixel_coords): - """ - Merge the patches on CPU. - - Args: - batch (torch.Tensor): The batch of inference results. - windows (torch.Tensor): The windows used for inference. - pixel_coords (list): The pixel coordinates of the patches. - - Returns: - None - """ - for output, window, (x, y, patch_width, patch_height) in zip(batch, windows, pixel_coords): - # It is best to have these functions scripted in the model - # if self.classes == 1: - # output = F.sigmoid(output) * window - # else: - # output = F.softmax(output, dim=0) * window - output = output * window - self.image[:, y : y + patch_height, x : x + patch_width] += output.cpu().numpy() - self.norm_mask[:, y : y + patch_height, x : x + patch_width] += window.cpu().numpy() - - @torch.no_grad() - def merge_on_gpu(self,): - """ - Merge the patches on GPU. - """ - pass - - def save_as_tiff(self, - height: int, - width: int, - output_meta: dict, - output_path: os.PathLike) -> torch.Tensor: - """ - Save mask to file. - - Args: - height (int): The height of the output mask. - width (int): The width of the output mask. - output_meta (dict): The meta data of the output mask. - output_path (os.PathLike): The path to save the output mask. 
-
-        Returns:
-            None
-        """
-        threshold = 0.5
-        self.image /= self.norm_mask
-
-        # Binary mask
-        if self.image.shape[0] == 1:
-            self.image = expit(self.image)
-            self.image = np.where(self.image > threshold, 1, 0).squeeze(0).astype(np.uint8)
-        else:
-            self.image = np.argmax(self.image, axis=0).astype(np.uint8)
-            # self.image = torch.argmax(self.image, dim=0).byte().cpu().numpy()
-
-        self.image = self.image[np.newaxis, :height, :width]
-        output_meta.update({"driver": "GTiff",
-                            "height": self.image.shape[1],
-                            "width": self.image.shape[2],
-                            "count": self.image.shape[0],
-                            "dtype": 'uint8',
-                            "compress": 'lzw'})
-        with rio.open(output_path, 'w+', **output_meta) as dest:
-            dest.write(self.image)
-        logger.info(f"Mask saved to {output_path}")
\ No newline at end of file
diff --git a/geo_inference/geo_dask.py b/geo_inference/geo_dask.py
new file mode 100644
index 0000000..c40f815
--- /dev/null
+++ b/geo_inference/geo_dask.py
@@ -0,0 +1,304 @@
+
+import torch
+import logging
+
+import numpy as np
+import scipy.signal.windows as w
+from scipy.special import expit
+
+logger = logging.getLogger(__name__)
+
+
+def runModel(
+    chunk_data: np.ndarray,
+    model,
+    patch_size: int,
+    device: str,
+    num_classes: int = 5,
+    block_info=None,
+):
+    """
+    Runs the model on a chunk together with its partial neighbors (the right and bottom neighbors).
+    After running the model, depending on the location of the chunk, the output is multiplied
+    by a 2D window, and that window is appended as an extra channel of the chunk before it is
+    returned. The window is used to reduce edge artifacts.
+    @param chunk_data: np.ndarray, a chunk of data from the dask array.
+        model: ScriptedModel, the TorchScript model to run.
+        patch_size: int, the size of each patch on which the model is run.
+        device: str, the torch device; either cpu or gpu.
+        num_classes: int, the number of classes the model outputs.
+        block_info: dict, supplied by dask; holds the chunk's location relative to the whole array.
+    @return: the predicted chunks
+    """
+    num_chunks = block_info[0]["num-chunks"]
+    chunk_location = block_info[0]["chunk-location"]
+    if chunk_data is not None and chunk_data.size > 0:
+        try:
+            # Defining the base window for window creation later
+            step = patch_size >> 1
+            window = w.hann(M=patch_size, sym=False)
+            window = window[:, np.newaxis] * window[np.newaxis, :]
+            final_window = np.empty((1, 1))
+
+            if chunk_location[2] >= num_chunks[2] - 2 and chunk_location[1] == 0:
+                # top right window
+                window_u = np.vstack(
+                    [
+                        np.tile(window[step : step + 1, :], (step, 1)),
+                        window[step:, :],
+                    ]
+                )
+                window_r = np.hstack(
+                    [
+                        window[:, :step],
+                        np.tile(window[:, step : step + 1], (1, step)),
+                    ]
+                )
+                final_window = np.block(
+                    [
+                        [window_u[:step, :step], np.ones((step, step))],
+                        [window_r[step:, :step], window_r[step:, step:]],
+                    ]
+                )
+            elif chunk_location[2] >= num_chunks[2] - 2 and (
+                chunk_location[1] > 0 and chunk_location[1] < num_chunks[1] - 2
+            ):
+                # right edge window
+                final_window = np.hstack(
+                    [
+                        window[:, :step],
+                        np.tile(window[:, step : step + 1], (1, step)),
+                    ]
+                )
+            elif chunk_location[2] >= num_chunks[2] - 2 and (
+                chunk_location[1] >= num_chunks[1] - 2
+            ):
+                # bottom right window
+                window_r = np.hstack(
+                    [
+                        window[:, :step],
+                        np.tile(window[:, step : step + 1], (1, step)),
+                    ]
+                )
+                window_b = np.vstack(
+                    [
+                        window[:step, :],
+                        np.tile(window[step : step + 1, :], (step, 1)),
+                    ]
+                )
+                final_window = np.block(
+                    [
+                        [window_r[:step, :step], window_r[:step, step:]],
+                        [window_b[step:, :step], np.ones((step, step))],
+                    ]
+                )
+            elif chunk_location[1] >= num_chunks[1] - 2 and (
+                chunk_location[2] > 0 and chunk_location[2] < num_chunks[2] - 2
+            ):
+                # bottom edge window
+                final_window = np.vstack(
+                    [
+                        window[:step, :],
+                        np.tile(window[step : step + 1, :], (step, 1)),
+                    ]
+                )
+            elif chunk_location[1] >= num_chunks[1] - 2 and chunk_location[2] == 0:
+                # bottom left window
+                window_l = np.hstack(
+                    [
+                        np.tile(window[:, step : step + 1], (1, step)),
+                        window[:, step:],
+                    ]
+                )
+                window_b = np.vstack(
+                    [
+                        window[:step, :],
+                        np.tile(window[step : step + 1, :], (step, 1)),
+                    ]
+                )
+                final_window = np.block(
+                    [
+                        [window_l[:step, :step], window_l[:step, step:]],
+                        [np.ones((step, step)), window_b[step:, step:]],
+                    ]
+                )
+            elif chunk_location[1] == 0 and chunk_location[2] == 0:
+                # Top left window
+                window_u = np.vstack(
+                    [
+                        np.tile(window[step : step + 1, :], (step, 1)),
+                        window[step:, :],
+                    ]
+                )
+                window_l = np.hstack(
+                    [
+                        np.tile(window[:, step : step + 1], (1, step)),
+                        window[:, step:],
+                    ]
+                )
+                final_window = np.block(
+                    [
+                        [np.ones((step, step)), window_u[:step, step:]],
+                        [window_l[step:, :step], window_l[step:, step:]],
+                    ]
+                )
+            elif chunk_location[2] == 0 and (
+                chunk_location[1] > 0 and chunk_location[1] < num_chunks[1]
+            ):
+                # left edge window
+                final_window = np.hstack(
+                    [
+                        np.tile(window[:, step : step + 1], (1, step)),
+                        window[:, step:],
+                    ]
+                )
+            elif (chunk_location[2] > 0 and chunk_location[2] < num_chunks[2] - 2) and (
+                chunk_location[1] == 0
+            ):
+                # top edge window
+                final_window = np.vstack(
+                    [
+                        np.tile(window[step : step + 1, :], (step, 1)),
+                        window[step:, :],
+                    ]
+                )
+            elif (chunk_location[1] > 0 and chunk_location[1] < num_chunks[1] - 2) and (
+                chunk_location[2] > 0 and chunk_location[2] < num_chunks[2] - 2
+            ):
+                # center window
+                final_window = window
+
+            tensor = torch.as_tensor(chunk_data[np.newaxis, ...]).to(
+                torch.device(device)
+            )
+            out = np.empty(
+                shape=(num_classes, chunk_data.shape[1], chunk_data.shape[2])
+            )  # pre-allocate the output array; it is overwritten by the model below
+            with torch.no_grad():
+                out = model(tensor).cpu().numpy()[0]
+            del tensor
+            if out.shape[1:] == final_window.shape and out.shape[1:] == (
+                patch_size,
+                patch_size,
+            ):
+                return np.concatenate(
+                    (out * final_window, final_window[np.newaxis, :, :]), axis=0
+                )
+            else:
+                return np.zeros((num_classes + 1, patch_size, patch_size))
+        except Exception as e:
+            logger.error(f"Error occurred in runModel: {e}")
+        finally:
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()  # Release unused memory
+
+
+def sum_overlapped_chunks(
+    aoi_chunk: np.ndarray,
+    chunk_size: int,
+    block_info=None,
+):
+    """
+    Sums up the overlapping parts of neighboring patches in order to reduce edge artifacts.
+    This function is mapped over the chunks after the model has been run on them.
+    @param aoi_chunk: np.ndarray, a chunk of data from the dask array.
+        chunk_size: int, the size of each patch on which the model was run.
+        block_info: dict, supplied by dask; holds the chunk's location relative to the whole array.
+    @return: ready-to-save chunks
+    """
+    if aoi_chunk is not None and aoi_chunk.size > 0:
+        num_chunks = block_info[0]["num-chunks"]
+        chunk_location = block_info[0]["chunk-location"]
+        full_array = np.empty((1, 1))
+        if (chunk_location[1] == 0 or chunk_location[1] == num_chunks[1] - 1) and (
+            chunk_location[2] == 0 or chunk_location[2] == num_chunks[2] - 1
+        ):
+            # all 4 corners
+            full_array = aoi_chunk[
+                :,
+                : int(chunk_size / 2),
+                : int(chunk_size / 2),
+            ]
+        elif (chunk_location[1] == 0 or chunk_location[1] == num_chunks[1] - 1) and (
+            chunk_location[2] > 0 and chunk_location[2] < num_chunks[2] - 1
+        ):
+            # top and bottom edges but not corners
+            full_array = (
+                aoi_chunk[
+                    :,
+                    : int(chunk_size / 2),
+                    int(chunk_size / 2) : int(chunk_size / 2) * 2,
+                ]
+                + aoi_chunk[
+                    :,
+                    : int(chunk_size / 2),
+                    : int(chunk_size / 2),
+                ]
+            )
+        elif (chunk_location[2] == 0 or chunk_location[2] == num_chunks[2] - 1) and (
+            chunk_location[1] > 0 and chunk_location[1] < num_chunks[1] - 1
+        ):
+            # left and right edges but not corners
+            full_array = (
+                aoi_chunk[
+                    :,
+                    int(chunk_size / 2) : int(chunk_size / 2) * 2,
+                    : int(chunk_size / 2),
+                ]
+                + aoi_chunk[
+                    :,
+                    : int(chunk_size / 2),
+                    : int(chunk_size / 2),
+                ]
+            )
+        elif (chunk_location[2] > 0 and chunk_location[2] < num_chunks[2] - 1) and (
+            chunk_location[1] > 0 and chunk_location[1] < num_chunks[1] - 1
+        ):
+            # middle chunks
+            full_array = (
+                aoi_chunk[
+                    :,
+                    : int(chunk_size / 2),
+                    : int(chunk_size / 2),
+                ]
+                + aoi_chunk[
+                    :,
+                    : int(chunk_size / 2),
+                    int(chunk_size / 2) : int(chunk_size / 2) * 2,
+                ]
+                + aoi_chunk[
+                    :,
+                    int(chunk_size / 2) : int(chunk_size / 2) * 2,
+                    : int(chunk_size / 2),
+                ]
+                + aoi_chunk[
+                    :,
+                    int(chunk_size / 2) : int(chunk_size / 2) * 2,
+                    int(chunk_size / 2) : int(chunk_size / 2) * 2,
+                ]
+            )
+
+        if full_array.shape != (
+            aoi_chunk.shape[0],
+            int(chunk_size / 2),
+            int(chunk_size / 2),
+        ):
+            logger.error(
+                f"In sum_overlapped_chunks the shape of full_array is not "
+                f"{(aoi_chunk.shape[0], int(chunk_size / 2), int(chunk_size / 2))};"
+                f" got {full_array.shape}"
+            )
+        else:
+            with np.errstate(divide="ignore", invalid="ignore"):
+                final_result = np.divide(
+                    full_array[:-1, :, :],
+                    full_array[-1, :, :][np.newaxis, :, :],
+                    out=np.zeros_like(full_array[:-1, :, :], dtype=float),
+                    where=full_array[-1, :, :] != 0,
+                )
+            if final_result.shape[0] == 1:
+                final_result = expit(final_result)
+                final_result = (
+                    np.where(final_result > 0.5, 1, 0).squeeze(0).astype(np.uint8)
+                )
+            else:
+                final_result = np.argmax(final_result, axis=0).astype(np.uint8)
+            return final_result
diff --git a/geo_inference/geo_inference.py b/geo_inference/geo_inference.py
index ba4ff0a..190a576 100644
--- a/geo_inference/geo_inference.py
+++ b/geo_inference/geo_inference.py
@@ -1,144 +1,349 @@
-import logging
+import os
 import time
+import torch # type: ignore
+import logging
+import pystac # type: ignore
+import numpy as np
+import dask.array as da
+import asyncio
+import gc
+from dask import config
+from typing import Dict # type: ignore
 from pathlib import Path
+import rasterio # type: ignore
+from rasterio.windows import from_bounds # type: ignore
+from dask_image.imread import imread as dask_imread # type: ignore
+from typing import Union, Sequence, List
+from omegaconf import ListConfig # type: ignore
+import threading
+import xarray as xr
+from dask.diagnostics import ResourceProfiler, ProgressBar
+from multiprocessing.pool import ThreadPool
 
-import torch
-import rasterio as rio
-from torch.utils.data import DataLoader
-from torchgeo.datasets import stack_samples
-from tqdm import tqdm
+from .utils.helpers import (
+    cmd_interface,
+    get_directory,
+    get_model,
+    xarray_profile_info,
+    select_model_device,
+    asset_by_common_name,
+)
+from .geo_dask import (
+    runModel,
+    sum_overlapped_chunks,
+)
 
-from .config.logging_config import logger
-from .geo_blocks import InferenceMerge, InferenceSampler, RasterDataset
-from .utils.helpers import cmd_interface, get_device, get_directory, get_model
 from .utils.polygon import gdf_to_yolo, mask_to_poly_geojson, geojson2coco
 
 logger = logging.getLogger(__name__)
 
 
 class GeoInference:
+
     """
     A class for performing geo inference on geospatial imagery using a pre-trained model.
 
     Args:
         model (str): The path or url to the model file
         work_dir (str): The directory where the model and output files will be saved.
-        batch_size (int): The batch size to use for inference.
         mask_to_vec (bool): Whether to convert the output mask to vector format.
+        mask_to_coco (bool): Whether to convert the output mask to coco format.
+        mask_to_yolo (bool): Whether to convert the output mask to yolo format.
         device (str): The device to use for inference (either "cpu" or "gpu").
+        multi_gpu (bool): Whether to run the inference on multi-gpu or not.
         gpu_id (int): The ID of the GPU to use for inference (if device is "gpu").
+        num_classes (int): The number of classes in the output of the model.
 
     Attributes:
-        batch_size (int): The batch size to use for inference.
         work_dir (Path): The directory where the model and output files will be saved.
-        device (torch.device): The device to use for inference.
+        device (str): The device to use for inference (either "cpu" or "gpu").
+        model (str): The path or url to the model file.
         mask_to_vec (bool): Whether to convert the output mask to vector format.
-        model (torch.jit.ScriptModule): The pre-trained model to use for inference.
+        mask_to_coco (bool): Whether to convert the output mask to coco format.
+        mask_to_yolo (bool): Whether to convert the output mask to yolo format.
         classes (int): The number of classes in the output of the model.
+        raster_meta: The metadata of the input raster.
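+
+    Example (a minimal sketch; the model path and band ids are placeholders):
+        >>> geo_inference = GeoInference(model="/path/to/model.pt", num_classes=5)
+        >>> geo_inference("/path/to/image.tif", bands_requested="1,2,3", patch_size=1024)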
""" - def __init__(self, - model: str = None, - work_dir: str = None, - batch_size: int = 1, - mask_to_vec: bool = False, - vec_to_yolo: bool = False, - vec_to_coco: bool = False, - device: str = "gpu", - gpu_id: int = 0): - self.gpu_id = int(gpu_id) - self.batch_size = int(batch_size) + def __init__( + self, + model: str = None, + work_dir: str = None, + mask_to_vec: bool = False, + mask_to_coco: bool = False, + mask_to_yolo: bool = False, + device: str = None, + multi_gpu: bool = False, + gpu_id: int = 0, + num_classes: int = 5, + ): self.work_dir: Path = get_directory(work_dir) - self.device = get_device(device=device, - gpu_id=self.gpu_id) - model_path: Path = get_model(model_path_or_url=model, - work_dir=self.work_dir) + self.device = ( + device if device == "cpu" else select_model_device(gpu_id, multi_gpu) + ) + self.model = torch.jit.load( + get_model( + model_path_or_url=model, + work_dir=self.work_dir, + ), + map_location=self.device, + ) self.mask_to_vec = mask_to_vec - self.vec_to_yolo = vec_to_yolo - self.vec_to_coco = vec_to_coco - self.model = torch.jit.load(model_path, map_location=self.device) - dummy_input = torch.ones((1, 3, 32, 32), device=self.device) - with torch.no_grad(): - self.classes = self.model(dummy_input).shape[1] - - @torch.no_grad() - def __call__(self, - tiff_image: str, - tiff_name: str = None, - bbox: str = None, patch_size: int = 512, stride_size: str = None) -> None: + self.mask_to_coco = mask_to_coco + self.mask_to_yolo = mask_to_yolo + self.classes = num_classes + self.raster_meta = None + + @torch.no_grad() + def __call__( + self, + inference_input: Union[Path, str], + bands_requested: List[str] = [], + patch_size: int = 1024, + bbox: str = None, + ) -> None: + + async def run_async(): + + # Start the periodic garbage collection task + self.gc_task = asyncio.create_task(self.constant_gc(5)) # Calls gc.collect() every 5 seconds + # Run the main computation asynchronously + await self.async_run_inference( + inference_input=inference_input, + bands_requested=bands_requested, + patch_size=patch_size, + bbox=bbox + ) + self.gc_task.cancel() + + try: + await self.gc_task + except asyncio.CancelledError: + logger.info("The End of Inference") + + asyncio.run(run_async()) + + async def async_run_inference(self, + inference_input: Union[Path, str], + bands_requested: List[str] = [], + patch_size: int = 1024, + bbox: str = None, + ) -> None: + """ - Perform geo inference on geospatial imagery. + Perform geo inference on geospatial imagery using dask array. Args: - tiff_image (str): The path to the geospatial image to perform inference on. - bbox (str): The bbox or extent of the image in this format "minx, miny, maxx, maxy" + inference_input Union[Path, str]: The path/url to the geospatial image to perform inference on. + bands_requested List[str]: The requested bands to consider for the inference. patch_size (int): The size of the patches to use for inference. - stride_size (int): The stride to use between patches. + bbox (str): The bbox or extent of the image in this format "minx, miny, maxx, maxy". 
Returns: None """ - if isinstance(tiff_image, rio.io.DatasetReader): - if tiff_name is not None: - tiff_id = Path(tiff_name).stem - else: - logger.error(f"tiff_name is required when tiff_image is a rasterio dataset") - raise ValueError("tiff_name is required when tiff_image is a rasterio dataset") - else: - tiff_id = Path(tiff_image).stem - mask_path = self.work_dir.joinpath(tiff_id + "_mask.tif") - polygons_path = self.work_dir.joinpath(tiff_id + "_polygons.geojson") - yolo_csv_path = self.work_dir.joinpath(tiff_id + "_yolo.csv") - coco_json_path = self.work_dir.joinpath(tiff_id + "_coco.json") - dataset = RasterDataset(tiff_image, bbox=bbox) - sampler = InferenceSampler(dataset, size=patch_size, - stride=patch_size >> 1 if stride_size is None else stride_size, roi=dataset.bbox) - roi_height = sampler.im_height - roi_width = sampler.im_width - h_padded, w_padded = roi_height + patch_size, roi_width + patch_size - output_meta = dataset.src.meta - merge_patches = InferenceMerge(height=h_padded, width=w_padded, classes=self.classes, device=self.device) - dataloader = DataLoader(dataset, batch_size=self.batch_size, sampler=sampler, collate_fn=stack_samples) + # configuring dask + try: + config.set(scheduler='threads', num_workers=int(os.getenv('SLURM_CPUS_PER_TASK', 'Not available')) - 1) + config.set(pool=ThreadPool(int(os.getenv('SLURM_CPUS_PER_TASK', 'Not available')) - 1)) + except ValueError: + config.set(scheduler='threads', num_workers = os.cpu_count() - 1) + config.set(pool=ThreadPool(os.cpu_count() -1)) + if not isinstance(inference_input, (str, Path)): + raise TypeError( + f"Invalid raster type.\nGot {inference_input} of type {type(inference_input)}" + ) + if not isinstance(bands_requested, (Sequence, ListConfig)): + raise ValueError( + f"Requested bands should be a list." + f"\nGot {bands_requested} of type {type(bands_requested)}" + ) + if not isinstance(patch_size, int): + raise TypeError( + f"Invalid patch size. 
Patch size should be an integer.\nGot {patch_size}"
+            )
+
+        base_name = os.path.basename(
+            Path(inference_input)
+            if isinstance(inference_input, str)
+            else inference_input
+        )
+        # strip the .tif extension; this also works for URLs
+        prefix_base_name = (
+            base_name if not base_name.endswith(".tif") else base_name[:-4]
+        )
+        mask_path = self.work_dir.joinpath(prefix_base_name + "_mask.tif")
+        polygons_path = self.work_dir.joinpath(prefix_base_name + "_polygons.geojson")
+        yolo_csv_path = self.work_dir.joinpath(prefix_base_name + "_yolo.csv")
+        coco_json_path = self.work_dir.joinpath(prefix_base_name + "_coco.json")
+        stride_patch_size = int(patch_size / 2)
+
+        # Processing starts
+        start_time = time.time()
+        import rioxarray # type: ignore
+        try:
+            raster_stac_item = False
+            if isinstance(inference_input, pystac.Item):
+                raster_stac_item = True
+            else:
+                try:
+                    pystac.Item.from_file(str(inference_input))
+                    raster_stac_item = True
+                except Exception:
+                    raster_stac_item = False
+            if not raster_stac_item:
+                with rasterio.open(inference_input, "r") as src:
+                    self.raster_meta = src.meta
+                    self.raster = src
+                aoi_dask_array = rioxarray.open_rasterio(inference_input, chunks=stride_patch_size)
+                try:
+                    if bands_requested:
+                        raster_bands_request = [int(b) for b in bands_requested.split(",")]
+                        if (
+                            len(raster_bands_request) != 0
+                            and len(raster_bands_request) != aoi_dask_array.shape[0]
+                        ):
+                            aoi_dask_array = xr.concat(
+                                [aoi_dask_array[i - 1, :, :] for i in raster_bands_request],
+                                dim="band"
+                            )
+                except Exception as e:
+                    raise e
+            else:
+                assets = asset_by_common_name(inference_input)
+                bands_requested = {
+                    band: assets[band] for band in bands_requested.split(",")
+                }
+                rio_gdal_options = {
+                    "GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR",
+                    "CPL_VSIL_CURL_ALLOWED_EXTENSIONS": ".tif",
+                }
+                all_bands_requested = []
+                with rasterio.Env(**rio_gdal_options):
+                    with rasterio.open(bands_requested[next(iter(bands_requested))]["meta"].href, "r") as src:
+                        self.raster_meta = src.meta
+                        self.raster = src
+                    for key, value in bands_requested.items():
+                        all_bands_requested.append(rioxarray.open_rasterio(value["meta"].href, chunks=stride_patch_size))
+                aoi_dask_array = xr.concat(all_bands_requested, dim="band")
+                del all_bands_requested
+
+            if bbox is not None:
+                bbox = tuple(map(float, bbox.split(", ")))
+                roi_window = from_bounds(
+                    left=bbox[0],
+                    bottom=bbox[1],
+                    right=bbox[2],
+                    top=bbox[3],
+                    transform=self.raster_meta["transform"],
+                )
+                col_off, row_off = roi_window.col_off, roi_window.row_off
+                width, height = roi_window.width, roi_window.height
+                aoi_dask_array = aoi_dask_array[
+                    :, row_off : row_off + height, col_off : col_off + width
+                ]
+            self.original_shape = aoi_dask_array.shape
+            # Pad the array to make dimensions multiples of the patch size
+            pad_height = (
+                stride_patch_size - aoi_dask_array.shape[1] % stride_patch_size
+            ) % stride_patch_size
+            pad_width = (
+                stride_patch_size - aoi_dask_array.shape[2] % stride_patch_size
+            ) % stride_patch_size
+            aoi_dask_array = da.pad(
+                aoi_dask_array.data,
+                ((0, 0), (0, pad_height), (0, pad_width)),
+                mode="constant",
+            ).rechunk((aoi_dask_array.shape[0], stride_patch_size, stride_patch_size))
+
+            # run the model
+            aoi_dask_array = aoi_dask_array.map_overlap(
+                runModel,
+                model=self.model,
+                patch_size=patch_size,
+                device=self.device,
+                chunks=(
+                    self.classes + 1,
+                    patch_size,
+                    patch_size,
+                ),
+                depth={1: (0, stride_patch_size), 2: (0, stride_patch_size)},
+                boundary="none",
+                trim=False,
+                dtype=np.float16,
+            )
+
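+            # runModel returned, for every chunk, window-weighted class scores with
+            # the blending window itself appended as an extra channel; this second
+            # map_overlap sums those weighted overlaps from neighboring chunks and
+            # divides by the accumulated window weights to smooth patch seams.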
+            aoi_dask_array = aoi_dask_array.map_overlap(
+                sum_overlapped_chunks,
+                chunk_size=patch_size,
+                drop_axis=0,
+                chunks=(
+                    stride_patch_size,
+                    stride_patch_size,
+                ),
+                depth={1: (stride_patch_size, 0), 2: (stride_patch_size, 0)},
+                trim=False,
+                boundary="none",
+                dtype=np.uint8,
+            )
 
-        for batch in tqdm(dataloader, desc='extracting features', unit='batch', total=len(dataloader)):
-            image_tensor = batch["image"].to(self.device)
-            window_tensor = batch["window"].unsqueeze(1).to(self.device)
-            pixel_xy = batch["pixel_coords"]
-            output = self.model(image_tensor)
-            merge_patches.merge_on_cpu(batch=output, windows=window_tensor, pixel_coords=pixel_xy)
-        merge_patches.save_as_tiff(height=dataset.image_height,
-                                   width=dataset.image_width,
-                                   output_meta=output_meta,
-                                   output_path=mask_path)
-
-        if self.mask_to_vec:
-            mask_to_poly_geojson(mask_path, polygons_path)
-            if self.vec_to_yolo:
-                gdf_to_yolo(polygons_path, mask_path, yolo_csv_path)
-            if self.vec_to_coco:
-                geojson2coco(mask_path, polygons_path, coco_json_path)
-
-        dataset.src.close()
-        end_time = time.time() - start_time
-
-        logger.info('Extraction Completed in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60))
+            with ResourceProfiler(dt=1) as prof:
+                with ProgressBar() as pbar:
+                    pbar.register()
+                    logger.info("Inference is running:")
+                    aoi_dask_array = xr.DataArray(
+                        aoi_dask_array[: self.original_shape[1], : self.original_shape[2]],
+                        dims=("y", "x"),
+                        attrs=xarray_profile_info(self.raster),
+                    )
+                    aoi_dask_array.rio.to_raster(mask_path, tiled=True, lock=threading.Lock())
+
+            total_time = time.time() - start_time
+            if self.mask_to_vec:
+                mask_to_poly_geojson(mask_path, polygons_path)
+                if self.mask_to_yolo:
+                    gdf_to_yolo(polygons_path, mask_path, yolo_csv_path)
+                if self.mask_to_coco:
+                    geojson2coco(mask_path, polygons_path, coco_json_path)
+            logger.info(
+                "Extraction Completed in {:.0f}m {:.0f}s".format(
+                    total_time // 60, total_time % 60
+                )
+            )
+            torch.cuda.empty_cache()
+
+        except Exception as e:
+            logger.error(f"Processing on the Dask cluster failed due to: {e}")
+            raise e
+
+    async def constant_gc(self, interval_seconds):
+        while True:
+            gc.collect()  # Call garbage collection
+            await asyncio.sleep(interval_seconds)  # Wait for the specified interval
 
 
 def main() -> None:
     arguments = cmd_interface()
-    geo_inference = GeoInference(model=arguments["model"],
-                                 work_dir=arguments["work_dir"],
-                                 batch_size=arguments["batch_size"],
-                                 mask_to_vec=arguments["vec"],
-                                 vec_to_yolo=arguments["yolo"],
-                                 vec_to_coco=arguments["coco"],
-                                 device=arguments["device"],
-                                 gpu_id=arguments["gpu_id"])
-    geo_inference(tiff_image=arguments["image"], bbox=arguments["bbox"])
-
+    geo_inference = GeoInference(
+        model=arguments["model"],
+        work_dir=arguments["work_dir"],
+        mask_to_vec=arguments["vec"],
+        mask_to_coco=arguments["coco"],
+        mask_to_yolo=arguments["yolo"],
+        multi_gpu=arguments["multi_gpu"],
+        device=arguments["device"],
+        gpu_id=arguments["gpu_id"],
+        num_classes=arguments["classes"],
+    )
+    geo_inference(
+        inference_input=arguments["image"],
+        bands_requested=arguments["bands_requested"],
+        patch_size=arguments["patch_size"],
+        bbox=arguments["bbox"],
+    )
+
+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
diff --git a/geo_inference/utils/geo.py b/geo_inference/utils/geo.py
index 70b530c..b202325 100644
--- a/geo_inference/utils/geo.py
+++ b/geo_inference/utils/geo.py
@@ -11,6 +11,7 @@
 from shapely.geometry.base import BaseGeometry
 from pyogrio.errors import DataSourceError
+
 from ..config.logging_config import logger
 
 logger = logging.getLogger(__name__)
diff --git a/geo_inference/utils/geo_transforms.py b/geo_inference/utils/geo_transforms.py
index 978cc28..80a7a75 100644
--- a/geo_inference/utils/geo_transforms.py
+++ b/geo_inference/utils/geo_transforms.py
@@ -508,4 +508,4 @@ def _row_to_coco(row, geom_col, category_id_col, image_id_col, score_col):
     with open(output_path, 'w') as outfile:
         json.dump(output_dict, outfile)
 
-    return output_dict
+    return output_dict
\ No newline at end of file
diff --git a/geo_inference/utils/helpers.py b/geo_inference/utils/helpers.py
index 71a1a74..86c659d 100644
--- a/geo_inference/utils/helpers.py
+++ b/geo_inference/utils/helpers.py
@@ -6,33 +6,50 @@
 import rasterio
 from pathlib import Path
 from urllib.parse import urlparse
-
 import requests
 import torch
 import yaml
+import csv
+from tqdm import tqdm
+from typing import Dict, Union
+from hydra.utils import to_absolute_path
+from pandas.io.common import is_url
+from collections import OrderedDict
+import pystac
+from pystac.extensions.eo import Band
+
+
 from ..config.logging_config import logger
 
 logger = logging.getLogger(__name__)
 
 USER_CACHE = Path.home().joinpath(".cache")
 script_dir = Path(__file__).resolve().parent.parent
-MODEL_CONFIG = script_dir / "config" / "models.yaml"
+MODEL_CONFIG = script_dir / "config" / "models.yaml"
+
 
 def is_tiff_path(path: str):
     # Check if the given path ends with .tiff or .tif (case insensitive)
-    return re.match(r'.*\.(tiff|tif)$', path, re.IGNORECASE) is not None
+    return re.match(r".*\.(tiff|tif)$", path, re.IGNORECASE) is not None
+
 
 def is_tiff_url(url: str):
     # Check if the URL ends with .tiff or .tif (case insensitive)
     parsed_url = urlparse(url)
-    return re.match(r'.*\.(tiff|tif)$', os.path.basename(parsed_url.path), re.IGNORECASE) is not None
+    return (
+        re.match(r".*\.(tiff|tif)$", os.path.basename(parsed_url.path), re.IGNORECASE)
+        is not None
+    )
+
 
 def read_yaml(yaml_file_path: str | Path):
     with open(yaml_file_path, "r") as f:
         config = yaml.safe_load(f.read())
     return config
+
 
 def validate_asset_type(image_asset: str):
     """Validate image asset type
@@ -43,10 +60,14 @@
        rasterio.io.DatasetReader: rasterio.io.DatasetReader.
     """
     if isinstance(image_asset, rasterio.io.DatasetReader):
-        return image_asset if not image_asset.closed else rasterio.open(image_asset.name)
-
+        return (
+            image_asset if not image_asset.closed else rasterio.open(image_asset.name)
+        )
+
     if isinstance(image_asset, str):
-        if urlparse(image_asset).scheme in ('http', 'https') and is_tiff_url(image_asset):
+        if urlparse(image_asset).scheme in ("http", "https") and is_tiff_url(
+            image_asset
+        ):
             try:
                 return rasterio.open(image_asset)
             except rasterio.errors.RasterioIOError as e:
@@ -58,10 +79,13 @@
             except rasterio.errors.RasterioIOError as e:
                 logger.error(f"Failed to open file {image_asset}: {e}")
                 raise ValueError(f"Invalid image_asset file: {image_asset}")
-
-    logger.error("Image asset is neither a valid TIFF image, Rasterio dataset, nor a valid TIFF URL.")
+
+    logger.error(
+        "Image asset is neither a valid TIFF image, Rasterio dataset, nor a valid TIFF URL."
+    )
     raise ValueError("Invalid image_asset type")
+
 
 def calculate_gpu_stats(gpu_id: int = 0):
     """Calculate GPU stats
@@ -71,36 +95,11 @@
     Returns:
         tuple(dict, dict): gpu stats.
""" - res = {'gpu': torch.cuda.utilization(gpu_id)} + res = {"gpu": torch.cuda.utilization(gpu_id)} torch_cuda_mem = torch.cuda.mem_get_info(gpu_id) - mem = { - 'used': torch_cuda_mem[-1] - torch_cuda_mem[0], - 'total': torch_cuda_mem[-1] - } + mem = {"used": torch_cuda_mem[-1] - torch_cuda_mem[0], "total": torch_cuda_mem[-1]} return res, mem -def download_file_from_url(url, save_path, access_token=None): - """Download a file from a URL - - Args: - url (str): URL to the file. - save_path (str or Path): Path to save the file. - access_token (str, optional): Access token. Defaults to None. - """ - try: - headers = {} - headers["Authorization"] = f"Bearer {access_token}" - response = requests.get(url, headers=headers, stream=True) - if response.status_code == 200: - with open(save_path, 'wb') as file: - for chunk in response.iter_content(chunk_size=128): - file.write(chunk) - logger.info(f"Downloaded {save_path}") - else: - logger.error(f"Failed to download the file from {url}. Status code: {response.status_code}") - except Exception as e: - logger.error(f"An error occurred: {e}") - raise def extract_tar_gz(tar_gz_file: str | Path, target_directory: str | Path): """Extracts a tar.gz file to a target directory @@ -113,68 +112,27 @@ def extract_tar_gz(tar_gz_file: str | Path, target_directory: str | Path): for member in tar.getmembers(): if member.isreg(): member.name = os.path.basename(member.name) - tar.extract(member, target_directory) + tar.extract(member, target_directory) # tar.extractall(path=target_directory) logger.info(f"Successfully extracted {tar_gz_file} to {target_directory}") os.remove(tar_gz_file) except tarfile.TarError as e: logger.error(f"Error while extracting {tar_gz_file}: {e}") except Exception as e: - logger.error(f"An error occurred: {e}") + logger.error(f"An error occurred: {e}") raise -def get_device(device: str = "gpu", - gpu_id: int = 0, - gpu_max_ram_usage: int = 25, - gpu_max_utilization: int = 15): - """Returns a torch device - - Args: - device (str): Accepts "cpu" or "gpu". Defaults to "gpu". - gpu_id (int): GPU id. Defaults to 0. - gpu_max_ram_usage (int): GPU max ram usage. Defaults to 25. - gpu_max_utilization (int): GPU max utilization. Defaults to 15. - Returns: - torch.device: torch device. 
+def get_directory(work_directory: str) -> Path: """ - if device == "cpu": - return torch.device('cpu') - elif device == "gpu": - res, mem = calculate_gpu_stats(gpu_id=gpu_id) - used_ram = mem['used'] / (1024 ** 2) - max_ram = mem['total'] / (1024 ** 2) - used_ram_percentage = (used_ram / max_ram) * 100 - logger.info(f"\nGPU RAM used: {round(used_ram_percentage)}%" - f"[used_ram: {used_ram:.0f}MiB] [max_ram: {max_ram:.0f}MiB]\n" - f"GPU Utilization: {res['gpu']}%") - if used_ram_percentage < gpu_max_ram_usage: - if res["gpu"] < gpu_max_utilization: - return torch.device(f"cuda:{gpu_id}") - else: - logger.warning(f"Reverting to CPU!\n" - f"Current GPU:{gpu_id} utilization: {res['gpu']}%\n" - f"Max GPU utilization allowed: {gpu_max_utilization}%") - return torch.device('cpu') - else: - logger.warning(f"Reverting to CPU!\n" - f"Current GPU:{gpu_id} RAM usage: {used_ram_percentage}%\n" - f"Max used RAM allowed: {gpu_max_ram_usage}%") - return torch.device('cpu') - else: - logger.error("Invalid device type requested: {device}") - raise ValueError("Invalid device type") - -def get_directory(work_directory: str)-> Path: - """Returns a working directory - + Returns a working directory Args: work_directory (str): User's specified path Returns: Path: working directory """ - + if work_directory: work_directory = Path(work_directory) if not work_directory.is_dir(): @@ -183,9 +141,34 @@ def get_directory(work_directory: str)-> Path: work_directory = USER_CACHE.joinpath("geo-inference") if not work_directory.is_dir(): Path.mkdir(work_directory, parents=True) - + return work_directory + +def download_file_from_url(url, save_path, access_token=None): + """Download a file from a URL + + Args: + url (str): URL to the file. + save_path (str or Path): Path to save the file. + access_token (str, optional): Access token. Defaults to None. + """ + try: + headers = {} + headers["Authorization"] = f"Bearer {access_token}" + response = requests.get(url, headers=headers, stream=True) + if response.status_code == 200: + with open(save_path, 'wb') as file: + for chunk in response.iter_content(chunk_size=128): + file.write(chunk) + logger.info(f"Downloaded {save_path}") + else: + logger.error(f"Failed to download the file from {url}. 
Status code: {response.status_code}")
+    except Exception as e:
+        logger.error(f"An error occurred: {e}")
+        raise
+
+
 def get_model(model_path_or_url: str, work_dir: Path) -> Path:
     """Download a model from the model zoo
@@ -212,6 +195,177 @@
     logger.error(f"Model {model_path_or_url} not found")
     raise ValueError("Invalid model path")
 
+
+def select_model_device(gpu_id: int, multi_gpu: bool):
+    """Select a torch device: the first GPU below 70% memory use and utilization, or cpu."""
+    device = "cpu"
+    if torch.cuda.is_available():
+        if not multi_gpu:
+            res = {"gpu": torch.cuda.utilization(gpu_id)}
+            torch_cuda_mem = torch.cuda.mem_get_info(gpu_id)
+            mem = {
+                "used": torch_cuda_mem[-1] - torch_cuda_mem[0],
+                "total": torch_cuda_mem[-1],
+            }
+            used_ram = mem["used"] / (1024**2)
+            max_ram = mem["total"] / (1024**2)
+            used_ram_percentage = (used_ram / max_ram) * 100
+            if used_ram_percentage < 70 and res["gpu"] < 70:
+                device = f"cuda:{gpu_id}"
+        else:
+            num_devices = torch.cuda.device_count()
+            for i in range(num_devices):
+                res = {"gpu": torch.cuda.utilization(i)}
+                torch_cuda_mem = torch.cuda.mem_get_info(i)
+                mem = {
+                    "used": torch_cuda_mem[-1] - torch_cuda_mem[0],
+                    "total": torch_cuda_mem[-1],
+                }
+                used_ram = mem["used"] / (1024**2)
+                max_ram = mem["total"] / (1024**2)
+                used_ram_percentage = (used_ram / max_ram) * 100
+                if used_ram_percentage < 70 and res["gpu"] < 70:
+                    device = f"cuda:{i}"
+                    break
+    return device
+
+
+def xarray_profile_info(
+    raster,
+):
+    """
+    Build a rasterio profile dict for writing the output mask.
+    Args:
+        raster : The opened input raster dataset.
+    Returns:
+        dict: profile keyword arguments describing the output GeoTIFF.
+    """
+    driver = 'GTiff' if raster.driver == 'VRT' else raster.driver
+    profile_kwargs = {
+        'crs': raster.crs.to_string(),  # Coordinate Reference System, using src.crs.to_string() to get a string representation
+        'transform': raster.transform,  # Affine transformation matrix
+        'count': 1,  # Number of bands
+        'width': raster.width,  # Width of the raster
+        'height': raster.height,  # Height of the raster
+        'driver': driver,  # Raster format driver
+        'dtype': "uint8",  # Data type (use dtype directly if it's a valid format for xarray)
+        'BIGTIFF': 'YES',  # BigTIFF option
+        'compress': 'lzw'  # Compression type
+    }
+    return profile_kwargs
+
+
+def get_tiff_paths_from_csv(
+    csv_path: Union[str, Path],
+):
+    """
+    Creates a list of to-be-processed tiff files from a csv file referencing input data
+    Args:
+        csv_path (Union[str, Path]) : path to csv file containing list of input data. See README for details on expected structure of csv.
+    Returns:
+        A list of tiff paths.
+    """
+    aois_dictionary = []
+    data_list = read_csv(csv_path)
+    logger.info(
+        f"\n\tSuccessfully read csv file: {Path(csv_path).name}\n"
+        f"\tNumber of rows: {len(data_list)}\n"
+        f"\tCopying first row:\n{data_list[0]}\n"
+    )
+    with tqdm(
+        enumerate(data_list), desc="Creating a list of tiff paths", total=len(data_list)
+    ) as _tqdm:
+        for i, aoi_dict in _tqdm:
+            _tqdm.set_postfix_str(f"Image: {Path(aoi_dict['tif']).stem}")
+            try:
+                aois_dictionary.append(aoi_dict)
+            except FileNotFoundError as e:
+                logger.error(
+                    f"{e}" f"Failed to get the path of :\n{aoi_dict}\n" f"Index: {i}"
+                )
+    return aois_dictionary
+
+
+def asset_by_common_name(raster_raw_input) -> Dict:
+    """
+    Get assets by common band name (only works for assets containing 1 band)
+    Adapted from:
+    https://github.com/sat-utils/sat-stac/blob/40e60f225ac3ed9d89b45fe564c8c5f33fdee7e8/satstac/item.py#L75
+    @return: OrderedDict mapping common band names to asset metadata.
+    """
+    _assets_by_common_name = OrderedDict()
+    item = pystac.Item.from_file(raster_raw_input)
+    for name, a_meta in item.assets.items():
+        bands = []
+        if "eo:bands" in a_meta.extra_fields.keys():
+            bands = a_meta.extra_fields["eo:bands"]
+        if len(bands) == 1:
+            eo_band = bands[0]
+            if "common_name" in eo_band.keys():
+                common_name = eo_band["common_name"]
+                if not Band.band_range(common_name):
+                    raise ValueError(
+                        f'Must be one of the accepted common names. Got "{common_name}".'
+                    )
+                else:
+                    _assets_by_common_name[common_name] = {
+                        "meta": a_meta,
+                        "name": name,
+                    }
+    if not _assets_by_common_name:
+        raise ValueError("Common names for assets cannot be retrieved")
+    return _assets_by_common_name
+
+
+def read_csv(csv_file_name: str) -> Dict:
+    """
+    Open csv file and parse it, returning a list of dictionaries with keys:
+    - "tif": path to a single image
+    - "gpkg": path to a single ground truth file
+    - "split": (str) "trn" or "tst"
+    - "aoi_id": (str) a string id for area of interest
+    @param csv_file_name:
+        path to csv file containing list of input data with expected columns
+        expected columns (without header): imagery, ground truth, dataset[, aoi id]
+    source : geo_deep_learning
+    """
+    list_values = []
+    with open(csv_file_name, "r") as f:
+        reader = csv.reader(f)
+        row_lengths_set = set()
+        for row in reader:
+            row_lengths_set.update([len(row)])
+            if ";" in row[0]:
+                raise TypeError(
+                    "Elements in rows should be delimited with comma, not semicolon."
+                )
+            if not len(row_lengths_set) == 1:
+                raise ValueError(
+                    f"Rows in csv should be of same length. Got rows with length: {row_lengths_set}"
+                )
+            row = [str(i) or None for i in row]  # replace empty strings with None.
+            row.extend(
+                [None] * (4 - len(row))
+            )  # fill row with None values to obtain a row of length == 4
+
+            row[0] = (
+                to_absolute_path(row[0]) if not is_url(row[0]) else row[0]
+            )  # Convert relative paths to absolute with hydra's util to_absolute_path()
+            try:
+                row[1] = str(to_absolute_path(row[1]) if not is_url(row[1]) else row[1])
+            except TypeError:
+                row[1] = None
+            # save all values
+            list_values.append(
+                {"tif": str(row[0]), "gpkg": row[1], "split": row[2], "aoi_id": row[3]}
+            )
+    try:
+        # Try sorting according to dataset name (i.e. group "train", "val" and "test" rows together)
+        list_values = sorted(list_values, key=lambda k: k["split"])
+    except TypeError:
+        logger.warning("Unable to sort csv rows")
+    return list_values
+
+
 def cmd_interface(argv=None):
     """
     Parse command line arguments for extracting features from high-resolution imagery using pre-trained models.
@@ -228,73 +382,104 @@
     Usage:
         Use the -h option to get supported arguments.
     """
-    parser = argparse.ArgumentParser(usage="%(prog)s [-h HELP] use -h to get supported arguments.",
-                                     description='Extract features from high-resolution imagery using pre-trained models.')
-
-    parser.add_argument("-a", "--args", nargs=1, help="Path to arguments stored in yaml, consult ./config/sample_config.yaml")
-
-    parser.add_argument("-i", "--image", nargs=1, help="Path to Geotiff")
-
-    parser.add_argument("-bb", "--bbox", nargs=1, help="AOI bbox in this format'minx, miny, maxx, maxy'")
-
+    parser = argparse.ArgumentParser(
+        usage="%(prog)s [-h HELP] use -h to get supported arguments.",
+        description="Extract features from high-resolution imagery using pre-trained models.",
+    )
+
+    parser.add_argument(
+        "-a",
+        "--args",
+        nargs=1,
+        help="Path to arguments stored in yaml, consult ./config/sample_config.yaml",
+    )
+
+    parser.add_argument(
+        "-bb", "--bbox", nargs=1, help="AOI bbox in this format 'minx, miny, maxx, maxy'"
+    )
+
+    parser.add_argument(
+        "-br",
+        "--bands_requested",
+        nargs=1,
+        help="Requested bands in this format 'R,G,B'",
+    )
+
+    parser.add_argument(
+        "-i", "--image", nargs=1, help="Path or URL to the input image"
+    )
+
     parser.add_argument("-m", "--model", nargs=1, help="Path or URL to the model file")
-
+
     parser.add_argument("-wd", "--work_dir", nargs=1, help="Working Directory")
-
-    parser.add_argument("-bs", "--batch_size", nargs=1, help="The Batch Size")
-
+
+    parser.add_argument("-ps", "--patch_size", nargs=1, help="The Patch Size")
+
     parser.add_argument("-v", "--vec", nargs=1, help="Vector Conversion")
-
+
+    parser.add_argument("-mg", "--mgpu", nargs=1, help="Multi GPU")
+
+    parser.add_argument("-cls", "--classes", nargs=1, help="Inference Classes")
+
     parser.add_argument("-y", "--yolo", nargs=1, help="Yolo Conversion")
-
+
     parser.add_argument("-c", "--coco", nargs=1, help="Coco Conversion")
-
+
     parser.add_argument("-d", "--device", nargs=1, help="CPU or GPU Device")
-
+
     parser.add_argument("-id", "--gpu_id", nargs=1, help="GPU ID, Default = 0")
-
+
     args = parser.parse_args()
-
+
     if args.args:
         config = read_yaml(args.args[0])
         image = config["arguments"]["image"]
-        bbox = config["arguments"]["bbox"]
         model = config["arguments"]["model"]
+        bbox = None if config["arguments"]["bbox"].lower() == "none" else config["arguments"]["bbox"]
         work_dir = config["arguments"]["work_dir"]
-        batch_size = config["arguments"]["batch_size"]
+        bands_requested = config["arguments"]["bands_requested"]
         vec = config["arguments"]["vec"]
         yolo = config["arguments"]["yolo"]
         coco = config["arguments"]["coco"]
         device = config["arguments"]["device"]
         gpu_id = config["arguments"]["gpu_id"]
+        multi_gpu = config["arguments"]["mgpu"]
+        classes = config["arguments"]["classes"]
+        patch_size = config["arguments"]["patch_size"]
     elif args.image:
-        image = args.image[0]
-        bbox = args.bbox[0] if args.bbox else None
+        image = args.image[0]
         model = args.model[0] if args.model else None
+        bbox = args.bbox[0] if args.bbox else None
         work_dir = args.work_dir[0] if args.work_dir else None
-        batch_size = args.batch_size[0] if args.batch_size else 1
+        bands_requested = args.bands_requested[0] if args.bands_requested else []
        vec = args.vec[0] if args.vec else False
         yolo = args.yolo[0] if args.yolo else False
         coco = args.coco[0] if args.coco else False
         device = args.device[0] if args.device else "gpu"
-        gpu_id = args.gpu_id[0] if args.gpu_id else 0
+        gpu_id = int(args.gpu_id[0]) if args.gpu_id else 0
+        multi_gpu = args.mgpu[0] if args.mgpu else False
+        classes = int(args.classes[0]) if args.classes else 5
+        patch_size = int(args.patch_size[0]) if args.patch_size else 1024
     else:
-        print('use the help [-h] option for correct usage')
+        print("use the help [-h] option for correct usage")
         raise SystemExit
-
-    arguments= {"image": image,
-                "bbox": bbox,
-                "model": model,
-                "work_dir": work_dir,
-                "batch_size": batch_size,
-                "vec": vec,
-                "yolo": yolo,
-                "coco": coco,
-                "device": device,
-                "gpu_id": gpu_id
-                }
+    arguments = {
+        "model": model,
+        "image": image,
+        "bands_requested": bands_requested,
+        "work_dir": work_dir,
+        "classes": classes,
+        "bbox": bbox,
+        "multi_gpu": multi_gpu,
+        "vec": vec,
+        "yolo": yolo,
+        "coco": coco,
+        "device": device,
+        "gpu_id": gpu_id,
+        "patch_size": patch_size,
+    }
     return arguments
-
-if __name__ == '__main__':
+
+
+if __name__ == "__main__":
     pass
-
\ No newline at end of file
diff --git a/geo_inference/utils/polygon.py b/geo_inference/utils/polygon.py
index d2d8e97..761996b 100644
--- a/geo_inference/utils/polygon.py
+++ b/geo_inference/utils/polygon.py
@@ -9,6 +9,7 @@
 import shapely
 from rasterio import features
 from shapely.geometry import shape
+from pathlib import Path
 
 from ..config.logging_config import logger
 from .geo import rasterio_load
@@ -256,4 +257,3 @@ def geojson2coco(image_src, label_src, output_path=None, category_attribute="val
         json.dump(coco_dataset, outfile)
 
     logger.info(f"CocoJson file saved to {output_path}")
-
diff --git a/requirements.txt b/requirements.txt
index b8207cf..ac6fc6e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,13 @@
-torchgeo>=0.3
-affine>=2.3.0
+torchgeo>=0.5.2
+affine>=2.4.0
 colorlog==6.7.0
-scipy>=1.11.2
+scipy>=1.13.1
 pyyaml>=5.2
-requests>=2.22.0
 pynvml>=11.0
-geopandas>=0.10.2
+geopandas>=0.14.4
+dask-image>=2024.5.3
+dask>=2024.6.2
+requests>=2.32.3
+xarray>=2024.6.0
+pystac>=1.10.1
+rioxarray>=0.15.6
\ No newline at end of file
diff --git a/tests/data/sample.yaml b/tests/data/sample.yaml
index 14bff8a..6fe6671 100644
--- a/tests/data/sample.yaml
+++ b/tests/data/sample.yaml
@@ -3,9 +3,13 @@
 arguments:
   bbox: None # "minx, miny, maxx, maxy"
   model: "rgb-4class-segformer" # Name of Extraction Model: str
   work_dir: None # Working Directory: str
-  batch_size: 1 # Batch size
   vec: False # Vector Conversion: bool
   yolo: False # YOLO Conversion: bool
   coco: False # COCO Conversion: bool
   device: "gpu" # cpu or gpu: str
   gpu_id: 0 # GPU ID: int
+  bands_requested: '1,2,3' # requested Bands
+  mgpu: False
+  classes: 5
+  n_workers: 20
+  patch_size: 1024
\ No newline at end of file
diff --git a/tests/test_geo_blocks.py b/tests/test_geo_blocks.py
deleted file mode 100644
index 7de142b..0000000
--- a/tests/test_geo_blocks.py
+++ /dev/null
@@ -1,146 +0,0 @@
-from pathlib import Path
-from typing import Any, Dict
-
-import os
-import pytest
-import numpy as np
-import scipy.signal.windows as w
-import rasterio as rio
-import torch
-from geo_inference.geo_blocks import RasterDataset, InferenceSampler, InferenceMerge
-
-
-@pytest.fixture
-def test_data_dir():
-    return Path(__file__).parent / "data"
-
-@pytest.fixture
-def raster_dataset(test_data_dir):
-    image_asset = str(test_data_dir / "0.tif")
-    return RasterDataset(image_asset)
-
-class TestRasterDataset:
-    def test_init(self, raster_dataset):
-        assert isinstance(raster_dataset.src, rio.DatasetReader)
-        assert raster_dataset.bands > 0
-        assert raster_dataset.res > 0
-        assert raster_dataset._crs is not None
-
-    def test_getitem(self, raster_dataset):
-        query: Dict[str, Any] = {
-            'path': raster_dataset.src.name,
-            'window': (0, 0, 10,
10), # replace with actual window - 'pixel_coords': (0, 0, 10, 10) # replace with actual pixel_coords - } - sample = raster_dataset.__getitem__(query) - assert isinstance(sample, dict) - assert 'image' in sample - assert 'crs' in sample - assert 'pixel_coords' in sample - assert 'window' in sample - assert 'path' in sample - - def test_get_tensor(self, raster_dataset): - query = (0, 0, 10, 10) # replace with actual query - size = 10 # replace with actual size - tensor = raster_dataset._get_tensor(query, size) - assert isinstance(tensor, torch.Tensor) - assert tensor.shape[-2:] == (size, size) - - def test_pad_patch(self): - x = torch.rand((3, 5, 5)) - patch_size = 10 - padded = RasterDataset.pad_patch(x, patch_size) - assert isinstance(padded, torch.Tensor) - assert padded.shape[-2:] == (patch_size, patch_size) - - -class TestInferenceSampler: - @pytest.fixture - def inference_sampler(self, raster_dataset): - size = (10, 10) - stride = (5, 5) - return InferenceSampler(raster_dataset, size, stride) - - def test_init(self, inference_sampler): - assert inference_sampler.size == (10, 10) - assert inference_sampler.stride == (5, 5) - assert inference_sampler.length > 0 - - def test_iter(self, inference_sampler): - for sample in inference_sampler: - assert isinstance(sample, dict) - assert 'pixel_coords' in sample - assert 'path' in sample - assert 'window' in sample - - def test_len(self, inference_sampler): - assert len(inference_sampler) == inference_sampler.length - - def test_generate_corner_windows(self): - window_size = 10 - step = window_size >> 1 - windows = InferenceSampler.generate_corner_windows(window_size) - center_window = np.matrix(w.hann(M=window_size, sym=False)) - center_window = center_window.T.dot(center_window) - window_top = np.vstack([np.tile(center_window[step:step + 1, :], (step, 1)), center_window[step:, :]]) - window_bottom = np.vstack([center_window[:step, :], np.tile(center_window[step:step + 1, :], (step, 1))]) - window_left = np.hstack([np.tile(center_window[:, step:step + 1], (1, step)), center_window[:, step:]]) - window_right = np.hstack([center_window[:, :step], np.tile(center_window[:, step:step + 1], (1, step))]) - window_top_left = np.block([[np.ones((step, step)), window_top[:step, step:]], - [window_left[step:, :step], window_left[step:, step:]]]) - window_top_right = np.block([[window_top[:step, :step], np.ones((step, step))], - [window_right[step:, :step], window_right[step:, step:]]]) - window_bottom_left = np.block([[window_left[:step, :step], window_left[:step, step:]], - [np.ones((step, step)), window_bottom[step:, step:]]]) - window_bottom_right = np.block([[window_right[:step, :step], window_right[:step, step:]], - [window_bottom[step:, :step], np.ones((step, step))]]) - assert isinstance(windows, np.ndarray) - assert windows.shape == (3, 3, window_size, window_size) - assert np.all(windows >= 0) and np.all(windows <= 1) - assert np.allclose(windows[1, 1], center_window) - assert np.allclose(windows[0, 1], window_top) - assert np.allclose(windows[2, 1], window_bottom) - assert np.allclose(windows[1, 0], window_left) - assert np.allclose(windows[1, 2], window_right) - assert np.allclose(windows[0, 0], window_top_left) - assert np.allclose(windows[0, 2], window_top_right) - assert np.allclose(windows[2, 0], window_bottom_left) - assert np.allclose(windows[2, 2], window_bottom_right) - -class TestInferenceMerge: - @pytest.fixture - def inference_merge(self): - height = 100 - width = 100 - classes = 3 - device = torch.device('cpu') - return 
InferenceMerge(height, width, classes, device) - - def test_init(self, inference_merge): - assert inference_merge.height == 100 - assert inference_merge.width == 100 - assert inference_merge.classes == 3 - assert inference_merge.device == torch.device('cpu') - assert isinstance(inference_merge.image, np.ndarray) - assert isinstance(inference_merge.norm_mask, np.ndarray) - - def test_merge_on_cpu(self, inference_merge): - batch = torch.rand((3, 3, 10, 10)) - windows = torch.rand((3, 1, 10, 10)) - pixel_coords = [(0, 0, 10, 10), (10, 10, 10, 10), (20, 20, 10, 10)] - inference_merge.merge_on_cpu(batch, windows, pixel_coords) - assert np.all(inference_merge.image >= 0) - assert np.all(inference_merge.norm_mask >= 1) - - def test_save_as_tiff(self, inference_merge, test_data_dir): - height = 100 - width = 100 - output_meta = { - "crs": "+proj=latlong", - "transform": rio.Affine(1.0, 0, 0, 0, 1.0, 0) - } - output_path = test_data_dir / "test_1.tiff" - inference_merge.save_as_tiff(height, width, output_meta, output_path) - assert output_path.exists() - os.remove(output_path) \ No newline at end of file diff --git a/tests/test_geo_dask.py b/tests/test_geo_dask.py new file mode 100644 index 0000000..06a0234 --- /dev/null +++ b/tests/test_geo_dask.py @@ -0,0 +1,632 @@ + + +import pytest +import numpy as np +import scipy.signal.windows as w +import torch + + + +@pytest.fixture +def mock_block_info_top_right_corner(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 0, 2]}] + return block_info + + +@pytest.fixture +def mock_block_info_top_left_corner(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 0, 0]}] + return block_info + + +@pytest.fixture +def mock_block_info_bottom_right_corner(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 2, 2]}] + return block_info + + +@pytest.fixture +def mock_block_info_bottom_left_corner(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 2, 0]}] + return block_info + + +@pytest.fixture +def mock_block_info_bottom_edge(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 2, 1]}] + return block_info + + +@pytest.fixture +def mock_block_info_top_edge(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 0, 1]}] + return block_info + + +@pytest.fixture +def mock_block_info_left_edge(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 1, 0]}] + return block_info + + +@pytest.fixture +def mock_block_info_right_edge(): + # Mocking block_info with predefined values + block_info = [{"num-chunks": [1, 3, 3], "chunk-location": [0, 1, 2]}] + return block_info + +@pytest.fixture +def generate_corner_windows() -> np.ndarray: + """ + Generates 9 2D signal windows that covers edge and corner coordinates + + Args: + window_size (int): The size of the window. + + Returns: + np.ndarray: 9 2D signal windows stacked in array (3, 3). 
+ """ + step = 4 >> 1 + window = np.matrix(w.hann(M=4, sym=False)) + window = window.T.dot(window) + window_u = np.vstack( + [np.tile(window[step : step + 1, :], (step, 1)), window[step:, :]] + ) + window_b = np.vstack( + [window[:step, :], np.tile(window[step : step + 1, :], (step, 1))] + ) + window_l = np.hstack( + [np.tile(window[:, step : step + 1], (1, step)), window[:, step:]] + ) + window_r = np.hstack( + [window[:, :step], np.tile(window[:, step : step + 1], (1, step))] + ) + window_ul = np.block( + [ + [np.ones((step, step)), window_u[:step, step:]], + [window_l[step:, :step], window_l[step:, step:]], + ] + ) + window_ur = np.block( + [ + [window_u[:step, :step], np.ones((step, step))], + [window_r[step:, :step], window_r[step:, step:]], + ] + ) + window_bl = np.block( + [ + [window_l[:step, :step], window_l[:step, step:]], + [np.ones((step, step)), window_b[step:, step:]], + ] + ) + window_br = np.block( + [ + [window_r[:step, :step], window_r[:step, step:]], + [window_b[step:, :step], np.ones((step, step))], + ] + ) + return np.array( + [ + [window_ul, window_u, window_ur], + [window_l, window, window_r], + [window_bl, window_b, window_br], + ] + ) + + +class TestSumOverlappedChunks: + + def test_sum_overlapped_chunks_top_edge( + self, + mock_block_info_top_edge, + ): + from geo_inference import geo_dask as code + + arr = np.zeros((3, 6, 8)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + + expected_result = np.divide( + arr[:-1, :2, :2] + arr[:-1, :2, 2:4], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + + produced_result = code.sum_overlapped_chunks(arr, 4, mock_block_info_top_edge) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + def test_sum_overlapped_chunks_top_right_corner( + self, + mock_block_info_top_right_corner, + ): + from geo_inference import geo_dask as code + + arr = np.zeros((3, 6, 6)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + expected_result = np.divide( + arr[:-1, :2, :2], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + produced_result = code.sum_overlapped_chunks( + arr, 4, mock_block_info_top_right_corner + ) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + def test_sum_overlapped_chunks_top_left_corner( + self, mock_block_info_top_left_corner + ): + from geo_inference import geo_dask as code + + arr = np.zeros((3, 6, 6)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + expected_result = np.divide( + arr[:-1, :2, :2], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + produced_result = code.sum_overlapped_chunks( + arr, 4, mock_block_info_top_left_corner + ) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + def test_sum_overlapped_chunks_bottom_right_corner( + self, + mock_block_info_bottom_right_corner, + ): + from geo_inference import geo_dask as code + + 
arr = np.zeros((3, 6, 6)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + expected_result = np.divide( + arr[:-1, :2, :2], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + produced_result = code.sum_overlapped_chunks( + arr, 4, mock_block_info_bottom_right_corner + ) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + def test_sum_overlapped_chunks_bottom_left_corner( + self, mock_block_info_bottom_left_corner + ): + from geo_inference import geo_dask as code + + arr = np.zeros((3, 6, 6)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 6)) + expected_result = np.divide( + arr[:-1, :2, :2], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + produced_result = code.sum_overlapped_chunks( + arr, 4, mock_block_info_bottom_left_corner + ) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + def test_sum_overlapped_chunks_bottom_edge( + self, + mock_block_info_bottom_edge, + ): + from geo_inference import geo_dask as code + + arr = np.zeros((3, 6, 8)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + + expected_result = np.divide( + arr[:-1, :2, :2] + arr[:-1, :2, 2:4], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + + produced_result = code.sum_overlapped_chunks( + arr, 4, mock_block_info_bottom_edge + ) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + def test_sum_overlapped_chunks_right_edge( + self, + mock_block_info_right_edge, + ): + from geo_inference import geo_dask as code + + arr = np.zeros((3, 6, 8)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + + expected_result = np.divide( + arr[:-1, :2, :2] + arr[:-1, 2:4, :2], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + + produced_result = code.sum_overlapped_chunks(arr, 4, mock_block_info_right_edge) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + def test_sum_overlapped_chunks_left_edge( + self, + mock_block_info_left_edge, + ): + from geo_inference import geo_dask as code + + arr = np.zeros((3, 6, 8)) + arr[0, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[1, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + arr[2, :, :] = np.random.randint(low=1, high=5, size=(1, 6, 8)) + + expected_result = np.divide( + arr[:-1, :2, :2] + arr[:-1, 2:4, :2], + arr[-1, :2, :2], + out=np.zeros_like(arr[:-1, :2, :2], dtype=float), + where=arr[:-1, :2, :2] != 0, + ) + + produced_result = code.sum_overlapped_chunks(arr, 4, mock_block_info_left_edge) + np.testing.assert_array_almost_equal( + np.argmax(expected_result, axis=0), produced_result,decimal=2) + + +class TestModelInference: 
+ + from unittest.mock import patch + + @patch("torch.jit.load") + def test_run_model_inference_left_edge(self, mock_load, mock_block_info_left_edge, generate_corner_windows): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=1) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_left_edge, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[2, 0, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert output.shape[1:] == (mock_chunk_size, mock_chunk_size) + + @patch("torch.jit.load") + def test_run_model_inference_right_edge( + self, mock_load, mock_block_info_right_edge, generate_corner_windows + ): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=1) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_right_edge, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[2, 2, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert output.shape[1:] == (mock_chunk_size, mock_chunk_size) + + @patch("torch.jit.load") + def test_run_model_inference_bottom_edge( + self, mock_load, mock_block_info_bottom_edge, generate_corner_windows + ): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=1) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_bottom_edge, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[2, 2, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert 
output.shape[1:] == (mock_chunk_size, mock_chunk_size) + + @patch("torch.jit.load") + def test_run_model_inference_bottom_left_corner( + self, mock_load, mock_block_info_bottom_left_corner, generate_corner_windows + ): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=2) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_bottom_left_corner, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[2,0, :, :] * mock_numpy_result[0, 0, :, :] + ) + assert np.array_equal( + output[3, :, :], generate_corner_windows[2,0, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert output.shape[1:] == (mock_chunk_size, mock_chunk_size) + + @patch("torch.jit.load") + def test_run_model_inference_bottom_right_corner( + self, mock_load, mock_block_info_bottom_right_corner, generate_corner_windows + ): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=2) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_bottom_right_corner, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[2,2, :, :] * mock_numpy_result[0, 0, :, :] + ) + assert np.array_equal( + output[3, :, :], generate_corner_windows[2,2, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert output.shape[1:] == (mock_chunk_size, mock_chunk_size) + + @patch("torch.jit.load") + def test_run_model_inference_top_left_corner( + self, mock_load, mock_block_info_top_left_corner, generate_corner_windows + ): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=2) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # 
Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_top_left_corner, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[0,0, :, :] * mock_numpy_result[0, 0, :, :] + ) + assert np.array_equal( + output[3, :, :], generate_corner_windows[0,0, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert output.shape[1:] == (mock_chunk_size, mock_chunk_size) + + @patch("torch.jit.load") + def test_run_model_inference_top_right_corner( + self, mock_load, mock_block_info_top_right_corner, generate_corner_windows + ): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=2) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_top_right_corner, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[0,2, :, :] * mock_numpy_result[0, 0, :, :] + ) + assert np.array_equal( + output[3, :, :], generate_corner_windows[0,2, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert output.shape[1:] == (mock_chunk_size, mock_chunk_size) + + @patch("torch.jit.load") + def test_run_model_inference_top_edge( + self, mock_load, mock_block_info_top_edge, generate_corner_windows + ): + from unittest.mock import MagicMock + from geo_inference import geo_dask as code + + # Mocking parameters + chunk_data = np.zeros((3, 12, 12)) + mock_call_result = MagicMock() + mock_cpu_result = MagicMock() + mock_chunk_size = 4 + mock_num_classes = 3 + mock_numpy_result = np.full((1, 3, 4, 4), fill_value=2) # Example NumPy array + + mock_cpu_result.numpy.return_value = mock_numpy_result + mock_cpu_result.shape = mock_numpy_result.shape + mock_call_result.cpu.return_value = mock_cpu_result + + # Mock TorchScript model and its methods + mock_model = MagicMock( + spec=torch.jit.ScriptModule, return_value=mock_call_result + ) + mock_model.to.return_value = mock_model + # Call the function under test + output = code.runModel( + chunk_data, + mock_model, + mock_chunk_size, + "cpu", + mock_num_classes, + mock_block_info_top_edge, + ) + assert np.array_equal( + output[0, :, :], generate_corner_windows[0,2, :, :] * mock_numpy_result[0, 0, :, :] + ) + assert np.array_equal( + output[3, :, :], generate_corner_windows[0,2, :, :] + ) + assert output.shape[0] == mock_num_classes + 1 + assert output.shape[1:] == (mock_chunk_size, mock_chunk_size) + \ No newline at end of file diff --git a/tests/test_geo_inference.py b/tests/test_geo_inference.py index d32898b..758e64e 100644 --- a/tests/test_geo_inference.py +++ b/tests/test_geo_inference.py @@ -1,6 +1,7 @@ import os import pytest import torch + from geo_inference.geo_inference import GeoInference from pathlib import Path @@ -9,35 +10,34 @@ def test_data_dir(): return Path(__file__).parent / "data" class TestGeoInference: + @pytest.fixture 
    def geo_inference(self, test_data_dir):
        model = str(test_data_dir / "inference"/ "test_model" / "test_model.pt")
        work_dir = str(test_data_dir / "inference")
-        batch_size = 1
        mask_to_vec = True
-        vec_to_yolo = True
-        vec_to_coco = True
+        mask_to_yolo = True
+        mask_to_coco = True
        device = 'cpu'
        gpu_id = 0
-        return GeoInference(model, work_dir, batch_size, mask_to_vec, vec_to_yolo, vec_to_coco, device, gpu_id)
+        return GeoInference(model, work_dir, mask_to_vec, mask_to_yolo, mask_to_coco, device, gpu_id)

    def test_init(self, geo_inference, test_data_dir):
-        assert geo_inference.gpu_id == 0
-        assert geo_inference.batch_size == 1
+
        assert geo_inference.work_dir == test_data_dir / "inference"
-        assert geo_inference.device == torch.device('cpu')
+        assert geo_inference.device == 'cpu'
        assert geo_inference.mask_to_vec == True
-        assert geo_inference.vec_to_yolo == True
-        assert geo_inference.vec_to_coco == True
+        assert geo_inference.mask_to_yolo == True
+        assert geo_inference.mask_to_coco == True
        assert isinstance(geo_inference.model, torch.jit.ScriptModule)
-        assert geo_inference.classes > 0
+        assert geo_inference.classes > 0

    def test_call(self, geo_inference, test_data_dir):
        tiff_image = test_data_dir / '0.tif'
        # bbox = '0,0,100,100'
-        # patch_size = 512
-        # stride_size = 256
-        geo_inference(str(tiff_image))
+        patch_size = 512
+        bands_requested = "1,2,3"
+        geo_inference(str(tiff_image), bands_requested, patch_size, None)
        mask_path = geo_inference.work_dir / "0_mask.tif"
        assert mask_path.exists()
        if geo_inference.mask_to_vec:
diff --git a/tests/utils/test_geo.py b/tests/utils/test_geo.py
index 262be5d..8e6d339 100644
--- a/tests/utils/test_geo.py
+++ b/tests/utils/test_geo.py
@@ -52,7 +52,6 @@ def test_df_load(self, test_data_dir):
    csv_path = test_data_dir / "0_yolo.csv"
    df_1 = df_load(str(csv_path))
    assert isinstance(df_1, pd.DataFrame)
-
    json_path = test_data_dir / "0_coco.json"
    gdf = gdf_load(json_path)
    assert isinstance(gdf, gpd.GeoDataFrame)
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index ce574c7..22f5c64 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -4,14 +4,12 @@ import rasterio
 from pathlib import Path
 from unittest.mock import MagicMock, patch
-
 import pytest
-import torch
-
 from geo_inference.utils.helpers import (calculate_gpu_stats,
-                                         download_file_from_url,
-                                         extract_tar_gz, get_device,
+                                         extract_tar_gz,
                                          get_directory, get_model,
+                                         download_file_from_url,
+                                         select_model_device,
                                          is_tiff_path, is_tiff_url,
                                          read_yaml, validate_asset_type,
                                          cmd_interface)
@@ -48,12 +46,16 @@ def test_read_yaml(test_data_dir):
        "bbox": "None",
        "model": "rgb-4class-segformer",
        "work_dir": "None",
-        "batch_size": 1,
+        "patch_size": 1024,
        "vec": False,
        "yolo": False,
        "coco": False,
        "device": "gpu",
-        "gpu_id": 0
+        "gpu_id": 0,
+        "bands_requested": '1,2,3',
+        "mgpu": False,
+        "classes": 5,
+        "n_workers": 20
    }

def test_validate_asset_type(test_data_dir):
@@ -97,8 +99,8 @@ def test_extract_tar_gz(temp_tar_gz_file, test_data_dir):
def test_get_device():
    with patch('geo_inference.utils.helpers.calculate_gpu_stats') as mock_calculate_gpu_stats:
        mock_calculate_gpu_stats.return_value = ({"gpu": 10}, {"used": 100, "total": 1024})
-        device = get_device(device="gpu", gpu_id=1, gpu_max_ram_usage=20, gpu_max_utilization=10)
-        assert device == torch.device('cpu')
+        device = select_model_device(gpu_id=1, multi_gpu=False)
+        assert device == 'cpu'

def test_get_directory():
    with patch('pathlib.Path.is_dir', return_value=False), patch('pathlib.Path.mkdir'):
@@ -126,17 +128,20 @@ def test_cmd_interface_with_args(monkeypatch, test_data_dir):
    # Call the function
    result = cmd_interface()
-
+
    assert result == {"image": "./data/areial.tiff",
-                      "bbox": "None",
+                      "bbox": None,
+                      "bands_requested": "1,2,3",
                      "model": "rgb-4class-segformer",
                      "work_dir": "None",
-                      "batch_size": 1,
                      "vec": False,
                      "yolo": False,
                      "coco": False,
                      "device": "gpu",
-                      "gpu_id": 0
+                      "gpu_id": 0,
+                      "classes": 5,
+                      "multi_gpu": False,
+                      "patch_size": 1024
                      }

def test_cmd_interface_with_image(monkeypatch):
@@ -148,14 +153,17 @@ def test_cmd_interface_with_image(monkeypatch):
    assert result == {
        "image": "image.tif",
        "bbox": None,
+        "bands_requested": [],
        "model": None,
        "work_dir": None,
-        "batch_size": 1,
+        "patch_size": 1024,
        "vec": False,
        "yolo": False,
        "coco": False,
        "device": "gpu",
-        "gpu_id": 0
+        "gpu_id": 0,
+        "classes": 5,
+        "multi_gpu": False,
    }

def test_cmd_interface_no_args(monkeypatch):