Skip to content

Commit

Permalink
Merge pull request #1638 from AdeelH/backport-docstrs
Browse files Browse the repository at this point in the history
[BACKPORT] Improve docstrings for most commonly used classes and configs
  • Loading branch information
AdeelH authored Dec 29, 2022
2 parents 4ac849e + d6f1a41 commit 368ba56
Show file tree
Hide file tree
Showing 75 changed files with 356 additions and 165 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def setup(app: 'Sphinx') -> None:
os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()
os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'
See this `Colab notebook <https://colab.research.google.com/drive/1qUl_6McbLJr9KjrhHnx_SWbSYkLPL2uW>`__ for an example.
""" # noqa
Expand Down
10 changes: 5 additions & 5 deletions docs/framework/pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.raster_source.multi_raster_source_config.MultiRasterSourceConfig`
- :class:`~data.raster_source.rasterized_source_config.RasterizedSourceConfig`

-
-
* - :class:`~data.raster_transformer.raster_transformer.RasterTransformer`

- :class:`~data.raster_transformer.cast_transformer.CastTransformer`
Expand All @@ -150,7 +150,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.raster_transformer.rgb_class_transformer_config.RGBClassTransformerConfig`
- :class:`~data.raster_transformer.stats_transformer_config.StatsTransformerConfig`

-
-
* - :class:`~data.vector_source.vector_source.VectorSource`

- :class:`~data.vector_source.geojson_vector_source.GeoJSONVectorSource`
Expand All @@ -159,7 +159,7 @@ The following table shows the corresponding ``Configs`` for various commonly use

- :class:`~data.vector_source.geojson_vector_source_config.GeoJSONVectorSourceConfig`

-
-
* - :class:`~data.vector_transformer.vector_transformer.VectorTransformer`

- :class:`~data.vector_transformer.buffer_transformer.BufferTransformer`
Expand All @@ -172,7 +172,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.vector_transformer.class_inference_transformer_config.ClassInferenceTransformerConfig`
- :class:`~data.vector_transformer.shift_transformer_config.ShiftTransformerConfig`

-
-
* - :class:`~data.label_source.label_source.LabelSource`

- :class:`~data.label_source.chip_classification_label_source.ChipClassificationLabelSource`
Expand All @@ -185,7 +185,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.label_source.semantic_segmentation_label_source_config.SemanticSegmentationLabelSourceConfig`
- :class:`~data.label_source.object_detection_label_source_config.ObjectDetectionLabelSourceConfig`

-
-
* - :class:`~data.label_store.label_store.LabelStore`

- :class:`~data.label_store.chip_classification_geojson_store.ChipClassificationGeoJSONStore`
Expand Down
8 changes: 4 additions & 4 deletions docs/framework/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ The Data
.. raw:: html

<div style="position: relative; padding-bottom: 56.25%; overflow: hidden; max-width: 100%;">
<iframe
src="../_static/tiny-spacenet-map.html"
frameborder="0"
<iframe
src="../_static/tiny-spacenet-map.html"
frameborder="0"
allowfullscreen style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;">
</iframe>
</div>

Configuring a semantic segmentation pipeline
----------------------------------------------

Create a Python file in the ``${RV_QUICKSTART_CODE_DIR}`` named ``tiny_spacenet.py``. Inside, you're going to write a function called ``get_config`` that returns a ``SemanticSegmentationConfig`` object. This object's type is a subclass of ``PipelineConfig``, and configures a semantic segmentation pipeline which (optionally) analyzes the imagery, (optionally) creates training chips, trains a model, makes predictions on validation scenes, evaluates the predictions, and saves a model bundle.
Create a Python file in the ``${RV_QUICKSTART_CODE_DIR}`` named ``tiny_spacenet.py``. Inside, you're going to write a function called ``get_config`` that returns a :class:`~rastervision.core.rv_pipeline.semantic_segmentation_config.SemanticSegmentationConfig` object. This object's type is a subclass of :class:`~rastervision.pipeline.pipeline_config.PipelineConfig`, and configures a semantic segmentation pipeline which (optionally) analyzes the imagery, (optionally) creates training chips, trains a model, makes predictions on validation scenes, evaluates the predictions, and saves a model bundle.

.. literalinclude:: /../rastervision_pytorch_backend/rastervision/pytorch_backend/examples/tiny_spacenet.py
:language: python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

@register_config('analyzer')
class AnalyzerConfig(Config):
"""Configure an :class:`.Analyzer`."""

def build(self, scene_group: Optional[Tuple[str, Iterable[str]]] = None
) -> 'Analyzer':
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class StatsAnalyzer(Analyzer):
"""Computes RasterStats against the entire scene set."""
"""Compute imagery statistics of scenes."""

def __init__(self,
stats_uri: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

@register_config('stats_analyzer')
class StatsAnalyzerConfig(AnalyzerConfig):
"""Config for an Analyzer that computes imagery statistics of scenes."""
"""Configure a :class:`.StatsAnalyzer`.
A :class:`.StatsAnalyzer` computes imagery statistics of scenes which can
be used to normalize chips read from them.
"""

output_uri: Optional[str] = Field(
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@register_config('backend')
class BackendConfig(Config):
"""Configuration for Backend."""
"""Configure a :class:`.Backend`."""

def build(self, pipeline: 'RVPipeline', tmp_dir: str) -> 'Backend':
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class BoxSizeError(ValueError):


class Box():
"""A multi-purpose box (ie. rectangle)."""
"""A multi-purpose box (ie. rectangle) representation ."""

def __init__(self, ymin, xmin, ymax, xmax):
"""Construct a bounding box.
Expand Down
3 changes: 2 additions & 1 deletion rastervision_core/rastervision/core/data/class_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

