-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy pathgenerate_data.py
executable file
·187 lines (164 loc) · 8.12 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#coding:utf8
# We did a lot of data preprocessing before feeding the data into our SGNN model to accelerate.
from gnn_with_args import *
def get_event_chains(event_list):
    """Split a list of events into four parallel token chains.

    Each event is indexable; the four chains are built from:
    verb_role ('%s_%s' of fields 0 and 2), subject (field 3),
    object (field 4) and prep (field 5).

    Returns a list of four lists, one token per event in each.
    """
    verb_chain = ['%s_%s' % (event[0], event[2]) for event in event_list]
    subject_chain = ['%s' % event[3] for event in event_list]
    object_chain = ['%s' % event[4] for event in event_list]
    prep_chain = ['%s' % event[5] for event in event_list]
    return [verb_chain, subject_chain, object_chain, prep_chain]
class Data_txt(object):
    """In-memory wrapper around a list of pre-indexed questions.

    Each question is a triple: question[0] is the context event list,
    question[1] the candidate event list, question[2] the answer index.
    Serves sequential (optionally wrapping) batches of
    [context_chains, choices_chains, answer] items.
    """

    def __init__(self, questions):
        super(Data_txt, self).__init__()
        # shuffling intentionally disabled upstream: random.shuffle(questions)
        self.corpus = questions
        self.corpus_length = len(questions)
        self.start = 0  # cursor for sequential batching

    def next_batch(self, batch_size):
        """Return (batch, epoch_flag).

        epoch_flag is True when this batch consumed (or wrapped past) the
        end of the corpus; the cursor then wraps modulo the corpus length.
        """
        batch = []
        for offset in range(batch_size):
            question = self.corpus[(self.start + offset) % self.corpus_length]
            batch.append([get_event_chains(question[0]),
                          get_event_chains(question[1]),
                          question[2]])
        self.start += batch_size
        epoch_flag = self.start >= self.corpus_length
        if epoch_flag:
            self.start %= self.corpus_length
        return batch, epoch_flag

    def all_data(self):
        """Return every question as [context_chains, choices_chains, answer]."""
        return [[get_event_chains(question[0]),
                 get_event_chains(question[1]),
                 question[2]]
                for question in self.corpus]
def build_graph(filename):
    """Load a weighted directed event graph from a whitespace-delimited file.

    Each non-empty line of *filename* is "<src> <dst> <weight>".

    Returns a networkx.DiGraph whose edges carry a float 'weight' attribute.
    """
    graph = nx.DiGraph()
    # Fix: the original iterated `open(filename)` directly and never closed
    # the handle; `with` guarantees it is released.
    with open(filename) as fh:
        for line in fh:
            parts = line.strip().split()
            if not parts:  # tolerate blank lines instead of IndexError
                continue
            graph.add_edge(parts[0], parts[1], weight=float(parts[2]))
    return graph
def return_id_list(event_list, word_id):
    """Map each event token to its vocabulary id.

    Tokens missing from word_id map to the *string* '0' (callers run int()
    over every id later, so the mixed int/str list is intentional).
    """
    return [word_id.get(event, '0') for event in event_list]
def get_matrix(g, node_list, edge_list):
    """Dense float32 adjacency matrix over node_list, filled from edge_list.

    Weights are read from g[src][dst]['weight']. The index map keeps the
    original quirk: a repeated node is re-assigned len(map) when it
    reappears (last occurrence wins) — preserved for identical behavior.
    """
    index_of = {}
    for node in node_list:
        index_of[node] = len(index_of)
    size = len(node_list)
    adjacency = np.zeros((size, size), dtype=np.float32)
    for edge in edge_list:
        src, dst = edge[0], edge[1]
        adjacency[index_of[src]][index_of[dst]] = g[src][dst]['weight']
    return adjacency
