Skip to content
This repository has been archived by the owner on Dec 1, 2021. It is now read-only.

create div2k dataset class #547

Merged
merged 6 commits into from
Oct 29, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion blueoil/generate_lmnet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@
"dataset_module": "camvid",
"dataset_class": "CamvidCustom",
},
"DIV2K": {
"dataset_module": "div2k",
"dataset_class": "Div2k",
}
}


Expand Down Expand Up @@ -127,7 +131,6 @@ def _blueoil_to_lmnet(blueoil_config):
}
dataset = {}


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we remove a line here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My autopep8 removes unnecessary line break such as double empty line in functions.
Is this line necessary?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line does not seem necessary but this diff is not related to the PR and thus should go into another PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

absolutely!

model_name = blueoil_config["model_name"]

template_file = _TASK_TYPE_TEMPLATE_FILE[blueoil_config["task_type"]]
Expand Down
51 changes: 51 additions & 0 deletions lmnet/lmnet/datasets/div2k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright 2018 The Blueoil Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import functools
import os
from glob import glob

from lmnet.datasets.base import Base
from lmnet.utils.image import load_image


class Div2k(Base):
classes = []
num_classes = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess setting num_classes = 0 is for base class restriction.
It's not clear. Please add comment.

extend_dir = "DIV2K"
available_subsets = ["train", "validation"]

@property
@functools.lru_cache(maxsize=None)
def files(self):
if self.subset == "train":
images_dir = os.path.join(self.data_dir, "DIV2K_train_HR")
else:
images_dir = os.path.join(self.data_dir, "DIV2K_valid_HR")

return [filepath for filepath in glob(os.path.join(images_dir, "*.png"))]

@property
def num_per_epoch(self):
return len(self.files)

def __getitem__(self, i, type=None):
target_file = self.files[i]
image = load_image(target_file)

return image, None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above. None label looks tricky.


def __len__(self):
return self.num_per_epoch
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
96 changes: 96 additions & 0 deletions lmnet/tests/lmnet_tests/datasets_tests/test_div2k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2018 The Blueoil Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import os
from glob import glob

import numpy as np
import pytest

from lmnet import environment
from lmnet.datasets.dataset_iterator import DatasetIterator
from lmnet.datasets.div2k import Div2k


def test_train_files(set_test_environment):
dataset = Div2k("train")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need to bind everything into local variables.

assert sorted(Div2k("train")) == expected is short and easy to understand.

Same for other tests.

files = sorted(dataset.files)

expected = [
"tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0001.png",
"tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0002.png",
"tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0003.png",
"tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0004.png",
"tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0005.png",
]

assert files == expected


def test_validation_files(set_test_environment):
dataset = Div2k("validation")
files = sorted(dataset.files)

expected = [
"tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0001.png",
"tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0002.png",
"tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0003.png",
"tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0004.png",
"tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0005.png",
]

assert files == expected


def test_length(set_test_environment):
dataset = Div2k("train")
expected = 5

assert len(dataset) == expected


def test_num_per_epoch(set_test_environment):
dataset = Div2k("train")
expected = 5

assert dataset.num_per_epoch == expected


def test_get_item(set_test_environment):
dataset = Div2k("train")

assert all([isinstance(image, np.ndarray) for image, _ in dataset])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to use list ([]) here. Generator comprehension is better as it'll fail early.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know that!
thank you!!



@pytest.mark.parametrize("subset", ["train", "validation"])
def test_can_iterate(set_test_environment, subset):
batch_size = 1
image_size = (100, 100)

dataset = Div2k(subset, batch_size=batch_size)
iterator = DatasetIterator(dataset)

for _ in range(len(dataset)):
images, labels = iterator.feed()

assert isinstance(images, np.ndarray)
assert images.shape[0] == batch_size
assert images.shape[1] == image_size[0]
assert images.shape[2] == image_size[1]
assert images.shape[3] == 3

assert isinstance(labels, np.ndarray)
assert labels.shape[0] == batch_size
assert labels[0] is None