-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathws_exp_sasaki.py
52 lines (40 loc) · 1.42 KB
/
ws_exp_sasaki.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging
import multiprocessing as mp
from pathlib import Path
from datasets import prepare_ws_combined_query_path, prepare_target_vector_paths
from sasaki_utils import inference, prepare_codecs_path, train
from utils import dotdict
from ws_exp_pbos import evaluate
# Module-level logger named after this module; where its records end up is
# decided inside exp(), which sets up logging for each experiment run.
logger = logging.getLogger(__name__)
def exp(ref_vec_name):
    """Run the word-similarity experiment for one set of reference vectors.

    The run has three stages: train a model on the reference embeddings,
    run inference on the combined word-similarity query file, and evaluate
    the predicted embeddings. The result directory
    ``results/ws/<ref_vec_name>_sasaki`` is handed to training, receives the
    evaluation output as ``result.txt``, and holds this run's ``log.txt``.

    Logging is attached to the root logger only for the duration of the
    call, so repeated calls in the same process (e.g. a reused pool worker)
    each write to their own log file and the file is always closed.

    Args:
        ref_vec_name: Name of the reference vectors, forwarded to
            ``prepare_target_vector_paths``.
    """
    result_path = Path("results") / "ws" / f"{ref_vec_name}_sasaki"
    ref_vec_path = prepare_target_vector_paths(ref_vec_name).w2v_emb_path
    codecs_path = prepare_codecs_path(ref_vec_path, result_path)

    # Do not rely on an earlier helper having created the directory the log
    # file lives in.
    result_path.mkdir(parents=True, exist_ok=True)

    # A dedicated handler instead of logging.basicConfig(): basicConfig is a
    # no-op once the root logger has a handler, which would send a second
    # run's records to the first run's log file. The formatter reproduces
    # the format basicConfig applies by default.
    root_logger = logging.getLogger()
    log_handler = logging.FileHandler(result_path / "log.txt", mode="w+")
    log_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
    previous_level = root_logger.level
    root_logger.addHandler(log_handler)
    root_logger.setLevel(logging.DEBUG)
    try:
        logger.info("Training...")
        model_info = train(
            ref_vec_path,
            result_path,
            codecs_path=codecs_path,
            H=40_000,
            F=500_000,
            epoch=300,
        )

        logger.info("Inferencing...")
        combined_query_path = prepare_ws_combined_query_path()
        result_emb_path = inference(model_info, combined_query_path)

        logger.info("Evaluating...")
        evaluate(dotdict(
            eval_result_path=result_path / "result.txt",
            pred_path=result_emb_path
        ))
    finally:
        # Detach and close even when a stage raises, so the file handle is
        # released and later runs in this process start from a clean state.
        root_logger.setLevel(previous_level)
        root_logger.removeHandler(log_handler)
        log_handler.close()
if __name__ == '__main__':
    # Reference-vector sets to process; each one becomes its own pool task.
    target_vector_names = ("polyglot", "google")

    with mp.Pool() as pool:
        # Submit every experiment first so they can run concurrently.
        pending = []
        for name in target_vector_names:
            pending.append(pool.apply_async(exp, (name,)))

        # Wait inside the ``with`` block: leaving it terminates the pool.
        # get() also re-raises any exception raised in a worker.
        for task in pending:
            task.get()