forked from jodaiber/semantic_compound_splitting
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecompound_annoy.py
299 lines (249 loc) · 12 KB
/
decompound_annoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
__author__ = 'lqrz'
import cPickle as pickle
import logging
import pdb
from nltk.corpus import PlaintextCorpusReader
from nltk.tokenize import WhitespaceTokenizer
import sys
import multiprocessing as mp
import codecs
from annoy import AnnoyIndex
from sklearn.metrics.pairwise import cosine_similarity
# Linking elements (German "Fugenlaute") that may sit between a prefix and
# its tail; '' means the tail follows the prefix directly.
_FUGENLAUTE = ('', 'e', 'es')


def _candidate_splits(inputCompound):
    """Return the set of (prefix, tail, tailIndex) splits worth scoring.

    A prefix qualifies when it is a non-empty key of the global ``vectors``
    with at least one direction vector and is a strict prefix of
    ``inputCompound``. After the prefix (and an optional linking element from
    _FUGENLAUTE) the remaining tail is looked up in the global
    ``pickledIndexes`` first title-cased, then exactly as written; every
    variant that has an index becomes a candidate. No candidates are produced
    while the global ``debug`` flag is true.
    """
    splits = set()
    if debug:
        return splits

    logger.info('Getting all matching prefixes')
    # Empty prefixes are skipped: they would not shorten the word and could
    # make the recursion in decompound() run forever.
    prefixes = {
        prefix for prefix in vectors
        if prefix
        and inputCompound.startswith(prefix)
        and len(inputCompound) > len(prefix)
        and len(vectors[prefix]) > 0
    }
    logger.debug('Possible prefixes')
    logger.debug(prefixes)

    logger.info('Getting possible splits')
    for prefix in prefixes:
        rest = inputCompound[len(prefix):]
        for fug in _FUGENLAUTE:
            # ''.startswith('') is True, so the "no linking element" case is
            # always tried.
            if not rest.startswith(fug):
                continue
            base = rest[len(fug):]
            if not base:
                # The word is just prefix + linking element: no tail to split off.
                continue
            for tail in (base.title(), base):
                try:
                    tailIndex = pickledIndexes[tail]
                except KeyError:
                    # Without a vector index for the tail the split is discarded.
                    logger.debug(' '.join(['Discarding split', inputCompound, prefix, tail]))
                    continue
                logger.debug('Tail index: ' + str(tailIndex))
                splits.add((prefix, tail, tailIndex))
                logger.debug(' '.join(['Considering split', inputCompound, prefix, tail]))
    return splits


def _best_split(inputCompoundIndex, inputCompoundRep, splits, nAccuracy, threshold):
    """Score candidate splits and return the best one above ``threshold``.

    For each (prefix, tail, tailIndex) in ``splits`` and each (origin, evidence)
    pair stored in ``vectors[prefix]``, the compound's vector is predicted as
    ``vec(tail) + (vec(origin[0]) - vec(origin[1]))`` using ``pickledVectors``.
    A prediction only counts if ``inputCompoundIndex`` is among the first
    ``nAccuracy`` of the ``globalNN`` neighbours ``annoy_tree`` returns for it;
    its score is the cosine similarity between the prediction and
    ``inputCompoundRep``.

    Returns (prefix, tail, origin[0], origin[1], rank, similarity) for the first
    prediction reaching the highest score strictly above ``threshold``, or None
    when no prediction qualifies. Errors from the Annoy lookup are logged and
    re-raised.
    """
    logger.info('Applying direction vectors to possible splits')
    best = None
    bestSimilarity = threshold
    for prefix, tail, tailIndex in splits:
        try:
            tailRep = pickledVectors[tailIndex]
        except KeyError:
            logger.debug('ERROR COULDNT FIND KEY ' + str(tailIndex) + ' IN VECTOR DICT')
            continue
        directions = vectors[prefix]
        logger.debug(' '.join(['Applying', str(len(directions)),
                               'direction vectors to split', prefix, tail]))
        for origin, _evidence in directions:
            try:
                originCompoundRep = pickledVectors[origin[0]]
                originTailRep = pickledVectors[origin[1]]
            except KeyError as err:
                logger.debug('ERROR COULDNT FIND KEY ' + str(err) + ' IN VECTOR DICT')
                continue
            # Same evaluation order as before: tail + (compound - originTail).
            prediction = tailRep + (originCompoundRep - originTailRep)
            try:
                neighbours = annoy_tree.get_nns_by_vector(
                    list(prediction), globalNN)[:nAccuracy]
            except Exception:
                # Re-raise rather than exit(): inside a Pool worker exit() drops
                # the task and may leave the parent's map() waiting.
                logger.exception('Problem found when retrieving KNN for prediction representation')
                logger.error(list(prediction))
                raise
            if inputCompoundIndex not in neighbours:
                logger.debug(str(inputCompoundIndex) + ' not found in neighbours. NO RANK. WONT SPLIT')
                continue
            rank = neighbours.index(inputCompoundIndex)
            # cosine_similarity expects 2-D (n_samples, n_features) input;
            # wrap each single vector as a one-row matrix.
            similarity = cosine_similarity([prediction], [inputCompoundRep])[0][0]
            logger.debug('Computed cosine similarity: ' + str(similarity))
            if similarity > bestSimilarity:
                logger.debug('Found new best similarity score. Old: ' + str(bestSimilarity)
                             + ' New: ' + str(similarity))
                bestSimilarity = similarity
                best = (prefix, tail, origin[0], origin[1], rank, similarity)
    return best


