cache_dev.py
"""Preprocess the XQA dev set and cache it to disk for the BERT pipeline."""
import argparse

from docqa.data_processing.document_splitter import MergeParagraphs, ShallowOpenWebRanker
from docqa.data_processing.multi_paragraph_qa import StratifyParagraphsBuilder, \
    StratifyParagraphSetsBuilder, RandomParagraphSetDatasetBuilder
from docqa.data_processing.preprocessed_corpus import PreprocessedData
from docqa.data_processing.qa_training_data import ContextLenBucketedKey
from docqa.dataset import ClusteredBatcher
from docqa.scripts.ablate_triviaqa import get_model
from docqa.text_preprocessor import WithIndicators
from docqa.triviaqa.training_data import ExtractMultiParagraphsPerQuestion

from build_span_corpus import XQADataset
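
# Example invocation (cache the English dev set with the shared-norm setup):
#   python cache_dev.py en shared-norm --n_tokens 400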


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("corpus", choices=["en", "fr", "de", "ru", "pt", "zh", "pl", "uk", "ta"])
    parser.add_argument("mode", choices=["confidence", "merge", "shared-norm",
                                         "sigmoid", "paragraph"])
    # Note: I haven't tested modes other than `shared-norm` on this corpus,
    # so some settings might need adjusting.
    parser.add_argument("-t", "--n_tokens", default=400, type=int,
                        help="Maximum paragraph size in tokens")
    args = parser.parse_args()

    mode = args.mode
    corpus = args.corpus

    # Build the model for this mode; its text preprocessor is needed by the
    # extraction step below.
    model = get_model(100, 140, mode, WithIndicators())
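
    # Preprocessing step: merge each document's text into paragraphs of at
    # most args.n_tokens tokens, then keep the 16 top-ranked paragraphs per
    # question.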
    extract = ExtractMultiParagraphsPerQuestion(MergeParagraphs(args.n_tokens),
                                                ShallowOpenWebRanker(16),
                                                model.preprocessor, intern=True)

    oversample = [1] * 2  # Sample the top two answer-containing paragraphs twice
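
    # Pick train/test dataset builders to match the training setup:
    # "paragraph", "confidence", and "sigmoid" batch individual paragraphs,
    # while "shared-norm" and "merge" batch sets of paragraphs per question.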
    # n_epochs mirrors the training configuration but is unused here, since
    # this script only caches preprocessed data.
    if mode == "paragraph":
        n_epochs = 120
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True),
                                          oversample, only_answers=True)
    elif mode == "confidence" or mode == "sigmoid":
        if mode == "sigmoid":
            n_epochs = 640
        else:
            n_epochs = 160
        test = RandomParagraphSetDatasetBuilder(120, "flatten", True, oversample)
        train = StratifyParagraphsBuilder(ClusteredBatcher(60, ContextLenBucketedKey(3), True),
                                          oversample)
    else:
        n_epochs = 80
        test = RandomParagraphSetDatasetBuilder(120, "merge" if mode == "merge" else "group",
                                                True, oversample)
        train = StratifyParagraphSetsBuilder(30, mode == "merge", True, oversample)

    data = XQADataset(corpus)
    data = PreprocessedData(data, extract, train, test, eval_on_verified=False)
    data.preprocess(1, 1000)  # single process, chunk size 1000

    # Dump the preprocessed dev data to a pickle for the BERT pipeline.
    data.cache_preprocess("dev_data_%s.pkl" % args.corpus)
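
    # A minimal sketch of reading the cache back in a later script, assuming
    # docqa's PreprocessedData.load_preprocess is the counterpart of
    # cache_preprocess:
    #
    #   data = PreprocessedData(XQADataset(corpus), extract, train, test,
    #                           eval_on_verified=False)
    #   data.load_preprocess("dev_data_%s.pkl" % corpus)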


if __name__ == "__main__":
    main()