forked from kunalBhashkar/seq2seq_chatbot_tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat_web.py
84 lines (70 loc) · 3.11 KB
/
chat_web.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Script for serving a trained chatbot model over http
"""
import datetime
import click
from os import path
from flask import Flask, request, send_from_directory
from flask_cors import CORS
from flask_restful import Resource, Api
import general_utils
import chat_command_handler
from chat_settings import ChatSettings
from chatbot_model import ChatbotModel
from vocabulary import Vocabulary
# Module-level Flask application; CORS is enabled so the chat UI page can
# call the /chat API from a different origin during local development.
app = Flask(__name__)
CORS(app)
def _load_vocabularies(model_dir, hparams):
    """Load the input and output vocabularies for a trained model.

    If the model was trained with a shared embedding, a single vocabulary
    file serves as both the encoder input and decoder output vocabulary;
    otherwise two separate files are loaded.

    Args:
        model_dir: Directory containing the trained model's vocabulary files.
        hparams: Hyperparameters object exposing ``model_hparams.share_embedding``.

    Returns:
        Tuple ``(input_vocabulary, output_vocabulary)``; the same object twice
        when the embedding is shared.
    """
    if hparams.model_hparams.share_embedding:
        shared_vocab_filepath = path.join(model_dir, Vocabulary.SHARED_VOCAB_FILENAME)
        input_vocabulary = Vocabulary.load(shared_vocab_filepath)
        output_vocabulary = input_vocabulary
    else:
        input_vocab_filepath = path.join(model_dir, Vocabulary.INPUT_VOCAB_FILENAME)
        input_vocabulary = Vocabulary.load(input_vocab_filepath)
        output_vocab_filepath = path.join(model_dir, Vocabulary.OUTPUT_VOCAB_FILENAME)
        output_vocabulary = Vocabulary.load(output_vocab_filepath)
    return input_vocabulary, output_vocabulary


@app.cli.command()
@click.argument("checkpointfile")
@click.option("-p", "--port", type=int)
def serve_chat(checkpointfile, port):
    """Serve a trained chatbot model over HTTP.

    Registers two flask-restful resources and blocks in ``app.run``:
        GET /chat/<question>  -> model answer (or a command acknowledgement)
        GET /chat_ui/         -> the static ``chat_ui.html`` page

    Args:
        checkpointfile: Path/name of the model checkpoint to load.
        port: TCP port to listen on (Flask's default is used when omitted).
    """
    api = Api(app)

    # Read the hyperparameters and resolve the model/checkpoint paths
    model_dir, hparams, checkpoint = general_utils.initialize_session_server(checkpointfile)

    print()
    print("Loading vocabulary...")
    input_vocabulary, output_vocabulary = _load_vocabularies(model_dir, hparams)

    # Create the model in inference mode; the context manager keeps the
    # underlying session open for the lifetime of the HTTP server.
    print("Initializing model...")
    print()
    with ChatbotModel(mode="infer",
                      model_hparams=hparams.model_hparams,
                      input_vocabulary=input_vocabulary,
                      output_vocabulary=output_vocabulary,
                      model_dir=model_dir) as model:
        # Load the trained weights from the selected checkpoint
        print()
        print("Loading model weights...")
        model.load(checkpoint)

        # Set up the chat session: a timestamped log file plus runtime settings
        chatlog_filepath = path.join(
            model_dir,
            "chat_logs",
            "web_chatlog_{0}.txt".format(datetime.datetime.now().strftime("%Y%m%d_%H%M%S")))
        chat_settings = ChatSettings(hparams.model_hparams, hparams.inference_hparams)
        chat_command_handler.print_commands()

        class Answer(Resource):
            """REST resource that answers one question per GET request."""

            def get(self, question):
                # Chat commands (settings changes etc.) are intercepted before
                # the text is handed to the model as a question.
                is_command, terminate_chat, _ = chat_command_handler.handle_command(
                    question, model, chat_settings)
                if terminate_chat:
                    # Terminating the server from a remote request is refused.
                    answer = "[Can't terminate from http request]"
                elif is_command:
                    answer = "[Command processed]"
                else:
                    # If it is not a command (it is a question), pass it on to
                    # the chatbot model to get the answer.
                    _, answer = model.chat(question, chat_settings)
                    if chat_settings.inference_hparams.log_chat:
                        chat_command_handler.append_to_chatlog(chatlog_filepath, question, answer)
                return answer

        class UI(Resource):
            """REST resource serving the static chat UI page."""

            def get(self):
                return send_from_directory(".", "chat_ui.html")

        api.add_resource(Answer, "/chat/<string:question>")
        api.add_resource(UI, "/chat_ui/")
        app.run(debug=False, port=port)