Skip to content

Commit

Permalink
Teach the chat handler how to forget on demand
Browse files Browse the repository at this point in the history
  • Loading branch information
mbklein committed Dec 11, 2024
1 parent 1d5490e commit 3b9cd33
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
62 changes: 62 additions & 0 deletions chat/src/agent/dynamodb_cleaner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import boto3
from boto3.dynamodb.conditions import Key
from botocore.exceptions import ClientError

def delete_checkpoint(thread_id, region_name=os.getenv("AWS_REGION")):
"""
Deletes all items with the specified thread_id from the checkpoint
DynamoDB tables.
:param thread_id: The thread_id value to delete.
:param region_name: AWS region where the table is hosted.
"""
for table_var in ["CHECKPOINT_TABLE", "CHECKPOINT_WRITES_TABLE"]:
delete_thread(os.getenv(table_var), thread_id, region_name)

def delete_thread(table_name, thread_id, region_name=os.getenv("AWS_REGION")):
"""
Deletes all items with the specified thread_id from the DynamoDB table.
:param table_name: Name of the DynamoDB table.
:param thread_id: The thread_id value to delete.
:param region_name: AWS region where the table is hosted.
"""
# Initialize a session using Amazon DynamoDB
session = boto3.Session(region_name=region_name)
dynamodb = session.resource('dynamodb')
table = dynamodb.Table(table_name)

try:
# Query the table for all items with the given thread_id
response = table.query(
KeyConditionExpression=Key('thread_id').eq(thread_id)
)

items = response.get('Items', [])

# Continue querying if there are more items (pagination)
while 'LastEvaluatedKey' in response:
response = table.query(
KeyConditionExpression=Key('thread_id').eq(thread_id),
ExclusiveStartKey=response['LastEvaluatedKey']
)
items.extend(response.get('Items', []))

if not items:
print(f"No items found with thread_id: {thread_id}")
return

# Prepare delete requests in batches of 25 (DynamoDB limit for BatchWriteItem)
with table.batch_writer() as batch:
for item in items:
key = {
'thread_id': item['thread_id'],
'sort_key': item['sort_key'] # Ensure you use the correct sort key name
}
batch.delete_item(Key=key)

print(f"Successfully deleted {len(items)} items with thread_id: {thread_id}")

except ClientError as e:
print(f"An error occurred: {e.response['Error']['Message']}")
2 changes: 2 additions & 0 deletions chat/src/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class EventConfig:
api_token: ApiToken = field(init=False)
debug_mode: bool = field(init=False)
event: dict = field(default_factory=dict)
forget: bool = field(init=False)
is_dev_team: bool = field(init=False)
is_logged_in: bool = field(init=False)
is_superuser: bool = field(init=False)
Expand All @@ -81,6 +82,7 @@ def __post_init__(self):
self.payload = json.loads(self.event.get("body", "{}"))
self.api_token = ApiToken(signed_token=self.payload.get("auth"))
self.debug_mode = self._is_debug_mode_enabled()
self.forget = self.payload.get("forget", False)
self.is_dev_team = self.api_token.is_dev_team()
self.is_logged_in = self.api_token.is_logged_in()
self.is_superuser = self.api_token.is_superuser()
Expand Down
4 changes: 4 additions & 0 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime
from event_config import EventConfig
# from honeybadger import honeybadger
from agent.dynamodb_cleaner import delete_checkpoint
from agent.search_agent import search_agent
from langchain_core.messages import HumanMessage
from agent.agent_handler import AgentHandler
Expand Down Expand Up @@ -70,6 +71,9 @@ def handler(event, context):
logGroupName=log_group, logStreamName=log_stream, logEvents=log_events
)

if config.forget:
delete_checkpoint(config.ref)

callbacks = [AgentHandler(config.socket, config.ref)]
try:
search_agent.invoke(
Expand Down

0 comments on commit 3b9cd33

Please sign in to comment.