-
Notifications
You must be signed in to change notification settings - Fork 86
create div2k dataset class #547
Changes from 2 commits
98e5c40
4d0772b
010d235
d73b42f
0f19936
9777c7f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2018 The Blueoil Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================= | ||
import functools | ||
import os | ||
from glob import glob | ||
|
||
tsawada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
from lmnet.datasets.base import Base | ||
from lmnet.utils.image import load_image | ||
|
||
|
||
class Div2k(Base):
    """Dataset class over the DIV2K high-resolution PNG images.

    The dataset defines no classes: every sample is an ``(image, None)``
    pair. Images are read from the ``DIV2K_train_HR`` directory for the
    ``"train"`` subset and from ``DIV2K_valid_HR`` for any other subset,
    both located under ``self.data_dir`` (presumably provided by ``Base``
    from ``extend_dir`` -- confirm against ``Base``).
    """

    # No labels exist for this dataset, so the class list is empty.
    # NOTE(review): the PR reviewer left a remark on these placeholder
    # attributes ("I guess setting ..."); the rest of the remark is
    # truncated in this capture of the review page.
    classes = []
    num_classes = 0

    extend_dir = "DIV2K"
    available_subsets = ["train", "validation"]

    @property
    def files(self):
        """Return the sorted list of PNG file paths for the current subset.

        The list is built on first access and then stored on the instance.
        Caching per instance (instead of ``functools.lru_cache`` on the
        method) avoids a module-wide cache that keeps every instance alive
        and requires instances to be hashable. Paths are sorted because
        ``glob`` returns them in arbitrary, filesystem-dependent order,
        which would make index-based access non-reproducible.
        """
        cached = getattr(self, "_div2k_files", None)
        if cached is not None:
            return cached

        if self.subset == "train":
            images_dir = os.path.join(self.data_dir, "DIV2K_train_HR")
        else:
            images_dir = os.path.join(self.data_dir, "DIV2K_valid_HR")

        self._div2k_files = sorted(glob(os.path.join(images_dir, "*.png")))
        return self._div2k_files

    @property
    def num_per_epoch(self):
        """Return the number of image files in the current subset."""
        return len(self.files)

    def __getitem__(self, i, type=None):
        """Load the ``i``-th image of the subset.

        Args:
            i: Index into ``self.files``; an out-of-range index raises
                ``IndexError`` from the underlying list.
            type: Unused here. The name shadows the builtin but is kept
                unchanged so existing callers keep working.

        Returns:
            A ``(image, None)`` tuple, where ``image`` is whatever
            ``load_image`` returns for the file and ``None`` stands in for
            the missing label.
        """
        image = load_image(self.files[i])

        # NOTE(review): the PR reviewer flagged this ``None`` placeholder
        # with "Same as above."
        return image, None

    def __len__(self):
        """Return the number of samples, identical to ``num_per_epoch``."""
        return self.num_per_epoch
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# -*- coding: utf-8 -*- | ||
# Copyright 2018 The Blueoil Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================= | ||
import os | ||
from glob import glob | ||
tsawada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
import numpy as np | ||
import pytest | ||
|
||
from lmnet import environment | ||
from lmnet.datasets.dataset_iterator import DatasetIterator | ||
from lmnet.datasets.div2k import Div2k | ||
|
||
|
||
def test_train_files(set_test_environment):
    """The train subset exposes exactly the five fixture PNG paths."""
    assert sorted(Div2k("train").files) == [
        "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0001.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0002.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0003.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0004.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_train_HR/0005.png",
    ]
|
||
|
||
def test_validation_files(set_test_environment):
    """The validation subset exposes exactly the five fixture PNG paths."""
    assert sorted(Div2k("validation").files) == [
        "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0001.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0002.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0003.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0004.png",
        "tests/fixtures/datasets/DIV2K/DIV2K_valid_HR/0005.png",
    ]
|
||
|
||
def test_length(set_test_environment):
    """``len()`` of the train dataset equals the five fixture images."""
    assert len(Div2k("train")) == 5
|
||
|
||
def test_num_per_epoch(set_test_environment):
    """``num_per_epoch`` of the train dataset equals the five fixture images."""
    assert Div2k("train").num_per_epoch == 5
|
||
|
||
def test_get_item(set_test_environment):
    """Every sample obtained by iterating the train subset has an ndarray image."""
    dataset = Div2k("train")

    # A generator expression lets all() stop at the first failure instead of
    # loading every image into a throwaway list first.
    assert all(isinstance(image, np.ndarray) for image, _ in dataset)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to use list ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did not know that! |
||
|
||
|
||
@pytest.mark.parametrize("subset", ["train", "validation"])
def test_can_iterate(set_test_environment, subset):
    """Feed one batch per sample and check the image and label batches.

    Each image batch is expected to be shaped ``(batch_size, 100, 100, 3)``
    and each label batch to hold ``None`` for its first entry.
    """
    batch_size = 1
    height, width = 100, 100
    expected_dims = (batch_size, height, width, 3)

    dataset = Div2k(subset, batch_size=batch_size)
    iterator = DatasetIterator(dataset)

    for _step in range(len(dataset)):
        images, labels = iterator.feed()

        assert isinstance(images, np.ndarray)
        # Compare axis by axis so only the first four dimensions are checked.
        for axis, size in enumerate(expected_dims):
            assert images.shape[axis] == size

        assert isinstance(labels, np.ndarray)
        assert labels.shape[0] == batch_size
        assert labels[0] is None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we remove a line here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My autopep8 removes unnecessary line breaks, such as double empty lines inside functions.
Is this line necessary?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line does not seem necessary, but this diff is not related to the PR and thus should go into another PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
absolutely!