This repository has been archived by the owner on Aug 23, 2019. It is now read-only.
forked from abisee/pointer-generator
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathscripts.py
343 lines (273 loc) · 10.9 KB
/
scripts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
# Import-time side effects: every import of this module registers a results
# file with the project-local `reproducible` module and writes a fingerprint.
# NOTE(review): what add_non_git_file / write_fingerprint record is defined
# outside this file (the name suggests a file not tracked by git) — confirm.
import reproducible
reproducible.add_non_git_file('results/articles/article_0.txt')
reproducible.write_fingerprint()
import json
import numpy as np
import os
import spacy
import string
import struct
import sys
from sklearn.decomposition.truncated_svd import TruncatedSVD
from tensorflow.core.example import example_pb2
import time
from data import N_FREE_TOKENS, Vocab
from make_datafiles import get_art_abs
from primer_core.analytic_pipelines.base.document_pipeline import SingleDocument
from primer_core.nlp.summary.lexrank.summary import compute_summaries
from decoder import generate_summary
######################################################
# Handling vocab
######################################################
def _to_unicode(word):
    """Return *word* as unicode, decoding UTF-8 bytes; None if it cannot be decoded."""
    if isinstance(word, unicode):
        return word
    try:
        return word.decode('utf-8')
    except UnicodeDecodeError:
        return None


def compute_reduced_embeddings_original_vocab(
    output_vocab_filepath, output_embeddings_filepath, input_vocab_filepath, vocab_size,
    embedding_dim
):
    """Write a spaCy-filtered vocab and a matching (optionally SVD-reduced) embedding matrix.

    Loads words from ``input_vocab_filepath`` through ``Vocab`` (with a size cap of
    ``1.5 * vocab_size``) and walks them in ascending id order until
    ``vocab_size`` matrix rows are filled:

    * ids below ``N_FREE_TOKENS`` always get a row (their spaCy vector) but are
      NOT written to the output vocab file;
    * every later word is kept only if spaCy's vocab contains it, and is then
      both written to the vocab file and given a row.

    If ``embedding_dim`` is smaller than spaCy's vector length, the matrix is
    reduced with ``TruncatedSVD`` (arpack) and the cumulative explained variance
    is printed; otherwise the raw matrix is saved.

    Args:
        output_vocab_filepath: text file receiving the kept words, one per line.
        output_embeddings_filepath: path handed to ``np.save`` for the matrix.
        input_vocab_filepath: vocab file read via ``Vocab``.
        vocab_size: number of matrix rows, free tokens included.
        embedding_dim: target embedding width.
    """
    print(N_FREE_TOKENS)
    # Load extra words as headroom: some are dropped below for missing spaCy entries.
    vocab = Vocab(input_vocab_filepath, 1.5 * vocab_size)
    spacy_vocab = spacy.load('en').vocab
    matrix = np.zeros((vocab_size, spacy_vocab.vectors_length), dtype=np.float32)
    new_i = 0
    final_vocab = []
    # Walk ids explicitly in ascending order rather than relying on dict order.
    for i, word in sorted(vocab._id_to_word.items()):
        if new_i == vocab_size:
            break
        u_word = _to_unicode(word)
        is_free_token = i < N_FREE_TOKENS
        if not is_free_token and (u_word is None or u_word not in spacy_vocab):
            continue
        if not is_free_token:
            final_vocab.append(word)
        if u_word is not None:
            matrix[new_i] = spacy_vocab[u_word].vector
        new_i += 1
    if final_vocab:
        print('Last word added: %s' % final_vocab[-1])
    if new_i < vocab_size:
        # Rows past new_i stay zero and correspond to no word in the vocab file.
        print('Warning: only %d of %d embedding rows were filled' % (new_i, vocab_size))
    if embedding_dim < spacy_vocab.vectors_length:
        svd = TruncatedSVD(n_components=embedding_dim, algorithm='arpack')
        embeddings = svd.fit_transform(matrix)
        print(embeddings.shape)
        print([sum(svd.explained_variance_ratio_[:i]) for i in range(1, embedding_dim + 1)])
    else:
        embeddings = matrix
    with open(output_vocab_filepath, 'w') as output:
        for word in final_vocab:
            output.write('%s\n' % word)
    np.save(output_embeddings_filepath, embeddings)
