diff --git a/blueoil/generate_lmnet_config.py b/blueoil/generate_lmnet_config.py index b8cc4c255..4e788a29d 100644 --- a/blueoil/generate_lmnet_config.py +++ b/blueoil/generate_lmnet_config.py @@ -76,6 +76,10 @@ "dataset_module": "camvid", "dataset_class": "CamvidCustom", }, + "DIV2K": { + "dataset_module": "div2k", + "dataset_class": "Div2k", + } } diff --git a/lmnet/lmnet/datasets/div2k.py b/lmnet/lmnet/datasets/div2k.py new file mode 100644 index 000000000..4aa0e4b86 --- /dev/null +++ b/lmnet/lmnet/datasets/div2k.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 The Blueoil Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +import functools +import os +from glob import glob + +from lmnet.datasets.base import Base +from lmnet.utils.image import load_image + + +class Div2k(Base): + classes = [] + num_classes = 0 + extend_dir = "DIV2K" + available_subsets = ["train", "validation"] + + @property + @functools.lru_cache(maxsize=None) + def files(self): + if self.subset == "train": + images_dir = os.path.join(self.data_dir, "DIV2K_train_HR") + else: + images_dir = os.path.join(self.data_dir, "DIV2K_valid_HR") + + return [filepath for filepath in glob(os.path.join(images_dir, "*.png"))] + + @property + def num_per_epoch(self): + return len(self.files) + + def __getitem__(self, i, type=None): + target_file = self.files[i] + image = load_image(target_file) + + return image, None + + def __len__(self): + return self.num_per_epoch diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0001.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0001.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0001.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0002.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0002.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0002.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0003.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0003.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0003.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0004.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0004.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0004.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0005.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0005.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0005.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0001.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0001.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0001.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0002.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0002.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0002.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0003.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0003.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0003.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0004.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0004.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0004.png differ diff --git a/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0005.png b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0005.png new file mode 100644 index 000000000..ab1b2b623 Binary files /dev/null and b/lmnet/tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0005.png differ diff --git a/lmnet/tests/lmnet_tests/datasets_tests/test_div2k.py b/lmnet/tests/lmnet_tests/datasets_tests/test_div2k.py new file mode 100644 index 000000000..4f96588ef --- /dev/null +++ b/lmnet/tests/lmnet_tests/datasets_tests/test_div2k.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 The Blueoil Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +import numpy as np +import pytest + +from lmnet.datasets.dataset_iterator import DatasetIterator +from lmnet.datasets.div2k import Div2k + + +def test_train_files(set_test_environment): + expected = [ + "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0001.png", + "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0002.png", + "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0003.png", + "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0004.png", + "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0005.png", + ] + assert sorted(Div2k("train").files) == expected + + +def test_validation_files(set_test_environment): + expected = [ + "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0001.png", + "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0002.png", + "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0003.png", + "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0004.png", + "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0005.png", + ] + assert sorted(Div2k("validation").files) == expected + + +def test_length(set_test_environment): + expected = 5 + assert len(Div2k("train")) == expected + + +def test_num_per_epoch(set_test_environment): + expected = 5 + assert Div2k("train").num_per_epoch == expected + + +def test_get_item(set_test_environment): + assert all(isinstance(image, np.ndarray) for image, _ in Div2k("train")) + + +@pytest.mark.parametrize("subset", ["train", "validation"]) +def test_can_iterate(set_test_environment, subset): + batch_size = 1 + image_size = (100, 100) + + dataset = Div2k(subset, batch_size=batch_size) + iterator = DatasetIterator(dataset) + + for _ in range(len(dataset)): + images, labels = iterator.feed() + + assert isinstance(images, np.ndarray) + assert images.shape[0] == batch_size + assert images.shape[1] == image_size[0] + assert images.shape[2] == image_size[1] + assert images.shape[3] == 3 + + assert isinstance(labels, np.ndarray) + assert labels.shape[0] == batch_size + assert labels[0] is None