-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert.py
70 lines (49 loc) · 2.02 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import json
import argparse
from collections import defaultdict
from utils.io_utils import load_json, save_json
def bspn_to_constraint_dict(bspn):
    """Parse a belief-span string into a ``{domain: {slot: value}}`` dict.

    The string is split on whitespace after removing the ``<bos_belief>``
    and ``<eos_belief>`` markers. Tokens are then read left to right:

    * ``[name]`` (a bracketed token not starting with ``value_``) selects
      the current domain.
    * ``[value_slot]`` opens a slot in the current domain and resets its
      value; it is ignored when no domain has been selected yet.
    * Any other token is appended to the value of the most recently opened
      slot, provided that slot exists under the current domain; otherwise
      the token is dropped.

    Args:
        bspn: Belief-span string, e.g.
            ``"<bos_belief> [hotel] [value_area] north <eos_belief>"``.

    Returns:
        A dict mapping each domain to a dict of slot name -> value, where
        the value is the slot's tokens joined by single spaces.
    """
    slot_prefix = 'value_'

    cleaned = bspn.replace('<bos_belief>', '').replace('<eos_belief>', '')
    tokens = cleaned.strip().split()

    collected = {}
    domain, slot = None, None

    for token in tokens:
        if not token.startswith('['):
            # Plain word: belongs to the last opened slot, but only if that
            # slot exists under the current domain. Otherwise it is a stray
            # token and is skipped.
            words = collected.get(domain, {}).get(slot)
            if words is not None:
                words.append(token)
            continue

        # Drop the surrounding brackets of the special token.
        tag = token[1:-1]

        if not tag.startswith(slot_prefix):
            domain = tag
            continue

        if domain is None:
            # A slot tag before any domain tag cannot be attributed.
            continue

        # Strip only the leading prefix so slot names that themselves
        # contain underscores (e.g. "price_range") are kept whole.
        slot = tag[len(slot_prefix):]
        collected.setdefault(domain, {})[slot] = []

    return {
        dom: {name: ' '.join(words) for name, words in slots.items()}
        for dom, slots in collected.items()
    }
if __name__ == "__main__":
    # Command-line entry point: read a JSON file of generated dialogs,
    # reduce every turn to its cleaned response text and parsed belief
    # state, and write the result grouped by dialog id.
    arg_parser = argparse.ArgumentParser(description='Conversion Output')
    arg_parser.add_argument('-input', type=str, required=True)
    arg_parser.add_argument('-output', type=str, required=True)
    cli_args = arg_parser.parse_args()

    converted_results = defaultdict(list)

    for raw_id, turns in load_json(cli_args.input).items():
        # Keep only the part of the dialog id before the first '.'.
        key = raw_id.split('.')[0]

        for turn in turns:
            # Remove the response boundary markers and surrounding whitespace.
            response = (
                turn['resp_gen']
                .replace('<bos_resp>', '')
                .replace('<eos_resp>', '')
                .strip()
            )
            converted_results[key].append({
                'response': response,
                'state': bspn_to_constraint_dict(turn['bspn_gen']),
            })

    save_json(converted_results, cli_args.output)