diff --git a/examples/rag-chat-api/README.md b/examples/rag-chat-api/README.md index 4eb8cd1b92..f236ae60cb 100644 --- a/examples/rag-chat-api/README.md +++ b/examples/rag-chat-api/README.md @@ -21,3 +21,16 @@ Alternatively, you can use the Python script for easier access + string formatti ``` python question.py "INPUT-QUESTION-HERE" ``` + +### UI + +To run the UI, you need to install `streamlit`: +``` +pip install streamlit +``` +Note that newer versions of `streamlit` require Python 3.8+. + +Then, run the app: +``` +streamlit run chat_app.py +``` diff --git a/examples/rag-chat-api/chat_app.py b/examples/rag-chat-api/chat_app.py new file mode 100644 index 0000000000..dc642d76ae --- /dev/null +++ b/examples/rag-chat-api/chat_app.py @@ -0,0 +1,29 @@ +# Copyright 2021-2024 VMware, Inc. +# SPDX-License-Identifier: Apache-2.0 +import streamlit as st +from question import get_api_response + +st.title("VDK Demo Chat Bot") + +# Initialize chat history +if "messages" not in st.session_state: + st.session_state.messages = [] + +# Display chat messages from history on app rerun +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.markdown(message["content"]) + +# React to user input +if prompt := st.chat_input("What is up?"): + # Display user message in chat message container + st.chat_message("user").markdown(prompt) + # Add user message to chat history + st.session_state.messages.append({"role": "user", "content": prompt}) + + response = get_api_response(prompt) + # Display assistant response in chat message container + with st.chat_message("assistant"): + st.markdown(response) + # Add assistant response to chat history + st.session_state.messages.append({"role": "assistant", "content": response}) diff --git a/examples/rag-chat-api/question.py b/examples/rag-chat-api/question.py index 2f23828219..e0af39c87f 100644 --- a/examples/rag-chat-api/question.py +++ b/examples/rag-chat-api/question.py @@ -5,16 +5,16 @@ import requests -def question(): - if len(sys.argv) != 2: - print("Wrap your question in quotation marks") - +def get_api_response(question): headers = {"Content-Type": "application/json"} - data = {"question": sys.argv[1]} + data = {"question": question} res = requests.post("http://127.0.0.1:8000/question/", headers=headers, json=data) - print(res.text.replace("\\n", "\n").replace("\\t", "\t")) + return res.text.replace("\\n", "\n").replace("\\t", "\t") if __name__ == "__main__": - question() + if len(sys.argv) != 2: + print("Wrap your question in quotation marks") + + print(get_api_response(sys.argv[1]))