Skip to content

Commit

Permalink
fix/add docs and types plus minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Jan 30, 2023
1 parent c0409a0 commit 7dd92bd
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 63 deletions.
120 changes: 69 additions & 51 deletions rastervision_core/rastervision/core/rv_pipeline/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, List, Optional
import logging

from rastervision.core.rv_pipeline.rv_pipeline import RVPipeline
Expand All @@ -10,71 +10,84 @@

if TYPE_CHECKING:
from rastervision.core.backend.backend import Backend
from rastervision.core.data import Labels, Scene
from rastervision.core.data import (Labels, Scene, RasterSource,
ObjectDetectionLabelSource)
from rastervision.core.rv_pipeline.object_detection_config import (
ObjectDetectionChipOptions)

log = logging.getLogger(__name__)


def _make_chip_pos_windows(image_extent, label_store, chip_size):
def _make_chip_pos_windows(image_extent: Box,
label_source: 'ObjectDetectionLabelSource',
chip_size: int) -> List[Box]:
chip_size = chip_size
pos_windows = []
boxes = label_store.get_labels().get_boxes()
boxes = label_source.get_labels().get_boxes()
done_boxes = set()

# Get a random window around each box. If a box was previously included
# in a window, then it is skipped.
for box in boxes:
if box.tuple_format() not in done_boxes:
# If this object is bigger than the chip,
# don't use this box.
if chip_size < box.width or chip_size < box.height:
log.warning(f'Label is larger than chip size: {box} '
'Skipping this label.')
continue

window = box.make_random_square_container(chip_size)
pos_windows.append(window)

# Get boxes that lie completely within window
window_boxes = label_store.get_labels(window=window)
window_boxes = ObjectDetectionLabels.get_overlapping(
window_boxes, window, ioa_thresh=1.0)
window_boxes = window_boxes.get_boxes()
window_boxes = [box.tuple_format() for box in window_boxes]
done_boxes.update(window_boxes)
if box in done_boxes:
continue
# If this object is bigger than the chip, don't use this box.
if chip_size < box.width or chip_size < box.height:
log.warning(f'Label is larger than chip size: {box} '
'Skipping this label.')
continue

window = box.make_random_square_container(chip_size)
pos_windows.append(window)

# Get boxes that lie completely within window
window_boxes = label_source.get_labels(window=window)
window_boxes = ObjectDetectionLabels.get_overlapping(
window_boxes, window, ioa_thresh=1.0)
window_boxes = window_boxes.get_boxes()
done_boxes.update(window_boxes)

return pos_windows


def _make_label_pos_windows(image_extent, label_store, label_buffer):
def _make_label_pos_windows(image_extent: Box,
label_source: 'ObjectDetectionLabelSource',
label_buffer: int) -> List[Box]:
pos_windows = []
for box in label_store.get_labels().get_boxes():
boxes = label_source.get_labels().get_boxes()
for box in boxes:
window = box.buffer(label_buffer, image_extent)
pos_windows.append(window)

return pos_windows


def make_pos_windows(image_extent, label_store, chip_size, window_method,
label_buffer):
def make_pos_windows(
image_extent: Box, label_source: 'ObjectDetectionLabelSource',
chip_size: int, window_method: ObjectDetectionWindowMethod,
label_buffer: Optional[int]) -> List[Box]:
if window_method == ObjectDetectionWindowMethod.chip:
return _make_chip_pos_windows(image_extent, label_store, chip_size)
if window_method == ObjectDetectionWindowMethod.label:
return _make_label_pos_windows(image_extent, label_store, label_buffer)
return _make_chip_pos_windows(image_extent, label_source, chip_size)
elif window_method == ObjectDetectionWindowMethod.label:
if label_buffer is None:
raise ValueError(
'label_buffer must be specified if '
'window_method=ObjectDetectionWindowMethod.label.')
return _make_label_pos_windows(image_extent, label_source,
label_buffer)
elif window_method == ObjectDetectionWindowMethod.image:
return [image_extent.copy()]
else:
raise Exception(
'Window method: {} is cannot be handled.'.format(window_method))
raise NotImplementedError(f'Window method: {window_method}.')


def make_neg_windows(raster_source,
label_store,
chip_size,
nb_windows,
max_attempts,
filter_windows,
chip_nodata_threshold=1.):
def make_neg_windows(raster_source: 'RasterSource',
label_source: 'ObjectDetectionLabelSource',
chip_size: int,
nb_windows: int,
max_attempts: int,
filter_windows: Callable,
chip_nodata_threshold: float = 1.) -> List[Box]:
extent = raster_source.extent
neg_windows = []
for _ in range(max_attempts):
Expand All @@ -84,7 +97,7 @@ def make_neg_windows(raster_source,
break
chip = raster_source.get_chip(window)
labels = ObjectDetectionLabels.get_overlapping(
label_store.get_labels(), window, ioa_thresh=0.2)
label_source.get_labels(), window, ioa_thresh=0.2)