@register_config('class_config')
class ClassConfig(Config):
"""Configures the class names that are being predicted."""
"""Configure class information for a machine learning task."""

names: List[str] = Field(
...,
description='Names of classes. The i-th class in this list will have '
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class CRSTransformer(ABC):
"""Transforms map points in some CRS into pixel coordinates.
Each transformer is associated with a particular RasterSource.
Each transformer is associated with a particular :class:`.RasterSource`.
"""

def __init__(self,
Expand Down
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def dataset_config_upgrader(cfg_dict: dict, version: int) -> dict:

@register_config('dataset', upgrader=dataset_config_upgrader)
class DatasetConfig(Config):
"""Config for a Dataset comprising the scenes for train, valid, and test splits."""
"""Configure train, validation, and test splits for a dataset."""
class_config: ClassConfig
train_scenes: List[SceneConfig]
validation_scenes: List[SceneConfig]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def cc_label_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
'chip_classification_label_source',
upgrader=cc_label_source_config_upgrader)
class ChipClassificationLabelSourceConfig(LabelSourceConfig):
"""Config for a source of labels for chip classification.
"""Configure a :class:`.ChipClassificationLabelSource`.
This can be provided explicitly as a grid of cells, or a grid of cells can be
inferred from arbitrary polygons.
This can be provided explicitly as a grid of cells, or a grid of cells can
be inferred from arbitrary polygons.
"""
vector_source: Optional[VectorSourceConfig] = None
ioa_thresh: float = Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

@register_config('label_source')
class LabelSourceConfig(Config):
"""Configure a :class:`.LabelSource`."""

def build(self, class_config, crs_transformer, extent, tmp_dir):
raise NotImplementedError()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

@register_config('object_detection_label_source')
class ObjectDetectionLabelSourceConfig(LabelSourceConfig):
"""Config for a read-only label source for object detection."""
"""Configure an :class:`.ObjectDetectionLabelSource`."""

vector_source: VectorSourceConfig

@validator('vector_source')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def ss_label_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
'semantic_segmentation_label_source',
upgrader=ss_label_source_config_upgrader)
class SemanticSegmentationLabelSourceConfig(LabelSourceConfig):
"""Config for a read-only label source for semantic segmentation."""
"""Configure a :class:`.SemanticSegmentationLabelSource`."""

raster_source: Union[RasterSourceConfig, RasterizedSourceConfig] = Field(
..., description='The labels in the form of rasters.')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@register_config('chip_classification_geojson_store')
class ChipClassificationGeoJSONStoreConfig(LabelStoreConfig):
"""Config for storage for chip classification predictions."""
"""Configure a :class:`.ChipClassificationGeoJSONStore`."""

uri: Optional[str] = Field(
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@


class LabelStore(ABC):
"""This defines how to store prediction labels are stored for a scene.
"""
"""This defines how to store prediction labels for a scene."""

@abstractmethod
def save(self, labels):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

@register_config('label_store')
class LabelStoreConfig(Config):
"""Configure a :class:`.LabelStore`."""

def build(self, class_config, crs_transformer, extent, tmp_dir):
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@register_config('object_detection_geojson_store')
class ObjectDetectionGeoJSONStoreConfig(LabelStoreConfig):
"""Config for storage for object detection predictions."""
"""Configure an :class:`.ObjectDetectionGeoJSONStore`."""

uri: Optional[str] = Field(
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
class SemanticSegmentationLabelStore(LabelStore):
"""Storage for semantic segmentation predictions.
Stores class raster as GeoTIFF, and can optionally vectorizes predictions and stores
them in GeoJSON files.
Can store predicted class ID raster and class scores raster as GeoTIFFs,
and can optionally vectorize predictions and store them as GeoJSON files.
"""

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_mode(self) -> str:

@register_config('semantic_segmentation_label_store')
class SemanticSegmentationLabelStoreConfig(LabelStoreConfig):
"""Config for storage for semantic segmentation predictions.
"""Configure a :class:`.SemanticSegmentationLabelStore`.
Stores class raster as GeoTIFF, and can optionally vectorizes predictions and stores
them in GeoJSON files.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,8 @@
from rastervision.core.data.utils import all_equal


class MultiRasterSourceError(Exception):
pass


class MultiRasterSource(RasterSource):
"""A RasterSource that combines multiple RasterSources by concatenting
their output along the channel dimension (assumed to be the last dimension).
"""
"""Merge multiple ``RasterSources`` by concatenating along channel dim."""

def __init__(self,
raster_sources: Sequence[RasterSource],
Expand Down Expand Up @@ -70,9 +64,18 @@ def __init__(self,
self.validate_raster_sources()

def validate_raster_sources(self) -> None:
"""Validate sub-``RasterSources``.
Checks if:
- dtypes are same or ``force_same_dtype`` is True.
- each sub-``RasterSource`` is a :class:`.RasterioSource` if extents
not identical.
"""
dtypes = [rs.dtype for rs in self.raster_sources]
if not self.force_same_dtype and not all_equal(dtypes):
raise MultiRasterSourceError(
raise ValueError(
'dtypes of all sub raster sources must be the same. '
f'Got: {dtypes} '
'(Use force_same_dtype to cast all to the dtype of the '
Expand All @@ -81,18 +84,13 @@ def validate_raster_sources(self) -> None:
all_rasterio_sources = all(
isinstance(rs, RasterioSource) for rs in self.raster_sources)
if not all_rasterio_sources:
raise MultiRasterSourceError(
'Non-identical extents are only supported '
'for RasterioSource raster sources.')

sub_num_channels = sum(rs.num_channels for rs in self.raster_sources)
if sub_num_channels != self.num_channels:
raise MultiRasterSourceError(
f'num_channels ({self.num_channels}) != sum of num_channels '
f'of sub raster sources ({sub_num_channels})')
raise NotImplementedError(
'Non-identical extents are only '
'supported for RasterioSource raster sources.')

@property
def primary_source(self) -> RasterSource:
"""Primary sub-``RasterSource``"""
return self.raster_sources[self.primary_source_idx]

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def multi_rs_config_upgrader(cfg_dict: dict, version: int) -> dict:

@register_config('multi_raster_source', upgrader=multi_rs_config_upgrader)
class MultiRasterSourceConfig(RasterSourceConfig):
"""Configure a :class:`.MultiRasterSource`."""

raster_sources: conlist(
RasterSourceConfig, min_items=1) = Field(
..., description='List of RasterSourceConfig to combine.')
Expand Down
Loading

0 comments on commit 368ba56

Please sign in to comment.