-
Notifications
You must be signed in to change notification settings - Fork 23
/
main.py
27 lines (24 loc) · 961 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from __future__ import print_function, division
import sys
from gdynet.model import GDyNet
from gdynet.parsers import main_parser as parser
def main(argv=None):
    """Parse command-line arguments, build a GDyNet model and train it.

    Args:
        argv: list of argument strings to parse, excluding the program name.
            When None, ``sys.argv[1:]`` is used, so calling ``main()`` behaves
            exactly like running this file as a script.

    Returns:
        The ``GDyNet`` instance after ``train_model()`` has been called on it.
    """
    if argv is None:
        argv = sys.argv[1:]
    args = parser.parse_args(argv)
    # Echo the parsed configuration so the run's settings end up in the output.
    print(args)
    model = GDyNet(train_flist=args.train_flist,
                   val_flist=args.val_flist,
                   test_flist=args.test_flist,
                   job_dir=args.job_dir,
                   mode=args.mode,
                   tau=args.tau,
                   n_classes=args.n_classes,
                   k_eig=args.k_eig,
                   atom_fea_len=args.atom_fea_len,
                   n_conv=args.n_conv,
                   learning_rate=args.lr,
                   batch_size=args.batch_size,
                   # The parser exposes opt-out switches (no_bn, no_shuffle);
                   # invert them into the positive flags the model takes.
                   use_bn=not args.no_bn,
                   n_epoch=args.n_epoch,
                   shuffle=not args.no_shuffle,
                   random_seed=args.random_seed)
    model.train_model()
    return model


if __name__ == '__main__':
    main()