Skip to content

Commit

Permalink
Refactor SearchAgent WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
bmquinn committed Dec 17, 2024
1 parent eea458f commit f5445ef
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 97 deletions.
12 changes: 12 additions & 0 deletions chat/src/agent/checkpoint_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import os
import logging

from agent.s3_saver import S3Saver
from langgraph.checkpoint.base import BaseCheckpointSaver

logger = logging.getLogger(__name__)

def create_checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
    """Create the checkpoint saver used to persist agent conversation state.

    Args:
        **kwargs: Extra keyword arguments forwarded to the S3Saver
            constructor (e.g. compression settings).

    Returns:
        A BaseCheckpointSaver backed by the S3 bucket named by the
        CHECKPOINT_BUCKET_NAME environment variable.

    Raises:
        ValueError: If CHECKPOINT_BUCKET_NAME is unset or empty, so a
            misconfigured deployment fails fast instead of constructing a
            saver with bucket_name=None.
    """
    checkpoint_bucket = os.getenv("CHECKPOINT_BUCKET_NAME")
    if not checkpoint_bucket:
        raise ValueError("CHECKPOINT_BUCKET_NAME environment variable is not set")

    return S3Saver(bucket_name=checkpoint_bucket, **kwargs)
24 changes: 17 additions & 7 deletions chat/src/agent/metrics_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,25 @@ def __init__(self, *args, **kwargs):
self.answers = []
self.artifacts = []
super().__init__(*args, **kwargs)

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
    """Record generated answer text and accumulate token-usage metadata.

    Appends each generation's text to ``self.answers`` and folds any
    ``usage_metadata`` counters from the generation's message into
    ``self.accumulator``, summing values per key across calls.

    Args:
        response: The LLM result; may be None or contain no generations.
    """
    if response is None:
        return

    # Guard against results with no generations (e.g. aborted calls).
    if not response.generations or not response.generations[0]:
        return

    for generation in response.generations[0]:
        self.answers.append(generation.text)

        # Not every generation carries a chat message with usage metadata
        # (plain-text generations have no .message), so probe defensively.
        message = getattr(generation, 'message', None)
        if message is None:
            continue

        metadata = getattr(message, 'usage_metadata', None)
        if metadata is None:
            continue

        for k, v in metadata.items():
            self.accumulator[k] = self.accumulator.get(k, 0) + v

def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
match output.name:
Expand Down
47 changes: 23 additions & 24 deletions chat/src/agent/s3_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,31 +371,30 @@ def _load_pending_writes(self, thread_id: str, checkpoint_ns: str, checkpoint_id

return writes

def delete_checkpoints(self, thread_id: str) -> None:
    """Delete all checkpoint objects for the given thread_id from the bucket.

    Pages through every object under ``checkpoints/<thread_id>/`` in
    ``self.bucket_name`` and removes them via ``delete_objects``, batching
    requests at S3's 1000-key-per-request limit.

    Args:
        thread_id: The thread whose checkpoints should be removed.
    """
    def delete_objects(objects: dict) -> None:
        # delete_objects rejects an empty key list, so skip no-op batches.
        if objects['Objects']:
            self.s3.delete_objects(Bucket=self.bucket_name, Delete=objects)

    paginator = self.s3.get_paginator("list_objects_v2")
    prefix = f"checkpoints/{thread_id}/"
    pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)

    to_delete = {'Objects': []}
    for item in pages.search('Contents'):
        # JMESPath search yields None for pages with no Contents key.
        if item is not None:
            to_delete['Objects'].append({'Key': item['Key']})

        # Batch deletions in groups of 1000 (S3's per-request limit)
        if len(to_delete['Objects']) >= 1000:
            delete_objects(to_delete)
            to_delete['Objects'] = []

    # Delete any remaining objects
    delete_objects(to_delete)
22 changes: 10 additions & 12 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import os

from typing import Literal, List

from agent.s3_saver import S3Saver, delete_checkpoints
from agent.checkpoint_factory import create_checkpoint_saver
from agent.tools import aggregate, discover_fields, search
from langchain_core.messages import HumanMessage
from langchain_core.messages.base import BaseMessage
Expand All @@ -23,21 +21,21 @@ def __init__(
self,
model: BaseModel,
*,
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"),
streaming: bool = True,
system_message: str = DEFAULT_SYSTEM_MESSAGE,
**kwargs):

self.checkpoint_bucket = checkpoint_bucket
**kwargs
):
self.streaming = streaming

tools = [discover_fields, search, aggregate]
tool_node = ToolNode(tools)

try:
model = model.bind_tools(tools)
except NotImplementedError:
print("Model does not support tool binding")
pass