# NOTE: Deprecated; do not use. The vocab it writes is not consistent with how Vocab loads words.
def write_spacy_vocab(output_dirpath, vocab_size, embedding_dim):
    """Write a filtered spaCy vocabulary and SVD-reduced embeddings (deprecated).

    Deprecated: the vocab written here is not consistent with how ``Vocab``
    loads words.

    Candidates are spaCy lexemes with ``rank <= 2 * vocab_size``, taken as
    ``str(lower_).strip()``. A candidate is skipped if it:

    * is empty or already seen;
    * cannot be converted with ``str()`` (non-ASCII under Python 2);
    * contains any of ``[ ] < > { }``;
    * contains a character outside ASCII letters and punctuation;
    * has more than two characters that are neither letters nor '.';
    * ends in '.' and has more than two letters.

    The ``vocab_size`` lowest-rank survivors are written one per line to
    ``<output_dirpath>/vocab``. Their vectors, reduced to ``embedding_dim``
    components by ``TruncatedSVD`` (arpack), are saved to
    ``<output_dirpath>/pretrained_embeddings.npy``. Creates ``output_dirpath``
    if it is missing.
    """
    if not os.path.exists(output_dirpath):
        os.makedirs(output_dirpath)
    allowed_chars = set(string.ascii_letters + string.punctuation)
    letters = set(string.ascii_letters)
    letters_or_period = set(string.ascii_letters + '.')
    seen_words = set()
    spacy_vocab = spacy.load('en').vocab
    top_words = []
    for w in spacy_vocab:
        if w.rank > 2 * vocab_size:
            continue
        try:
            word_string = str(w.lower_).strip()
        except UnicodeError:
            # str() on a non-ASCII unicode lexeme raises under Python 2; skip it
            # (previously a bare except here also hid KeyboardInterrupt and bugs).
            continue
        if not word_string or word_string in seen_words:
            continue
        if any(bad_char in word_string for bad_char in ('[', ']', '<', '>', '{', '}')):
            # these are used to mark word types and person ids.
            continue
        if any(c not in allowed_chars for c in word_string):
            continue
        if sum(1 for c in word_string if c not in letters_or_period) > 2:
            continue
        if word_string[-1] == '.' and sum(1 for c in word_string if c in letters) > 2:
            continue
        top_words.append(w)
        seen_words.add(word_string)
    top_words.sort(key=lambda lexeme: lexeme.rank)
    top_words = top_words[:vocab_size]
    with open(os.path.join(output_dirpath, 'vocab'), 'w') as f:
        for word in top_words:
            f.write('%s\n' % word.lower_.strip())
    vectors = np.array([lexeme.vector for lexeme in top_words])
    svd = TruncatedSVD(n_components=embedding_dim, algorithm='arpack')
    embeddings = svd.fit_transform(vectors)
    print(embeddings.shape)
    print([sum(svd.explained_variance_ratio_[:i]) for i in range(1, embedding_dim + 1)])
    np.save(os.path.join(output_dirpath, 'pretrained_embeddings.npy'), embeddings)
def compute_vocab_overlap(vocab_1, vocab_2):
    """Print a 5x5 overlap table between two word -> rank maps, then a miss list.

    The rank cutoffs are 10k, 20k, 30k, 40k and 50k. Cell ``[a][b]`` counts the
    words whose rank in ``vocab_1`` is below cutoff ``a`` and whose rank in
    ``vocab_2`` is below cutoff ``b``. A word absent from ``vocab_2`` is treated
    as rank 100000.

    After the table, prints ``(word, rank)`` for every word in the top 10k of
    ``vocab_1`` that is not in the top 20k of ``vocab_2``, ordered by rank.
    """
    cutoffs = (10000, 20000, 30000, 40000, 50000)
    table = [[0 for _ in cutoffs] for _ in cutoffs]
    not_shared = []
    for word, rank_1 in vocab_1.items():
        rank_2 = vocab_2.get(word, 100000)
        for row, limit_1 in enumerate(cutoffs):
            if rank_1 >= limit_1:
                continue
            for col, limit_2 in enumerate(cutoffs):
                if rank_2 < limit_2:
                    table[row][col] += 1
                elif (row, col) == (0, 1):
                    # Top-10k in vocab_1 but outside the top-20k of vocab_2.
                    not_shared.append((word, rank_1))
    for line in table:
        print(line)
    for entry in sorted(not_shared, key=lambda pair: pair[1]):
        print(entry)
