-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbert_features.py
116 lines (97 loc) · 3.9 KB
/
bert_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#########################################################################################
# Script name: bert_features.py
# Author: Jiao Wenxiang
# Date: 2019-08-12
# Function:
# 1. Extract BERT features for transcripts
# HOWTO:
# a. Reconstruct the utterances from the tokens split by ourselves, and save them line by line in a .txt file for each video;
# b. Reprocess our tokens with the BERT tokenizer, and record the range of indexes of our tokens in the reconstructed utterance;
# c. Extract BERT features with this script and align the features to our tokens based on the recorded ranges.
#########################################################################################
# Standard library
import os
import tqdm
import pickle
import jsonlines
import simplejson as json
import numpy as np
# Project-local: BERT tokenizer wrapper and the feature-extraction entry point
from extract_functions import BertTokenizer, extract_features
import logging
logging.basicConfig(level=logging.INFO)
# Module-level tokenizer shared by alignTokens(); loads the local pretrained
# BERT vocabulary once at import time (lowercasing matches extract_features).
tokenizer = BertTokenizer.from_pretrained('./pretrained_model_bert', do_lower_case=True)
def loadFrJson(path):
    """Load and return the object stored in the JSON file at *path*.

    Args:
        path: Filesystem path to a JSON file.

    Returns:
        The deserialized Python object (dict/list/...).
    """
    # `with` guarantees the handle is closed even if json.load raises,
    # unlike the original manual open/close pair.
    with open(path, 'r') as f:
        return json.load(f)
def saveToPickle(path, object):
    """Pickle *object* to the file at *path*.

    Args:
        path: Destination file path (opened in binary write mode).
        object: Any picklable Python object.
            NOTE(review): the parameter name shadows the ``object`` builtin;
            kept as-is to preserve the keyword-call interface.

    Returns:
        1, kept for backward compatibility with existing callers.
    """
    # `with` closes the handle even if pickle.dump raises.
    with open(path, 'wb') as f:
        pickle.dump(object, f)
    return 1
# Reconstruct utterance from the tokens
# Save the transcripts of each video into a .txt file
# Call the extract_features function to extract top 4 layers' representation, and save into a .jsonl file
def reconUtter(dict_path, input_dir, output_dir):
    """Reconstruct utterances from our tokens and run BERT feature extraction.

    For each video in the dataset dict, joins every segment's tokens back into
    a space-separated utterance, writes the utterances one per line to
    ``<input_dir>/<vid>.txt``, then calls ``extract_features`` to dump BERT
    representations to ``<output_dir>/<vid>_bert.jsonl``.

    Args:
        dict_path: Path to the dataset JSON; maps vid -> {'data': {seg_id: [token dicts]}}.
        input_dir: Directory receiving the reconstructed .txt transcripts (created if missing).
        output_dir: Directory receiving the .jsonl BERT feature files (created if missing).
    """
    data_dict = loadFrJson(dict_path)
    # exist_ok=True avoids the check-then-create race of the original
    # isdir/makedirs pair and is a no-op when the directory already exists.
    os.makedirs(input_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    for vid, vdata in data_dict.items():
        print("Reconstructing {}".format(vid))
        input_file = os.path.join(input_dir, vid + ".txt")
        n_segs = len(vdata['data'])
        with open(input_file, 'w') as f:
            for seg_id in range(n_segs):
                # Collect the tokens in an utterance.
                # w['word'] appears to be a list of sub-tokens (alignTokens
                # joins it with spaces), so += extends the list — TODO confirm.
                utt = []
                for w in vdata['data'][str(seg_id)]:
                    utt += w['word']
                f.write(" ".join(utt) + "\n")
        output_file = os.path.join(output_dir, vid + "_bert.jsonl")
        extract_features(input_file=input_file, output_file=output_file,
                         bert_model="./pretrained_model_bert", do_lower_case=True)
# Align BERT tokens to our tokens and initialize BERT features for our tokens
def alignTokens(dict_path, bert_dir, feat_name="dataset_bert.pt"):
    """Align BERT sub-tokens back to our tokens and attach averaged features.

    For each utterance read from ``<bert_dir>/<vid>_bert.jsonl``, each of our
    token groups is re-tokenized with the module-level BERT tokenizer to find
    how many sub-tokens it spans; that slice of the jsonl features is averaged
    over 4 layers and over the sub-tokens, and stored in ``w['bert']``. The
    augmented dataset dict is finally pickled to *feat_name*.

    Args:
        dict_path: Path to the dataset JSON (same structure reconUtter reads).
        bert_dir: Directory containing the ``<vid>_bert.jsonl`` feature files.
        feat_name: Output pickle path for the feature-augmented dataset dict.
    """
    data_dict = loadFrJson(dict_path)
    for vid,vdata in tqdm.tqdm(data_dict.items(), ncols=100, ascii=True):
        bertname = vid + "_bert.jsonl"
        bert_file = os.path.join(bert_dir, bertname)
        with jsonlines.open(bert_file, 'r') as reader:
            # Visit all utterances; relies on jsonl line order matching the
            # seg_id order reconUtter wrote the utterances in.
            for seg_id,obj in enumerate(reader):
                # Visit all tokens in each utterance
                # [CLS]: 0 — feature index 0 is the [CLS] token, so our first
                # real token starts at 1.
                start_idx = 1
                for w in vdata['data'][str(seg_id)]:
                    rc_phrase = " ".join(w['word'])
                    # Re-tokenizing tells us how many BERT sub-tokens this
                    # token group occupies in the extracted feature sequence.
                    bert_tokens = tokenizer.tokenize(rc_phrase)
                    end_idx = start_idx + len(bert_tokens)
                    features = obj['features'][start_idx:end_idx]
                    # Indexed tokens to check if the alignment is correct
                    feat_real_tokens = []
                    # Features of tokens by BertTokenizer
                    feat_our_token = []
                    for feature in features:
                        feat_real_tokens.append(feature["token"])
                        # Feature of each bert token
                        feat_bert_token = []
                        # 4 layers — average the top 4 layers stored in the jsonl.
                        for layer in [0,1,2,3]:
                            feat_bert_token.append(np.array(feature["layers"][layer]["values"]))
                        feat_our_token.append(np.mean(feat_bert_token, axis=0))
                    # Mean over sub-tokens gives one vector per original token.
                    w['bert'] = np.mean(feat_our_token, axis=0)
                    # Advance the running cursor so the next token group starts
                    # right after this one's sub-tokens.
                    start_idx = end_idx
                    #print(w['word'], bert_tokens, feat_real_tokens)
    saveToPickle(feat_name, data_dict)
def main():
    """Driver: reconstruct transcripts, then align and pickle BERT features."""
    dataset_json = "./Dataset/dataset_dd.json"
    transcripts_dir = "./Dataset/bert_input"
    features_dir = "./Dataset/bert_output"
    # Step 1: rebuild utterances and run the BERT extractor per video.
    reconUtter(dict_path=dataset_json, input_dir=transcripts_dir, output_dir=features_dir)
    # Step 2: map the extracted sub-token features back onto our tokens.
    alignTokens(dict_path=dataset_json, bert_dir=features_dir, feat_name="./Dataset/dataset_bert.pt")


if __name__ == '__main__':
    main()