-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathqe_single_entity.py
122 lines (111 loc) · 4.13 KB
/
qe_single_entity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from enum import auto, IntEnum
import models_manager
class QEMethod(IntEnum):
    """Query-expansion strategies understood by ``QESingleEntity``."""

    # No expansion: the expansion methods return an empty list.
    none = 1
    # Expand with the model selected by ``models_manager.STATIC_YEAR``.
    global_word2vec = 2
    # Expand with the model selected by the ``time`` argument.
    specific_word2vec = 3
class QESingleEntity:
    """Query expansion (QE) of a single entity or vector via word2vec neighbours.

    The instance stores a default QE method and a default number of
    neighbours ``k``; both can be overridden per call.
    """

    def __init__(self, qe_method=None, k=2, models_manager=None, global_model=None):
        """
        Store the defaults and resolve the global model.

        :param qe_method: default ``QEMethod`` used when a call passes none.
        :param k: default number of related terms to return.
        :param models_manager: object indexable by time that yields a model.
        :param global_model: model used for the static (global) expansion; when
            falsy and ``models_manager`` is given, it is looked up in the manager.
        """
        self.qe_method = qe_method
        self.k = k
        self.models_manager = models_manager
        # NOTE(review): inside this method the ``models_manager`` parameter
        # shadows the imported ``models_manager`` module, so ``STATIC_YEAR`` is
        # read from the manager object here, while the other methods read it
        # from the module. Confirm the manager object exposes ``STATIC_YEAR``.
        self.global_model = (
            global_model
            if global_model or not models_manager
            else models_manager[models_manager.STATIC_YEAR]
        )

    def expand_entity(self, entity, time=None, qe_method=None, k=None):
        """
        Apply QE with the selected method and return the related terms of an entity.

        :param entity: entity name; underscores are replaced by spaces and the
            result is title-cased before the lookup.
        :param time: time period, used only by ``QEMethod.specific_word2vec``.
        :param qe_method: overrides the instance default when not None.
        :param k: overrides the instance default when not None.
        :return: ``[]`` for ``QEMethod.none``; otherwise whatever
            ``expand_entity_word2vec`` returns (which may be None).
        :raises ValueError: if the method is not a known ``QEMethod``.
        """
        if qe_method is None:
            qe_method = self.qe_method
        if k is None:
            k = self.k
        entity = entity.replace('_', ' ').title()
        if qe_method == QEMethod.none:
            related_tuples = []
        elif qe_method == QEMethod.global_word2vec:
            related_tuples = self.expand_entity_word2vec(
                entity, models_manager.STATIC_YEAR, k=k
            )
        elif qe_method == QEMethod.specific_word2vec:
            related_tuples = self.expand_entity_word2vec(entity, time, k=k)
        else:
            raise ValueError('Unknown QE method: {}'.format(qe_method))
        return related_tuples

    def expand_vector(self, vector, time=None, qe_method=None, k=None, entity=None):
        """
        Apply QE with the selected method and return the related terms of a vector.

        :param vector: query vector handed to the model's ``most_similar``.
        :param time: time period, used only by ``QEMethod.specific_word2vec``.
        :param qe_method: overrides the instance default when not None.
        :param k: overrides the instance default when not None.
        :param entity: entity (or list of entities) to exclude from the result.
        :return: ``[]`` for ``QEMethod.none``; otherwise whatever
            ``expand_vector_word2vec`` returns (which may be None).
        :raises ValueError: if the method is not a known ``QEMethod``.
        """
        if qe_method is None:
            qe_method = self.qe_method
        if k is None:
            k = self.k
        if qe_method == QEMethod.none:
            return []
        elif qe_method == QEMethod.global_word2vec:
            related_tuples = self.expand_vector_word2vec(
                vector, models_manager.STATIC_YEAR, k=k, entity=entity
            )
        elif qe_method == QEMethod.specific_word2vec:
            related_tuples = self.expand_vector_word2vec(
                vector, time, k=k, entity=entity
            )
        else:
            raise ValueError('Unknown QE method for vectors: {}'.format(qe_method))
        return related_tuples

    def expand_entity_word2vec(self, entity, time, k=None):
        """
        Get an entity and a timestamp.
        Return tuples (term, score) of the k closest terms from the word2vec model of that time period.
        Return None when no model is available for ``time`` or the model has no key for the entity.
        """
        w2v_model = self._select_model(time)
        if not w2v_model:
            return None
        if k is None:
            k = self.k
        key = w2v_model.get_key(entity)
        if not key:
            return None
        vector = w2v_model.word_vec(entity)
        return self.expand_vector_word2vec(vector, time, k, entity)

    def expand_vector_word2vec(self, vector, time, k=None, entity=None):
        """
        Get a vector and a timestamp.
        Return tuples (term, score) of the k closest terms from the word2vec model of that time period.
        `entity` can be either a string or a list of entities that compose the given vector, so they won't be returned.
        Return None when no model is available for ``time``.
        """
        w2v_model = self._select_model(time)
        if not w2v_model:
            return None
        if k is None:
            k = self.k
        topn = k
        if entity:  # retrieve extra neighbor(s), and then remove the given word(s)
            if isinstance(entity, str):
                entity = [entity]
            topn += len(entity)
        related_tuples = w2v_model.most_similar(
            [vector],
            topn=topn,
            # Keep only terms known to the model that are not entities themselves.
            filter_func=lambda word: word in w2v_model
            and not w2v_model.is_entity(word),
        )
        if entity:
            related_tuples = self._drop_entities(related_tuples, entity, k)
        return related_tuples

    def _select_model(self, time):
        """Return the model for ``time``: the global model for the static year,
        otherwise the manager's model for that time, or None without a manager."""
        if time == models_manager.STATIC_YEAR:
            return self.global_model
        if self.models_manager:
            return self.models_manager[time]
        return None

    @staticmethod
    def _drop_entities(related_tuples, entity, k):
        """Remove the pairs whose term is in ``entity`` and keep the first ``k``.

        Assumes ``related_tuples`` holds (term, score) pairs, as the public
        docstrings state. The term is the first element, so it is the first
        element that must be compared against the excluded entities.
        """
        return [
            (word, score) for word, score in related_tuples if word not in entity
        ][:k]