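"""Ablation script: train a document-QA model on the XQA corpus, comparing
paragraph-handling modes (confidence, merge, shared-norm, sigmoid, paragraph)."""
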
import argparse
from datetime import datetime
from docqa import model_dir
from docqa import trainer
from docqa.data_processing.document_splitter import MergeParagraphs, ShallowOpenWebRanker
from docqa.data_processing.multi_paragraph_qa import StratifyParagraphsBuilder, \
StratifyParagraphSetsBuilder, RandomParagraphSetDatasetBuilder
from docqa.data_processing.preprocessed_corpus import PreprocessedData
from docqa.data_processing.qa_training_data import ContextLenBucketedKey
from docqa.dataset import ClusteredBatcher
from docqa.evaluator import LossEvaluator, MultiParagraphSpanEvaluator
from docqa.scripts.ablate_triviaqa import get_model
from docqa.text_preprocessor import WithIndicators
from docqa.trainer import SerializableOptimizer, TrainParams
from docqa.triviaqa.training_data import ExtractMultiParagraphsPerQuestion
from build_span_corpus import XQADataset


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("corpus", choices=["en", "en_trans_de", "en_trans_zh"])
    parser.add_argument("mode", choices=["confidence", "merge", "shared-norm",
                                         "sigmoid", "paragraph"])
    # Note: modes other than `shared-norm` have not been tested on this corpus,
    # so some settings might need adjusting
    parser.add_argument("name", help="Where to store the model")
    parser.add_argument("-t", "--n_tokens", default=400, type=int,
                        help="Maximum paragraph size in tokens")
    parser.add_argument("-n", "--n_processes", type=int, default=2,
                        help="Number of processes to use when preprocessing the data "
                             "(i.e., selecting which paragraphs to train on)")
    args = parser.parse_args()
    mode = args.mode
    corpus = args.corpus
    out = args.name + "-" + datetime.now().strftime("%m%d-%H%M%S")
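
    # Build the model. Based on get_model's signature in docqa.scripts.ablate_triviaqa,
    # 100 and 140 appear to be the character-count threshold and hidden dimension.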
    model = get_model(100, 140, mode, WithIndicators())
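
    # Preprocessing: merge each document's paragraphs into chunks of up to
    # n_tokens tokens, then keep the 16 top-ranked paragraphs per question
    # (ShallowOpenWebRanker); intern=True should deduplicate repeated strings
    # to save memory.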
    extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens),
                                                ShallowOpenWebRanker(16),
                                                model.preprocessor, intern=True)
    evaluators = [LossEvaluator(),
                  MultiParagraphSpanEvaluator(8, "triviaqa", mode != "merge", per_doc=False)]
    oversample = [1] * 2  # Sample the top two answer-containing paragraphs twice
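
    # Each mode uses a different dataset builder: "paragraph" trains only on
    # answer-containing paragraphs, "confidence"/"sigmoid" stratify individual
    # paragraphs, and the remaining modes build paragraph *sets* (merged into
    # one passage for "merge", grouped for "shared-norm").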
    if mode == "paragraph":
        n_epochs = 120
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True),
                                          oversample, only_answers=True)
    elif mode == "confidence" or mode == "sigmoid":
        if mode == "sigmoid":
            n_epochs = 640
        else:
            n_epochs = 160
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True), oversample)
    else:
        n_epochs = 80
        test = RandomParagraphSetDatasetBuilder(120, "merge" if mode == "merge" else "group", True, oversample)
        train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample)

    data = XQADataset(corpus)
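
    # Training setup: Adadelta with an exponential moving average of the weights;
    # evaluation and checkpointing share a period of 1800, and the best weights
    # are picked by the dev set's "b8/question-text-f1" metric.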
    params = TrainParams(
        SerializableOptimizer("Adadelta", dict(learning_rate=1)),
        num_epochs=n_epochs, ema=0.999, max_checkpoints_to_keep=2,
        async_encoding=10, log_period=30, eval_period=1800, save_period=1800,
        best_weights=("dev", "b8/question-text-f1"),
        eval_samples=dict(dev=None, train=6000)
    )
    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)
    data.preprocess(args.n_processes, 1000)
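
    # Record this script's source (prefixed with the chosen mode) as the run's
    # notes, so the exact configuration is stored alongside the model.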
    with open(__file__, "r") as f:
        notes = f.read()
    notes = "Mode: " + args.mode + "\n" + notes

    trainer.start_training(data, model, params, evaluators, model_dir.ModelDir(out), notes)


if __name__ == "__main__":
    main()
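

# Example invocation (paths are illustrative, not from the original script):
#   python ablate_xqa.py en shared-norm /path/to/models/xqa-en -t 400 -n 4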