def get_matrix_for_chain(g, node_list, edge_list, context_len=8):
    """Adjacency matrix that keeps only narrative-chain edges.

    node_list is assumed to be context events followed by candidate events
    (default layout: 8 context events then the candidates; the original
    hard-coded 8 and a total of 13 nodes). Only two edge families are
    copied from g:
      * consecutive context events (i -> i+1), and
      * last context event -> each candidate event.

    Parameters:
        g: graph-like object supporting g[src][dst]['weight'].
        node_list: ordered node ids (context first, then candidates).
        edge_list: edges actually present; only these are copied.
        context_len: number of leading context events (default 8, matching
            the original hard-coded layout — backward compatible).

    Returns a float32 (n, n) matrix, n = len(node_list).
    """
    # Preserve the original duplicate-node quirk: a repeated node is
    # re-assigned len(map) when it reappears.
    node_index = {}
    for node in node_list:
        node_index[node] = len(node_index)
    size = len(node_list)
    A = np.zeros((size, size), dtype=np.float32)

    def _copy_edge(src, dst):
        # Copy the weight only when the edge survives in edge_list.
        if (src, dst) in edge_list:
            A[node_index[src]][node_index[dst]] = g[src][dst]['weight']

    # Consecutive transitions inside the context chain.
    for i in range(context_len - 1):
        _copy_edge(node_list[i], node_list[i + 1])
    # Last context event -> every candidate (original sliced [8:13];
    # using [context_len:] also lifts the hard-coded 13-node cap).
    last_context = node_list[context_len - 1]
    for candidate in node_list[context_len:]:
        _copy_edge(last_context, candidate)
    return A
def process(data,word_id,g,predict=False):
    """Turn (context_chains, choices_chains, answer) samples into model tensors.

    Parameters:
        data: output of Data_txt.all_data()/next_batch(); each item is
            [context_chains, choices_chains, answer], where each *chains*
            entry holds four parallel lists (verb_role, subject, object, prep).
        word_id: token -> vocabulary id mapping (ids convertible with int()).
        g: weighted event graph; per-sample subgraphs feed the adjacency matrices.
        predict: when True, wrap inputs with volatile=True (legacy pre-0.4
            PyTorch inference mode that disables autograd bookkeeping).

    Returns:
        (A, input_data, targets) as torch Variables: stacked adjacency
        matrices, flat id rows, and answer labels.
    """
    input_data=[]
    targets=[]
    A=[]
    # progress-bar helper comes from the gnn_with_args star import
    pbar=get_progress_bar(len(data),title='Process Data')
    for i in range(len(data)):
        pbar.update(i)
        context,choice,answer=data[i]
        targets.append(answer)
        # Map each of the four chains to ids; OOV tokens become the
        # string '0' (see return_id_list), hence the int() casts below.
        context_id=return_id_list(context[0],word_id)
        choice_id=return_id_list(choice[0],word_id)
        context_subject_id=return_id_list(context[1],word_id)
        choice_subject_id=return_id_list(choice[1],word_id)
        context_object_id=return_id_list(context[2],word_id)
        choice_object_id=return_id_list(choice[2],word_id)
        context_perp_id=return_id_list(context[3],word_id)
        choice_perp_id=return_id_list(choice[3],word_id)
        node_list=context_id+choice_id
        node_list_int=[int(i) for i in node_list]
        node_list_subject=context_subject_id+choice_subject_id
        node_list_int_subject=[int(i) for i in node_list_subject]
        node_list_object=context_object_id+choice_object_id
        node_list_int_object=[int(i) for i in node_list_object]
        node_list_perp=context_perp_id+choice_perp_id
        node_list_int_perp=[int(i) for i in node_list_perp]
        # print node_list_int+node_list_int_subject+node_list_int_object
        # One flat row per sample: verb ids, subject ids, object ids, prep ids.
        input_data.append(node_list_int+node_list_int_subject+node_list_int_object+node_list_int_perp)
        # Adjacency restricted to this sample's verb nodes only.
        new_g=g.subgraph(node_list)
        edge_list=list(new_g.edges())
        # A.append(get_matrix(new_g,node_list,edge_list))
        A.append(get_matrix_for_chain(new_g,node_list,edge_list))
    pbar.finish()
    A=Variable(torch.from_numpy(np.array(A)))
    if not predict:
        input_data=Variable(torch.from_numpy(np.array(input_data)))
    else:
        # volatile=True is the pre-0.4 PyTorch inference flag (no grad history).
        input_data=Variable(torch.from_numpy(np.array(input_data)),volatile=True)
    targets=Variable(torch.from_numpy(np.array(targets)))
    return A,input_data,targets
