Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKPORT] Improve docstrings for most commonly used classes and configs #1638

Merged
merged 4 commits into from
Dec 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def setup(app: 'Sphinx') -> None:

os.environ['GDAL_DATA'] = check_output('pip show rasterio | grep Location | awk \'{print $NF"/rasterio/gdal_data/"}\'', shell=True).decode().strip()
os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'

See this `Colab notebook <https://colab.research.google.com/drive/1qUl_6McbLJr9KjrhHnx_SWbSYkLPL2uW>`__ for an example.

""" # noqa
Expand Down
10 changes: 5 additions & 5 deletions docs/framework/pipelines.rst
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.raster_source.multi_raster_source_config.MultiRasterSourceConfig`
- :class:`~data.raster_source.rasterized_source_config.RasterizedSourceConfig`

-
-
* - :class:`~data.raster_transformer.raster_transformer.RasterTransformer`

- :class:`~data.raster_transformer.cast_transformer.CastTransformer`
Expand All @@ -150,7 +150,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.raster_transformer.rgb_class_transformer_config.RGBClassTransformerConfig`
- :class:`~data.raster_transformer.stats_transformer_config.StatsTransformerConfig`

-
-
* - :class:`~data.vector_source.vector_source.VectorSource`

- :class:`~data.vector_source.geojson_vector_source.GeoJSONVectorSource`
Expand All @@ -159,7 +159,7 @@ The following table shows the corresponding ``Configs`` for various commonly use

- :class:`~data.vector_source.geojson_vector_source_config.GeoJSONVectorSourceConfig`

-
-
* - :class:`~data.vector_transformer.vector_transformer.VectorTransformer`

- :class:`~data.vector_transformer.buffer_transformer.BufferTransformer`
Expand All @@ -172,7 +172,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.vector_transformer.class_inference_transformer_config.ClassInferenceTransformerConfig`
- :class:`~data.vector_transformer.shift_transformer_config.ShiftTransformerConfig`

-
-
* - :class:`~data.label_source.label_source.LabelSource`

- :class:`~data.label_source.chip_classification_label_source.ChipClassificationLabelSource`
Expand All @@ -185,7 +185,7 @@ The following table shows the corresponding ``Configs`` for various commonly use
- :class:`~data.label_source.semantic_segmentation_label_source_config.SemanticSegmentationLabelSourceConfig`
- :class:`~data.label_source.object_detection_label_source_config.ObjectDetectionLabelSourceConfig`

-
-
* - :class:`~data.label_store.label_store.LabelStore`

- :class:`~data.label_store.chip_classification_geojson_store.ChipClassificationGeoJSONStore`
Expand Down
8 changes: 4 additions & 4 deletions docs/framework/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ The Data
.. raw:: html

<div style="position: relative; padding-bottom: 56.25%; overflow: hidden; max-width: 100%;">
<iframe
src="../_static/tiny-spacenet-map.html"
frameborder="0"
<iframe
src="../_static/tiny-spacenet-map.html"
frameborder="0"
allowfullscreen style="position: absolute; top: 0; left: 0; width: 100%; height: 100%;">
</iframe>
</div>

Configuring a semantic segmentation pipeline
----------------------------------------------

Create a Python file in the ``${RV_QUICKSTART_CODE_DIR}`` named ``tiny_spacenet.py``. Inside, you're going to write a function called ``get_config`` that returns a ``SemanticSegmentationConfig`` object. This object's type is a subclass of ``PipelineConfig``, and configures a semantic segmentation pipeline which (optionally) analyzes the imagery, (optionally) creates training chips, trains a model, makes predictions on validation scenes, evaluates the predictions, and saves a model bundle.
Create a Python file in the ``${RV_QUICKSTART_CODE_DIR}`` named ``tiny_spacenet.py``. Inside, you're going to write a function called ``get_config`` that returns a :class:`~rastervision.core.rv_pipeline.semantic_segmentation_config.SemanticSegmentationConfig` object. This object's type is a subclass of :class:`~rastervision.pipeline.pipeline_config.PipelineConfig`, and configures a semantic segmentation pipeline which (optionally) analyzes the imagery, (optionally) creates training chips, trains a model, makes predictions on validation scenes, evaluates the predictions, and saves a model bundle.

.. literalinclude:: /../rastervision_pytorch_backend/rastervision/pytorch_backend/examples/tiny_spacenet.py
:language: python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

@register_config('analyzer')
class AnalyzerConfig(Config):
"""Configure an :class:`.Analyzer`."""

def build(self, scene_group: Optional[Tuple[str, Iterable[str]]] = None
) -> 'Analyzer':
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


class StatsAnalyzer(Analyzer):
"""Computes RasterStats against the entire scene set."""
"""Compute imagery statistics of scenes."""

def __init__(self,
stats_uri: Optional[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

@register_config('stats_analyzer')
class StatsAnalyzerConfig(AnalyzerConfig):
"""Config for an Analyzer that computes imagery statistics of scenes."""
"""Configure a :class:`.StatsAnalyzer`.

