-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path extract_word2vec.py
75 lines (63 loc) · 2.48 KB
/
extract_word2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from pathlib import Path
import fasttext
import numpy as np
import pandas as pd
def main(
    model_path="/shared-network/lriesch29/word2vec/wiki.en.bin",
    class_file="/home/lriesch29/ExplainableAudioVisualLowShotLearning/dat/ActivityNet/class-split/all_class.txt",
    output_path="word_embeddings_activity_normed.npy",
):
    """Embed the ActivityNet class names with fastText and save them to disk.

    To embed another dataset, swap ``_get_class_names`` for one of the other
    loaders in this module (``get_audioset_classes``, ``get_vggsound_classes``,
    ``get_ucf_classes``) and pick a matching ``output_path``.

    Args:
        model_path: fastText ``.bin`` model to load. Defaults to the original
            hard-coded location (an alternative copy was previously referenced
            at ``/shared-local/datasets/word2vec/wiki.en.bin``).
        class_file: Text file listing one class name per line.
        output_path: Destination ``.npy`` file for the embeddings.
    """
    model = fasttext.load_model(model_path)
    classes = _get_class_names(Path(class_file))
    embeddings = extract_label_embeddings(model, classes)
    print(len(embeddings))
    # ``embeddings`` is a dict, so np.save stores it as a pickled 0-d object
    # array; read it back with ``np.load(output_path, allow_pickle=True).item()``.
    np.save(output_path, embeddings)
def get_audioset_classes(path="data/all_class_clean.txt"):
    """Return the AudioSet class names listed one per line in ``path``.

    Args:
        path: Text file (str or path-like) with one class name per line.
            Defaults to the original relative location.

    Returns:
        list[str]: Whitespace-stripped class names in file order. Blank
        lines are skipped.
    """
    with Path(path).open(encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        # Blank lines (e.g. extra trailing newlines) would otherwise be
        # returned as empty-string class names.
        return [name for name in stripped if name]
def get_ucf_classes(
    path="/home/lriesch29/akata-shared/shared/avzsl/UCF/class-split/ucf_manual_names_ask.csv",
):
    """Return the UCF class names stored in the ``manual`` column of a CSV.

    Args:
        path: CSV file containing a ``manual`` column. Defaults to the
            original hard-coded location.

    Returns:
        list: The ``manual`` column values in file order. Empty cells are
        read by pandas as NaN and passed through unchanged.
    """
    # Bracket access avoids clashes between column names and Series/DataFrame
    # attributes; tolist() yields plain Python objects.
    return pd.read_csv(path)["manual"].tolist()
def get_vggsound_classes(path="data/vggsound_class_clean.txt"):
    """Return the VGGSound class names listed one per line in ``path``.

    Args:
        path: Text file (str or path-like) with one class name per line.
            Defaults to the original relative location.

    Returns:
        list[str]: Whitespace-stripped class names in file order. Blank
        lines are skipped.
    """
    with Path(path).open(encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        # Blank lines (e.g. extra trailing newlines) would otherwise be
        # returned as empty-string class names.
        return [name for name in stripped if name]
def _get_class_names(path):
if isinstance(path, str):
path = Path(path)
with path.open("r") as f:
classes = sorted([line.strip() for line in f])
return classes
def extract_label_embeddings(model, classes, normalize=True):
    """Look up a word vector for every class name.

    Args:
        model: Object exposing ``get_word_vector(str)``; ``main`` passes a
            loaded fastText model.
        classes: Iterable of class-name strings.
        normalize: If True, scale each vector to unit L2 norm.

    Returns:
        dict: Maps each class name to its vector (a NumPy array copy). If a
        name appears more than once, the later entry overwrites the earlier.

    Raises:
        ValueError: If ``normalize`` is True and a vector's norm is zero or
            non-finite, because such a vector cannot be scaled to unit length.
    """
    result = {}
    for name in classes:
        # np.array copies the returned vector.
        vector = np.array(model.get_word_vector(name))
        if normalize:
            norm = np.linalg.norm(vector)
            # Check the norm up front: dividing by 0 would silently produce
            # NaNs. No exact unit-norm assertion follows, because with float32
            # input, rounding can leave the recomputed norm more than 1e-7
            # away from 1.
            if not np.isfinite(norm) or norm == 0:
                raise ValueError(
                    f"cannot normalize embedding of class {name!r}: norm is {norm}"
                )
            vector = vector / norm
        result[name] = vector
    return result
# Run the extraction only when executed as a script, not on import.
if __name__ == "__main__":
    main()