Commit: timeout prep backend
joachim-danswer committed Feb 7, 2025
1 parent 2bb211d commit f4ae944
Showing 19 changed files with 643 additions and 173 deletions.
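Every changed node below follows the same pattern: the LLM streaming call gets a per-call timeout override and is wrapped so that openai.APITimeoutError (or any other failure) is turned into an AgentError instead of propagating out of the LangGraph node. The nodes import AgentError from onyx.agents.agent_search.shared_graph_utils.models, which is not part of this excerpt; based on how it is constructed in the hunks below, it presumably looks roughly like the following Pydantic model. This is a hypothetical reconstruction for orientation only, and the optionality of the fields is an assumption.

# Hypothetical reconstruction of the AgentError model imported by the changed
# nodes; the real definition in shared_graph_utils/models.py may differ.
# Field names are taken from the constructor calls in this diff.
from pydantic import BaseModel


class AgentError(BaseModel):
    error_type: str | None = None
    error_message: str | None = None
    error_result: str | None = None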
@@ -1,6 +1,7 @@
 from datetime import datetime
 from typing import cast
 
+import openai
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_message_runs
 from langchain_core.runnables.config import RunnableConfig
@@ -12,10 +13,14 @@
     SubQuestionAnswerCheckUpdate,
 )
 from onyx.agents.agent_search.models import GraphConfig
+from onyx.agents.agent_search.shared_graph_utils.models import AgentError
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERWRITE_LLM_SUBANSWER_CHECK
+from onyx.prompts.agent_search import AGENT_LLM_ERROR_MESSAGE
+from onyx.prompts.agent_search import AGENT_LLM_TIMEOUT_MESSAGE
 from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
 from onyx.prompts.agent_search import UNKNOWN_ANSWER

@@ -53,14 +58,46 @@ def check_sub_answer(

     graph_config = cast(GraphConfig, config["metadata"]["config"])
     fast_llm = graph_config.tooling.fast_llm
-    response = list(
-        fast_llm.stream(
-            prompt=msg,
-        )
-    )
-
-    quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
-    answer_quality = "yes" in quality_str.lower()
+    agent_error: AgentError | None = None
+    response: list | None = None
+    try:
+        response = list(
+            fast_llm.stream(
+                prompt=msg,
+                timeout_overwrite=AGENT_TIMEOUT_OVERWRITE_LLM_SUBANSWER_CHECK,
+            )
+        )
+
+    except openai.APITimeoutError:
+        agent_error = AgentError(
+            error_type="timeout",
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result="LLM Timeout Error",
+        )
+
+    except Exception:
+        agent_error = AgentError(
+            error_type="LLM error",
+            error_message=AGENT_LLM_ERROR_MESSAGE,
+            error_result="LLM Error",
+        )
+
+    if agent_error:
+        answer_quality = True
+        log_result = agent_error.error_result
+
+    else:
+        if response:
+            quality_str: str = merge_message_runs(response, chunk_separator="")[
+                0
+            ].content
+            answer_quality = "yes" in quality_str.lower()
+
+        else:
+            answer_quality = True
+            quality_str = "yes - because LLM error"
+
+        log_result = f"Answer quality: {quality_str}"
 
     return SubQuestionAnswerCheckUpdate(
         answer_quality=answer_quality,
@@ -69,7 +106,7 @@ def check_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=f"Answer quality: {quality_str}",
result=log_result,
)
],
)
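The hunk above is the template the other nodes follow: wrap the streaming call in try/except, convert openai.APITimeoutError and generic failures into an AgentError, and fall back to a permissive default so the graph keeps running. The sketch below is a minimal, self-contained illustration of that control flow; it is not part of the commit, and the helper name stream_with_fallback and its signature are illustrative assumptions rather than Onyx code.

# Illustrative sketch (not from the Onyx codebase) of the try/except wrapper
# applied above: stream an LLM response, but degrade to an error label instead
# of raising, so the surrounding LangGraph node can still return an update.
from collections.abc import Callable, Iterator

import openai


def stream_with_fallback(
    stream_fn: Callable[[], Iterator[str]],
) -> tuple[list[str], str | None]:
    """Return (collected chunks, error label); the label is None on success."""
    chunks: list[str] = []
    try:
        for chunk in stream_fn():
            chunks.append(chunk)
    except openai.APITimeoutError:
        return chunks, "LLM Timeout Error"
    except Exception:
        return chunks, "LLM Error"
    return chunks, None

In check_sub_answer above, a non-None error makes the node treat the sub-answer as acceptable (answer_quality = True) rather than failing the whole graph.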
@@ -2,6 +2,7 @@
 from typing import Any
 from typing import cast
 
+import openai
 from langchain_core.messages import merge_message_runs
 from langchain_core.runnables.config import RunnableConfig
 from langgraph.types import StreamWriter
@@ -16,6 +17,7 @@
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     build_sub_question_answer_prompt,
 )
+from onyx.agents.agent_search.shared_graph_utils.models import AgentError
 from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
@@ -30,6 +32,10 @@
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERWRITE_LLM_SUBANSWER_GENERATION
+from onyx.prompts.agent_search import AGENT_LLM_ERROR_MESSAGE
+from onyx.prompts.agent_search import AGENT_LLM_TIMEOUT_MESSAGE
+from onyx.prompts.agent_search import LLM_ANSWER_ERROR_MESSAGE
 from onyx.prompts.agent_search import NO_RECOVERED_DOCS
 from onyx.utils.logger import setup_logger

@@ -57,6 +63,8 @@ def generate_sub_answer(

     if len(context_docs) == 0:
         answer_str = NO_RECOVERED_DOCS
+        cited_documents: list = []
+        log_results = "No documents retrieved"
         write_custom_event(
             "sub_answers",
             AgentAnswerPiece(
@@ -79,41 +87,66 @@ def generate_sub_answer(

     response: list[str | list[str | dict[str, Any]]] = []
     dispatch_timings: list[float] = []
-    for message in fast_llm.stream(
-        prompt=msg,
-    ):
-        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-        content = message.content
-        if not isinstance(content, str):
-            raise ValueError(
-                f"Expected content to be a string, but got {type(content)}"
-            )
-        start_stream_token = datetime.now()
-        write_custom_event(
-            "sub_answers",
-            AgentAnswerPiece(
-                answer_piece=content,
-                level=level,
-                level_question_num=question_num,
-                answer_type="agent_sub_answer",
-            ),
-            writer,
-        )
-        end_stream_token = datetime.now()
-        dispatch_timings.append(
-            (end_stream_token - start_stream_token).microseconds
-        )
-        response.append(content)
-
-    answer_str = merge_message_runs(response, chunk_separator="")[0].content
-    logger.debug(
-        f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
-    )
-
-    answer_citation_ids = get_answer_citation_ids(answer_str)
-    cited_documents = [
-        context_docs[id] for id in answer_citation_ids if id < len(context_docs)
-    ]
+
+    agent_error: AgentError | None = None
+
+    try:
+        for message in fast_llm.stream(
+            prompt=msg,
+            timeout_overwrite=AGENT_TIMEOUT_OVERWRITE_LLM_SUBANSWER_GENERATION,
+        ):
+            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
+            content = message.content
+            if not isinstance(content, str):
+                raise ValueError(
+                    f"Expected content to be a string, but got {type(content)}"
+                )
+            start_stream_token = datetime.now()
+            write_custom_event(
+                "sub_answers",
+                AgentAnswerPiece(
+                    answer_piece=content,
+                    level=level,
+                    level_question_num=question_num,
+                    answer_type="agent_sub_answer",
+                ),
+                writer,
+            )
+            end_stream_token = datetime.now()
+            dispatch_timings.append(
+                (end_stream_token - start_stream_token).microseconds
+            )
+            response.append(content)
+
+    except openai.APITimeoutError:
+        agent_error = AgentError(
+            error_type="timeout",
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result="LLM Timeout Error",
+        )
+
+    except Exception:
+        agent_error = AgentError(
+            error_type="LLM error",
+            error_message=AGENT_LLM_ERROR_MESSAGE,
+            error_result="LLM Error",
+        )
+
+    if agent_error:
+        answer_str = LLM_ANSWER_ERROR_MESSAGE
+        cited_documents = []
+        log_results = (
+            agent_error.error_result
+            or "Sub-answer generation failed due to LLM error"
+        )
+
+    else:
+        answer_str = merge_message_runs(response, chunk_separator="")[0].content
+        answer_citation_ids = get_answer_citation_ids(answer_str)
+        cited_documents = [
+            context_docs[id] for id in answer_citation_ids if id < len(context_docs)
+        ]
+        log_results = ""
 
     stop_event = StreamStopInfo(
         stop_reason=StreamStopReason.FINISHED,
@@ -131,7 +164,7 @@ def generate_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result="",
result=log_results,
)
],
)
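Both nodes above pass a new timeout_overwrite keyword into the LLM wrapper's stream() method. The wrapper itself is not part of this diff; the sketch below only illustrates how such a per-call override could map onto the OpenAI client's per-request timeout. Every name in it (TimeoutAwareLLM, default_timeout) is an assumption rather than Onyx's actual LLM interface, and it assumes the official openai Python client with OPENAI_API_KEY set in the environment.

# Illustrative-only sketch of honoring a timeout_overwrite parameter in an LLM
# wrapper; class and method names here are assumptions, not Onyx's interface.
from collections.abc import Iterator

import openai


class TimeoutAwareLLM:
    def __init__(self, model: str, default_timeout: int = 60) -> None:
        # Assumes OPENAI_API_KEY is set in the environment.
        self._client = openai.OpenAI()
        self._model = model
        self._default_timeout = default_timeout

    def stream(
        self, prompt: str, timeout_overwrite: int | None = None
    ) -> Iterator[str]:
        # The per-call override falls back to the instance default when unset.
        timeout = timeout_overwrite or self._default_timeout
        stream = self._client.chat.completions.create(
            model=self._model,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
            timeout=timeout,  # exceeded -> openai.APITimeoutError
        )
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content

A timeout configured this way surfaces as openai.APITimeoutError, which is exactly the exception the new except branches above catch.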
@@ -2,6 +2,7 @@
 from typing import Any
 from typing import cast
 
+import openai
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_content
 from langchain_core.runnables import RunnableConfig
@@ -26,6 +27,7 @@
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
+from onyx.agents.agent_search.shared_graph_utils.models import AgentError
 from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
 from onyx.agents.agent_search.shared_graph_utils.operators import (
     dedup_inference_sections,
@@ -42,12 +44,16 @@
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
+from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
-from onyx.context.search.models import InferenceSection
-from onyx.prompts.agent_search import (
-    INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_OVERWRITE_LLM_INITIAL_ANSWER_GENERATION,
 )
+from onyx.context.search.models import InferenceSection
+from onyx.prompts.agent_search import AGENT_LLM_ERROR_MESSAGE
+from onyx.prompts.agent_search import AGENT_LLM_TIMEOUT_MESSAGE
+from onyx.prompts.agent_search import INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
 from onyx.prompts.agent_search import (
     INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
 )
@@ -224,30 +230,80 @@ def generate_initial_answer(

     streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
     dispatch_timings: list[float] = []
-    for message in model.stream(msg):
-        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
-        content = message.content
-        if not isinstance(content, str):
-            raise ValueError(
-                f"Expected content to be a string, but got {type(content)}"
-            )
-        start_stream_token = datetime.now()
-
-        write_custom_event(
-            "initial_agent_answer",
-            AgentAnswerPiece(
-                answer_piece=content,
-                level=0,
-                level_question_num=0,
-                answer_type="agent_level_answer",
-            ),
-            writer,
-        )
-        end_stream_token = datetime.now()
-        dispatch_timings.append(
-            (end_stream_token - start_stream_token).microseconds
-        )
-        streamed_tokens.append(content)
+
+    agent_error: AgentError | None = None
+
+    try:
+        for message in model.stream(
+            msg,
+            timeout_overwrite=AGENT_TIMEOUT_OVERWRITE_LLM_INITIAL_ANSWER_GENERATION,
+        ):
+            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
+            content = message.content
+            if not isinstance(content, str):
+                raise ValueError(
+                    f"Expected content to be a string, but got {type(content)}"
+                )
+            start_stream_token = datetime.now()
+
+            write_custom_event(
+                "initial_agent_answer",
+                AgentAnswerPiece(
+                    answer_piece=content,
+                    level=0,
+                    level_question_num=0,
+                    answer_type="agent_level_answer",
+                ),
+                writer,
+            )
+            end_stream_token = datetime.now()
+            dispatch_timings.append(
+                (end_stream_token - start_stream_token).microseconds
+            )
+            streamed_tokens.append(content)
+
+    except openai.APITimeoutError:
+        agent_error = AgentError(
+            error_type="timeout",
+            error_message=AGENT_LLM_TIMEOUT_MESSAGE,
+            error_result="LLM Timeout Error",
+        )
+
+    except Exception:
+        agent_error = AgentError(
+            error_type="LLM error",
+            error_message=AGENT_LLM_ERROR_MESSAGE,
+            error_result="LLM Error",
+        )
+
+    if agent_error:
+        write_custom_event(
+            "initial_agent_answer",
+            StreamingError(
+                error=AGENT_LLM_TIMEOUT_MESSAGE,
+            ),
+            writer,
+        )
+        return InitialAnswerUpdate(
+            initial_answer=None,
+            error=AgentError(
+                error_message=agent_error.error_message or "An LLM error occurred",
+                error_type=agent_error.error_type,
+                error_result=agent_error.error_result,
+            ),
+            initial_agent_stats=None,
+            generated_sub_questions=sub_questions,
+            agent_base_end_time=None,
+            agent_base_metrics=None,
+            log_messages=[
+                get_langgraph_node_log_string(
+                    graph_component="initial - generate initial answer",
+                    node_name="generate initial answer",
+                    node_start_time=node_start_time,
+                    result=agent_error.error_result or "An LLM error occurred",
+                )
+            ],
+        )
 
     logger.debug(
         f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
@@ -25,7 +25,7 @@ def validate_initial_answer(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)

verdict = True
verdict = True # not actually requitred as already streamed out. Refinement will do similar

return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,
Expand Down