Skip to content
This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

Support TFDS format for segmentation #1005

Merged
merged 1 commit into from
Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion blueoil/cmd/build_tfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from blueoil.utils import config as config_util
from blueoil.utils.tfds_builders.classification import ClassificationBuilder
from blueoil.utils.tfds_builders.object_detection import ObjectDetectionBuilder
from blueoil.utils.tfds_builders.segmentation import SegmentationBuilder


def _get_tfds_settings(config_file):
Expand All @@ -47,7 +48,7 @@ def _get_tfds_builder_class(dataset_class):
raise ValueError("You cannot use dataset classes which is already a TFDS format.")

if issubclass(dataset_class, SegmentationBase):
raise NotImplementedError("A dataset builder for segmentation dataset is not implemented yet.")
return SegmentationBuilder

if issubclass(dataset_class, ObjectDetectionBase):
return ObjectDetectionBuilder
Expand Down
6 changes: 4 additions & 2 deletions blueoil/cmd/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@

from blueoil import environment
from blueoil.common import Tasks
from blueoil.datasets.base import ObjectDetectionBase
from blueoil.datasets.base import ObjectDetectionBase, SegmentationBase
from blueoil.datasets.dataset_iterator import DatasetIterator
from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection
from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection, TFDSSegmentation
from blueoil.utils import config as config_util
from blueoil.utils import executor
from blueoil.utils import horovod as horovod_util
Expand All @@ -51,6 +51,8 @@ def setup_dataset(config, subset, rank, local_rank):
if tfds_kwargs:
if issubclass(DatasetClass, ObjectDetectionBase):
DatasetClass = TFDSObjectDetection
elif issubclass(DatasetClass, SegmentationBase):
DatasetClass = TFDSSegmentation
else:
DatasetClass = TFDSClassification

Expand Down
49 changes: 48 additions & 1 deletion blueoil/datasets/tfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import tensorflow as tf
import tensorflow_datasets as tfds

from blueoil.datasets.base import Base, ObjectDetectionBase
from blueoil.datasets.base import Base, ObjectDetectionBase, SegmentationBase
from blueoil.utils.tfds_builders.classification import ClassificationBuilder
from blueoil.utils.tfds_builders.object_detection import ObjectDetectionBuilder
from blueoil.utils.tfds_builders.segmentation import SegmentationBuilder


def _grayscale_to_rgb(record):
Expand Down Expand Up @@ -67,6 +68,13 @@ def _format_object_detection_record(record, image_size, num_max_boxes):
return {"image": image, "label": gt_boxes}


def _format_segmentation_record(record, image_size):
image = tf.image.resize(record["image"], image_size)
segmentation_mask = tf.squeeze(tf.image.resize(record["segmentation_mask"], image_size), axis=2)

return {"image": image, "label": segmentation_mask}


class TFDSMixin:
"""A Mixin to compose dataset classes for TFDS."""
available_subsets = ["train", "validation"]
Expand Down Expand Up @@ -273,3 +281,42 @@ def _format_dataset(self):
lambda record: _format_object_detection_record(record, self._image_size, num_max_boxes),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)


class TFDSSegmentation(TFDSMixin, SegmentationBase):
"""A dataset class for loading TensorFlow Datasets for segmentation.
TensorFlow Datasets which have "label" and "image" features can be loaded by this class.
"""
builder_class = SegmentationBuilder

@property
def classes(self):
return self.info.features["label"].names

@property
def num_classes(self):
return self.info.features["label"].num_classes

def _validate_feature_structure(self):
is_valid = \
"label" in self.info.features and \
"image" in self.info.features and \
"segmentation_mask" in self.info.features and \
isinstance(self.info.features["label"], tfds.features.ClassLabel) and \
isinstance(self.info.features["image"], tfds.features.Image) and \
isinstance(self.info.features["segmentation_mask"], tfds.features.Image)

if not is_valid:
raise ValueError("Datasets should have \"label\", \"image\" and \"segmentation_mask\" features.")

def _format_dataset(self):
if self.info.features['image'].shape[2] == 1:
self.tf_dataset = self.tf_dataset.map(
_grayscale_to_rgb,
num_parallel_calls=tf.data.experimental.AUTOTUNE
)

