-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathmain.py
34 lines (24 loc) · 919 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from bert.preprocess.preprocess import add_preprocess_parser
from bert.train.train import add_pretrain_parser, add_finetune_parser
import json
from argparse import ArgumentParser
def main():
    """Parse the command line and dispatch to the selected BERT sub-command.

    Registers the preprocess, pretrain and finetune sub-parsers, then builds
    a ``config`` dictionary from the parsed arguments. When ``--config_path``
    is given, the JSON file at that path supplies the configuration and the
    parsed command-line values only fill in keys the file does not define.
    The resulting dictionary is passed to the handler stored on
    ``args.function`` both expanded as keyword arguments and whole as the
    ``config`` keyword.

    Exits with an argparse usage error when no sub-command was selected or
    when the config file's top-level JSON value is not an object.
    """
    parser = ArgumentParser('BERT')
    parser.add_argument('-c', '--config_path', type=str, default=None)

    subparsers = parser.add_subparsers()
    add_preprocess_parser(subparsers)
    add_pretrain_parser(subparsers)
    add_finetune_parser(subparsers)

    args = parser.parse_args()

    # argparse treats sub-commands as optional by default, so ``function`` is
    # absent when none was given. Report a usage error instead of letting the
    # final call fail with a bare AttributeError.
    # NOTE(review): ``function`` is presumably set by each sub-parser via
    # ``set_defaults`` -- confirm in bert.preprocess / bert.train.
    handler = getattr(args, 'function', None)
    if handler is None:
        parser.error('a sub-command is required')

    cli_values = vars(args)  # convert to dictionary
    if args.config_path is None:
        config = cli_values
    else:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(args.config_path, encoding='utf-8') as f:
            config = json.load(f)
        if not isinstance(config, dict):
            parser.error(
                'config file %r must contain a JSON object at the top level'
                % args.config_path
            )
        # Values from the file win; command-line values only fill the gaps.
        for key, default_value in cli_values.items():
            config.setdefault(key, default_value)

    handler(**config, config=config)
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()