-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: enable Gradio page to share state between tabs (#48)
* Remove entrypoint * Update ui * Revert dockerfile change * Delete debug clauses * Fix embed dim value * Update view model
- Loading branch information
Showing
12 changed files
with
803 additions
and
615 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from typing import TYPE_CHECKING, Dict, Generator, List, Tuple | ||
|
||
if TYPE_CHECKING: | ||
from gradio.components import Component | ||
|
||
|
||
class ElementManager: | ||
def __init__(self) -> None: | ||
self._id_to_elem: Dict[str, "Component"] = {} | ||
self._elem_to_id: Dict["Component", str] = {} | ||
|
||
def add_elems(self, elem_dict: Dict[str, "Component"]) -> None: | ||
r""" | ||
Adds elements to manager. | ||
""" | ||
for elem_id, elem in elem_dict.items(): | ||
self._id_to_elem[elem_id] = elem | ||
self._elem_to_id[elem] = elem_id | ||
|
||
def get_elem_list(self) -> List["Component"]: | ||
r""" | ||
Returns the list of all elements. | ||
""" | ||
return list(self._id_to_elem.values()) | ||
|
||
def get_elem_iter(self) -> Generator[Tuple[str, "Component"], None, None]: | ||
r""" | ||
Returns an iterator over all elements with their names. | ||
""" | ||
for elem_id, elem in self._id_to_elem.items(): | ||
yield elem_id.split(".")[-1], elem | ||
|
||
def get_elem_by_id(self, elem_id: str) -> "Component": | ||
r""" | ||
Gets element by id. | ||
Example: top.lang, train.dataset | ||
""" | ||
return self._id_to_elem[elem_id] | ||
|
||
def get_id_by_elem(self, elem: "Component") -> str: | ||
r""" | ||
Gets id by element. | ||
""" | ||
return self._elem_to_id[elem] | ||
|
||
|
||
elem_manager = ElementManager() |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
from typing import Dict, Any, List | ||
import gradio as gr | ||
from pai_rag.app.web.rag_client import rag_client | ||
from pai_rag.app.web.view_model import view_model | ||
from pai_rag.app.web.ui_constants import ( | ||
SIMPLE_PROMPTS, | ||
GENERAL_PROMPTS, | ||
EXTRACT_URL_PROMPTS, | ||
ACCURATE_CONTENT_PROMPTS, | ||
) | ||
|
||
|
||
current_session_id = None | ||
|
||
|
||
def clear_history(chatbot): | ||
chatbot = [] | ||
global current_session_id | ||
current_session_id = None | ||
return chatbot, 0 | ||
|
||
|
||
def respond(input_elements: List[Any]): | ||
global current_session_id | ||
|
||
update_dict = {} | ||
for element, value in input_elements.items(): | ||
update_dict[element.elem_id] = value | ||
|
||
# empty input. | ||
if not update_dict["question"]: | ||
return "", update_dict["chatbot"], 0 | ||
|
||
view_model.update(update_dict) | ||
new_config = view_model.to_app_config() | ||
rag_client.reload_config(new_config) | ||
|
||
query_type = update_dict["query_type"] | ||
msg = update_dict["question"] | ||
chatbot = update_dict["chatbot"] | ||
|
||
if query_type == "LLM": | ||
response = rag_client.query_llm( | ||
msg, | ||
session_id=current_session_id, | ||
) | ||
|
||
elif query_type == "Retrieval": | ||
response = rag_client.query_vector(msg) | ||
else: | ||
response = rag_client.query(msg, session_id=current_session_id) | ||
if update_dict["include_history"]: | ||
current_session_id = response.session_id | ||
else: | ||
current_session_id = None | ||
chatbot.append((msg, response.answer)) | ||
return "", chatbot, 0 | ||
|
||
|
||
def create_chat_tab() -> Dict[str, Any]: | ||
with gr.Row(): | ||
with gr.Column(scale=2): | ||
query_type = gr.Radio( | ||
["Retrieval", "LLM", "RAG (Retrieval + LLM)"], | ||
label="\N{fire} Which query do you want to use?", | ||
elem_id="query_type", | ||
value="RAG (Retrieval + LLM)", | ||
) | ||
|
||
with gr.Column(visible=True) as vs_col: | ||
vec_model_argument = gr.Accordion("Parameters of Vector Retrieval") | ||
|
||
with vec_model_argument: | ||
similarity_top_k = gr.Slider( | ||
minimum=0, | ||
maximum=100, | ||
step=1, | ||
elem_id="similarity_top_k", | ||
label="Top K (choose between 0 and 100)", | ||
) | ||
# similarity_cutoff = gr.Slider(minimum=0, maximum=1, step=0.01,elem_id="similarity_cutoff",value=view_model.similarity_cutoff, label="Similarity Distance Threshold (The more similar the vectors, the smaller the value.)") | ||
rerank_model = gr.Radio( | ||
[ | ||
"no-reranker", | ||
"bge-reranker-base", | ||
"bge-reranker-large", | ||
"llm-reranker", | ||
], | ||
label="Re-Rank Model (Note: It will take a long time to load the model when using it for the first time.)", | ||
elem_id="rerank_model", | ||
) | ||
retrieval_mode = gr.Radio( | ||
["Embedding Only", "Keyword Ensembled", "Keyword Only"], | ||
label="Retrieval Mode", | ||
elem_id="retrieval_mode", | ||
) | ||
vec_args = { | ||
similarity_top_k, | ||
# similarity_cutoff, | ||
rerank_model, | ||
retrieval_mode, | ||
} | ||
with gr.Column(visible=True) as llm_col: | ||
model_argument = gr.Accordion("Inference Parameters of LLM") | ||
with model_argument: | ||
include_history = gr.Checkbox( | ||
label="Chat history", | ||
info="Query with chat history.", | ||
elem_id="include_history", | ||
) | ||
llm_topk = gr.Slider( | ||
minimum=0, | ||
maximum=100, | ||
step=1, | ||
value=30, | ||
elem_id="llm_topk", | ||
label="Top K (choose between 0 and 100)", | ||
) | ||
llm_topp = gr.Slider( | ||
minimum=0, | ||
maximum=1, | ||
step=0.01, | ||
value=0.8, | ||
elem_id="llm_topp", | ||
label="Top P (choose between 0 and 1)", | ||
) | ||
llm_temp = gr.Slider( | ||
minimum=0, | ||
maximum=1, | ||
step=0.01, | ||
value=0.7, | ||
elem_id="llm_temp", | ||
label="Temperature (choose between 0 and 1)", | ||
) | ||
llm_args = {llm_topk, llm_topp, llm_temp, include_history} | ||
|
||
with gr.Column(visible=True) as lc_col: | ||
prm_type = gr.Radio( | ||
[ | ||
"Simple", | ||
"General", | ||
"Extract URL", | ||
"Accurate Content", | ||
"Custom", | ||
], | ||
label="\N{rocket} Please choose the prompt template type", | ||
elem_id="prm_type", | ||
) | ||
text_qa_template = gr.Textbox( | ||
label="prompt template", | ||
placeholder="", | ||
elem_id="text_qa_template", | ||
lines=4, | ||
) | ||
|
||
def change_prompt_template(prm_type): | ||
if prm_type == "Simple": | ||
return { | ||
text_qa_template: gr.update( | ||
value=SIMPLE_PROMPTS, interactive=False | ||
) | ||
} | ||
elif prm_type == "General": | ||
return { | ||
text_qa_template: gr.update( | ||
value=GENERAL_PROMPTS, interactive=False | ||
) | ||
} | ||
elif prm_type == "Extract URL": | ||
return { | ||
text_qa_template: gr.update( | ||
value=EXTRACT_URL_PROMPTS, interactive=False | ||
) | ||
} | ||
elif prm_type == "Accurate Content": | ||
return { | ||
text_qa_template: gr.update( | ||
value=ACCURATE_CONTENT_PROMPTS, | ||
interactive=False, | ||
) | ||
} | ||
else: | ||
return {text_qa_template: gr.update(value="", interactive=True)} | ||
|
||
prm_type.change( | ||
fn=change_prompt_template, | ||
inputs=prm_type, | ||
outputs=[text_qa_template], | ||
) | ||
|
||
cur_tokens = gr.Textbox( | ||
label="\N{fire} Current total count of tokens", visible=False | ||
) | ||
|
||
def change_query_radio(query_type): | ||
global current_session_id | ||
current_session_id = None | ||
if query_type == "Retrieval": | ||
return { | ||
vs_col: gr.update(visible=True), | ||
llm_col: gr.update(visible=False), | ||
lc_col: gr.update(visible=False), | ||
} | ||
elif query_type == "LLM": | ||
return { | ||
vs_col: gr.update(visible=False), | ||
llm_col: gr.update(visible=True), | ||
lc_col: gr.update(visible=False), | ||
} | ||
elif query_type == "RAG (Retrieval + LLM)": | ||
return { | ||
vs_col: gr.update(visible=True), | ||
llm_col: gr.update(visible=True), | ||
lc_col: gr.update(visible=True), | ||
} | ||
|
||
query_type.change( | ||
fn=change_query_radio, | ||
inputs=query_type, | ||
outputs=[vs_col, llm_col, lc_col], | ||
) | ||
|
||
with gr.Column(scale=8): | ||
chatbot = gr.Chatbot(height=500, elem_id="chatbot") | ||
question = gr.Textbox(label="Enter your question.", elem_id="question") | ||
with gr.Row(): | ||
submitBtn = gr.Button("Submit", variant="primary") | ||
clearBtn = gr.Button("Clear History", variant="secondary") | ||
|
||
chat_args = ( | ||
{text_qa_template, question, query_type, chatbot} | ||
.union(vec_args) | ||
.union(llm_args) | ||
) | ||
|
||
submitBtn.click( | ||
respond, | ||
chat_args, | ||
[question, chatbot, cur_tokens], | ||
api_name="respond", | ||
) | ||
clearBtn.click(clear_history, [chatbot], [chatbot, cur_tokens]) | ||
return { | ||
similarity_top_k.elem_id: similarity_top_k, | ||
rerank_model.elem_id: rerank_model, | ||
retrieval_mode.elem_id: retrieval_mode, | ||
prm_type.elem_id: prm_type, | ||
text_qa_template.elem_id: text_qa_template, | ||
} |
Oops, something went wrong.