-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEncoder.py
50 lines (32 loc) · 1.48 KB
/
Encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import scipy as sp
import torch
from sklearn import preprocessing
import operators, platforms, OperatorEmbedding
class OperatorEncoder:
    """Encode operator objects as one-hot vectors, integer labels, or embeddings.

    The label and one-hot encoders are fitted on the ``name`` attribute of every
    entry in ``operators.all_opts``; operators passed to :meth:`encode` must
    therefore have names drawn from that set.
    """

    def __init__(self, embedding_path=None) -> None:
        """Fit the label/one-hot encoders and load the operator embedding model.

        Args:
            embedding_path: Path handed to ``OperatorEmbedding.getModel``.
                Defaults to ``<cwd>/data/operator_embedding8.ebd`` (the
                previously hard-coded location), which depends on the current
                working directory.
        """
        # Materialize once so the encoders see the same names even if
        # ``all_opts`` is a one-shot iterable.
        names = [o.name for o in operators.all_opts]

        self.label_encoder = preprocessing.LabelEncoder()
        self.label_encoder.fit(names)

        # scikit-learn renamed OneHotEncoder's ``sparse`` to ``sparse_output``
        # in 1.2 and removed ``sparse`` in 1.4. Try the new name first and fall
        # back for older releases so a dense array is produced on both.
        try:
            self.onehot_encoder = preprocessing.OneHotEncoder(sparse_output=False)
        except TypeError:
            self.onehot_encoder = preprocessing.OneHotEncoder(sparse=False)
        self.onehot_encoder.fit([[name] for name in names])

        if embedding_path is None:
            embedding_path = os.path.join(os.getcwd(), 'data', 'operator_embedding8.ebd')
        # NOTE(review): the returned model is assumed to support ``model[name]``
        # lookups (see ``encode``); its concrete type is defined elsewhere.
        self.embedding_model = OperatorEmbedding.getModel(embedding_path)

    def encode(self, opts, method='onehot'):
        """Encode a list of operators.

        Args:
            opts: List of operator objects, each exposing a ``name`` attribute.
            method: ``'onehot'`` returns the one-hot encoder's dense transform
                (one row per operator); ``'embedding'`` returns a list with
                ``embedding_model[name]`` for each operator; ``'label'`` returns
                the label encoder's integer transform of the names.

        Returns:
            The encoding selected by ``method``.

        Raises:
            TypeError: If ``opts`` is not a list.
            ValueError: If ``method`` is not a supported value.
        """
        if not isinstance(opts, list):
            raise TypeError(f"opts must be a list, got {type(opts).__name__}")
        if method == 'onehot':
            return self.onehot_encoder.transform([[o.name] for o in opts])
        if method == 'embedding':
            return [self.embedding_model[o.name] for o in opts]
        if method == 'label':
            return self.label_encoder.transform([o.name for o in opts])
        # An unknown method used to fall through and silently return None.
        raise ValueError(
            f"unknown encoding method {method!r}; expected 'onehot', 'embedding' or 'label'"
        )
class PlatformEncoder:
    """Encode platform objects by their ``name`` attribute.

    Both a label encoder and a one-hot encoder are fitted on the names of every
    entry in ``platforms.all_plt``.
    """

    def __init__(self) -> None:
        """Fit the label and one-hot encoders on all known platform names."""
        # Materialize once so both encoders see the same names even if
        # ``all_plt`` is a one-shot iterable.
        names = [p.name for p in platforms.all_plt]

        self.label_encoder = preprocessing.LabelEncoder()
        self.label_encoder.fit(names)

        # scikit-learn renamed OneHotEncoder's ``sparse`` to ``sparse_output``
        # in 1.2 and removed ``sparse`` in 1.4. Try the new name first and fall
        # back for older releases so a dense array is produced on both.
        try:
            self.onehot_encoder = preprocessing.OneHotEncoder(sparse_output=False)
        except TypeError:
            self.onehot_encoder = preprocessing.OneHotEncoder(sparse=False)
        self.onehot_encoder.fit([[name] for name in names])

    def encode(self, plts, method='onehot'):
        """Return the label encoder's integer transform of the platforms' names.

        Args:
            plts: List of platform objects, each exposing a ``name`` attribute.
            method: Accepted for signature parity with ``OperatorEncoder.encode``
                but currently ignored.

        Raises:
            TypeError: If ``plts`` is not a list.

        NOTE(review): despite the ``'onehot'`` default and the fitted
        ``onehot_encoder``, this always returns label encoding. Kept as-is
        because callers may depend on the label output; confirm the intent
        before honoring ``method``.
        """
        if not isinstance(plts, list):
            raise TypeError(f"plts must be a list, got {type(plts).__name__}")
        return self.label_encoder.transform([p.name for p in plts])
# Shared module-level encoder instances, constructed when this module is
# imported (operator encoder first, then platform encoder).
# NOTE(review): OperatorEncoder() loads its embedding file from a path under
# the current working directory by default, so importing this module can fail
# when run from another directory — confirm this import-time work is intended.
operator_encoder, platform_encoder = OperatorEncoder(), PlatformEncoder()