data_handler.py
import os
import json
from collections import Counter

import numpy as np
from nltk.tokenize import word_tokenize

PAD, BOS, EOS, UNK = 0, 1, 2, 3
MAX_TOKENS = 15

np.random.seed(1)

def load_vocab(path):
    tok2id, id2tok = {}, {}
    # Fixed special tokens
    for idx, token in [(PAD, "<PAD>"), (BOS, "<BOS>"), (EOS, "<EOS>"), (UNK, "<UNK>")]:
        tok2id[token] = idx
        id2tok[idx] = token

    with open(path, "r") as f:
        for token in f:
            pr_token = token.strip()
            if pr_token not in tok2id:
                tok2id[pr_token] = len(tok2id)
                id2tok[len(id2tok)] = pr_token

    return tok2id, id2tok

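# Note (illustrative, not from the original file): `load_vocab` expects a plain-text vocabulary
# with one token per line, in the order the ids were assigned — this is the format written by the
# `__main__` block below, e.g.
#   <PAD>
#   <BOS>
#   <EOS>
#   <UNK>
#   a
#   man
# Tokens that collide with the fixed special tokens keep their fixed id and are not re-added.
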
# Load and tokenize [source, target] pairs by path
def load_pairs(src_path, tgt_path):
    pairs = []
    with open(src_path) as f_src, open(tgt_path) as f_tgt:
        for curr_src in f_src:
            processed_src = curr_src.strip().split(" ")
            processed_tgt = f_tgt.readline().strip().split(" ")
            pairs.append([processed_src, processed_tgt])

    return pairs

def preprocess(seq):
    # Lowercase, drop periods and newlines, then tokenize and truncate to MAX_TOKENS tokens
    seq = seq.lower().replace(".", "").strip().replace("\n", " ")
    return word_tokenize(seq)[:MAX_TOKENS]

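# Illustrative example (assumes NLTK's punkt tokenizer data is available):
#   preprocess("A man rides a horse.\n")  ->  ["a", "man", "rides", "a", "horse"]
# Captions longer than MAX_TOKENS (15) tokens are truncated.
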
def read_image_annotations(path):
    # path... str (path to captions_train2014.json or captions_val2014.json file)
    with open(path) as f:
        data = json.load(f)

    # Group all captions by the image they describe
    id2caption = {}
    for annotation in data["annotations"]:
        img_id = annotation["image_id"]
        cap = annotation["caption"]
        if img_id in id2caption:
            id2caption[img_id].append(cap)
        else:
            id2caption[img_id] = [cap]

    return id2caption

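# The COCO captions JSON is assumed to look roughly like:
#   {"annotations": [{"image_id": 123, "caption": "A man rides a horse.", ...}, ...], ...}
# so read_image_annotations returns e.g. {123: ["A man rides a horse.", "A person on a horse.", ...]}.
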
def mscoco_training_set(train_path, n_most_common=30_000):
    token_counter = Counter()
    tok2id, id2tok = {}, {}
    # Fixed special tokens
    for idx, token in [(PAD, "<PAD>"), (BOS, "<BOS>"), (EOS, "<EOS>"), (UNK, "<UNK>")]:
        tok2id[token] = idx
        id2tok[idx] = token

    dataset = []
    id2caption = read_image_annotations(train_path)
    for img_id, captions in id2caption.items():
        # Sample 4 distinct captions per image and pair them up as paraphrases
        chosen_captions = np.random.choice(captions, size=4, replace=False).tolist()
        chosen_captions = list(map(preprocess, chosen_captions))

        # Count tokens only in the chosen (preprocessed) captions
        for curr_cap in chosen_captions:
            token_counter.update(curr_cap)

        src1, src2 = chosen_captions[0], chosen_captions[1]
        tgt1, tgt2 = chosen_captions[2], chosen_captions[3]
        # Add each caption pair in both directions
        dataset.append([src1, tgt1])
        dataset.append([tgt1, src1])
        dataset.append([src2, tgt2])
        dataset.append([tgt2, src2])

    # Keep only the `n_most_common` most frequent tokens in the vocabulary
    for curr_token, _ in token_counter.most_common(n_most_common):
        tok2id[curr_token] = len(tok2id)
        id2tok[len(id2tok)] = curr_token

    return dataset, (tok2id, id2tok)

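# Sketch of the return value (derived from the code above, not part of the original file):
#   dataset ... list of [src_tokens, tgt_tokens] pairs, 4 per image (two caption pairs, each in
#               both directions), e.g. [["a", "man", ...], ["a", "person", ...]]
#   (tok2id, id2tok) ... vocabulary mappings with the 4 special tokens plus (up to) the
#                        `n_most_common` most frequent caption tokens
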
def mscoco_test_set(test_path, include_self_ref=False):
    dev_size, test_size = 20_000, 20_000
    dev_dataset, dev_refs = [], []
    test_dataset, test_refs = [], []

    id2caption = read_image_annotations(test_path)
    assert len(id2caption) >= dev_size + test_size
    # Randomly select 20k images for each of dev and test set (non-overlapping) and create a
    # (1.) source-target pair (for loss evaluation) and
    # (2.) references (for BLEU/METEOR/... evaluation)
    indices = np.random.choice(np.arange(len(id2caption)), size=(dev_size + test_size), replace=False)
    all_captions = list(id2caption.items())

    for idx in indices[: dev_size]:
        img_id, captions = all_captions[idx]
        chosen_captions = np.random.choice(captions, size=5, replace=False).tolist()
        chosen_captions = list(map(preprocess, chosen_captions))

        cap1, cap2, cap3, cap4, cap5 = chosen_captions
        dev_dataset.append([cap1, cap2])
        dev_refs.append([cap1, cap2, cap3, cap4, cap5] if include_self_ref else [cap2, cap3, cap4, cap5])

    for idx in indices[dev_size:]:
        img_id, captions = all_captions[idx]
        chosen_captions = np.random.choice(captions, size=5, replace=False).tolist()
        chosen_captions = list(map(preprocess, chosen_captions))

        cap1, cap2, cap3, cap4, cap5 = chosen_captions
        test_dataset.append([cap1, cap2])
        test_refs.append([cap1, cap2, cap3, cap4, cap5] if include_self_ref else [cap2, cap3, cap4, cap5])

    return (dev_dataset, dev_refs), (test_dataset, test_refs)

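# Sketch of the return value (derived from the code above): two (dataset, refs) tuples, where
#   dataset ... list of [src_tokens, tgt_tokens] pairs (first two sampled captions per image)
#   refs    ... list of reference lists for metric evaluation; with include_self_ref=True the
#               source caption itself is reference #0, followed by the remaining 4 captions
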
if __name__ == "__main__":
    DATA_DIR = "data/mscoco"
    train_name = "captions_train2014.json"
    dev_name = "captions_val2014.json"
    train_path = os.path.join(DATA_DIR, train_name)
    dev_path = os.path.join(DATA_DIR, dev_name)

    train_dir = os.path.join(DATA_DIR, "train_set")
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)

    raw_train_set, (tok2id, _) = mscoco_training_set(train_path)
    with open(os.path.join(train_dir, "train_src.txt"), "w") as f_src, \
            open(os.path.join(train_dir, "train_dst.txt"), "w") as f_dst:
        for curr_src, curr_tgt in raw_train_set:
            print(" ".join(curr_src), file=f_src)
            print(" ".join(curr_tgt), file=f_dst)
    print(f"**Wrote {len(raw_train_set)} train examples to '{train_dir}'**")

    # Write vocabulary (sorted by mapping index) to file
    with open(os.path.join(DATA_DIR, "vocab.txt"), "w") as f_src:
        for token, _ in sorted(tok2id.items(), key=lambda tup: tup[1]):
            print(token, file=f_src)
    print(f"**Wrote vocabulary ({len(tok2id)}) to '{DATA_DIR}'**")

    dev_dir = os.path.join(DATA_DIR, "dev_set")
    if not os.path.exists(dev_dir):
        os.makedirs(dev_dir)
    test_dir = os.path.join(DATA_DIR, "test_set")
    if not os.path.exists(test_dir):
        os.makedirs(test_dir)

    # NOTE: Self-reference is always placed as reference#0
    (raw_dev_set, _), (raw_test_set, test_refs) = mscoco_test_set(dev_path, include_self_ref=True)
    with open(os.path.join(dev_dir, "dev_src.txt"), "w") as f_src, \
            open(os.path.join(dev_dir, "dev_dst.txt"), "w") as f_dst:
        # Only write source-target pairs for dev set
        for curr_src, curr_tgt in raw_dev_set:
            print(" ".join(curr_src), file=f_src)
            print(" ".join(curr_tgt), file=f_dst)
    print(f"**Wrote {len(raw_dev_set)} dev examples to '{dev_dir}'**")

    num_refs = len(test_refs[0])
    with open(os.path.join(test_dir, "test_src.txt"), "w") as f_src, \
            open(os.path.join(test_dir, "test_dst.txt"), "w") as f_dst:
        ref_files = [open(os.path.join(test_dir, f"test_ref{i}.txt"), "w") for i in range(num_refs)]
        # Write source-target pairs and source-references for test set
        for idx, (curr_src, curr_tgt) in enumerate(raw_test_set):
            curr_refs = test_refs[idx]
            print(" ".join(curr_src), file=f_src)
            print(" ".join(curr_tgt), file=f_dst)
            for idx_ref, ref in enumerate(curr_refs):
                print(" ".join(ref), file=ref_files[idx_ref])

        for f in ref_files:
            f.close()
    print(f"**Wrote {len(raw_test_set)} test examples to '{test_dir}', {num_refs} references per example**")