diff --git a/pyexamples/attention.py b/pyexamples/attention.py
index 90d7c642c..09dd7d452 100644
--- a/pyexamples/attention.py
+++ b/pyexamples/attention.py
@@ -20,23 +20,24 @@
 enc_fwd_lstm = pc.LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, model)
 enc_bwd_lstm = pc.LSTMBuilder(LSTM_NUM_OF_LAYERS, EMBEDDINGS_SIZE, STATE_SIZE, model)
 
-dec_lstm = pc.LSTMBuilder(LSTM_NUM_OF_LAYERS, STATE_SIZE*2, STATE_SIZE, model)
+dec_lstm = pc.LSTMBuilder(LSTM_NUM_OF_LAYERS, STATE_SIZE*2+EMBEDDINGS_SIZE, STATE_SIZE, model)
 
-lookup = model.add_lookup_parameters( (VOCAB_SIZE, EMBEDDINGS_SIZE))
+input_lookup = model.add_lookup_parameters((VOCAB_SIZE, EMBEDDINGS_SIZE))
 attention_w1 = model.add_parameters( (ATTENTION_SIZE, STATE_SIZE*2))
 attention_w2 = model.add_parameters( (ATTENTION_SIZE, STATE_SIZE*LSTM_NUM_OF_LAYERS*2))
 attention_v = model.add_parameters( (1, ATTENTION_SIZE))
 decoder_w = model.add_parameters( (VOCAB_SIZE, STATE_SIZE))
 decoder_b = model.add_parameters( (VOCAB_SIZE))
+output_lookup = model.add_lookup_parameters((VOCAB_SIZE, EMBEDDINGS_SIZE))
 
 
 def embed_sentence(sentence):
     sentence = [EOS] + list(sentence) + [EOS]
     sentence = [char2int[c] for c in sentence]
 
-    global lookup
+    global input_lookup
 
-    return [lookup[char] for char in sentence]
+    return [input_lookup[char] for char in sentence]
 
 
 def run_lstm(init_state, input_vecs):
@@ -86,15 +87,16 @@ def decode(dec_lstm, vectors, output):
     w = pc.parameter(decoder_w)
     b = pc.parameter(decoder_b)
 
-    s = dec_lstm.initial_state().add_input(pc.vecInput(STATE_SIZE*2))
-
+    last_output_embeddings = output_lookup[char2int[EOS]]
+    s = dec_lstm.initial_state().add_input(pc.concatenate([pc.vecInput(STATE_SIZE*2), last_output_embeddings]))
     loss = []
     for char in output:
-        vector = attend(vectors, s)
+        vector = pc.concatenate([attend(vectors, s), last_output_embeddings])
 
         s = s.add_input(vector)
         out_vector = w * s.output() + b
         probs = pc.softmax(out_vector)
+        last_output_embeddings = output_lookup[char]
         loss.append(-pc.log(pc.pick(probs, char)))
     loss = pc.esum(loss)
     return loss
@@ -114,18 +116,20 @@ def sample(probs):
     w = pc.parameter(decoder_w)
     b = pc.parameter(decoder_b)
 
-    s = dec_lstm.initial_state().add_input(pc.vecInput(STATE_SIZE * 2))
+    last_output_embeddings = output_lookup[char2int[EOS]]
+    s = dec_lstm.initial_state().add_input(pc.concatenate([pc.vecInput(STATE_SIZE * 2), last_output_embeddings]))
 
     out = ''
     count_EOS = 0
     for i in range(len(input)*2):
         if count_EOS == 2: break
-        vector = attend(encoded, s)
+        vector = pc.concatenate([attend(encoded, s), last_output_embeddings])
         s = s.add_input(vector)
         out_vector = w * s.output() + b
         probs = pc.softmax(out_vector)
         probs = probs.vec_value()
         next_char = sample(probs)
+        last_output_embeddings = output_lookup[next_char]
         if int2char[next_char] == EOS:
             count_EOS += 1
             continue
@@ -155,5 +159,3 @@ def train(model, sentence):
 
 
 train(model, "it is working")
-
-
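
For quick reference, the patch threads the previous output's embedding into every decoder step (input feeding), which is why dec_lstm's input size grows from STATE_SIZE*2 to STATE_SIZE*2+EMBEDDINGS_SIZE. Below is a minimal sketch of one such decoder step; decode_step is a hypothetical helper, not part of the patch, assuming pc is the pycnn module and attend, output_lookup, decoder_w, decoder_b are the globals defined in attention.py above.

    # Sketch only (hypothetical helper, not in the patch): one input-feeding
    # decoder step as introduced above. Assumes `pc` (pycnn) plus the globals
    # from attention.py: attend(), dec_lstm, output_lookup, decoder_w, decoder_b.
    def decode_step(s, encoded, last_output_embeddings, w, b):
        # Concatenate the attention context with the previous output's
        # embedding, giving the decoder LSTM an input of size
        # STATE_SIZE*2 + EMBEDDINGS_SIZE (matching the new dec_lstm builder).
        vector = pc.concatenate([attend(encoded, s), last_output_embeddings])
        s = s.add_input(vector)
        # Project the new decoder state to a distribution over VOCAB_SIZE.
        probs = pc.softmax(w * s.output() + b)
        return s, probs

Training (decode) and sampling (generate) share this step; they differ only in whether last_output_embeddings is refreshed from the gold character or from the sampled one.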