-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprepare_question.py
executable file
·53 lines (50 loc) · 2.31 KB
/
prepare_question.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import re
import json
import yaml
import torch
def prepare_questions(annotations):
    """Normalize and tokenize the question text of every annotation.

    Each question is lower-cased, a single regex pass replaces punctuation
    and a fixed set of conversational phrases, and the result is split on
    single spaces with empty tokens dropped.

    Args:
        annotations: iterable of mappings, each with a 'question' key
            holding a string.

    Returns:
        A list with one entry per annotation, each entry being the list of
        token strings of that question.

    Raises:
        KeyError: if an annotation has no 'question' key.
    """
    # Literal substring -> replacement text. Insertion order is significant:
    # the regex alternation below tries alternatives left to right, so
    # "thank you" has to come before "thanks" and "thank".
    replacements = {
        # punctuation
        '.': ' ', "'": '', '?': ' ', '_': ' ', '-': ' ', '/': ' ', ',': ' ',
        # conversational filler
        "thank you": '', "thanks": '', "thank": '',
        "please": '', "hello": '',
        "hi ": ' ', "hey ": ' ', "good morning": '',
        "good afternoon": '', "have a nice day": '',
        "okay": '', "goodbye": '',
    }
    # NOTE(review): the pattern has no word boundaries, so these substrings
    # are also replaced inside longer words (e.g. "hello" within "othello").
    # Kept as-is to preserve the original tokenization.
    # Built once per call instead of once per question: the table and the
    # compiled pattern do not depend on the question being processed.
    pattern = re.compile("|".join(re.escape(key) for key in replacements))

    def _substitute(match):
        # The matched text is always exactly one of the literal keys.
        return replacements[match.group(0)]

    prepared = []
    for annotation in annotations:
        text = pattern.sub(_substitute, annotation['question'].lower())
        # Splitting on ' ' leaves empty strings for runs of spaces; drop them.
        prepared.append([token for token in text.split(' ') if token])
    return prepared
def encode_question(question, token_to_index, max_length):
    """Turn a tokenized question into a fixed-size tensor of vocabulary ids.

    Args:
        question: sequence of token strings.
        token_to_index: mapping from token to integer id; tokens missing
            from the mapping are encoded as 0.
        max_length: size of the returned tensor; tokens beyond this
            position are ignored.

    Returns:
        A tuple ``(vector, length)`` where ``vector`` is a long tensor of
        size ``max_length`` (zero-filled past the last encoded token) and
        ``length`` is the number of encoded tokens, reported as at least 1.
    """
    n_tokens = min(len(question), max_length)
    encoded = torch.zeros(max_length).long()
    for position, word in enumerate(question[:n_tokens]):
        encoded[position] = token_to_index.get(word, 0)
    # The reported length never goes below 1, even for an empty question.
    return encoded, max(n_tokens, 1)
if __name__ == '__main__':
    # Load the run configuration. yaml.safe_load is used because calling
    # yaml.load without an explicit Loader is rejected by PyYAML >= 6 and
    # can construct arbitrary objects on older versions.
    config_path = 'config/default.yaml'
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Vocabulary file: its 'question' entry is used as the token -> index map.
    vocabs_path = config['annotations']['path_vocabs']
    with open(vocabs_path, 'r') as f:
        vocabs = json.load(f)
    token_to_index = vocabs['question']
    max_question_length = config['annotations']['max_length']

    for split in ['train', 'val']:
        # Read the annotations of this split and tokenize their questions.
        annotations_path = config['annotations']['dir'] + '/' + split + '.json'
        with open(annotations_path, 'r') as f:
            annotations_json = json.load(f)
        questions = prepare_questions(annotations_json)

        # Persist the tokenized questions next to the other prepared data.
        question_json = 'data/' + split + '_questions.json'
        with open(question_json, 'w') as f:
            json.dump(questions, f, indent=2)

        # NOTE(review): the encoded questions are neither saved nor used
        # afterwards in this script; the pass is kept from the original,
        # presumably as a smoke check that every question encodes - confirm.
        encoded = [encode_question(question, token_to_index,
                                   max_question_length)
                   for question in questions]