Assemble BiDAF model and add training script #15
base: bidaf
New file (70 added lines): Attention Flow Layer
```python
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Attention Flow Layer"""
from mxnet import gluon

from .similarity_function import DotProductSimilarity


class AttentionFlow(gluon.HybridBlock):
    """
    This ``block`` takes two ndarrays as input and returns an ndarray of attentions.

    We compute the similarity between each row in each matrix and return unnormalized similarity
    scores. Because these scores are unnormalized, we don't take a mask as input; it's up to the
    caller to deal with masking properly when this output is used.

    By default similarity is computed with a dot product, but you can alternatively use a
    parameterized similarity function if you wish.

    Input:
        - ndarray_1: ``(batch_size, num_rows_1, embedding_dim)``
        - ndarray_2: ``(batch_size, num_rows_2, embedding_dim)``

    Output:
        - ``(batch_size, num_rows_1, num_rows_2)``

    Parameters
    ----------
    similarity_function : ``SimilarityFunction``, optional (default=``DotProductSimilarity``)
        The similarity function to use when computing the attention.
    """
    def __init__(self, similarity_function, batch_size, passage_length,
                 question_length, embedding_size, **kwargs):
        super(AttentionFlow, self).__init__(**kwargs)

        self._similarity_function = similarity_function or DotProductSimilarity()
        self._batch_size = batch_size
        self._passage_length = passage_length
        self._question_length = question_length
        self._embedding_size = embedding_size

    def hybrid_forward(self, F, matrix_1, matrix_2):
        # pylint: disable=arguments-differ
        tiled_matrix_1 = matrix_1.expand_dims(2).broadcast_to(shape=(self._batch_size,
                                                                     self._passage_length,
                                                                     self._question_length,
                                                                     self._embedding_size))
        tiled_matrix_2 = matrix_2.expand_dims(1).broadcast_to(shape=(self._batch_size,
                                                                     self._passage_length,
                                                                     self._question_length,
                                                                     self._embedding_size))
        return self._similarity_function(tiled_matrix_1, tiled_matrix_2)
```

Review thread on the `broadcast_to` calls:

Comment: To make this hybridizable, we can use …
Reply: It is hybridizable even if we use expand directly on matrix_1.
Reply: Sounds like this is a bug in MXNet; based on offline discussion, we should have the automatic registration of the operators in …

Comment (on the second `broadcast_to` call): Same as above.
Reply: This block is hybridizable even when calling …
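As context for the hybridization thread above: with the default dot-product similarity, tiling both matrices to ``(batch, passage, question, dim)`` and reducing the last axis is mathematically the same as one batched matrix product. A NumPy sketch of that equivalence (illustration only, not the mxnet implementation; all names are made up for the example):

```python
import numpy as np

batch, p_len, q_len, dim = 2, 5, 4, 3
rng = np.random.default_rng(0)
matrix_1 = rng.standard_normal((batch, p_len, dim))
matrix_2 = rng.standard_normal((batch, q_len, dim))

# Tiling approach, mirroring expand_dims + broadcast_to in the diff:
# both inputs become (batch, p_len, q_len, dim) before the dot product.
tiled_1 = np.broadcast_to(matrix_1[:, :, None, :], (batch, p_len, q_len, dim))
tiled_2 = np.broadcast_to(matrix_2[:, None, :, :], (batch, p_len, q_len, dim))
tiled_similarity = (tiled_1 * tiled_2).sum(axis=-1)

# Equivalent batched matrix multiply, with no 4-D intermediate.
direct_similarity = np.einsum('bpd,bqd->bpq', matrix_1, matrix_2)

assert tiled_similarity.shape == (batch, p_len, q_len)
assert np.allclose(tiled_similarity, direct_similarity)
```

The equivalence only holds for the dot-product case; a parameterized similarity function may genuinely need the tiled 4-D inputs.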
New file (121 added lines): Bidirectional attention flow layer
```python
# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Bidirectional attention flow layer"""
from mxnet import gluon
import numpy as np
```

Review thread on `import numpy as np`:

Comment: One question, why are we using `numpy` here?
Reply: It is only used to get the epsilon and minimum values for different precisions, instead of hardcoding the numbers into the code. Look at how … Do we have this info available via the `nd` package?
Reply: Please find the comment in the …
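For context on this thread: `numpy.finfo` provides exactly the precision-dependent limits the reply describes, so nothing has to be hardcoded per dtype:

```python
import numpy as np

# finfo exposes machine limits per floating-point precision, so code can
# look them up instead of hardcoding constants like -1e30 or 1e-8.
big_negative = np.finfo(np.float32).min    # most negative representable float32
small_positive = np.finfo(np.float32).eps  # gap between 1.0 and the next float32

print(big_negative)              # ~ -3.4028235e+38
print(small_positive)            # ~ 1.1920929e-07
print(np.finfo(np.float16).eps)  # larger for half precision
```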
```python
from .utils import last_dim_softmax, weighted_sum, replace_masked_values, masked_softmax


class BidirectionalAttentionFlow(gluon.HybridBlock):
    """
    This class implements Minjoon Seo's `Bidirectional Attention Flow model
    <https://www.semanticscholar.org/paper/Bidirectional-Attention-Flow-for-Machine-Seo-Kembhavi/7586b7cca1deba124af80609327395e613a20e9d>`_
    for answering reading comprehension questions (ICLR 2017).
    """

    def __init__(self,
                 batch_size,
                 passage_length,
                 question_length,
                 encoding_dim,
                 **kwargs):
        super(BidirectionalAttentionFlow, self).__init__(**kwargs)

        self._batch_size = batch_size
        self._passage_length = passage_length
        self._question_length = question_length
        self._encoding_dim = encoding_dim

    def _get_big_negative_value(self):
        """Provides the most negative float32 value

        Returns
        -------
        value : float32
            Most negative representable float32 value
        """
        return np.finfo(np.float32).min

    def _get_small_positive_value(self):
        """Provides a small positive float32 value (machine epsilon)

        Returns
        -------
        value : float32
            float32 machine epsilon
        """
        return np.finfo(np.float32).eps
```

Review thread on `_get_big_negative_value` and `_get_small_positive_value`:

Comment: We might want to change this based on ndarray.
Reply: Sorry, didn't get it. Do you want to do this based on the dtype of the NDArray instead of using an …?
Reply: We can keep this as-is. Another way is that we can use Python …
```python
    def hybrid_forward(self, F, passage_question_similarity,
                       encoded_passage, encoded_question, question_mask, passage_mask):
        # pylint: disable=arguments-differ
        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity_shape = (self._batch_size, self._passage_length,
                                             self._question_length)

        question_mask_shape = (self._batch_size, self._question_length)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = last_dim_softmax(F,
                                                      passage_question_similarity,
                                                      question_mask,
                                                      passage_question_similarity_shape,
                                                      question_mask_shape,
                                                      epsilon=self._get_small_positive_value())
        # Shape: (batch_size, passage_length, encoding_dim)
        encoded_question_shape = (self._batch_size, self._question_length, self._encoding_dim)
        passage_question_attention_shape = (self._batch_size, self._passage_length,
                                            self._question_length)
        passage_question_vectors = weighted_sum(F, encoded_question, passage_question_attention,
                                                encoded_question_shape,
                                                passage_question_attention_shape)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = passage_question_similarity if question_mask is None else \
            replace_masked_values(F,
                                  passage_question_similarity,
                                  question_mask.expand_dims(1),
                                  replace_with=self._get_big_negative_value())

        # Shape: (batch_size, passage_length)
        question_passage_similarity = masked_similarity.max(axis=-1)

        # Shape: (batch_size, passage_length)
        question_passage_attention = masked_softmax(F, question_passage_similarity, passage_mask,
                                                    epsilon=self._get_small_positive_value())

        # Shape: (batch_size, encoding_dim)
        encoded_passage_shape = (self._batch_size, self._passage_length, self._encoding_dim)
        question_passage_attention_shape = (self._batch_size, self._passage_length)
        question_passage_vector = weighted_sum(F, encoded_passage, question_passage_attention,
                                               encoded_passage_shape,
                                               question_passage_attention_shape)

        # Shape: (batch_size, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.expand_dims(1)

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        final_merged_passage = F.concat(encoded_passage,
                                        passage_question_vectors,
                                        encoded_passage * passage_question_vectors,
                                        F.broadcast_mul(encoded_passage,
                                                        tiled_question_passage_vector),
                                        dim=-1)

        return final_merged_passage
```

Review thread on passing `F` into the `.utils` helpers:

Comment (on `last_dim_softmax`): There seems no need to pass `F` here.
Reply: The problem here is that … Since the Symbol object doesn't have it (and …).

Comment (on `weighted_sum`): The same as the above comment, what's the reason for passing `F`?
Reply: Same reason as above, but in this case the problem is …

Comment (on `replace_masked_values`): Same issue, here it is both …
Comment (on `masked_softmax`): Same: …
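The `.utils` helpers used above are not part of this diff. As a hedged sketch of the computation `hybrid_forward` performs, here is a NumPy version, under the assumption that the masked softmax zeroes masked positions and renormalizes with the epsilon guard discussed above; all names are illustrative:

```python
import numpy as np

def masked_softmax(scores, mask, epsilon=np.finfo(np.float32).eps):
    # Zero masked entries after exponentiation, then renormalize; epsilon
    # keeps the division finite if every position in a row is masked.
    exp = np.exp(scores - scores.max(axis=-1, keepdims=True)) * mask
    return exp / (exp.sum(axis=-1, keepdims=True) + epsilon)

batch, p_len, q_len, dim = 2, 5, 4, 3
rng = np.random.default_rng(0)
similarity = rng.standard_normal((batch, p_len, q_len))
encoded_passage = rng.standard_normal((batch, p_len, dim))
encoded_question = rng.standard_normal((batch, q_len, dim))
question_mask = np.ones((batch, q_len)); question_mask[:, -1] = 0
passage_mask = np.ones((batch, p_len))

# Passage-to-question attention: softmax over the question axis, then a
# weighted sum of question vectors for every passage position.
p2q_attention = masked_softmax(similarity, question_mask[:, None, :])
passage_question_vectors = p2q_attention @ encoded_question  # (batch, p_len, dim)

# Question-to-passage: replace masked entries with a huge negative value,
# max over questions, softmax over passages, weighted sum of passage vectors.
masked_sim = np.where(question_mask[:, None, :] > 0, similarity,
                      np.finfo(np.float32).min)
q2p_attention = masked_softmax(masked_sim.max(axis=-1), passage_mask)
question_passage_vector = np.einsum('bp,bpd->bd', q2p_attention, encoded_passage)
tiled = question_passage_vector[:, None, :]  # broadcast over passage positions

# Merge into (batch, p_len, 4 * dim), mirroring the final F.concat.
merged = np.concatenate([encoded_passage,
                         passage_question_vectors,
                         encoded_passage * passage_question_vectors,
                         encoded_passage * tiled], axis=-1)
```

The merged output has shape `(batch, p_len, 4 * dim)`, which is the `encoding_dim * 4` noted in the code's shape comment.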
Review thread on data loading:

Comment: Should data always come from `SQuAD._get_records`?
Reply: Yep. The thing is that the official evaluation scripts need the original JSON content, and instead of trying to figure out the path in the filesystem and loading it manually, it is easier to just have a separate method that reads the JSON from disk and parses it. Then I can send the JSON as-is to the official evaluation script without needing to know where MXNet stores the file itself.
Reply: I think in MXNet in general, we normally download the data into the MXNet root dir, then check whether the data is there and load it from there if it exists; otherwise we download the data from S3.
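The download-or-reuse pattern described in the last reply can be sketched as below; the helper name, paths, and URL handling are hypothetical and not part of this PR:

```python
import json
import os
from urllib.request import urlretrieve

def load_cached_json(path, url):
    # Hypothetical helper: reuse the local copy if it already exists,
    # otherwise download it first; then parse the JSON so it can be handed
    # to the official evaluation script as-is.
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        urlretrieve(url, path)
    with open(path) as handle:
        return json.load(handle)
```

Because the caller gets the parsed JSON back, it never needs to know where the dataset file is actually cached on disk.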