# If no labels and not too many nodata pixels, append the chip
nodata_below_thresh = nodata_below_threshold(
Expand All @@ -98,9 +111,12 @@ def make_neg_windows(raster_source,
return list(neg_windows)


def get_train_windows(scene, chip_opts, chip_size, chip_nodata_threshold=1.):
def get_train_windows(scene: 'Scene',
chip_opts: 'ObjectDetectionChipOptions',
chip_size: int,
chip_nodata_threshold: float = 1.) -> List[Box]:
raster_source = scene.raster_source
label_store = scene.label_source
label_source = scene.label_source

def filter_windows(windows):
if scene.aoi_polygons:
Expand All @@ -110,14 +126,14 @@ def filter_windows(windows):
window_method = chip_opts.window_method
if window_method == ObjectDetectionWindowMethod.sliding:
stride = chip_size
return list(
filter_windows((raster_source.extent.get_windows(
chip_size, stride))))
windows = raster_source.extent.get_windows(chip_size, stride)
return list(filter_windows(windows))

# Make positive windows which contain labels.
pos_windows = filter_windows(
make_pos_windows(raster_source.extent, label_store, chip_size,
chip_opts.window_method, chip_opts.label_buffer))
pos_windows = make_pos_windows(raster_source.extent, label_source,
chip_size, chip_opts.window_method,
chip_opts.label_buffer)
pos_windows = filter_windows(pos_windows)
nb_pos_windows = len(pos_windows)

# Make negative windows which do not contain labels.
Expand All @@ -132,7 +148,7 @@ def filter_windows(windows):
max_attempts = 100 * nb_neg_windows
neg_windows = make_neg_windows(
raster_source,
label_store,
label_source,
chip_size,
nb_neg_windows,
max_attempts,
Expand All @@ -143,14 +159,15 @@ def filter_windows(windows):


class ObjectDetection(RVPipeline):
def get_train_windows(self, scene):
def get_train_windows(self, scene: 'Scene') -> List[Box]:
return get_train_windows(
scene,
self.config.chip_options,
self.config.train_chip_sz,
chip_nodata_threshold=self.config.chip_nodata_threshold)

def get_train_labels(self, window, scene):
def get_train_labels(self, window: Box,
scene: 'Scene') -> ObjectDetectionLabels:
window_labels = scene.label_source.get_labels(window=window)
return ObjectDetectionLabels.get_overlapping(
window_labels,
Expand All @@ -166,7 +183,8 @@ def predict_scene(self, scene: 'Scene', backend: 'Backend') -> 'Labels':
stride = chip_sz // 2
return backend.predict_scene(scene, chip_sz=chip_sz, stride=stride)

def post_process_predictions(self, labels, scene):
def post_process_predictions(self, labels: ObjectDetectionLabels,
scene: 'Scene') -> ObjectDetectionLabels:
return ObjectDetectionLabels.prune_duplicates(
labels,
score_thresh=self.config.predict_options.score_thresh,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from rastervision.core.data.utils import make_od_scene

if TYPE_CHECKING:
from rastervision.core.data import ClassConfig
from rastervision.core.data import ClassConfig, ObjectDetectionLabelSource
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -196,17 +196,23 @@ def __init__(self, *args, **kwargs):
super().__init__(
*args, **kwargs, transform_type=TransformType.object_detection)

self.scene.label_source.ioa_thresh = ioa_thresh
self.scene.label_source.clip = clip
label_source: Optional[
'ObjectDetectionLabelSource'] = self.scene.label_source
if label_source is not None:
label_source.ioa_thresh = ioa_thresh
label_source.clip = clip

if neg_ratio is not None:
if label_source is None:
raise ValueError(
'Scene must have a LabelSource if neg_ratio is set.')
self.neg_probability = neg_ratio / (neg_ratio + 1)
self.neg_ioa_thresh: float = neg_ioa_thresh

# Get labels for the entire scene.
# clip=True here to ensure that any window we draw around a box
# will always lie inside the scene.
self.labels = self.scene.label_source.get_labels(
self.labels: ObjectDetectionLabels = label_source.get_labels(
ioa_thresh=ioa_thresh, clip=True)
self.bboxes = self.labels.get_boxes()
if len(self.bboxes) == 0:
Expand All @@ -215,7 +221,8 @@ def __init__(self, *args, **kwargs):
else:
self.neg_probability = None

def get_resize_transform(self, transform, out_size):
def get_resize_transform(self, transform: A.BasicTransform,
out_size: Tuple[int, int]) -> A.BasicTransform:
resize_tf = A.Resize(*out_size, always_apply=True)
if transform is None:
transform = resize_tf
Expand All @@ -225,11 +232,12 @@ def get_resize_transform(self, transform, out_size):
return transform

def _sample_pos_window(self) -> Box:
"""Sample a window that contains at least one bounding box.
"""Sample a window containing at least one bounding box.
This is done by randomly sampling one of the bounding boxes in the
scene and drawing a random window around it.
"""
bbox = np.random.choice(self.bboxes)
bbox: Box = np.random.choice(self.bboxes)
box_h, box_w = bbox.size

# check if it is possible to sample a containing widnow
Expand All @@ -252,8 +260,8 @@ def _sample_pos_window(self) -> Box:
return window

def _sample_neg_window(self) -> Box:
"""Attempt to sample, within self.max_sample_attempts, a window
containing no bounding boxes.
"""Attempt to sample a window containing no bounding boxes.
If not found within self.max_sample_attempts, just return the last
sampled window.
"""
Expand All @@ -269,9 +277,10 @@ def _sample_neg_window(self) -> Box:
return window

def _sample_window(self) -> Box:
"""If self.neg_probability is specified, sample a negative or positive window
based on that probability. Otherwise, just use RandomWindowGeoDataset's
default window sampling behavior.
"""Sample negative or positive window based on neg_probability, if set.
If neg_probability is not set, use
:meth:`.RandomWindowGeoDataset._sample_window`.
"""
if self.neg_probability is None:
return super()._sample_window()
Expand Down

0 comments on commit 7dd92bd

Please sign in to comment.