# Define the function that determines whether to continue or not
def should_continue(state: MessagesState) -> Literal["tools", END]:
messages = state["messages"]
Expand Down Expand Up @@ -73,13 +71,13 @@ def call_model(state: MessagesState):
# Add a normal edge from `tools` to `agent`
workflow.add_edge("tools", "agent")

checkpointer = S3Saver(bucket_name=checkpoint_bucket, compression="gzip")
self.search_agent = workflow.compile(checkpointer=checkpointer)
self.checkpointer = create_checkpoint_saver()
self.search_agent = workflow.compile(checkpointer=self.checkpointer)

def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs):
if forget:
delete_checkpoints(self.checkpoint_bucket, ref)

self.checkpointer.delete_checkpoints(ref)
return self.search_agent.invoke(
{"messages": [HumanMessage(content=question)]},
config={"configurable": {"thread_id": ref}, "callbacks": callbacks},
Expand Down
1 change: 1 addition & 0 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def handler(event, context):
metrics = MetricsHandler()
callbacks = [AgentHandler(config.socket, config.ref), metrics]
search_agent = SearchAgent(model=chat_model(config), streaming=True)

try:
search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks)
log_metrics(context, metrics, config)
Expand Down
Empty file added chat/test/agent/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions chat/test/agent/test_search_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from unittest import TestCase
from unittest.mock import patch
import sys

sys.path.append('./src')

from agent.search_agent import SearchAgent
from handlers.model import chat_model, set_model_override
from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langgraph.checkpoint.memory import MemorySaver


class TestSearchAgent(TestCase):
    """Exercises SearchAgent construction, invocation, and checkpoint forgetting
    using an in-memory checkpointer and fake chat models."""

    @staticmethod
    def _saver_with_mock_delete():
        """Return a MemorySaver whose delete_checkpoints is a Mock for assertions."""
        from unittest.mock import Mock
        saver = MemorySaver()
        saver.delete_checkpoints = Mock()
        return saver

    @patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
    def test_search_agent_init(self, mock_create_saver):
        set_model_override(FakeListChatModel(responses=["fake response"]))
        agent = SearchAgent(model=chat_model("test"), streaming=True)
        self.assertIsNotNone(agent)

    @patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
    def test_search_agent_invoke_simple(self, mock_create_saver):
        canned = "This is a mocked LLM response."
        set_model_override(FakeListChatModel(responses=[canned]))

        agent = SearchAgent(model=chat_model("test"), streaming=True)
        outcome = agent.invoke(question="What is the capital of France?", ref="test_ref")

        self.assertIn("messages", outcome)
        self.assertGreater(len(outcome["messages"]), 0)
        self.assertEqual(outcome["messages"][-1].content, canned)

    @patch('agent.search_agent.create_checkpoint_saver')
    def test_search_agent_invocation(self, mock_create_saver):
        saver = self._saver_with_mock_delete()
        mock_create_saver.return_value = saver

        # The agent should feed questions to the model in order.
        set_model_override(FakeListChatModel(responses=["first response", "second response"]))
        agent = SearchAgent(model=chat_model("test"), streaming=True)

        first = agent.invoke(question="First question?", ref="test_ref")
        self.assertIn("messages", first)
        self.assertEqual(first["messages"][-1].content, "first response")

        # Same ref: conversation state should carry over between invocations.
        second = agent.invoke(question="Second question?", ref="test_ref")
        self.assertEqual(second["messages"][-1].content, "second response")

        # No forget flag was passed, so checkpoints were never deleted.
        saver.delete_checkpoints.assert_not_called()

    @patch('agent.search_agent.create_checkpoint_saver')
    def test_search_agent_invoke_forget(self, mock_create_saver):
        saver = self._saver_with_mock_delete()
        mock_create_saver.return_value = saver

        # Without forget=True the conversation accumulates across calls.
        set_model_override(FakeListChatModel(responses=["first response", "second response"]))
        agent = SearchAgent(model=chat_model("test"), streaming=True)

        first = agent.invoke(question="First question?", ref="test_ref")
        self.assertIn("messages", first)
        self.assertEqual(first["messages"][-1].content, "first response")

        second = agent.invoke(question="Second question?", ref="test_ref")
        self.assertEqual(second["messages"][-1].content, "second response")

        # Rebuild with a fresh fake model, then forget the thread's state.
        set_model_override(FakeListChatModel(responses=["fresh response"]))
        agent = SearchAgent(model=chat_model("test"), streaming=True)

        third = agent.invoke(question="Third question after forgetting?", ref="test_ref", forget=True)
        # With a fresh FakeListChatModel, the response should now be "fresh response"
        self.assertEqual(third["messages"][-1].content, "fresh response")

        # The forget flag must have triggered checkpoint deletion for the ref.
        saver.delete_checkpoints.assert_called_once_with("test_ref")
