Skip to content

Commit

Permalink
major Agent Search Updates (#3994)
Browse files Browse the repository at this point in the history
  • Loading branch information
joachim-danswer authored Feb 14, 2025
1 parent ec78f78 commit 6687d5d
Show file tree
Hide file tree
Showing 36 changed files with 2,120 additions and 436 deletions.
3 changes: 1 addition & 2 deletions backend/onyx/agents/agent_search/core_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ class CoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""

base_question: str = ""
log_messages: Annotated[list[str], add] = []


Expand All @@ -18,4 +17,4 @@ class SubgraphCoreState(BaseModel):
This is the core state that is shared across all subgraphs.
"""

log_messages: Annotated[list[str], add]
log_messages: Annotated[list[str], add] = []
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig

from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
Expand All @@ -12,14 +12,43 @@
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
binary_string_test,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_POSITIVE_VALUE_STR,
)
from onyx.agents.agent_search.shared_graph_utils.constants import AgentLLMErrorType
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. The sub-answer will be treated as 'relevant'",
rate_limit="LLM Rate Limit Error. The sub-answer will be treated as 'relevant'",
general_error="General LLM Error. The sub-answer will be treated as 'relevant'",
)


@log_function_time(print_only=True)
def check_sub_answer(
state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
Expand Down Expand Up @@ -53,14 +82,40 @@ def check_sub_answer(

graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
response = list(
fast_llm.stream(
agent_error: AgentErrorLog | None = None
response: BaseMessage | None = None
try:
response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
)

quality_str: str = cast(str, response.content)
answer_quality = binary_string_test(
text=quality_str, positive_value=AGENT_POSITIVE_VALUE_STR
)
log_result = f"Answer quality: {quality_str}"

except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Timeout Error - check sub answer")

except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
)

quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
answer_quality = True
log_result = agent_error.error_result
logger.error("LLM Rate Limit Error - check sub answer")

return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
Expand All @@ -69,7 +124,7 @@ def check_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=f"Answer quality: {quality_str}",
result=log_result,
)
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,23 @@
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.calculations import (
dedup_sort_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_RATELIMIT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AGENT_LLM_TIMEOUT_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
Expand All @@ -30,12 +47,23 @@
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time

logger = setup_logger()

_llm_node_error_strings = LLMNodeErrorStrings(
timeout="LLM Timeout Error. A sub-answer could not be constructed and the sub-question will be ignored.",
rate_limit="LLM Rate Limit Error. A sub-answer could not be constructed and the sub-question will be ignored.",
general_error="General LLM Error. A sub-answer could not be constructed and the sub-question will be ignored.",
)


@log_function_time(print_only=True)
def generate_sub_answer(
state: AnswerQuestionState,
config: RunnableConfig,
Expand All @@ -51,12 +79,17 @@ def generate_sub_answer(
state.verified_reranked_documents
level, question_num = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]

context_docs = dedup_sort_inference_section_list(context_docs)

persona_contextualized_prompt = get_persona_agent_prompt_expressions(
graph_config.inputs.search_request.persona
).contextualized_prompt

if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
cited_documents: list = []
log_results = "No documents retrieved"
write_custom_event(
"sub_answers",
AgentAnswerPiece(
Expand All @@ -79,41 +112,67 @@ def generate_sub_answer(

response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"

agent_error: AgentErrorLog | None = None

try:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
response.append(content)

except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
error_result=_llm_node_error_strings.timeout,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
logger.error("LLM Timeout Error - generate sub answer")
except LLMRateLimitError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.RATE_LIMIT,
error_message=AGENT_LLM_RATELIMIT_MESSAGE,
error_result=_llm_node_error_strings.rate_limit,
)
response.append(content)
logger.error("LLM Rate Limit Error - generate sub answer")

answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.debug(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
if agent_error:
answer_str = LLM_ANSWER_ERROR_MESSAGE
cited_documents = []
log_results = (
agent_error.error_result
or "Sub-answer generation failed due to LLM error"
)

answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
else:
answer_str = merge_message_runs(response, chunk_separator="")[0].content
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
log_results = None

stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
Expand All @@ -131,7 +190,7 @@ def generate_sub_answer(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result="",
result=log_results or "",
)
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):


class SubQuestionAnsweringInput(SubgraphCoreState):
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
question: str
question_id: str
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.

Expand Down
Loading

0 comments on commit 6687d5d

Please sign in to comment.