Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add ImageNet21k #225

Merged
merged 2 commits into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mmselfsup/datasets/data_sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,9 @@
from .cifar import CIFAR10, CIFAR100
from .image_list import ImageList
from .imagenet import ImageNet
from .imagenet_21k import ImageNet21k

__all__ = ['BaseDataSource', 'CIFAR10', 'CIFAR100', 'ImageList', 'ImageNet']
__all__ = [
'BaseDataSource', 'CIFAR10', 'CIFAR100', 'ImageList', 'ImageNet',
'ImageNet21k'
]
10 changes: 9 additions & 1 deletion mmselfsup/datasets/data_sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,15 @@ def get_img(self, idx):
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)

if self.data_infos[idx].get('img_prefix', None) is not None:
if 'ImageNet-21k' in self.data_prefix:
filename = osp.join(self.data_prefix,
self.data_infos[idx].decode('utf-8'))
img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes(
img_bytes,
flag=self.color_type,
channel_order=self.channel_order)
elif self.data_infos[idx].get('img_prefix', None) is not None:
if self.data_infos[idx]['img_prefix'] is not None:
filename = osp.join(
self.data_infos[idx]['img_prefix'],
Expand Down
114 changes: 114 additions & 0 deletions mmselfsup/datasets/data_sources/imagenet_21k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings

import numpy as np
from mmcv.utils import scandir

from ..builder import DATASOURCES
from .base import BaseDataSource
from .imagenet import find_folders


@DATASOURCES.register_module()
class ImageNet21k(BaseDataSource):
"""ImageNet21k Dataset. Since the dataset ImageNet21k is extremely big,
cantains 21k+ classes and 1.4B files. This class has improved the following
points on the basis of the class ``ImageNet``, in order to save memory
usage and time required :

- Delete the samples attribute
- using 'slots' create a Data_item tp replace dict
- Modify setting ``info`` dict from function ``load_annotations`` to
function ``prepare_data``
- using int instead of np.array(..., np.int64)
Args:
data_prefix (str): the prefix of data path
ann_file (str | None): the annotation file. When ann_file is str,
the subclass is expected to read from the ann_file. When ann_file
is None, the subclass is expected to read according to data_prefix
test_mode (bool): in train mode or test mode
multi_label (bool): use multi label or not.
recursion_subdir(bool): whether to use sub-directory pictures, which
are meet the conditions in the folder under category directory.
"""

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
'.JPEG', '.JPG')
CLASSES = None

def __init__(self,
data_prefix,
classes=None,
ann_file=None,
multi_label=False,
recursion_subdir=False,
test_mode=False):
self.recursion_subdir = recursion_subdir
if multi_label:
raise NotImplementedError('Multi_label have not be implemented.')
self.multi_lable = multi_label
super(ImageNet21k, self).__init__(data_prefix, classes, ann_file,
test_mode)

def load_annotations(self):
"""load dataset annotations."""
if self.ann_file is None:
data_infos = self._load_annotations_from_dir()
elif isinstance(self.ann_file, str):
data_infos = self._load_annotations_from_file()
else:
raise TypeError('ann_file must be a str or None')

if len(data_infos) == 0:
msg = 'Found no valid file in '
msg += f'{self.ann_file}. ' if self.ann_file \
else f'{self.data_prefix}. '
msg += 'Supported extensions are: ' + \
', '.join(self.IMG_EXTENSIONS)
raise RuntimeError(msg)

return data_infos

def _find_allowed_files(self, root, folder_name):
"""find all the allowed files in a folder, including sub folder if
recursion_subdir is true."""
_dir = os.path.join(root, folder_name)
data_infos = []
for path in scandir(_dir, self.IMG_EXTENSIONS, self.recursion_subdir):
path = os.path.join(folder_name, path)
data_infos.append(path)
return data_infos

def _load_annotations_from_dir(self):
"""load annotations from self.data_prefix directory."""
data_infos, empty_classes = [], []
folder_to_idx = find_folders(self.data_prefix)
root = os.path.expanduser(self.data_prefix)
for folder_name in folder_to_idx.keys():
infos_pre_class = self._find_allowed_files(root, folder_name)
if len(infos_pre_class) == 0:
empty_classes.append(folder_name)
data_infos.extend(infos_pre_class)

if len(empty_classes) != 0:
msg = 'Found no valid file for the classes ' + \
f"{', '.join(sorted(empty_classes))} "
msg += 'Supported extensions are: ' + \
f"{', '.join(self.IMG_EXTENSIONS)}."
warnings.warn(msg)

return np.array(data_infos, dtype='S36')

def _load_annotations_from_file(self):
"""load annotations from self.ann_file."""
data_infos = []
with open(self.ann_file) as f:
for line in f.readlines():
if line == '':
continue
filepath, gt_label = line.strip().rsplit(' ', 1)
# info = ImageInfo(filepath, int(gt_label))
data_infos.append(filepath)

return np.array(data_infos, dtype='S36')