83 changes: 29 additions & 54 deletions chat/test/handlers/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# ruff: noqa: E402

from unittest import TestCase
from unittest.mock import patch
import json
import os
import sys

sys.path.append('./src')

from unittest import mock, TestCase
from unittest.mock import patch
from handlers.chat import handler
from handlers.model import set_model_override
from helpers.apitoken import ApiToken
from langchain_core.language_models.fake_chat_models import FakeListChatModel
from langgraph.checkpoint.memory import MemorySaver
from websocket import Websocket

from langchain_core.language_models.fake_chat_models import FakeListChatModel
from handlers.model import set_model_override

class MockClient:
def __init__(self):
Expand All @@ -24,55 +24,30 @@ def post_to_connection(self, Data, ConnectionId):
return Data

class MockContext:
def __init__(self):
self.log_stream_name = 'test'

# TODO: Find a way to build a better mock response (maybe using helpers.metrics.debug_response)
def mock_response(**kwargs):
result = {
'answer': 'Answer.',
'attributes': ['accession_number', 'alternate_title', 'api_link', 'canonical_link', 'caption', 'collection', 'contributor', 'date_created', 'date_created_edtf', 'description', 'genre', 'id', 'identifier', 'keywords', 'language', 'notes', 'physical_description_material', 'physical_description_size', 'provenance', 'publisher', 'rights_statement', 'subject', 'table_of_contents', 'thumbnail', 'title', 'visibility', 'work_type'],
'azure_endpoint': 'https://nul-ai-east.openai.azure.com/',
'deployment_name': 'gpt-4o',
'is_dev_team': False,
'is_superuser': False,
'k': 10,
'openai_api_version': '2024-02-01',
'prompt': "Prompt",
'question': 'Question?',
'ref': 'ref123',
'size': 20,
'source_documents': [],
'temperature': 0.2,
'text_key': 'id',
'token_counts': {'question': 19, 'answer': 348, 'prompt': 329, 'source_documents': 10428,'total': 11124}
}
result.update(kwargs)
return result

@mock.patch.dict(
os.environ,
{
"AZURE_OPENAI_RESOURCE_NAME": "test",
},
)

def __init__(self):
self.log_stream_name = 'test'

class TestHandler(TestCase):
def test_handler_unauthorized(self):

@patch.object(ApiToken, 'is_logged_in', return_value=False)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
def test_handler_unauthorized(self, mock_create_saver, mock_is_logged_in):
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")}
self.assertEqual(handler(event, MockContext()), {'body': 'Unauthorized', 'statusCode': 401})

@patch.object(ApiToken, 'is_logged_in')
def test_handler_success(self, mock_is_logged_in):
mock_is_logged_in.return_value = True
set_model_override(FakeListChatModel(responses=["one", "two", "three"]))
event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test"), "body": '{"question": "Question?"}' }
self.assertEqual(handler(event, MockContext()), {'statusCode': 200})
self.assertEqual(handler(event, MockContext()), {'statusCode': 401, 'body': 'Unauthorized'})

@patch.object(ApiToken, 'is_logged_in', return_value=True)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
def test_handler_success(self, mock_create_saver, mock_is_logged_in):
set_model_override(FakeListChatModel(responses=["fake response"]))
event = {
"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test"),
"body": '{"question": "Question?"}'
}
self.assertEqual(handler(event, MockContext()), {'statusCode': 200})

@patch.object(ApiToken, 'is_logged_in')
def test_handler_question_missing(self, mock_is_logged_in):
mock_is_logged_in.return_value = True
@patch.object(ApiToken, 'is_logged_in', return_value=True)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
def test_handler_question_missing(self, mock_create_saver, mock_is_logged_in):
mock_client = MockClient()
mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
event = {"socket": mock_websocket}
Expand All @@ -81,13 +56,13 @@ def test_handler_question_missing(self, mock_is_logged_in):
self.assertEqual(response["type"], "error")
self.assertEqual(response["message"], "Question cannot be blank")

@patch.object(ApiToken, 'is_logged_in')
def test_handler_question_blank(self, mock_is_logged_in):
mock_is_logged_in.return_value = True
@patch.object(ApiToken, 'is_logged_in', return_value=True)
@patch('agent.search_agent.create_checkpoint_saver', return_value=MemorySaver())
def test_handler_question_typo(self, mock_create_saver, mock_is_logged_in):
mock_client = MockClient()
mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test")
event = {"socket": mock_websocket, "body": '{"quesion": ""}'}
handler(event, MockContext())
response = json.loads(mock_client.received_data)
self.assertEqual(response["type"], "error")
self.assertEqual(response["message"], "Question cannot be blank")
self.assertEqual(response["message"], "Question cannot be blank")

0 comments on commit f5445ef

Please sign in to comment.