-
Notifications
You must be signed in to change notification settings - Fork 97
/
Copy pathpredict.py
26 lines (24 loc) · 911 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import os
import pickle
import tensorflow as tf
from utils import create_model, get_logger
from model import Model
from loader import input_from_line
from train import FLAGS, load_config
def main(_):
    """Run an interactive prompt that tags sentences typed on stdin.

    Loads the configuration from ``FLAGS.config_file`` and the tag mappings
    from ``FLAGS.map_file``, builds the model from ``FLAGS.ckpt_path`` inside
    a TensorFlow session, then repeatedly prompts for a sentence and prints
    the ``'entities'`` entry of each evaluation result.

    The loop ends cleanly when stdin reaches end-of-file (e.g. Ctrl-D or
    exhausted piped input) or when the user presses Ctrl-C at the prompt.

    Args:
        _: Ignored positional argument (this function is started through
            ``tf.app.run``).
    """
    config = load_config(FLAGS.config_file)
    logger = get_logger(FLAGS.log_file)

    # Let TensorFlow grow its GPU allocation on demand instead of reserving
    # memory up front. Note this does not impose an upper limit.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    # NOTE(security): unpickling can execute arbitrary code. Only point
    # FLAGS.map_file at a mapping file from a trusted source.
    with open(FLAGS.map_file, "rb") as f:
        tag_to_id, id_to_tag = pickle.load(f)

    with tf.Session(config=tf_config) as sess:
        model = create_model(sess, Model, FLAGS.ckpt_path, config, logger)
        while True:
            try:
                line = input("input sentence, please:")
            except (EOFError, KeyboardInterrupt):
                # stdin was closed or the user interrupted the prompt: leave
                # the loop so the session is closed by its context manager
                # instead of dying with a traceback.
                print()
                break
            result = model.evaluate_line(
                sess,
                input_from_line(line, FLAGS.max_seq_len, tag_to_id),
                id_to_tag,
            )
            print(result['entities'])
if __name__ == "__main__":
    # Set CUDA_VISIBLE_DEVICES to "0" before handing control to
    # TensorFlow's application runner.
    os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
    tf.app.run(main=main)