self.tf_dataset = self.tf_dataset.map(
lambda record: _format_segmentation_record(record, self._image_size),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
81 changes: 81 additions & 0 deletions blueoil/utils/tfds_builders/segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2018 The Blueoil Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import numpy as np
import tensorflow_datasets as tfds


class SegmentationBuilder(tfds.core.GeneratorBasedBuilder):
"""
A custom TFDS builder for segmentation dataset.
This class loads data from existing dataset classes and
generate TFDS formatted dataset which is equivalent to the original one.
See also: https://www.tensorflow.org/datasets/add_dataset
"""

VERSION = tfds.core.Version("0.1.0")

def __init__(self, dataset_name, dataset_class=None, dataset_kwargs=None, **kwargs):
self.name = dataset_name
self.dataset_class = dataset_class
self.dataset_kwargs = dataset_kwargs
super().__init__(**kwargs)

def _info(self):
return tfds.core.DatasetInfo(
builder=self,
description="Custom TFDS dataset for segmentation",
features=tfds.features.FeaturesDict({
"image": tfds.features.Image(),
"label": tfds.features.ClassLabel(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the feature label is used for having a list of classes but not needed for every example. Is my understanding correct?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tfujiwar
Yes! Your understanding is correct! 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

"segmentation_mask": tfds.features.Image(shape=(None, None, 1)),
}),
)

def _split_generators(self, dl_manager):
self.info.features["label"].names = self.dataset_class(**self.dataset_kwargs).classes

predefined_names = {
"train": tfds.Split.TRAIN,
"validation": tfds.Split.VALIDATION,
"test": tfds.Split.TEST,
}

splits = []
for subset in self.dataset_class.available_subsets:
dataset = self.dataset_class(subset=subset, **self.dataset_kwargs)
splits.append(
tfds.core.SplitGenerator(
name=predefined_names[subset],
num_shards=self._num_shards(dataset),
gen_kwargs=dict(dataset=dataset)
)
)

return splits

def _generate_examples(self, dataset):
for i, (image, segmentation_mask) in enumerate(dataset):
yield i, {
"image": image,
"segmentation_mask": np.expand_dims(segmentation_mask, axis=2),
"label": -1, # dummy label
}

def _num_shards(self, dataset):
"""Decide a number of shards so as not the size of each shard exceeds 256MiB"""
max_shard_size = 256 * 1024 * 1024 # 256MiB
total_size = sum((image.nbytes + mask.nbytes) for image, mask in dataset)
return (total_size + max_shard_size - 1) // max_shard_size
60 changes: 59 additions & 1 deletion tests/unit/executor_tests/test_build_tfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from blueoil.cmd.train import run as train_run
from blueoil import environment
from blueoil.datasets.dataset_iterator import DatasetIterator
from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection
from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection, TFDSSegmentation
from blueoil.utils import config as config_util

_RUN_AS_A_SCRIPT = False
Expand Down Expand Up @@ -151,3 +151,61 @@ def test_build_tfds_object_detection():
assert labels.shape[0] == config.BATCH_SIZE
assert labels.shape[1] == num_max_boxes
assert labels.shape[2] == 5


def test_build_tfds_segmentation():
environment.setup_test_environment()

# Build TFDS Dataset
config_file = "unit/fixtures/configs/for_build_tfds_segmentation.py"
run(config_file, overwrite=True)

# Check if the builded dataset can be loaded with the same config file
expriment_id = "tfds_segmentation"
train_run(None, None, config_file, expriment_id, recreate=True)

# Check if the dataset was build correctly
train_data_num = 5
validation_data_num = 5
config = config_util.load(config_file)

train_dataset = setup_dataset(TFDSSegmentation,
subset="train",
batch_size=config.BATCH_SIZE,
**config.DATASET.TFDS_KWARGS)

validation_dataset = setup_dataset(TFDSSegmentation,
subset="validation",
batch_size=config.BATCH_SIZE,
**config.DATASET.TFDS_KWARGS)

assert train_dataset.num_per_epoch == train_data_num
assert validation_dataset.num_per_epoch == validation_data_num

for _ in range(train_data_num):
images, labels = train_dataset.feed()

assert isinstance(images, np.ndarray)
assert images.shape[0] == config.BATCH_SIZE
assert images.shape[1] == config.IMAGE_SIZE[0]
assert images.shape[2] == config.IMAGE_SIZE[1]
assert images.shape[3] == 3

assert isinstance(labels, np.ndarray)
assert labels.shape[0] == config.BATCH_SIZE
assert labels.shape[1] == config.IMAGE_SIZE[0]
assert labels.shape[2] == config.IMAGE_SIZE[1]

for _ in range(validation_data_num):
images, labels = validation_dataset.feed()

assert isinstance(images, np.ndarray)
assert images.shape[0] == config.BATCH_SIZE
assert images.shape[1] == config.IMAGE_SIZE[0]
assert images.shape[2] == config.IMAGE_SIZE[1]
assert images.shape[3] == 3

assert isinstance(labels, np.ndarray)
assert labels.shape[0] == config.BATCH_SIZE
assert labels.shape[1] == config.IMAGE_SIZE[0]
assert labels.shape[2] == config.IMAGE_SIZE[1]
106 changes: 106 additions & 0 deletions tests/unit/fixtures/configs/for_build_tfds_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright 2018 The Blueoil Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
from easydict import EasyDict
import tensorflow as tf

from blueoil.common import Tasks
from blueoil.networks.segmentation.lm_segnet_v1 import LmSegnetV1Quantize
from blueoil.datasets.camvid import CamvidCustom
from blueoil.data_processor import Sequence
from blueoil.pre_processor import (
Resize,
DivideBy255,
)
from blueoil.data_augmentor import (
Brightness,
Color,
Contrast,
FlipLeftRight,
Hue,
)
from blueoil.quantizations import (
binary_mean_scaling_quantizer,
linear_mid_tread_half_quantizer,
)


class SegmentationDataset(CamvidCustom):
extend_dir = "camvid_custom"
validation_extend_dir = "camvid_custom"


IS_DEBUG = False

NETWORK_CLASS = LmSegnetV1Quantize
DATASET_CLASS = SegmentationDataset

IMAGE_SIZE = [80, 120]
BATCH_SIZE = 8
DATA_FORMAT = "NHWC"
TASK = Tasks.SEMANTIC_SEGMENTATION
CLASSES = DATASET_CLASS(subset="train", batch_size=1).classes

MAX_STEPS = 1
SAVE_CHECKPOINT_STEPS = 1
KEEP_CHECKPOINT_MAX = 5
TEST_STEPS = 100
SUMMARISE_STEPS = 100

# distributed training
IS_DISTRIBUTION = False

# pretrain
IS_PRETRAIN = False
PRETRAIN_VARS = []
PRETRAIN_DIR = ""
PRETRAIN_FILE = ""

PRE_PROCESSOR = Sequence([
Resize(size=IMAGE_SIZE),
DivideBy255(),
])
POST_PROCESSOR = None

NETWORK = EasyDict()
NETWORK.OPTIMIZER_CLASS = tf.compat.v1.train.AdamOptimizer
NETWORK.OPTIMIZER_KWARGS = {"learning_rate": 0.001}
NETWORK.IMAGE_SIZE = IMAGE_SIZE
NETWORK.BATCH_SIZE = BATCH_SIZE
NETWORK.DATA_FORMAT = DATA_FORMAT
NETWORK.ACTIVATION_QUANTIZER = linear_mid_tread_half_quantizer
NETWORK.ACTIVATION_QUANTIZER_KWARGS = {
'bit': 2,
'max_value': 2
}
NETWORK.WEIGHT_QUANTIZER = binary_mean_scaling_quantizer
NETWORK.WEIGHT_QUANTIZER_KWARGS = {}

DATASET = EasyDict()
DATASET.BATCH_SIZE = BATCH_SIZE
DATASET.DATA_FORMAT = DATA_FORMAT
DATASET.PRE_PROCESSOR = PRE_PROCESSOR
DATASET.AUGMENTOR = Sequence([
Brightness((0.75, 1.25)),
Color((0.75, 1.25)),
Contrast((0.75, 1.25)),
FlipLeftRight(),
Hue((-10, 10)),
])
DATASET.TFDS_KWARGS = {
"name": "tfds_segmentation",
"data_dir": "tmp/tests/datasets",
"image_size": IMAGE_SIZE,
}