Skip to content

Commit

Permalink
Get agent callbacks working
Browse files Browse the repository at this point in the history
  • Loading branch information
kdid committed Dec 9, 2024
1 parent a38842d commit 491d3cf
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 46 deletions.
12 changes: 12 additions & 0 deletions chat/src/agent/agent_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult


class AgentHandler(BaseCallbackHandler):
def on_tool_start(self, serialized: Dict[str, Any], input_str: str, **kwargs: Any) -> Any:
print(f"on_tool_start (A tool is starting): {serialized, input_str}")

callbacks = [AgentHandler()]
4 changes: 2 additions & 2 deletions chat/src/agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@tool(response_format="content_and_artifact")
def search(query: str):
"""Perform a semantic search of Northwestern University Library digital collections. When answering a search query, ground your answer in the context of the results with references to the document's metadata."""
query_results = opensearch_vector_store.similarity_search(query, size=20)
query_results = opensearch_vector_store().similarity_search(query, size=20)
return json.dumps(query_results, default=str), query_results

@tool(response_format="content_and_artifact")
Expand All @@ -29,7 +29,7 @@ def aggregate(aggregation_query: str):
- Number of works by work type: work_type
"""
try:
response = opensearch_vector_store.aggregations_search(aggregation_query)
response = opensearch_vector_store().aggregations_search(aggregation_query)
return json.dumps(response, default=str), response
except Exception as e:
return json.dumps({"error": str(e)}), None
24 changes: 11 additions & 13 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from honeybadger import honeybadger
from agent.search_agent import search_agent
from langchain_core.messages import HumanMessage
from agent.agent_handler import callbacks

honeybadger.configure()
logging.getLogger("honeybadger").addHandler(logging.StreamHandler())
Expand Down Expand Up @@ -61,23 +62,20 @@ def handler(event, context):
config.socket.send({"type": "error", "message": "Question cannot be blank"})
return {"statusCode": 400, "body": "Question cannot be blank"}

log_group = os.getenv("METRICS_LOG_GROUP")
log_stream = context.log_stream_name
if log_group and ensure_log_stream_exists(log_group, log_stream):
log_client = boto3.client("logs")
log_events = [{"timestamp": timestamp(), "message": "Hello world"}]
log_client.put_log_events(
logGroupName=log_group, logStreamName=log_stream, logEvents=log_events
)

response = search_agent.invoke(
{"messages": [HumanMessage(content=config.question)]},
config={"configurable": {"thread_id": config.ref}},
config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks},
)

log_group = os.getenv('METRICS_LOG_GROUP')
log_stream = context.log_stream_name
if log_group and ensure_log_stream_exists(log_group, log_stream):
log_client = boto3.client('logs')
log_events = [
{
'timestamp': timestamp(),
'message': json.dumps(response)
}
]
log_client.put_log_events(logGroupName=log_group, logStreamName=log_stream, logEvents=log_events)

return {"statusCode": 200}


Expand Down
63 changes: 32 additions & 31 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ Resources:
HONEYBADGER_REVISION: !Ref HoneybadgerRevision
METRICS_LOG_GROUP: !Ref ChatMetricsLog
SECRETS_PATH: !Ref SecretsPath
NO_COLOR: 1
Policies:
- !Ref SecretsPolicy
- Statement:
Expand All @@ -230,37 +231,37 @@ Resources:
Resource: !Sub "${ChatMetricsLog.Arn}:*"
#* Metadata:
#* BuildMethod: nodejs20.x
ChatSyncFunction:
Type: AWS::Serverless::Function
Properties:
CodeUri: ./src
Runtime: python3.12
Architectures:
- x86_64
#* Layers:
#* - !Ref ChatDependencies
MemorySize: 1024
Handler: handlers/chat_sync.handler
Timeout: 300
Environment:
Variables:
API_TOKEN_NAME: !Ref ApiTokenName
ENV_PREFIX: !Ref EnvironmentPrefix
HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv
HONEYBADGER_REVISION: !Ref HoneybadgerRevision
METRICS_LOG_GROUP: !Ref ChatMetricsLog
SECRETS_PATH: !Ref SecretsPath
FunctionUrlConfig:
AuthType: NONE
Policies:
- !Ref SecretsPolicy
- Statement:
- Effect: Allow
Action:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
# ChatSyncFunction:
# Type: AWS::Serverless::Function
# Properties:
# CodeUri: ./src
# Runtime: python3.12
# Architectures:
# - x86_64
# #* Layers:
# #* - !Ref ChatDependencies
# MemorySize: 1024
# Handler: handlers/chat_sync.handler
# Timeout: 300
# Environment:
# Variables:
# API_TOKEN_NAME: !Ref ApiTokenName
# ENV_PREFIX: !Ref EnvironmentPrefix
# HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
# HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv
# HONEYBADGER_REVISION: !Ref HoneybadgerRevision
# METRICS_LOG_GROUP: !Ref ChatMetricsLog
# SECRETS_PATH: !Ref SecretsPath
# FunctionUrlConfig:
# AuthType: NONE
# Policies:
# - !Ref SecretsPolicy
# - Statement:
# - Effect: Allow
# Action:
# - 'es:ESHttpGet'
# - 'es:ESHttpPost'
# Resource: '*'
# - Statement:
# - Effect: Allow
# Action:
Expand Down

0 comments on commit 491d3cf

Please sign in to comment.