def decompound(args):
    """Recursively split a compound word using prefix direction vectors.

    ``args`` is a single ``(inputCompound, nAccuracy, bestSimilarity)`` tuple so
    the function can be passed directly to ``multiprocessing.Pool.map``:

    * inputCompound -- the word to split.
    * nAccuracy -- how many nearest neighbours of a prediction are searched
      for the compound itself.
    * bestSimilarity -- cosine-similarity threshold a split must exceed.

    Returns the list of parts, ``[prefix, ..., lastTail]``. Tails may come back
    title-cased, because that spelling is looked up first. A word that cannot
    be split (not in the index, no candidate split, or no split above the
    threshold) is returned as ``[inputCompound]``; an empty word gives ``[]``.

    Reads the module globals set up by the ``__main__`` block: ``vectors``,
    ``pickledIndexes``, ``pickledVectors``, ``annoy_tree``, ``globalNN``,
    ``logger`` and ``debug``.
    """
    inputCompound, nAccuracy, threshold = args
    if not inputCompound:
        return []

    try:
        inputCompoundIndex = pickledIndexes[inputCompound]
    except KeyError:
        logger.debug('ERROR COULDNT FIND KEY ' + inputCompound + ' IN INDEX VECTOR')
        return [inputCompound]
    try:
        inputCompoundRep = pickledVectors[inputCompoundIndex]
    except KeyError:
        logger.debug('ERROR COULDNT FIND KEY ' + str(inputCompoundIndex) + ' IN VECTOR DICT')
        return [inputCompound]

    splits = _candidate_splits(inputCompound)
    if not splits:
        logger.error('Cannot decompound ' + inputCompound)
        return [inputCompound]

    best = _best_split(inputCompoundIndex, inputCompoundRep, splits, nAccuracy, threshold)
    if best is None:
        # No prediction recovered the compound among its neighbours with a
        # similarity above the threshold.
        logger.debug('Not splitting compound ' + inputCompound)
        return [inputCompound]

    prefix, tail, origin0, origin1, rank, similarity = best
    logger.debug(' '.join(['Splitting', inputCompound, 'as', prefix, tail,
                           str(origin0), str(origin1), 'rank', str(rank),
                           'similarity', str(similarity)]))
    # Recurse with the caller's threshold, not the similarity just achieved:
    # that score belongs to the outer word, not to the tail.
    return [prefix] + decompound((tail, nAccuracy, threshold))
if __name__ == '__main__':
    import os

    # Send every record (DEBUG and up) from the root logger to decompound_annoy.log.
    logger = logging.getLogger('')
    hdlr = logging.FileHandler('decompound_annoy.log')
    hdlr.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
    logger.addHandler(hdlr)
    logger.setLevel(logging.DEBUG)

    # Number of neighbours fetched from the Annoy index per prediction;
    # decompound() keeps only the first nAccuracy of them.
    globalNN = 500

    # Exactly 10 arguments are required. Previously argc == 1 fell through and
    # crashed later with a NameError, and bad params exited with status 0.
    if len(sys.argv) != 11:
        sys.stderr.write('Error in params\n')
        sys.stderr.write(
            'usage: %s RESULTS_PICKLE ANNOY_TREE CORPUS INDEXES_PICKLE '
            'VECTORS_PICKLE MULTIPROCESSED(True|False) N_WORKERS OUT_PATH '
            'N_ACCURACY SIMILARITY_THRESHOLD\n' % sys.argv[0])
        sys.exit(1)

    resultsPath = sys.argv[1]
    annoyTreeFile = sys.argv[2]
    corpusPath = sys.argv[3]
    pickledIndexesPath = sys.argv[4]
    pickledVectorsPath = sys.argv[5]
    multiprocessed = sys.argv[6] == 'True'
    nWorkers = sys.argv[7]
    outPath = sys.argv[8]
    nAccuracy = int(sys.argv[9])
    similarityThreshold = float(sys.argv[10])

    # The corpus reader wants a root folder plus a file name inside it.
    folder, filename = os.path.split(corpusPath)
    logger.debug('Corpus folder: ' + folder)
    logger.debug('Corpus filename: ' + filename)
    corpus = PlaintextCorpusReader(folder, filename, word_tokenizer=WhitespaceTokenizer(),
                                   encoding='utf-8')
    inputCompounds = corpus.words()
    logger.debug('Words in corpus')
    logger.debug(inputCompounds)

    # Read by decompound(); must be set before any worker process is created.
    debug = False

    # NOTE: pickle can execute arbitrary code while loading; only use trusted files.
    logger.info('Getting pickled direction vectors file')
    with open(resultsPath, 'rb') as f:
        vectors = pickle.load(f)
    logger.info('Getting pickled indexes')
    with open(pickledIndexesPath, 'rb') as f:
        pickledIndexes = pickle.load(f)
    with open(pickledVectorsPath, 'rb') as f:
        pickledVectors = pickle.load(f)

    logger.info('Getting annoy tree')
    # 'angular' is Annoy's default metric, passed explicitly because newer
    # Annoy versions warn when it is omitted.
    # NOTE(review): 500 must match the vector size the tree was built with.
    annoy_tree = AnnoyIndex(500, 'angular')
    annoy_tree.load(annoyTreeFile)

    jobs = [(word, nAccuracy, similarityThreshold) for word in inputCompounds]
    if multiprocessed:
        logger.info('Instantiating pool with ' + str(nWorkers))
        # NOTE(review): workers read the module globals assigned above, which
        # presumably requires the 'fork' start method.
        pool = mp.Pool(processes=int(nWorkers))
        results = pool.map(decompound, jobs)
        pool.close()
        pool.join()
    else:
        results = [decompound(job) for job in jobs]

    print(results)

    # One line per input word: "<word>\t<space-separated parts>".
    with codecs.open(outPath, 'w', encoding='utf-8') as fout:
        for word, split in zip(inputCompounds, results):
            fout.write(word + '\t' + ' '.join(split) + '\n')

    logger.info('End')