def read_vocab(filename):
    """Map each word of a vocab file to its 0-based line number (its rank).

    The word is the first whitespace-separated token on each line. Blank lines
    are skipped, though they still count toward line numbering, so ranks of the
    other lines are unchanged. If a word appears more than once, its first
    (lowest) line number is kept. As a debugging aid, the partial mapping is
    printed after line 100 has been read.
    """
    vocab = {}
    with open(filename) as f:
        for i, line in enumerate(f):
            fields = line.split()
            if fields:
                # Keep the first occurrence: a lower line number is a better rank.
                vocab.setdefault(fields[0], i)
            if i == 100:
                print(vocab)
    return vocab
def see_vocab_overlap(filepath1, filepath2):
    """Load two vocab files as rank maps and print their overlap report."""
    compute_vocab_overlap(read_vocab(filepath1), read_vocab(filepath2))
######################################################
# Testing
######################################################
def write_dummy_example(out_file):
    """Write a tiny synthetic dataset of length-prefixed serialized tf Examples.

    Each record is a native ``struct`` 'q' length prefix followed by the
    serialized ``example_pb2.Example``. The Example carries 'article' and
    'abstract' bytes features. The file contains 1000 repetitions of the same
    two article/abstract pairs, written in alternating order.
    """
    samples = (
        ('hi there Michael{1} . that was Jake{2} .', '<s> bye Michael{1} guy . </s>'),
        ('hi there Michael{1} . this is Jake{2} .', '<s> bye Jake{2} . </s>'),
    )

    def _emit(writer, article, abstract):
        # Serialize one Example and write it as <length><payload>.
        example = example_pb2.Example()
        features = example.features.feature
        features['article'].bytes_list.value.extend([article])
        features['abstract'].bytes_list.value.extend([abstract])
        payload = example.SerializeToString()
        size = len(payload)
        writer.write(struct.pack('q', size))
        writer.write(struct.pack('%ds' % size, payload))

    with open(out_file, 'wb') as writer:
        for _ in range(1000):
            for article, abstract in samples:
                _emit(writer, article, abstract)
