From 7feb191ca6ba7fb40ffef2c13524b4a985fc9275 Mon Sep 17 00:00:00 2001 From: Hideaki Masuda Date: Mon, 20 Apr 2020 09:18:25 +0900 Subject: [PATCH] support TFDS format for segmentation --- blueoil/cmd/build_tfds.py | 3 +- blueoil/cmd/train.py | 6 +- blueoil/datasets/tfds.py | 49 +++++++- blueoil/utils/tfds_builders/segmentation.py | 81 +++++++++++++ tests/unit/executor_tests/test_build_tfds.py | 60 +++++++++- .../configs/for_build_tfds_segmentation.py | 106 ++++++++++++++++++ 6 files changed, 300 insertions(+), 5 deletions(-) create mode 100644 blueoil/utils/tfds_builders/segmentation.py create mode 100644 tests/unit/fixtures/configs/for_build_tfds_segmentation.py diff --git a/blueoil/cmd/build_tfds.py b/blueoil/cmd/build_tfds.py index b10ca5668..f13b05eba 100644 --- a/blueoil/cmd/build_tfds.py +++ b/blueoil/cmd/build_tfds.py @@ -24,6 +24,7 @@ from blueoil.utils import config as config_util from blueoil.utils.tfds_builders.classification import ClassificationBuilder from blueoil.utils.tfds_builders.object_detection import ObjectDetectionBuilder +from blueoil.utils.tfds_builders.segmentation import SegmentationBuilder def _get_tfds_settings(config_file): @@ -47,7 +48,7 @@ def _get_tfds_builder_class(dataset_class): raise ValueError("You cannot use dataset classes which is already a TFDS format.") if issubclass(dataset_class, SegmentationBase): - raise NotImplementedError("A dataset builder for segmentation dataset is not implemented yet.") + return SegmentationBuilder if issubclass(dataset_class, ObjectDetectionBase): return ObjectDetectionBuilder diff --git a/blueoil/cmd/train.py b/blueoil/cmd/train.py index 966a9afac..5bc670609 100644 --- a/blueoil/cmd/train.py +++ b/blueoil/cmd/train.py @@ -24,9 +24,9 @@ from blueoil import environment from blueoil.common import Tasks -from blueoil.datasets.base import ObjectDetectionBase +from blueoil.datasets.base import ObjectDetectionBase, SegmentationBase from blueoil.datasets.dataset_iterator import DatasetIterator -from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection +from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection, TFDSSegmentation from blueoil.utils import config as config_util from blueoil.utils import executor from blueoil.utils import horovod as horovod_util @@ -51,6 +51,8 @@ def setup_dataset(config, subset, rank, local_rank): if tfds_kwargs: if issubclass(DatasetClass, ObjectDetectionBase): DatasetClass = TFDSObjectDetection + elif issubclass(DatasetClass, SegmentationBase): + DatasetClass = TFDSSegmentation else: DatasetClass = TFDSClassification diff --git a/blueoil/datasets/tfds.py b/blueoil/datasets/tfds.py index 8810d10dc..03ef01ee5 100644 --- a/blueoil/datasets/tfds.py +++ b/blueoil/datasets/tfds.py @@ -19,9 +19,10 @@ import tensorflow as tf import tensorflow_datasets as tfds -from blueoil.datasets.base import Base, ObjectDetectionBase +from blueoil.datasets.base import Base, ObjectDetectionBase, SegmentationBase from blueoil.utils.tfds_builders.classification import ClassificationBuilder from blueoil.utils.tfds_builders.object_detection import ObjectDetectionBuilder +from blueoil.utils.tfds_builders.segmentation import SegmentationBuilder def _grayscale_to_rgb(record): @@ -67,6 +68,13 @@ def _format_object_detection_record(record, image_size, num_max_boxes): return {"image": image, "label": gt_boxes} +def _format_segmentation_record(record, image_size): + image = tf.image.resize(record["image"], image_size) + segmentation_mask = tf.squeeze(tf.image.resize(record["segmentation_mask"], 
+        image_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR), axis=2)  # nearest-neighbor keeps mask values valid class IDs
+
+    return {"image": image, "label": segmentation_mask}
+
+
 class TFDSMixin:
     """A Mixin to compose dataset classes for TFDS."""
     available_subsets = ["train", "validation"]
@@ -273,3 +281,42 @@ def _format_dataset(self):
             lambda record: _format_object_detection_record(record, self._image_size, num_max_boxes),
             num_parallel_calls=tf.data.experimental.AUTOTUNE
         )
+
+
+class TFDSSegmentation(TFDSMixin, SegmentationBase):
+    """A dataset class for loading TensorFlow Datasets for segmentation.
+    TensorFlow Datasets which have "image", "label" and "segmentation_mask"
+    features can be loaded by this class.
+    """
+    builder_class = SegmentationBuilder
+
+    @property
+    def classes(self):
+        return self.info.features["label"].names
+
+    @property
+    def num_classes(self):
+        return self.info.features["label"].num_classes
+
+    def _validate_feature_structure(self):
+        is_valid = \
+            "label" in self.info.features and \
+            "image" in self.info.features and \
+            "segmentation_mask" in self.info.features and \
+            isinstance(self.info.features["label"], tfds.features.ClassLabel) and \
+            isinstance(self.info.features["image"], tfds.features.Image) and \
+            isinstance(self.info.features["segmentation_mask"], tfds.features.Image)
+
+        if not is_valid:
+            raise ValueError("Datasets should have \"label\", \"image\" and \"segmentation_mask\" features.")
+
+    def _format_dataset(self):
+        if self.info.features["image"].shape[2] == 1:
+            self.tf_dataset = self.tf_dataset.map(
+                _grayscale_to_rgb,
+                num_parallel_calls=tf.data.experimental.AUTOTUNE
+            )
+
+        self.tf_dataset = self.tf_dataset.map(
+            lambda record: _format_segmentation_record(record, self._image_size),
+            num_parallel_calls=tf.data.experimental.AUTOTUNE
+        )
diff --git a/blueoil/utils/tfds_builders/segmentation.py b/blueoil/utils/tfds_builders/segmentation.py
new file mode 100644
index 000000000..84525b201
--- /dev/null
+++ b/blueoil/utils/tfds_builders/segmentation.py
@@ -0,0 +1,81 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 The Blueoil Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+import numpy as np
+import tensorflow_datasets as tfds
+
+
+class SegmentationBuilder(tfds.core.GeneratorBasedBuilder):
+    """
+    A custom TFDS builder for segmentation datasets.
+    This class loads data from existing dataset classes and
+    generates a TFDS-formatted dataset equivalent to the original one.
+
+    See also: https://www.tensorflow.org/datasets/add_dataset
+    """
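+
+    # Illustrative sketch (not normative) of one example this builder yields
+    # from _generate_examples below, assuming an HxW uint8 image and mask:
+    #     (0, {"image":             array of shape (H, W, 3),
+    #          "segmentation_mask": array of shape (H, W, 1),
+    #          "label":             -1})  # dummy; class names live in the feature metadata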
+
+    VERSION = tfds.core.Version("0.1.0")
+
+    def __init__(self, dataset_name, dataset_class=None, dataset_kwargs=None, **kwargs):
+        self.name = dataset_name
+        self.dataset_class = dataset_class
+        self.dataset_kwargs = dataset_kwargs
+        super().__init__(**kwargs)
+
+    def _info(self):
+        return tfds.core.DatasetInfo(
+            builder=self,
+            description="Custom TFDS dataset for segmentation",
+            features=tfds.features.FeaturesDict({
+                "image": tfds.features.Image(),
+                "label": tfds.features.ClassLabel(),
+                "segmentation_mask": tfds.features.Image(shape=(None, None, 1)),
+            }),
+        )
+
+    def _split_generators(self, dl_manager):
+        self.info.features["label"].names = self.dataset_class(**self.dataset_kwargs).classes
+
+        predefined_names = {
+            "train": tfds.Split.TRAIN,
+            "validation": tfds.Split.VALIDATION,
+            "test": tfds.Split.TEST,
+        }
+
+        splits = []
+        for subset in self.dataset_class.available_subsets:
+            dataset = self.dataset_class(subset=subset, **self.dataset_kwargs)
+            splits.append(
+                tfds.core.SplitGenerator(
+                    name=predefined_names[subset],
+                    num_shards=self._num_shards(dataset),
+                    gen_kwargs=dict(dataset=dataset)
+                )
+            )
+
+        return splits
+
+    def _generate_examples(self, dataset):
+        for i, (image, segmentation_mask) in enumerate(dataset):
+            yield i, {
+                "image": image,
+                "segmentation_mask": np.expand_dims(segmentation_mask, axis=2),
+                "label": -1,  # dummy label; class names are stored in the feature metadata
+            }
+
+    def _num_shards(self, dataset):
+        """Decide the number of shards so that the size of each shard does not exceed 256MiB."""
+        max_shard_size = 256 * 1024 * 1024  # 256MiB
+        total_size = sum((image.nbytes + mask.nbytes) for image, mask in dataset)
+        return (total_size + max_shard_size - 1) // max_shard_size  # ceiling division
diff --git a/tests/unit/executor_tests/test_build_tfds.py b/tests/unit/executor_tests/test_build_tfds.py
index 46dfcc6f0..1b7980d19 100644
--- a/tests/unit/executor_tests/test_build_tfds.py
+++ b/tests/unit/executor_tests/test_build_tfds.py
@@ -20,7 +20,7 @@
 from blueoil.cmd.train import run as train_run
 from blueoil import environment
 from blueoil.datasets.dataset_iterator import DatasetIterator
-from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection
+from blueoil.datasets.tfds import TFDSClassification, TFDSObjectDetection, TFDSSegmentation
 from blueoil.utils import config as config_util

 _RUN_AS_A_SCRIPT = False
@@ -151,3 +151,61 @@ def test_build_tfds_object_detection():
     assert labels.shape[0] == config.BATCH_SIZE
     assert labels.shape[1] == num_max_boxes
     assert labels.shape[2] == 5
+
+
+def test_build_tfds_segmentation():
+    environment.setup_test_environment()
+
+    # Build the TFDS dataset
+    config_file = "unit/fixtures/configs/for_build_tfds_segmentation.py"
+    run(config_file, overwrite=True)
+
+    # Check that the built dataset can be loaded with the same config file
+    experiment_id = "tfds_segmentation"
+    train_run(None, None, config_file, experiment_id, recreate=True)
+
+    # Check that the dataset was built correctly
+    train_data_num = 5
+    validation_data_num = 5
+    config = config_util.load(config_file)
+
+    train_dataset = setup_dataset(TFDSSegmentation,
+                                  subset="train",
+                                  batch_size=config.BATCH_SIZE,
+                                  **config.DATASET.TFDS_KWARGS)
+
+    validation_dataset = setup_dataset(TFDSSegmentation,
+                                       subset="validation",
+                                       batch_size=config.BATCH_SIZE,
+                                       **config.DATASET.TFDS_KWARGS)
+
+    assert train_dataset.num_per_epoch == train_data_num
+    assert validation_dataset.num_per_epoch == validation_data_num
+
+    for _ in range(train_data_num):
+        images, labels = 
train_dataset.feed() + + assert isinstance(images, np.ndarray) + assert images.shape[0] == config.BATCH_SIZE + assert images.shape[1] == config.IMAGE_SIZE[0] + assert images.shape[2] == config.IMAGE_SIZE[1] + assert images.shape[3] == 3 + + assert isinstance(labels, np.ndarray) + assert labels.shape[0] == config.BATCH_SIZE + assert labels.shape[1] == config.IMAGE_SIZE[0] + assert labels.shape[2] == config.IMAGE_SIZE[1] + + for _ in range(validation_data_num): + images, labels = validation_dataset.feed() + + assert isinstance(images, np.ndarray) + assert images.shape[0] == config.BATCH_SIZE + assert images.shape[1] == config.IMAGE_SIZE[0] + assert images.shape[2] == config.IMAGE_SIZE[1] + assert images.shape[3] == 3 + + assert isinstance(labels, np.ndarray) + assert labels.shape[0] == config.BATCH_SIZE + assert labels.shape[1] == config.IMAGE_SIZE[0] + assert labels.shape[2] == config.IMAGE_SIZE[1] diff --git a/tests/unit/fixtures/configs/for_build_tfds_segmentation.py b/tests/unit/fixtures/configs/for_build_tfds_segmentation.py new file mode 100644 index 000000000..803a60b1c --- /dev/null +++ b/tests/unit/fixtures/configs/for_build_tfds_segmentation.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 The Blueoil Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +from easydict import EasyDict +import tensorflow as tf + +from blueoil.common import Tasks +from blueoil.networks.segmentation.lm_segnet_v1 import LmSegnetV1Quantize +from blueoil.datasets.camvid import CamvidCustom +from blueoil.data_processor import Sequence +from blueoil.pre_processor import ( + Resize, + DivideBy255, +) +from blueoil.data_augmentor import ( + Brightness, + Color, + Contrast, + FlipLeftRight, + Hue, +) +from blueoil.quantizations import ( + binary_mean_scaling_quantizer, + linear_mid_tread_half_quantizer, +) + + +class SegmentationDataset(CamvidCustom): + extend_dir = "camvid_custom" + validation_extend_dir = "camvid_custom" + + +IS_DEBUG = False + +NETWORK_CLASS = LmSegnetV1Quantize +DATASET_CLASS = SegmentationDataset + +IMAGE_SIZE = [80, 120] +BATCH_SIZE = 8 +DATA_FORMAT = "NHWC" +TASK = Tasks.SEMANTIC_SEGMENTATION +CLASSES = DATASET_CLASS(subset="train", batch_size=1).classes + +MAX_STEPS = 1 +SAVE_CHECKPOINT_STEPS = 1 +KEEP_CHECKPOINT_MAX = 5 +TEST_STEPS = 100 +SUMMARISE_STEPS = 100 + +# distributed training +IS_DISTRIBUTION = False + +# pretrain +IS_PRETRAIN = False +PRETRAIN_VARS = [] +PRETRAIN_DIR = "" +PRETRAIN_FILE = "" + +PRE_PROCESSOR = Sequence([ + Resize(size=IMAGE_SIZE), + DivideBy255(), +]) +POST_PROCESSOR = None + +NETWORK = EasyDict() +NETWORK.OPTIMIZER_CLASS = tf.compat.v1.train.AdamOptimizer +NETWORK.OPTIMIZER_KWARGS = {"learning_rate": 0.001} +NETWORK.IMAGE_SIZE = IMAGE_SIZE +NETWORK.BATCH_SIZE = BATCH_SIZE +NETWORK.DATA_FORMAT = DATA_FORMAT +NETWORK.ACTIVATION_QUANTIZER = linear_mid_tread_half_quantizer +NETWORK.ACTIVATION_QUANTIZER_KWARGS = { + 'bit': 2, + 'max_value': 2 +} +NETWORK.WEIGHT_QUANTIZER = binary_mean_scaling_quantizer +NETWORK.WEIGHT_QUANTIZER_KWARGS = {} + +DATASET = EasyDict() +DATASET.BATCH_SIZE = BATCH_SIZE +DATASET.DATA_FORMAT = DATA_FORMAT +DATASET.PRE_PROCESSOR = PRE_PROCESSOR +DATASET.AUGMENTOR = Sequence([ + Brightness((0.75, 1.25)), + Color((0.75, 1.25)), + Contrast((0.75, 1.25)), + FlipLeftRight(), + Hue((-10, 10)), +]) +DATASET.TFDS_KWARGS = { + "name": "tfds_segmentation", + "data_dir": "tmp/tests/datasets", + "image_size": IMAGE_SIZE, +}
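-- 
Editor's note (a sketch, not part of the patch): once `run(config_file, overwrite=True)`
from `blueoil.cmd.build_tfds` has written the TFDS files, the dataset can be loaded
through the new `TFDSSegmentation` class. A minimal check, reusing the `setup_dataset`
helper defined in test_build_tfds.py and the values from `DATASET.TFDS_KWARGS` in the
fixture config above; the shapes assume the fixture's IMAGE_SIZE and BATCH_SIZE, and
anything not shown in the patch should be treated as an assumption:

    from blueoil.datasets.tfds import TFDSSegmentation

    # Keyword arguments mirror DATASET.TFDS_KWARGS in the fixture config.
    train_dataset = setup_dataset(TFDSSegmentation,
                                  subset="train",
                                  batch_size=8,
                                  name="tfds_segmentation",
                                  data_dir="tmp/tests/datasets",
                                  image_size=[80, 120])

    images, labels = train_dataset.feed()
    # images: (8, 80, 120, 3) RGB batch; labels: (8, 80, 120) per-pixel class IDs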