From 93d04701c71dc7799f3da7db96ac34e5fe531df2 Mon Sep 17 00:00:00 2001 From: Wannabeasmartguy <997139385@qq.com> Date: Fri, 13 Sep 2024 13:54:23 +0800 Subject: [PATCH] Fix bugs - Fix bug: An error occurs when the number of referenced files is less than 6; - Fix bug: If the default model is used after modifying the `model type`, the response will report an error if no other model is selected. --- RAGenT.py | 5 +-- pages/RAG_Chat.py | 84 ++++++++++++++++++++--------------------------- 2 files changed, 38 insertions(+), 51 deletions(-) diff --git a/RAGenT.py b/RAGenT.py index 026efc1..cb4d457 100644 --- a/RAGenT.py +++ b/RAGenT.py @@ -345,9 +345,10 @@ def get_selected_non_llamafile_model_index(model_type) -> int: logger.debug(f"model {model} in options, index: {options_index}") return options_index else: - logger.debug(f"model {model} not in options") + st.session_state.chat_config_list[0].update({"model": options[0]}) + logger.debug(f"model {model} not in options, set model in config list to first option: {options[0]}") return 0 - except ValueError: + except (ValueError, AttributeError, IndexError): logger.warning(f"Model {model} not found in model_selector for {model_type}, returning 0") return 0 diff --git a/pages/RAG_Chat.py b/pages/RAG_Chat.py index 614c2e6..2978abf 100644 --- a/pages/RAG_Chat.py +++ b/pages/RAG_Chat.py @@ -2,6 +2,7 @@ import os import json from uuid import uuid4 +from copy import deepcopy from functools import lru_cache from loguru import logger from datetime import datetime @@ -73,21 +74,24 @@ def save_rag_chat_history(response: BaseRAGResponse): def display_rag_sources(response_sources): - row1 = st.columns(3) - row2 = st.columns(3) - - for index, pop in enumerate(row1 + row2): - a = pop.popover(i18n("Cited Source") + f" {index+1}", use_container_width=True) - file_name = response_sources["metadatas"][index]["source"] - file_content = response_sources["page_content"][index] - a.text(i18n("Cited Source") + ": " + file_name) - relevance_score_placeholder = a.empty() - if "relevance_score" in response_sources["metadatas"][index]: - relevance_score = response_sources["metadatas"][index]["relevance_score"] - relevance_score_placeholder.text( - i18n("Relevance Score") + ": " + str(relevance_score) - ) - a.code(file_content, language="plaintext") + import itertools + num_sources = len(response_sources["metadatas"]) + num_columns = min(3, num_sources) + rows = [st.columns(num_columns) for _ in range((num_sources + 2) // 3)] + + for index, pop in enumerate(itertools.chain(*rows)): + if index < num_sources: + a = pop.popover(i18n("Cited Source") + f" {index+1}", use_container_width=True) + file_name = response_sources["metadatas"][index]["source"] + file_content = response_sources["page_content"][index] + a.text(i18n("Cited Source") + ": " + file_name) + relevance_score_placeholder = a.empty() + if "relevance_score" in response_sources["metadatas"][index]: + relevance_score = response_sources["metadatas"][index]["relevance_score"] + relevance_score_placeholder.text( + i18n("Relevance Score") + ": " + str(relevance_score) + ) + a.code(file_content, language="plaintext") @st.cache_data @@ -323,32 +327,28 @@ def update_rag_config_in_db_callback(): dialogs_container = st.container(height=250, border=True) def rag_saved_dialog_change_callback(): + origin_config_list = deepcopy(st.session_state.rag_chat_config_list) st.session_state.rag_run_id = st.session_state.rag_saved_dialog.run_id st.session_state.rag_chat_config_list = [ chat_history_storage.get_specific_run( st.session_state.rag_saved_dialog.run_id ).llm ] + log_dict_changes(original_dict=origin_config_list[0], new_dict=st.session_state.rag_chat_config_list[0]) try: - st.session_state.rag_chat_history = ( + st.session_state.custom_rag_chat_history = ( + chat_history_storage.get_specific_run(st.session_state.rag_saved_dialog.run_id).memory[ + "chat_history" + ] + ) + st.session_state.custom_rag_sources = ( chat_history_storage.get_specific_run( st.session_state.rag_saved_dialog.run_id - ).memory["chat_history"] + ).task_data["source_documents"] ) - except: - st.session_state.rag_chat_history = [] - - # 更新 select_box1 的值 - if st.session_state.model_type != "Llamafile": - st.session_state.model = st.session_state.rag_chat_config_list[0].get("model") - else: - st.session_state.model = st.session_state.rag_chat_config_list[0].get("model", "") - try: - st.session_state.model = st.session_state.rag_chat_config_list[0].get("model") - except: - st.session_state.model = oai_model_config_selector( - st.session_state.oai_like_model_config_dict - )[0] + except (TypeError, ValidationError): + st.session_state.custom_rag_chat_history = [] + st.session_state.custom_rag_sources = {} saved_dialog = dialogs_container.radio( label=i18n("Saved dialog"), @@ -442,23 +442,6 @@ def delete_rag_dialog_callback(): on_click=delete_rag_dialog_callback, ) - if saved_dialog: - st.session_state.rag_run_id = saved_dialog.run_id - try: - st.session_state.custom_rag_chat_history = ( - chat_history_storage.get_specific_run(saved_dialog.run_id).memory[ - "chat_history" - ] - ) - st.session_state.custom_rag_sources = ( - chat_history_storage.get_specific_run( - saved_dialog.run_id - ).task_data["source_documents"] - ) - except (TypeError, ValidationError): - st.session_state.custom_rag_chat_history = [] - st.session_state.custom_rag_sources = {} - # 保存对话 def get_run_name(): try: @@ -617,7 +600,8 @@ def get_selected_non_llamafile_model_index(model_type) -> int: logger.debug(f"model {model} in options, index: {options_index}") return options_index else: - logger.debug(f"model {model} not in options") + st.session_state.rag_chat_config_list[0].update({"model": options[0]}) + logger.debug(f"model {model} not in options, set model in config list to first option: {options[0]}") return 0 except ValueError: logger.warning(f"Model {model} not found in model_selector for {model_type}, returning 0") @@ -632,6 +616,7 @@ def get_selected_non_llamafile_model_index(model_type) -> int: key="model", on_change=update_rag_config_in_db_callback, ) + elif select_box0 == "Llamafile": def get_selected_llamafile_model() -> str: @@ -649,6 +634,7 @@ def get_selected_llamafile_model() -> str: key="model", placeholder=i18n("Fill in custom model name. (Optional)"), ) + with model_choosing_container.popover( label=i18n("Llamafile config"), use_container_width=True ):