######################################################
# Comparing results
######################################################
# Hard-coded absolute location of the raw CNN / Daily Mail story directories.
RAW_DATA_DIR = '/Users/michaelwu/dev/cnn-dailymail/raw_data/'
# A tuple rather than a generator expression: a generator would be exhausted
# after the first scan, so any later iteration (e.g. a second find_articles()
# call in the same process) would silently see no directories.
RAW_ARTICLE_DIRS = tuple(
    os.path.join(RAW_DATA_DIR, subdir) for subdir in ('cnn', 'dailymail')
)
# Output tree for selected articles and their reference abstracts, next to this file.
RESULTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'results')
RESULTS_ARTICLE_DIR = os.path.join(RESULTS_DIR, 'articles')
RESULTS_ABSTRACT_DIR = os.path.join(RESULTS_DIR, 'abstracts')
# Number of top-scoring articles find_articles() writes out.
N_ARTICLES = 100
# Terms that mark an article as relevant. find_articles() matches them against
# the article's lowercased single words and adjacent-word bigrams, so every
# entry must be lowercase and contain at most one space.
SEARCH_TERMS = {
    'tesla',
    'crispr',
    'bitcoin',
    'zika',
    'wall street',
    'tim cook',
    'merkel',
    'intelligence',
    'amazon',
    'microsoft',
    'fbi',  # was 'FBI', which could never match lowercased article text
    'vladimir putin',
    'zuckerberg',  # was misspelled 'zuckerburg'
}
def find_articles():
    """Select the stories that best match SEARCH_TERMS and write them out.

    Creates RESULTS_DIR with its article and abstract subdirectories. It then
    scans every story file under RAW_ARTICLE_DIRS (via ``get_art_abs``) and
    scores each story by how many distinct search terms occur among its
    lowercased words and adjacent-word bigrams. The ``N_ARTICLES``
    highest-scoring stories are written to ``article_<i>.txt`` /
    ``abstract_<i>.txt`` and each abstract is printed. Ties keep scan order,
    which is made deterministic by sorting each directory listing.

    Raises:
        Exception: if RESULTS_DIR already exists, to avoid clobbering earlier results.
    """
    if os.path.exists(RESULTS_DIR):
        raise Exception('%s already exists; refusing to overwrite it' % RESULTS_DIR)
    os.mkdir(RESULTS_DIR)
    os.mkdir(RESULTS_ARTICLE_DIR)
    os.mkdir(RESULTS_ABSTRACT_DIR)
    # Article words are lowercased below, so compare against lowercased terms.
    search_terms = set(term.lower() for term in SEARCH_TERMS)
    articles = []
    for article_dir in RAW_ARTICLE_DIRS:
        # os.listdir order is arbitrary; sort so score ties break the same way every run.
        for filename in sorted(os.listdir(article_dir)):
            article_path = os.path.join(article_dir, filename)
            article, abstract = get_art_abs(article_path, add_periods=True, is_cable=False)
            article_words = article.lower().split()
            # Single words plus adjacent-word bigrams, so two-word terms can match.
            full_article_words = set(article_words)
            full_article_words.update(
                ' '.join(pair) for pair in zip(article_words, article_words[1:])
            )
            search_count = sum(1 for term in search_terms if term in full_article_words)
            if search_count:
                articles.append((article, abstract, search_count))
    # list.sort is stable (also with reverse=True), so equal counts keep scan order.
    articles.sort(key=lambda info: info[2], reverse=True)
    for i, (article, abstract, _) in enumerate(articles[:N_ARTICLES]):
        article_path = os.path.join(RESULTS_ARTICLE_DIR, 'article_%d.txt' % i)
        abstract_path = os.path.join(RESULTS_ABSTRACT_DIR, 'abstract_%d.txt' % i)
        print('#####################')
        print(i)
        print(abstract)
        with open(article_path, 'w') as f:
            f.write(article)
        with open(abstract_path, 'w') as f:
            f.write(abstract)
def get_lexrank_summary(doc):
    """Return the LexRank summary for one document.

    Calls ``compute_summaries`` as a one-document batch keyed by id 0. The call
    passes the document text, its sentence spans as ``{'start', 'end'}`` dicts,
    and an empty fourth mapping. It returns that entry's ``'summary'``.
    """
    doc_id = 0
    text = doc.text()
    sentences = [{'start': span[0], 'end': span[1]} for span in doc.sentence_spans()]
    results = compute_summaries([doc_id], {doc_id: text}, {doc_id: sentences}, {})
    return results[doc_id]['summary']
def _result_article_id(filename):
    """Parse N out of 'article_N.txt'; return None for names that do not fit."""
    try:
        return int(filename.split('.')[0].split('_')[1])
    except (IndexError, ValueError):
        return None


def _tsv_cell(text):
    """Replace tabs and newlines with spaces so *text* stays within one TSV cell."""
    return text.replace('\t', ' ').replace('\n', ' ')


