-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflask_app.py
140 lines (108 loc) · 3.44 KB
/
flask_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import json
import os
from flask import Flask, jsonify, request
from flask_cors import CORS
import src.config as cfg
from src.csv_retriever import CSVRetriever
from src.qna import QnA, QnAResponse
from src.rag import RAG
from src.utils.conversation import parse_llm_messages
# Flask application object; CORS is enabled on all routes so browser
# front-ends served from other origins can call /chat and /stream directly.
app = Flask(__name__)
CORS(app)
def get_qna():
    """Build and return a QnA pipeline wired to the default Azure OpenAI model.

    The pipeline combines a RAG document retriever (with the configured
    rerank setting) and a CSV retriever over the preprocessed data folder.

    NOTE(review): a fresh RAG/CSVRetriever/QnA stack is constructed on every
    call (i.e. per request) — confirm whether caching one instance is safe.
    """
    model = cfg.llm_options['azure-openai'].get('llm')
    rag_pipeline = RAG(model=model, rerank=cfg.rerank)
    tabular_retriever = CSVRetriever(
        llm=model,
        directory_path='./data/preprocessed/csv/',
    )
    return QnA(
        model=model,
        retriever=rag_pipeline.retriever,
        data_retriever=tabular_retriever,
    )
@app.route('/chat', methods=['POST'])
def chat():
    """POST /chat — answer a question in one shot (no streaming).

    Expected JSON body: {"query": str, "session_id": str (optional)}.
    Returns the answer, the parsed chat history, and the retrieved context
    (source, relevance score, and the first 100 chars of each document).
    """
    input_json = request.get_json(force=True)
    query = input_json.get("query")
    if not query:
        # Fail fast with a JSON 400 instead of a KeyError-driven HTML 500.
        return jsonify({"status": "Error", "message": "'query' is required"}), 400
    session_id = input_json.get('session_id', '')

    qa: QnA = get_qna()
    response: QnAResponse = qa.ask_question(
        query=query,
        session_id=session_id,
        stream=False
    )
    return jsonify({
        "response": {
            "answer": response['answer'],
            "chat_history": parse_llm_messages(response['chat_history']),
            "context": [
                {
                    "source": doc.metadata['source'],
                    "relevance_score": doc.metadata['relevance_score'],
                    "page_content": doc.page_content[:100]
                }
                for doc in response['context']
            ]
        },
        "status": "Success"
    })
def fake_stream():
    """Yield 10 canned SSE messages; fallback when the QnA stream is empty.

    Each yielded string follows text/event-stream framing:
    ``id: <n>\\ndata: <json>\\n\\n``.
    """
    for i in range(10):
        data = json.dumps({
            "response": {
                "answer": f"Answer {i}",
                "chat_history": [],
                "context": []
            },
            "status": "Success"
        })
        # Use the loop index as the SSE event id — it was hard-coded to 1,
        # which defeats client-side Last-Event-ID bookkeeping.
        yield f'id: {i}\ndata: {data}\n\n'
def stream_parser(response):
    """Translate streamed QnA chunks into SSE messages.

    Context and chat history arrive on their own chunks and are carried
    forward, so every subsequent message repeats the latest known values
    alongside that chunk's incremental answer text.
    """
    context = []
    chat_history = []
    for event_id, chunk in enumerate(response):
        if 'context' in chunk:
            for doc in chunk['context']:
                # Flat list of dicts, matching the /chat response shape
                # (previously each dict was wrapped in its own one-element list).
                context.append({
                    "source": doc.metadata['source'],
                    "relevance_score": doc.metadata['relevance_score'],
                    "page_content": doc.page_content[:100]
                })
        if 'chat_history' in chunk:
            chat_history = parse_llm_messages(chunk['chat_history'])
        data = json.dumps({
            "response": {
                "answer": chunk.get('answer', ''),
                "chat_history": chat_history,
                "context": context
            },
            "status": "Success"
        })
        # Monotonic SSE event ids (previously hard-coded to 1).
        yield f'id: {event_id}\ndata: {data}\n\n'
@app.route('/stream', methods=['POST'])
def stream():
    """POST /stream — answer a question as a text/event-stream response.

    Expected JSON body: {"query": str, "session_id": str (optional)}.
    Falls back to fake_stream() when the QnA call yields a falsy response.
    """
    input_json = request.get_json(force=True)
    query = input_json.get("query")
    if not query:
        # Mirror /chat: return a JSON 400 instead of a KeyError-driven 500.
        return jsonify({"status": "Error", "message": "'query' is required"}), 400
    session_id = input_json.get('session_id', '')

    qa: QnA = get_qna()
    response: QnAResponse = qa.ask_question(
        query=query,
        session_id=session_id,
        stream=True
    )
    if response:
        return app.response_class(stream_parser(response), content_type='text/event-stream')
    return app.response_class(fake_stream(), content_type='text/event-stream')
if __name__ == "__main__":
    # os.getenv returns a str when PORT is set in the environment (but the
    # int default 5000 otherwise) — coerce so app.run always gets an int.
    # NOTE(review): debug=True is unsafe for production; confirm this entry
    # point is dev-only (production should use a WSGI server instead).
    app.run(host="0.0.0.0", port=int(os.getenv('PORT', 5000)), debug=True)