Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
- Fix bug: An error occurs when the number of referenced files is less than 6;
- Fix bug: If the default model is used after modifying the `model type`, the response will report an error if no other model is selected.
  • Loading branch information
Wannabeasmartguy committed Sep 13, 2024
1 parent 657d2d3 commit 93d0470
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 51 deletions.
5 changes: 3 additions & 2 deletions RAGenT.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,10 @@ def get_selected_non_llamafile_model_index(model_type) -> int:
logger.debug(f"model {model} in options, index: {options_index}")
return options_index
else:
logger.debug(f"model {model} not in options")
st.session_state.chat_config_list[0].update({"model": options[0]})
logger.debug(f"model {model} not in options, set model in config list to first option: {options[0]}")
return 0
except ValueError:
except (ValueError, AttributeError, IndexError):
logger.warning(f"Model {model} not found in model_selector for {model_type}, returning 0")
return 0

Expand Down
84 changes: 35 additions & 49 deletions pages/RAG_Chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import json
from uuid import uuid4
from copy import deepcopy
from functools import lru_cache
from loguru import logger
from datetime import datetime
Expand Down Expand Up @@ -73,21 +74,24 @@ def save_rag_chat_history(response: BaseRAGResponse):


def display_rag_sources(response_sources):
row1 = st.columns(3)
row2 = st.columns(3)

for index, pop in enumerate(row1 + row2):
a = pop.popover(i18n("Cited Source") + f" {index+1}", use_container_width=True)
file_name = response_sources["metadatas"][index]["source"]
file_content = response_sources["page_content"][index]
a.text(i18n("Cited Source") + ": " + file_name)
relevance_score_placeholder = a.empty()
if "relevance_score" in response_sources["metadatas"][index]:
relevance_score = response_sources["metadatas"][index]["relevance_score"]
relevance_score_placeholder.text(
i18n("Relevance Score") + ": " + str(relevance_score)
)
a.code(file_content, language="plaintext")
import itertools
num_sources = len(response_sources["metadatas"])
num_columns = min(3, num_sources)
rows = [st.columns(num_columns) for _ in range((num_sources + 2) // 3)]

for index, pop in enumerate(itertools.chain(*rows)):
if index < num_sources:
a = pop.popover(i18n("Cited Source") + f" {index+1}", use_container_width=True)
file_name = response_sources["metadatas"][index]["source"]
file_content = response_sources["page_content"][index]
a.text(i18n("Cited Source") + ": " + file_name)
relevance_score_placeholder = a.empty()
if "relevance_score" in response_sources["metadatas"][index]:
relevance_score = response_sources["metadatas"][index]["relevance_score"]
relevance_score_placeholder.text(
i18n("Relevance Score") + ": " + str(relevance_score)
)
a.code(file_content, language="plaintext")


@st.cache_data
Expand Down Expand Up @@ -323,32 +327,28 @@ def update_rag_config_in_db_callback():
dialogs_container = st.container(height=250, border=True)

def rag_saved_dialog_change_callback():
origin_config_list = deepcopy(st.session_state.rag_chat_config_list)
st.session_state.rag_run_id = st.session_state.rag_saved_dialog.run_id
st.session_state.rag_chat_config_list = [
chat_history_storage.get_specific_run(
st.session_state.rag_saved_dialog.run_id
).llm
]
log_dict_changes(original_dict=origin_config_list[0], new_dict=st.session_state.rag_chat_config_list[0])
try:
st.session_state.rag_chat_history = (
st.session_state.custom_rag_chat_history = (
chat_history_storage.get_specific_run(st.session_state.rag_saved_dialog.run_id).memory[
"chat_history"
]
)
st.session_state.custom_rag_sources = (
chat_history_storage.get_specific_run(
st.session_state.rag_saved_dialog.run_id
).memory["chat_history"]
).task_data["source_documents"]
)
except:
st.session_state.rag_chat_history = []

# 更新 select_box1 的值
if st.session_state.model_type != "Llamafile":
st.session_state.model = st.session_state.rag_chat_config_list[0].get("model")
else:
st.session_state.model = st.session_state.rag_chat_config_list[0].get("model", "")
try:
st.session_state.model = st.session_state.rag_chat_config_list[0].get("model")
except:
st.session_state.model = oai_model_config_selector(
st.session_state.oai_like_model_config_dict
)[0]
except (TypeError, ValidationError):
st.session_state.custom_rag_chat_history = []
st.session_state.custom_rag_sources = {}

saved_dialog = dialogs_container.radio(
label=i18n("Saved dialog"),
Expand Down Expand Up @@ -442,23 +442,6 @@ def delete_rag_dialog_callback():
on_click=delete_rag_dialog_callback,
)

if saved_dialog:
st.session_state.rag_run_id = saved_dialog.run_id
try:
st.session_state.custom_rag_chat_history = (
chat_history_storage.get_specific_run(saved_dialog.run_id).memory[
"chat_history"
]
)
st.session_state.custom_rag_sources = (
chat_history_storage.get_specific_run(
saved_dialog.run_id
).task_data["source_documents"]
)
except (TypeError, ValidationError):
st.session_state.custom_rag_chat_history = []
st.session_state.custom_rag_sources = {}

# 保存对话
def get_run_name():
try:
Expand Down Expand Up @@ -617,7 +600,8 @@ def get_selected_non_llamafile_model_index(model_type) -> int:
logger.debug(f"model {model} in options, index: {options_index}")
return options_index
else:
logger.debug(f"model {model} not in options")
st.session_state.rag_chat_config_list[0].update({"model": options[0]})
logger.debug(f"model {model} not in options, set model in config list to first option: {options[0]}")
return 0
except ValueError:
logger.warning(f"Model {model} not found in model_selector for {model_type}, returning 0")
Expand All @@ -632,6 +616,7 @@ def get_selected_non_llamafile_model_index(model_type) -> int:
key="model",
on_change=update_rag_config_in_db_callback,
)

elif select_box0 == "Llamafile":

def get_selected_llamafile_model() -> str:
Expand All @@ -649,6 +634,7 @@ def get_selected_llamafile_model() -> str:
key="model",
placeholder=i18n("Fill in custom model name. (Optional)"),
)

with model_choosing_container.popover(
label=i18n("Llamafile config"), use_container_width=True
):
Expand Down

0 comments on commit 93d0470

Please sign in to comment.