A :class:`.StatsAnalyzer` computes imagery statistics of scenes which can
be used to normalize chips read from them.
"""

output_uri: Optional[str] = Field(
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@register_config('backend')
class BackendConfig(Config):
"""Configuration for Backend."""
"""Configure a :class:`.Backend`."""

def build(self, pipeline: 'RVPipeline', tmp_dir: str) -> 'Backend':
raise NotImplementedError()
Expand Down
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class BoxSizeError(ValueError):


class Box():
"""A multi-purpose box (ie. rectangle)."""
"""A multi-purpose box (ie. rectangle) representation ."""

def __init__(self, ymin, xmin, ymax, xmax):
"""Construct a bounding box.
Expand Down
3 changes: 2 additions & 1 deletion rastervision_core/rastervision/core/data/class_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

@register_config('class_config')
class ClassConfig(Config):
"""Configures the class names that are being predicted."""
"""Configure class information for a machine learning task."""

names: List[str] = Field(
...,
description='Names of classes. The i-th class in this list will have '
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class CRSTransformer(ABC):
"""Transforms map points in some CRS into pixel coordinates.

Each transformer is associated with a particular RasterSource.
Each transformer is associated with a particular :class:`.RasterSource`.
"""

def __init__(self,
Expand Down
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def dataset_config_upgrader(cfg_dict: dict, version: int) -> dict:

@register_config('dataset', upgrader=dataset_config_upgrader)
class DatasetConfig(Config):
"""Config for a Dataset comprising the scenes for train, valid, and test splits."""
"""Configure train, validation, and test splits for a dataset."""
class_config: ClassConfig
train_scenes: List[SceneConfig]
validation_scenes: List[SceneConfig]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def cc_label_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
'chip_classification_label_source',
upgrader=cc_label_source_config_upgrader)
class ChipClassificationLabelSourceConfig(LabelSourceConfig):
"""Config for a source of labels for chip classification.
"""Configure a :class:`.ChipClassificationLabelSource`.

This can be provided explicitly as a grid of cells, or a grid of cells can be
inferred from arbitrary polygons.
This can be provided explicitly as a grid of cells, or a grid of cells can
be inferred from arbitrary polygons.
"""
vector_source: Optional[VectorSourceConfig] = None
ioa_thresh: float = Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

@register_config('label_source')
class LabelSourceConfig(Config):
"""Configure a :class:`.LabelSource`."""

def build(self, class_config, crs_transformer, extent, tmp_dir):
raise NotImplementedError()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

@register_config('object_detection_label_source')
class ObjectDetectionLabelSourceConfig(LabelSourceConfig):
"""Config for a read-only label source for object detection."""
"""Configure an :class:`.ObjectDetectionLabelSource`."""

vector_source: VectorSourceConfig

@validator('vector_source')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def ss_label_source_config_upgrader(cfg_dict: dict, version: int) -> dict:
'semantic_segmentation_label_source',
upgrader=ss_label_source_config_upgrader)
class SemanticSegmentationLabelSourceConfig(LabelSourceConfig):
"""Config for a read-only label source for semantic segmentation."""
"""Configure a :class:`.SemanticSegmentationLabelSource`."""

raster_source: Union[RasterSourceConfig, RasterizedSourceConfig] = Field(
..., description='The labels in the form of rasters.')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@register_config('chip_classification_geojson_store')
class ChipClassificationGeoJSONStoreConfig(LabelStoreConfig):
"""Config for storage for chip classification predictions."""
"""Configure a :class:`.ChipClassificationGeoJSONStore`."""

uri: Optional[str] = Field(
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@


class LabelStore(ABC):
"""This defines how to store prediction labels are stored for a scene.
"""
"""This defines how to store prediction labels for a scene."""

@abstractmethod
def save(self, labels):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

@register_config('label_store')
class LabelStoreConfig(Config):
"""Configure a :class:`.LabelStore`."""

def build(self, class_config, crs_transformer, extent, tmp_dir):
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

@register_config('object_detection_geojson_store')
class ObjectDetectionGeoJSONStoreConfig(LabelStoreConfig):
"""Config for storage for object detection predictions."""
"""Configure an :class:`.ObjectDetectionGeoJSONStore`."""

uri: Optional[str] = Field(
None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
class SemanticSegmentationLabelStore(LabelStore):
"""Storage for semantic segmentation predictions.

Stores class raster as GeoTIFF, and can optionally vectorizes predictions and stores
them in GeoJSON files.
Can store predicted class ID raster and class scores raster as GeoTIFFs,
and can optionally vectorize predictions and store them as GeoJSON files.
"""

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_mode(self) -> str:

@register_config('semantic_segmentation_label_store')
class SemanticSegmentationLabelStoreConfig(LabelStoreConfig):
"""Config for storage for semantic segmentation predictions.
"""Configure a :class:`.SemanticSegmentationLabelStore`.

Stores class raster as GeoTIFF, and can optionally vectorizes predictions and stores
them in GeoJSON files.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,8 @@
from rastervision.core.data.utils import all_equal


class MultiRasterSourceError(Exception):
pass


class MultiRasterSource(RasterSource):
"""A RasterSource that combines multiple RasterSources by concatenting
their output along the channel dimension (assumed to be the last dimension).
"""
"""Merge multiple ``RasterSources`` by concatenating along channel dim."""

def __init__(self,
raster_sources: Sequence[RasterSource],
Expand Down Expand Up @@ -70,9 +64,18 @@ def __init__(self,
self.validate_raster_sources()

def validate_raster_sources(self) -> None:
"""Validate sub-``RasterSources``.

Checks if:

- dtypes are same or ``force_same_dtype`` is True.
- each sub-``RasterSource`` is a :class:`.RasterioSource` if extents
not identical.

"""
dtypes = [rs.dtype for rs in self.raster_sources]
if not self.force_same_dtype and not all_equal(dtypes):
raise MultiRasterSourceError(
raise ValueError(
'dtypes of all sub raster sources must be the same. '
f'Got: {dtypes} '
'(Use force_same_dtype to cast all to the dtype of the '
Expand All @@ -81,18 +84,13 @@ def validate_raster_sources(self) -> None:
all_rasterio_sources = all(
isinstance(rs, RasterioSource) for rs in self.raster_sources)
if not all_rasterio_sources:
raise MultiRasterSourceError(
'Non-identical extents are only supported '
'for RasterioSource raster sources.')

sub_num_channels = sum(rs.num_channels for rs in self.raster_sources)
if sub_num_channels != self.num_channels:
raise MultiRasterSourceError(
f'num_channels ({self.num_channels}) != sum of num_channels '
f'of sub raster sources ({sub_num_channels})')
raise NotImplementedError(
'Non-identical extents are only '
'supported for RasterioSource raster sources.')

@property
def primary_source(self) -> RasterSource:
"""Primary sub-``RasterSource``"""
return self.raster_sources[self.primary_source_idx]

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def multi_rs_config_upgrader(cfg_dict: dict, version: int) -> dict:

@register_config('multi_raster_source', upgrader=multi_rs_config_upgrader)
class MultiRasterSourceConfig(RasterSourceConfig):
"""Configure a :class:`.MultiRasterSource`."""

raster_sources: conlist(
RasterSourceConfig, min_items=1) = Field(
..., description='List of RasterSourceConfig to combine.')
Expand Down
Loading