def write_results(out_file):
    """Write reference, LexRank and seq-to-seq summaries side by side as TSV.

    For every ``article_<N>.txt`` in RESULTS_ARTICLE_DIR, in numeric order of N:

    * read the article as UTF-8 and the matching ``abstract_<N>.txt`` reference;
    * compute a LexRank summary and a seq-to-seq summary with its score;
    * append one row: Reference / Lexrank / Seq-to-seq / Score.

    Files whose names do not parse, such as ``.DS_Store``, are skipped. Cells
    have tabs and newlines replaced by spaces. The output is flushed after each
    row, and the seq-to-seq summary, elapsed time and score are printed for
    every article.
    """
    with open(out_file, 'w') as out:
        out.write('\t'.join(['Reference', 'Lexrank', 'Seq-to-seq', 'Score']) + '\n')
        article_files = []
        for filename in os.listdir(RESULTS_ARTICLE_DIR):
            article_id = _result_article_id(filename)
            if article_id is not None:
                article_files.append((article_id, filename))
        article_files.sort()
        for article_id, filename in article_files:
            # Read article
            with open(os.path.join(RESULTS_ARTICLE_DIR, filename)) as f:
                article_text = unicode(f.read(), 'utf-8')
            # Flatten non-breaking spaces, tabs and newlines to plain spaces.
            article_text = article_text.replace(u'\xa0', ' ').replace('\t', ' ').replace('\n', ' ')
            # Read reference summary
            with open(os.path.join(RESULTS_ABSTRACT_DIR, 'abstract_%d.txt' % article_id)) as f:
                reference_summary = f.read()
            doc = SingleDocument(0, raw={'body': article_text})
            # Generate lexrank summary
            lexrank_summary = get_lexrank_summary(doc).encode('utf-8')
            # Generate seq-to-seq summary (timed, including spaCy processing)
            t0 = time.time()
            spacy_article = doc.spacy_text()
            seq_to_seq_summary, score = generate_summary(spacy_article)
            seq_to_seq_summary = seq_to_seq_summary.encode('utf-8')
            print('####################')
            print(seq_to_seq_summary)
            print('Time: %s | Score: %s' % (time.time() - t0, score))
            # Write all results together
            cells = [reference_summary, lexrank_summary, seq_to_seq_summary, str(score)]
            out.write('\t'.join(_tsv_cell(cell) for cell in cells) + '\n')
            out.flush()
######################################################
# Generate sample summaries
######################################################
def get_cable_results(data_file, out_file):
    """Write LexRank and seq-to-seq summaries of sample cables to a TSV file.

    Writes a header row (Cable / Lexrank / Seq-to-seq) to ``out_file``, then reads
    a JSON list of cable texts from ``data_file``. For each of the first 100
    cables, skips it if the document text is shorter than 500 characters;
    otherwise writes one row with the cable and both summaries. Cells are
    UTF-8 encoded with tabs and newlines replaced by spaces, and the output is
    flushed after each row. The output file is closed even if summarization
    raises.
    """
    with open(out_file, 'w') as out:
        out.write('\t'.join(['Cable', 'Lexrank', 'Seq-to-seq']) + '\n')
        with open(data_file) as f:
            cables = json.load(f)
        for cable in cables[:100]:
            doc = SingleDocument(0, raw={'body': cable})
            if len(doc.text()) < 500:
                continue
            lexrank = get_lexrank_summary(doc)
            seq2seq = generate_summary(doc.spacy_text())[0]
            # `text`, not `string`: avoid shadowing the string module.
            cells = [
                text.encode('utf-8').replace('\t', ' ').replace('\n', ' ')
                for text in [cable, lexrank, seq2seq]
            ]
            out.write('\t'.join(cells) + '\n')
            out.flush()
######################################################
# Command-line entry point
######################################################
if __name__ == '__main__':
    # Alternative entry points, enabled by hand as needed:
    #compute_reduced_embeddings_original_vocab(
    #    sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[4]), int(sys.argv[5])
    #)
    #find_articles()
    #generate_input_file(sys.argv[1])  # NOTE(review): not defined in this file
    #get_cable_results(sys.argv[1], sys.argv[2])
    if len(sys.argv) < 2:
        # Fail with a usage message instead of an IndexError traceback.
        sys.exit('usage: %s OUTPUT_TSV' % sys.argv[0])
    write_results(sys.argv[1])