In the paper, you mention that "we used 300D GloVe [11] vectors to initialize the word embeddings and then passed it to a GRU network and a question attention module to extract attentive text features".
However, question_embeding.py defines two question-embedding modules, and as far as I can tell from the config you are using att_que_embed, which does not pass the embeddings through a GRU layer:
```python
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# `cfg` is the repository's global config object; its import is omitted here.


def build_question_encoding_module(method, par, num_vocab):
    if method == "default_que_embed":
        return QuestionEmbeding(num_vocab, **par)
    elif method == "att_que_embed":
        return AttQuestionEmbedding(num_vocab, **par)
    else:
        raise NotImplementedError(
            "unknown question encoding model %s" % method)


# "default_que_embed": GloVe embedding followed by a GRU; the question
# feature is the GRU output at the last time step.
class QuestionEmbeding(nn.Module):
    def __init__(self, num_vocab, **kwargs):
        super(QuestionEmbeding, self).__init__()
        self.text_out_dim = kwargs['LSTM_hidden_size']
        self.num_vocab = num_vocab
        self.embedding_dim = kwargs['embedding_dim']
        self.embedding = nn.Embedding(
            num_vocab, kwargs['embedding_dim'])
        self.gru = nn.GRU(
            input_size=kwargs['embedding_dim'],
            hidden_size=kwargs['LSTM_hidden_size'],
            num_layers=kwargs['lstm_layer'],
            dropout=kwargs['lstm_dropout'],
            batch_first=True)
        self.batch_first = True

        if 'embedding_init' in kwargs and kwargs['embedding_init'] is not None:
            self.embedding.weight.data.copy_(
                torch.from_numpy(kwargs['embedding_init']))

    def forward(self, input_text):
        embeded_txt = self.embedding(input_text)
        out, hidden_state = self.gru(embeded_txt)
        res = out[:, -1]    # last time step of the GRU output
        return res


# "att_que_embed" (the one selected in the config): GloVe embedding followed
# by an LSTM and a convolutional question-attention module.
class AttQuestionEmbedding(nn.Module):
    def __init__(self, num_vocab, **kwargs):
        super(AttQuestionEmbedding, self).__init__()
        self.embedding = nn.Embedding(num_vocab, kwargs['embedding_dim'])
        self.LSTM = nn.LSTM(input_size=kwargs['embedding_dim'],
                            hidden_size=kwargs['LSTM_hidden_size'],
                            num_layers=kwargs['LSTM_layer'],
                            batch_first=True)
        self.Dropout = nn.Dropout(p=kwargs['dropout'])
        self.conv1 = nn.Conv1d(
            in_channels=kwargs['LSTM_hidden_size'],
            out_channels=kwargs['conv1_out'],
            kernel_size=kwargs['kernel_size'],
            padding=kwargs['padding'])
        self.conv2 = nn.Conv1d(
            in_channels=kwargs['conv1_out'],
            out_channels=kwargs['conv2_out'],
            kernel_size=kwargs['kernel_size'],
            padding=kwargs['padding'])
        self.text_out_dim = kwargs['LSTM_hidden_size'] * kwargs['conv2_out']

        if 'embedding_init_file' in kwargs \
                and kwargs['embedding_init_file'] is not None:
            if os.path.isabs(kwargs['embedding_init_file']):
                embedding_file = kwargs['embedding_init_file']
            else:
                embedding_file = os.path.join(
                    cfg.data.data_root_dir, kwargs['embedding_init_file'])
            embedding_init = np.load(embedding_file)
            self.embedding.weight.data.copy_(torch.from_numpy(embedding_init))

    def forward(self, input_text):
        batch_size, _ = input_text.data.shape

        embed_txt = self.embedding(input_text)      # N * T * embedding_dim

        # self.LSTM.flatten_parameters()
        lstm_out, _ = self.LSTM(embed_txt)          # N * T * LSTM_hidden_size
        lstm_drop = self.Dropout(lstm_out)          # N * T * LSTM_hidden_size
        lstm_reshape = lstm_drop.permute(0, 2, 1)   # N * LSTM_hidden_size * T

        qatt_conv1 = self.conv1(lstm_reshape)       # N x conv1_out x T
        qatt_relu = F.relu(qatt_conv1)
        qatt_conv2 = self.conv2(qatt_relu)          # N x conv2_out x T
        qtt_softmax = F.softmax(qatt_conv2, dim=2)
        # N * conv2_out * LSTM_hidden_size
        qtt_feature = torch.bmm(qtt_softmax, lstm_drop)
        # N * (conv2_out * LSTM_hidden_size)
        qtt_feature_concat = qtt_feature.view(batch_size, -1)

        return qtt_feature_concat
```
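To make the shapes concrete, here is a small usage sketch of the att_que_embed path, reusing the AttQuestionEmbedding class quoted above. The hyperparameter values are made up for illustration; they are not the ones from your config:

```python
import torch

# Illustrative hyperparameters only -- not the values from the repo's config.
par = dict(
    embedding_dim=300,      # e.g. 300D GloVe vectors
    LSTM_hidden_size=512,
    LSTM_layer=1,
    dropout=0.1,
    conv1_out=512,
    conv2_out=2,            # number of attention glimpses
    kernel_size=1,
    padding=0,
)

que_embed = AttQuestionEmbedding(num_vocab=10000, **par)

# A dummy batch of 4 questions, each padded to 14 token ids.
questions = torch.randint(0, 10000, (4, 14))
features = que_embed(questions)

# One concatenated vector per question:
# (4, conv2_out * LSTM_hidden_size) == (4, 1024) == (4, que_embed.text_out_dim)
print(features.shape)
```

So the question feature is a concatenation of conv2_out attention-weighted sums of LSTM hidden states (which is why text_out_dim is LSTM_hidden_size * conv2_out), and no GRU is involved anywhere on this path.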
Yup, you are correct. I am not sure, but I think it is a typo in the paper. It should be clearer in the newer version. Please reopen if you have any more questions.
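If anyone wants the code to match the paper's sentence literally, swapping the recurrent layer would be enough, since the attention convolutions do not care whether the hidden states come from an LSTM or a GRU. A minimal sketch (hypothetical, not part of the repository):

```python
import torch.nn as nn

class GRUAttQuestionEmbedding(AttQuestionEmbedding):
    """Hypothetical variant matching the paper's wording:
    GloVe embeddings -> GRU -> question attention. Not in the repo."""

    def __init__(self, num_vocab, **kwargs):
        super(GRUAttQuestionEmbedding, self).__init__(num_vocab, **kwargs)
        # Rebind self.LSTM to a GRU of the same size; AttQuestionEmbedding's
        # forward() still works because nn.GRU also returns (output, hidden).
        self.LSTM = nn.GRU(input_size=kwargs['embedding_dim'],
                           hidden_size=kwargs['LSTM_hidden_size'],
                           num_layers=kwargs['LSTM_layer'],
                           batch_first=True)
```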