import keras as ks
from keras.layers import Dense
from kgcnn.layers.conv import GIN, GINE
from kgcnn.layers.mlp import GraphMLP, MLP
from kgcnn.layers.modules import Embedding
from kgcnn.layers.pooling import PoolingNodes


def model_disjoint(inputs,
                   use_node_embedding: bool = None,
                   input_node_embedding: dict = None,
                   depth: int = None,
                   gin_args: dict = None,
                   gin_mlp: dict = None,
                   last_mlp: dict = None,
                   dropout: float = None,
                   output_embedding: str = None,
                   output_mlp: dict = None):
    n, disjoint_indices, batch_id_node, count_nodes = inputs

    # Embedding, if no feature dimension.
    if use_node_embedding:
        n = Embedding(**input_node_embedding)(n)

    # Model.
    # Map node features to the required number of units.
    n_units = gin_mlp["units"][-1] if isinstance(gin_mlp["units"], list) else int(gin_mlp["units"])
    n = Dense(n_units, use_bias=True, activation='linear')(n)

    # Keep the node embeddings of every depth for the readout.
    list_embeddings = [n]
    for i in range(0, depth):
        n = GIN(**gin_args)([n, disjoint_indices])
        n = GraphMLP(**gin_mlp)([n, batch_id_node, count_nodes])
        list_embeddings.append(n)

    # Output embedding choice.
    if output_embedding == "graph":
        # Pool the node embeddings of each depth into a graph embedding and sum the MLP outputs.
        out = [PoolingNodes()([count_nodes, x, batch_id_node]) for x in list_embeddings]  # will return tensor
        out = [MLP(**last_mlp)(x) for x in out]
        out = [ks.layers.Dropout(dropout)(x) for x in out]
        out = ks.layers.Add()(out)
        out = MLP(**output_mlp)(out)
    elif output_embedding == "node":  # Node labeling
        out = GraphMLP(**last_mlp)([n, batch_id_node, count_nodes])
        out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
    else:
        raise ValueError("Unsupported output embedding for mode `GIN`.")
    return out
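

# A minimal, hedged usage sketch for `model_disjoint` (illustrative only: the
# hyperparameter values below are assumptions, not defaults defined in this
# file). The function expects disjoint-batch tensors, i.e. the node features
# of all graphs in the batch stacked along one axis, together with the index
# tensors that assign nodes to edges and graphs:
#
#     out = model_disjoint(
#         [n, disjoint_indices, batch_id_node, count_nodes],
#         use_node_embedding=True,
#         input_node_embedding={"input_dim": 96, "output_dim": 64},
#         depth=5,
#         gin_args={},
#         gin_mlp={"units": [64, 64], "use_bias": True,
#                  "activation": ["relu", "linear"]},
#         last_mlp={"units": [64, 32, 1], "use_bias": True,
#                   "activation": ["relu", "relu", "linear"]},
#         dropout=0.05,
#         output_embedding="graph",
#         output_mlp={"units": 1, "activation": "linear"},
#     )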


def model_disjoint_edge(
        inputs,
        use_node_embedding: bool = None,
        use_edge_embedding: bool = None,
        input_node_embedding: dict = None,
        input_edge_embedding: dict = None,
        depth: int = None,
        gin_args: dict = None,
        gin_mlp: dict = None,
        last_mlp: dict = None,
        dropout: float = None,
        output_embedding: str = None,
        output_mlp: dict = None):
    n, ed, disjoint_indices, batch_id_node, count_nodes = inputs

    # Embedding, if no feature dimension.
    if use_node_embedding:
        n = Embedding(**input_node_embedding)(n)
    if use_edge_embedding:
        ed = Embedding(**input_edge_embedding)(ed)

    # Model.
    # Map node and edge features to the required number of units.
    n_units = gin_mlp["units"][-1] if isinstance(gin_mlp["units"], list) else int(gin_mlp["units"])
    n = Dense(n_units, use_bias=True, activation='linear')(n)
    ed = Dense(n_units, use_bias=True, activation='linear')(ed)

    # Keep the node embeddings of every depth for the readout.
    list_embeddings = [n]
    for i in range(0, depth):
        n = GINE(**gin_args)([n, disjoint_indices, ed])
        n = GraphMLP(**gin_mlp)([n, batch_id_node, count_nodes])
        list_embeddings.append(n)

    # Output embedding choice.
    if output_embedding == "graph":
        # Pool the node embeddings of each depth into a graph embedding and sum the MLP outputs.
        out = [PoolingNodes()([count_nodes, x, batch_id_node]) for x in list_embeddings]  # will return tensor
        out = [MLP(**last_mlp)(x) for x in out]
        out = [ks.layers.Dropout(dropout)(x) for x in out]
        out = ks.layers.Add()(out)
        out = MLP(**output_mlp)(out)
    elif output_embedding == "node":  # Node labeling
        out = GraphMLP(**last_mlp)([n, batch_id_node, count_nodes])
        out = GraphMLP(**output_mlp)([out, batch_id_node, count_nodes])
    else:
        raise ValueError("Unsupported output embedding for mode `GINE`.")
    return out
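

# A hedged sketch for the edge-conditioned variant (again an assumption, not a
# default of this file): `model_disjoint_edge` additionally takes edge features
# `ed`, which are projected to the same width as the node features so that the
# GINE convolution can combine them with the neighbor messages.
#
#     out = model_disjoint_edge(
#         [n, ed, disjoint_indices, batch_id_node, count_nodes],
#         use_node_embedding=True,
#         use_edge_embedding=True,
#         input_node_embedding={"input_dim": 96, "output_dim": 64},
#         input_edge_embedding={"input_dim": 5, "output_dim": 64},
#         depth=5,
#         gin_args={},
#         gin_mlp={"units": [64, 64], "use_bias": True,
#                  "activation": ["relu", "linear"]},
#         last_mlp={"units": [64, 32, 1], "use_bias": True,
#                   "activation": ["relu", "relu", "linear"]},
#         dropout=0.05,
#         output_embedding="graph",
#         output_mlp={"units": 1, "activation": "linear"},
#     )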