#!/usr/bin/env python
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to download all datasets and create .tfrecord files.

SVHN, CIFAR-10 and CelebA are downloaded directly; the gzipped MNIST
archives are expected to already be present in ./Data.
"""
import collections
import gzip
import os
import tarfile
import tempfile
import urllib.request as urllib
import zipfile

from google_drive_downloader import GoogleDriveDownloader as gdd
import numpy as np
import scipy.io
import tensorflow as tf
from tqdm import trange, tqdm

from lib.data import DATA_DIR
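
# Download sources: 'celeba' is a Google Drive file id consumed by
# GoogleDriveDownloader; the other entries are URL templates formatted
# with a split or file name.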
URLS = {
'svhn': 'http://ufldl.stanford.edu/housenumbers/{}_32x32.mat',
'cifar10': 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz',
'celeba': '0B7EVK8r0v71pZjFTYXZWM3FlRnM',
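    # Unused below: _load_mnist reads pre-downloaded archives from ./Data.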
'mnist': 'https://storage.googleapis.com/cvdf-datasets/mnist/{}.gz',
}


def _encode_png(images):
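    """PNG-encode a batch of uint8 (N, H, W, C) images into byte strings."""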
raw = []
with tf.Session() as sess, tf.device('cpu:0'):
image_x = tf.placeholder(tf.uint8, [None, None, None], 'image_x')
to_png = tf.image.encode_png(image_x)
for x in trange(images.shape[0], desc='PNG Encoding', leave=False):
raw.append(sess.run(to_png, feed_dict={image_x: images[x]}))
return raw


def _load_svhn():
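    """Download the SVHN .mat splits (train/test/extra) and PNG-encode them."""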
splits = collections.OrderedDict()
for split in ['train', 'test', 'extra']:
with tempfile.NamedTemporaryFile() as f:
urllib.urlretrieve(URLS['svhn'].format(split), f.name)
data_dict = scipy.io.loadmat(f.name)
dataset = {}
dataset['images'] = np.transpose(data_dict['X'], [3, 0, 1, 2])
dataset['images'] = _encode_png(dataset['images'])
dataset['labels'] = data_dict['y'].reshape((-1))
# SVHN raw data uses labels from 1 to 10; use 0 to 9 instead.
dataset['labels'] -= 1
splits[split] = dataset
return splits


def _load_cifar10():
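    """Download the CIFAR-10 MATLAB archive and build train/test splits."""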
def unflatten(images):
return np.transpose(images.reshape((images.shape[0], 3, 32, 32)),
[0, 2, 3, 1])
with tempfile.NamedTemporaryFile() as f:
urllib.urlretrieve(URLS['cifar10'], f.name)
tar = tarfile.open(fileobj=f)
train_data_batches, train_data_labels = [], []
for batch in range(1, 6):
data_dict = scipy.io.loadmat(tar.extractfile(
'cifar-10-batches-mat/data_batch_{}.mat'.format(batch)))
train_data_batches.append(data_dict['data'])
train_data_labels.append(data_dict['labels'].flatten())
train_set = {'images': np.concatenate(train_data_batches, axis=0),
'labels': np.concatenate(train_data_labels, axis=0)}
data_dict = scipy.io.loadmat(tar.extractfile(
'cifar-10-batches-mat/test_batch.mat'))
test_set = {'images': data_dict['data'],
'labels': data_dict['labels'].flatten()}
train_set['images'] = _encode_png(unflatten(train_set['images']))
test_set['images'] = _encode_png(unflatten(test_set['images']))
return dict(train=train_set, test=test_set)


def _load_celeba():
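    """Download the CelebA JPEG archive from Google Drive; labels are all zero."""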
with tempfile.NamedTemporaryFile() as f:
gdd.download_file_from_google_drive(
file_id=URLS['celeba'], dest_path=f.name, overwrite=True)
zip_f = zipfile.ZipFile(f)
images = []
for image_file in tqdm(zip_f.namelist(), 'Decompressing', leave=False):
if os.path.splitext(image_file)[1] == '.jpg':
with zip_f.open(image_file) as image_f:
images.append(image_f.read())
train_set = {'images': images, 'labels': np.zeros(len(images), int)}
return dict(train=train_set)


def _load_mnist():
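    """Read gzipped MNIST IDX archives from ./Data and PNG-encode the images."""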
def _read32(data):
dt = np.dtype(np.uint32).newbyteorder('>')
return np.frombuffer(data.read(4), dtype=dt)[0]
image_filename = '{}-images-idx3-ubyte.gz'
label_filename = '{}-labels-idx1-ubyte.gz'
split_files = collections.OrderedDict(
[('train', 'train'), ('test', 't10k')])
splits = collections.OrderedDict()
    for split, split_file in split_files.items():
        path = os.path.join('./Data', image_filename.format(split_file))
        with open(path, 'rb') as f, gzip.GzipFile(fileobj=f, mode='r') as data:
            assert _read32(data) == 2051  # Magic number of IDX3 image files.
            n_images = _read32(data)
            row = _read32(data)
            col = _read32(data)
            images = np.frombuffer(
                data.read(n_images * row * col), dtype=np.uint8)
            images = images.reshape((n_images, row, col, 1))
        path = os.path.join('./Data', label_filename.format(split_file))
        with open(path, 'rb') as f, gzip.GzipFile(fileobj=f, mode='r') as data:
            assert _read32(data) == 2049  # Magic number of IDX1 label files.
            n_labels = _read32(data)
            labels = np.frombuffer(data.read(n_labels), dtype=np.uint8)
        splits[split] = {'images': _encode_png(images), 'labels': labels}
    return splits


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _save_as_tfrecord(data, filename):
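    """Write one tf.train.Example per image to DATA_DIR/<filename>.tfrecord."""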
assert len(data['images']) == len(data['labels'])
filename = os.path.join(DATA_DIR, filename + '.tfrecord')
    # print('Saving dataset:', filename)
with tf.python_io.TFRecordWriter(filename) as writer:
for x in trange(len(data['images']), desc='Building records'):
feat = dict(label=_int64_feature(data['labels'][x]),
image=_bytes_feature(data['images'][x]))
record = tf.train.Example(features=tf.train.Features(feature=feat))
writer.write(record.SerializeToString())
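

# For reference, a record written above can be parsed back with TF1-style code
# along these lines (a sketch only: the feature names match _save_as_tfrecord,
# but the repository's own input pipeline in lib/data.py may differ):
#
#   features = tf.parse_single_example(
#       serialized_example,
#       features={'image': tf.FixedLenFeature([], tf.string),
#                 'label': tf.FixedLenFeature([], tf.int64)})
#   image = tf.image.decode_image(features['image'])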


LOADERS = [
    ('mnist', _load_mnist),
    ('cifar10', _load_cifar10),
    ('svhn', _load_svhn),
    ('celeba', _load_celeba),
]


if __name__ == '__main__':
    os.makedirs(DATA_DIR, exist_ok=True)
for name, loader in LOADERS:
print('Preparing', name)
        splits = loader()
        for sub_name, data in splits.items():
            _save_as_tfrecord(data, '%s-%s' % (name, sub_name))