def dump_data():
    """Pre-process every corpus split and pickle the resulting tensors.

    Loads the four raw splits, builds the word-id table and the event graph,
    then runs process() on each split and dumps [A, input_data, targets]
    to its *_with_args_all_chain.data file.
    """
    splits = [
        ('../data/corpus_index_dev_small.txt',
         '../data/corpus_index_dev_small_with_args_all_chain.data',
         'dev_small_data done.'),
        ('../data/corpus_index_dev.txt',
         '../data/corpus_index_dev_with_args_all_chain.data',
         'dev_data done.'),
        ('../data/corpus_index_test.txt',
         '../data/corpus_index_test_with_args_all_chain.data',
         'test_data done.'),
        ('../data/corpus_index_train0.txt',
         '../data/corpus_index_train0_with_args_all_chain.data',
         'train_data done.'),
    ]
    # Load all raw splits up-front, exactly as the original did.
    datasets = [Data_txt(pickle.load(open(src, 'rb'))) for src, _, _ in splits]
    print ('train data prepare done')
    word_id, id_vec, word_vec = get_hash_for_word(
        '../data/deepwalk_128_unweighted_with_args.txt',
        verb_net3_mapping_with_args)
    g = build_graph('../data/data2.csv')
    print ('word vector prepare done')
    for dataset, (_, out_path, done_msg) in zip(datasets, splits):
        A, input_data, targets = process(dataset.all_data(), word_id, g)
        pickle.dump([A, input_data, targets], open(out_path, 'wb'), -1)
        print (done_msg)
def process_matrix(data):
    """Replace every non-zero adjacency weight with the constant 0.01.

    data is a [A, input_data, targets] triple as dumped by dump_data();
    returns the same triple with a fresh "unweighted" A.

    NOTE(review): `(A[i,j,k]!=0).data[0]` is the legacy (pre-0.4) PyTorch
    idiom for reading a scalar comparison — this code only runs on old
    torch versions where 0-dim tensors are indexable via .data[0].
    """
    A=data[0]
    new_A=Variable(torch.zeros_like(A.data))
    # Element-wise scan: pure-Python O(n^3) loop over the stacked tensor.
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            for k in range(A.shape[2]):
                if (A[i,j,k]!=0).data[0]:
                    new_A[i,j,k]=0.01
    return [new_A,data[1],data[2]]
def change_graph_to_unweighted():
    """Re-dump each *_with_args_all split with all edge weights forced to 0.01."""
    names = ['dev_small', 'dev', 'test', 'train0']
    # Load every split first, then dump — same order as the original.
    loaded = [pickle.load(open('../data/corpus_index_%s_with_args_all.data' % name, 'rb'))
              for name in names]
    for name, dataset in zip(names, loaded):
        pickle.dump(process_matrix(dataset),
                    open('../data/corpus_index_%s_with_args_all_unweighted.data' % name, 'wb'),
                    -1)
def change_chain_to_unweighted():
    """Re-dump each *_with_args_all_chain split with weights forced to 0.01."""
    names = ['dev_small', 'dev', 'test', 'train0']
    # Load every split first, then dump — same order as the original.
    loaded = [pickle.load(open('../data/corpus_index_%s_with_args_all_chain.data' % name, 'rb'))
              for name in names]
    for name, dataset in zip(names, loaded):
        pickle.dump(process_matrix(dataset),
                    open('../data/corpus_index_%s_with_args_all_chain_unweighted.data' % name, 'wb'),
                    -1)