-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels_dict.py
115 lines (114 loc) · 3.58 KB
/
models_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import models
import params
from utils import NLL_loss
# ---------------------------------------------------------------------------
# Model registry
#
# MODELS maps a model key to that model's configuration. Every entry carries
# the same top-level keys, in the same order:
#   'name'               -- same string as the entry's key
#   'ref'                -- model constructor taken from the local `models`
#                           module
#   'disc'               -- free-text note (presumably "description"; only
#                           REGIONATTCNN3 fills it in)
#   'tag'                -- extra label; empty for every entry here
#   'hyperparams'        -- per-model settings; the key set differs by model
#   'optimizer'          -- optimizer class (torch.optim.Adam everywhere)
#   'lc loss function'   -- loss for the "lc" output (CrossEntropyLoss, a
#                           classification loss)
#   'ttlc loss function' -- loss for the "ttlc" output (MSELoss, a regression
#                           loss)
#   'data type'          -- input kind: 'state' or 'image'
#   'state type'         -- empty for every entry here
#
# NOTE(review): the optimizer and loss values are stored as classes, except
# VCNN's 'ttlc loss function', which is an MSELoss *instance*. Confirm how
# the training code consumes that key before making the entries uniform.
# NOTE(review): 'task' is params.DUAL or params.REGRESSION; what those mean
# is defined in the params module, not here.
# ---------------------------------------------------------------------------
MODELS = {
    'MLP': {
        'name': 'MLP',
        'ref': models.MLP,
        'disc': '',
        'tag': '',
        'hyperparams': {
            'hidden dim': 512,
            'tv only': False,
            'task': params.DUAL,
            'curriculum loss': False,
            'curriculum seq': False,
            'curriculum virtual': False,
        },
        'optimizer': torch.optim.Adam,
        'lc loss function': torch.nn.CrossEntropyLoss,
        'ttlc loss function': torch.nn.MSELoss,
        'data type': 'state',
        'state type': '',
    },
    'VLSTM': {
        'name': 'VLSTM',
        'ref': models.VanillaLSTM,
        'disc': '',
        'tag': '',
        'hyperparams': {
            'layer number': 1,
            'tv only': False,
            'hidden dim': 512,
            'task': params.REGRESSION,
            'curriculum loss': False,
            'curriculum seq': False,
            'curriculum virtual': False,
        },
        'optimizer': torch.optim.Adam,
        'lc loss function': torch.nn.CrossEntropyLoss,
        'ttlc loss function': torch.nn.MSELoss,
        'data type': 'state',
        'state type': '',
    },
    'VGRU': {
        'name': 'VGRU',
        'ref': models.VanillaGRU,
        'disc': '',
        'tag': '',
        'hyperparams': {
            'layer number': 1,
            'tv only': False,
            'hidden dim': 512,
            'task': params.REGRESSION,
            'curriculum loss': False,
            'curriculum seq': False,
            'curriculum virtual': False,
        },
        'optimizer': torch.optim.Adam,
        'lc loss function': torch.nn.CrossEntropyLoss,
        'ttlc loss function': torch.nn.MSELoss,
        'data type': 'state',
        'state type': '',
    },
    'VCNN': {
        'name': 'VCNN',
        'ref': models.VanillaCNN,
        'disc': '',
        'tag': '',
        'hyperparams': {
            'kernel size': 3,
            'channel number': 32,
            'merge channels': True,
            'task': params.DUAL,
            'curriculum loss': False,
            'curriculum seq': False,
            'curriculum virtual': False,
            # Other values listed by the author: 'Resnet_LSTM',
            # 'pretrainedResnetModel', 'pretrained_denseNet', 'probabilistic',
            # 'CNN-LSTM-v1', 'CNN-LSTM-v3', 'CNN-Linear'.
            'model type': 'CNN-LSTM-v2',
        },
        'optimizer': torch.optim.Adam,
        'lc loss function': torch.nn.CrossEntropyLoss,
        # NOTE(review): an MSELoss *instance*, unlike the MSELoss class used
        # by every other entry. NLL_loss (imported from utils) was noted by
        # the author as an alternative value for this key.
        'ttlc loss function': torch.nn.MSELoss(),
        'data type': 'image',
        'state type': '',
    },
    'REGIONATTCNN3': {
        'name': 'REGIONATTCNN3',
        'ref': models.ATTCNN3,
        'disc': 'attention weights for quad-regions',
        'tag': '',
        'hyperparams': {
            'kernel size': 3,
            'channel number': 16,
            'merge channels': True,
            'task': params.DUAL,
            'curriculum loss': False,
            'curriculum seq': False,
            'curriculum virtual': False,
        },
        'optimizer': torch.optim.Adam,
        'lc loss function': torch.nn.CrossEntropyLoss,
        'ttlc loss function': torch.nn.MSELoss,
        'data type': 'image',
        'state type': '',
    },
}