"""
Seq-to-seq based summarization method. Model is based off of
https://github.com/abisee/pointer-generator and is trained on 300K news articles from
CNN / Dailymail and 100K new cables.
"""
import os

from spacy.tokens.doc import Doc

_model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'model_parameters')
_vocab_path = os.path.join(_model_dir, 'vocab')
_vocab_size = 20000
_beam_size = 4

_settings = None
_hps = None
_vocab = None
_sess = None
_model = None


def _load_model():
    # These imports are slow - lazy import.
    import tensorflow as tf
    from data import Vocab
    from model import Hps, Settings, SummarizationModel

    global _settings, _hps, _vocab, _sess, _model

    # Define settings and hyperparameters
    _settings = Settings(
        embeddings_path='',
        log_root='',
        trace_path='',  # 'traces/traces_blog'
    )
    _hps = Hps(
        # parameters important for decoding
        attn_only_entities=False,
        batch_size=_beam_size,
        copy_only_entities=False,
        emb_dim=128,
        enc_hidden_dim=200,
        dec_hidden_dim=300,
        max_dec_steps=1,
        max_enc_steps=400,
        mode='decode',
        output_vocab_size=20000,
        restrictive_embeddings=False,
        save_matmul=False,
        tied_output=True,
        two_layer_lstm=True,
        # other parameters
        adagrad_init_acc=.1,
        adam_optimizer=True,
        copy_common_loss_wt=0.,
        cov_loss_wt=0.,
        high_attn_loss_wt=0.,
        lr=.15,
        max_grad_norm=2.,
        people_loss_wt=0.,
        rand_unif_init_mag=.02,
        scatter_loss_wt=0.,
        sharp_loss_wt=0.,
        trunc_norm_init_std=1e-4,
    )

    # Define model
    _vocab = Vocab(_vocab_path, _vocab_size)
    _model = SummarizationModel(_settings, _hps, _vocab)
    _model.build_graph()

    # Load model from disk
    saver = tf.train.Saver()
    config = tf.ConfigProto(
        allow_soft_placement=True,
        # intra_op_parallelism_threads=1,
        # inter_op_parallelism_threads=1,
    )
    _sess = tf.Session(config=config)
    ckpt_state = tf.train.get_checkpoint_state(_model_dir)
    saver.restore(_sess, ckpt_state.model_checkpoint_path)
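
# Illustrative note on the expected on-disk layout: tf.train.get_checkpoint_state()
# reads the 'checkpoint' index file in _model_dir, so the directory should contain
# that file, the saved parameters it points at, and the vocab file loaded above.
# The checkpoint file names here are hypothetical placeholders:
#
#   model_parameters/
#       checkpoint
#       model.ckpt-<step>.data-00000-of-00001
#       model.ckpt-<step>.index
#       model.ckpt-<step>.meta
#       vocab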


def generate_summary(spacy_article, ideal_summary_length_tokens=60):
    """
    Generates a summary of the given article. Note that this is slow (~20 seconds on a
    single CPU).

    Args:
        spacy_article: Spacy-processed text. The model was trained on the output of
            doc.spacy_text(), so for best results the input here should also come from
            doc.spacy_text().
        ideal_summary_length_tokens: Target summary length in tokens; the minimum and
            maximum decoding lengths are derived from this and the article length.

    Returns:
        Tuple of unicode summary of the text and scalar score of its quality. The score is
        approximately an average log-likelihood of the summary (so it is < 0.) and typically
        falls in the range [-.5, -.2]. Summaries with scores below -.4 are usually not very
        good.
    """
    assert isinstance(spacy_article, Doc)
    # These imports are slow - lazy import.
    from batcher import Batch, Example
    from beam_search import run_beam_search
    from io_processing import process_article, process_output

    if _model is None:
        _load_model()

    # Handle short inputs
    article_tokens, _, orig_article_tokens = process_article(spacy_article)
    if len(article_tokens) <= ideal_summary_length_tokens:
        return spacy_article.text, 0.

    # Integer division keeps the decoding lengths whole under both Python 2 and 3.
    min_summary_length = min(10 + len(article_tokens) // 10, 2 * ideal_summary_length_tokens // 3)
    max_summary_length = min(10 + len(article_tokens) // 5, 3 * ideal_summary_length_tokens // 2)

    # Make input data
    example = Example(' '.join(article_tokens), abstract='', vocab=_vocab, hps=_hps)
    batch = Batch([example] * _beam_size, _hps, _vocab)

    # Generate output
    hyp, score = run_beam_search(
        _sess, _model, _vocab, batch, _beam_size, max_summary_length, min_summary_length,
        _settings.trace_path,
    )

    # Extract the output ids from the hypothesis and convert back to words
    return process_output(hyp.token_strings[1:], orig_article_tokens), score
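

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes spaCy's
    # 'en_core_web_sm' model is installed (the model name is an assumption) and
    # takes the path to a plain-text article as the only command-line argument.
    # Per the docstring above, the model was trained on doc.spacy_text() output,
    # so a plain spaCy pipeline may give slightly different results.
    import io
    import sys

    import spacy

    nlp = spacy.load('en_core_web_sm')
    with io.open(sys.argv[1], encoding='utf-8') as f:
        article = nlp(f.read())

    summary, score = generate_summary(article)
    # Scores below -.4 usually indicate a low-quality summary (see docstring).
    print('score: %.3f' % score)
    print(summary)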