
Commit

Merge pull request #257 from emanjavacas/attn
A pull request for #242 (greedy decoding and vectorization in attention.py)
neubig authored Jan 20, 2017
2 parents d3ac428 + 991aa50 commit 95cac6b
Showing 1 changed file with 29 additions and 27 deletions.
examples/python/attention.py (56 changes: 29 additions & 27 deletions)
@@ -62,22 +62,22 @@ def encode_sentence(enc_fwd_lstm, enc_bwd_lstm, sentence):
     return vectors
 
 
-def attend(input_vectors, state):
-    global attention_w1
+def attend(input_mat, state, w1dt):
     global attention_w2
     global attention_v
-    w1 = dy.parameter(attention_w1)
     w2 = dy.parameter(attention_w2)
     v = dy.parameter(attention_v)
-    attention_weights = []
 
+    # input_mat: (encoder_state x seqlen) => input vecs concatenated as cols
+    # w1dt: (attdim x seqlen)
+    # w2dt: (attdim x 1)
     w2dt = w2*dy.concatenate(list(state.s()))
-    for input_vector in input_vectors:
-        attention_weight = v*dy.tanh(w1*input_vector + w2dt)
-        attention_weights.append(attention_weight)
-    attention_weights = dy.softmax(dy.concatenate(attention_weights))
-    output_vectors = dy.esum([vector*attention_weight for vector, attention_weight in zip(input_vectors, attention_weights)])
-    return output_vectors
+    # att_weights: (seqlen,) column vector
+    unnormalized = dy.transpose(v * dy.tanh(dy.colwise_add(w1dt, w2dt)))
+    att_weights = dy.softmax(unnormalized)
+    # context: (encoder_state)
+    context = input_mat * att_weights
+    return context
 
 
 def decode(dec_lstm, vectors, output):
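Note on the new attend() above: the per-position Python loop is replaced by one batched expression, with dy.colwise_add broadcasting the decoder-state term w2dt across every column of w1dt. Below is a minimal numpy sketch of the same shape algebra, for intuition only; enc_dim, att_dim and seqlen are illustrative values, not taken from the script.

import numpy as np

enc_dim, att_dim, seqlen = 4, 3, 5
input_mat = np.random.randn(enc_dim, seqlen)   # encoder states stacked as columns
w1dt = np.random.randn(att_dim, seqlen)        # stands in for W1 @ input_mat
w2dt = np.random.randn(att_dim, 1)             # stands in for W2 @ concatenated decoder state
v = np.random.randn(1, att_dim)

scores = v @ np.tanh(w1dt + w2dt)                    # broadcasting plays the role of colwise_add; shape (1, seqlen)
att_weights = np.exp(scores) / np.exp(scores).sum()  # softmax over source positions
context = input_mat @ att_weights.T                  # (enc_dim, 1): weighted sum of encoder columns

The weighted sum that the old code built with dy.esum over a zip is exactly this final matrix-vector product of input_mat with the attention weights.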
@@ -86,13 +86,18 @@ def decode(dec_lstm, vectors, output):
 
     w = dy.parameter(decoder_w)
     b = dy.parameter(decoder_b)
+    w1 = dy.parameter(attention_w1)
+    input_mat = dy.concatenate_cols(vectors)
+    w1dt = None
 
     last_output_embeddings = output_lookup[char2int[EOS]]
     s = dec_lstm.initial_state().add_input(dy.concatenate([dy.vecInput(STATE_SIZE*2), last_output_embeddings]))
     loss = []
-    for char in output:
-        vector = dy.concatenate([attend(vectors, s), last_output_embeddings])
+
+    for char in output:
+        # w1dt can be computed and cached once for the entire decoding phase
+        w1dt = w1dt or w1 * input_mat
+        vector = dy.concatenate([attend(input_mat, s, w1dt), last_output_embeddings])
         s = s.add_input(vector)
         out_vector = w * s.output() + b
         probs = dy.softmax(out_vector)
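The line w1dt = w1dt or w1 * input_mat hoists the attention_w1 projection out of attend(): W1 only touches the encoder matrix, which is fixed for the whole sentence, so the projection is computed on the first iteration and reused afterwards. A small standalone numpy sketch of that caching idiom (dimensions invented for illustration, not from the script):

import numpy as np

att_dim, enc_dim, seqlen, steps = 3, 4, 5, 7
w1 = np.random.randn(att_dim, enc_dim)
input_mat = np.random.randn(enc_dim, seqlen)

w1dt = None
for _ in range(steps):
    if w1dt is None:               # same effect as the `or` idiom used in the diff
        w1dt = w1 @ input_mat      # (att_dim, seqlen), computed once and then reused
    # ... an attend(input_mat, state, w1dt) call would consume the cached projection here

The explicit None check in the sketch and the "or" form in the diff are equivalent here; both skip the matrix product on every step after the first.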
@@ -102,33 +107,30 @@
     return loss
 
 
-def generate(input, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
-    def sample(probs):
-        rnd = random.random()
-        for i, p in enumerate(probs):
-            rnd -= p
-            if rnd <= 0: break
-        return i
-
-    embedded = embed_sentence(input)
+def generate(in_seq, enc_fwd_lstm, enc_bwd_lstm, dec_lstm):
+    embedded = embed_sentence(in_seq)
     encoded = encode_sentence(enc_fwd_lstm, enc_bwd_lstm, embedded)
 
     w = dy.parameter(decoder_w)
     b = dy.parameter(decoder_b)
+    w1 = dy.parameter(attention_w1)
+    input_mat = dy.concatenate_cols(encoded)
+    w1dt = None
 
     last_output_embeddings = output_lookup[char2int[EOS]]
     s = dec_lstm.initial_state().add_input(dy.concatenate([dy.vecInput(STATE_SIZE * 2), last_output_embeddings]))
 
     out = ''
     count_EOS = 0
-    for i in range(len(input)*2):
+    for i in range(len(in_seq)*2):
         if count_EOS == 2: break
-        vector = dy.concatenate([attend(encoded, s), last_output_embeddings])
+
+        # w1dt can be computed and cached once for the entire decoding phase
+        w1dt = w1dt or w1 * input_mat
+        vector = dy.concatenate([attend(input_mat, s, w1dt), last_output_embeddings])
         s = s.add_input(vector)
         out_vector = w * s.output() + b
-        probs = dy.softmax(out_vector)
-        probs = probs.vec_value()
-        next_char = sample(probs)
+        probs = dy.softmax(out_vector).vec_value()
+        next_char = probs.index(max(probs))
         last_output_embeddings = output_lookup[next_char]
         if int2char[next_char] == EOS:
             count_EOS += 1
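Generation now decodes greedily: the stochastic sample() helper is removed and the next character is the argmax of the softmax output, probs.index(max(probs)), which is the greedy decoding requested in #242. A standalone sketch, not part of the script, contrasting the two selection rules:

import random

def sample(probs):                       # the removed sampler, kept here for comparison
    rnd = random.random()
    for i, p in enumerate(probs):
        rnd -= p
        if rnd <= 0:
            break
    return i

def greedy(probs):                       # what generate() does after this commit
    return probs.index(max(probs))

probs = [0.1, 0.6, 0.3]
print(greedy(probs))                     # always 1
print(sample(probs))                     # 0, 1, or 2, drawn in proportion to probs

Greedy selection makes generation deterministic; sampling remains an option when more diverse outputs are wanted.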
