-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_data.py
61 lines (41 loc) · 2.01 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import tensorflow as tf
import numpy as np
import json
import os
def ParseFunction(serialized, image_shape=(32, 32, 3)):
    """Parse one serialized ``tf.train.Example`` into an image/label dict.

    Args:
        serialized: A scalar string tensor holding one serialized example
            with an ``'image'`` bytes feature and a ``'label'`` int64 feature.
        image_shape: Static shape assigned to the decoded image. Any
            sequence accepted by ``Tensor.set_shape`` works; the default is
            an immutable tuple so it cannot be shared-and-mutated across
            calls.

    Returns:
        A dict with keys ``'image'`` (the decoded image tensor) and
        ``'label'`` (the parsed int64 label).
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_example = tf.io.parse_single_example(
        serialized=serialized, features=feature_spec)
    image = tf.image.decode_image(parsed_example['image'])
    # decode_image does not fix a static shape on its own, so pin it here
    # for downstream consumers that need one.
    image.set_shape(image_shape)
    # NOTE: the image is returned exactly as decoded; no float cast or
    # rescaling is applied here.
    return {'image': image, 'label': parsed_example['label']}
def LoadData(filename, tensor=False):
    """Load every example of a TFRecord file into memory.

    Args:
        filename: Passed unchanged to ``tf.data.TFRecordDataset``.
        tensor: If true, return a ``tf.data.Dataset`` of ``(image, label)``
            slices instead of the raw arrays.

    Returns:
        ``(images, labels)`` as stacked NumPy arrays when ``tensor`` is
        false; otherwise a ``tf.data.Dataset`` built from those two arrays.

    Raises:
        ValueError: If the file yields no records.
    """
    dataset = tf.data.TFRecordDataset(filename).map(ParseFunction)

    # Collect both fields in a single pass so each record is read and
    # decoded only once.
    image_list, label_list = [], []
    for example in dataset:
        image_list.append(example['image'])
        label_list.append(example['label'])

    if not image_list:
        raise ValueError("no records found in {!r}".format(filename))

    images = np.stack(image_list)
    labels = np.stack(label_list)

    if tensor:
        return tf.data.Dataset.from_tensor_slices((images, labels))
    return images, labels
def LoadAll(dir, dataset, seed, n_labeled, tensor=False):
    """Load a labeled split plus the unlabeled images not in that split.

    Reads four files from ``dir``:
        ``<dataset>.<seed>@<n_labeled>-label.tfrecord`` / ``.json``
        ``<dataset>-unlabel.tfrecord`` / ``.json``

    The ``'label'`` entry of the labeled JSON is used as the collection of
    indexes to exclude, and the ``'indexes'`` entry of the unlabeled JSON as
    the candidate row indexes into the unlabeled image array.

    Args:
        dir: Directory containing the files above.
        dataset: Dataset name used as the file-name prefix.
        seed: Seed component of the labeled file name.
        n_labeled: Labeled-count component of the labeled file name.
        tensor: If true, return ``tf.data.Dataset`` objects.

    Returns:
        When ``tensor`` is false: ``(labeled_images, unlabeled_images,
        labels)`` as NumPy arrays (note the labels come last).
        When ``tensor`` is true: a 2-tuple ``(labeled_dataset,
        unlabeled_dataset)`` where the first yields ``(image, label)`` and
        the second yields images only.
    """
    l_data_fname = os.path.join(
        dir, "{}.{}@{}-label.tfrecord".format(dataset, str(seed), str(n_labeled)))
    l_json_fname = os.path.join(
        dir, "{}.{}@{}-label.json".format(dataset, str(seed), str(n_labeled)))
    u_data_fname = os.path.join(dir, "{}-unlabel.tfrecord".format(dataset))
    u_json_fname = os.path.join(dir, "{}-unlabel.json".format(dataset))

    with open(l_json_fname, "r") as f:
        # Set for O(1) membership tests below.
        # NOTE(review): assumes the entries are hashable (presumably integer
        # indexes) - confirm against the files' producer.
        labeled_indexes = set(json.load(f)['label'])
    with open(u_json_fname, "r") as f:
        unlabeled_indexes = json.load(f)['indexes']

    ds_l, ls = LoadData(l_data_fname)
    ds_u, _ = LoadData(u_data_fname)

    keep = [i for i in unlabeled_indexes if i not in labeled_indexes]
    # One integer-array gather along axis 0; an empty selection yields an
    # empty array with the same trailing image dimensions.
    new_ds_u = ds_u[keep] if keep else ds_u[:0]

    if tensor:
        labeled_dataset = tf.data.Dataset.from_tensor_slices((ds_l, ls))
        unlabeled_dataset = tf.data.Dataset.from_tensor_slices(new_ds_u)
        return labeled_dataset, unlabeled_dataset
    return ds_l, new_ds_u, ls
def LoadTest(dir, dataset, tensor=False):
    """Load the ``<dataset>-test.tfrecord`` file found in ``dir``.

    Args:
        dir: Directory containing the test record file.
        dataset: Dataset name used as the file-name prefix.
        tensor: Forwarded to ``LoadData``.

    Returns:
        Whatever ``LoadData`` returns for the test file and ``tensor`` flag.
    """
    test_record_name = "{}-test.tfrecord".format(dataset)
    test_record_path = os.path.join(dir, test_record_name)
    return LoadData(test_record_path, tensor)