diff --git a/chat/src/agent/dynamodb_cleaner.py b/chat/src/agent/dynamodb_cleaner.py new file mode 100644 index 0000000..d2c9346 --- /dev/null +++ b/chat/src/agent/dynamodb_cleaner.py @@ -0,0 +1,62 @@ +import os +import boto3 +from boto3.dynamodb.conditions import Key +from botocore.exceptions import ClientError + +def delete_checkpoint(thread_id, region_name=os.getenv("AWS_REGION")): + """ + Deletes all items with the specified thread_id from the checkpoint + DynamoDB tables. + + :param thread_id: The thread_id value to delete. + :param region_name: AWS region where the table is hosted. + """ + for table_var in ["CHECKPOINT_TABLE", "CHECKPOINT_WRITES_TABLE"]: + delete_thread(os.getenv(table_var), thread_id, region_name) + +def delete_thread(table_name, thread_id, region_name=os.getenv("AWS_REGION")): + """ + Deletes all items with the specified thread_id from the DynamoDB table. + + :param table_name: Name of the DynamoDB table. + :param thread_id: The thread_id value to delete. + :param region_name: AWS region where the table is hosted. + """ + # Initialize a session using Amazon DynamoDB + session = boto3.Session(region_name=region_name) + dynamodb = session.resource('dynamodb') + table = dynamodb.Table(table_name) + + try: + # Query the table for all items with the given thread_id + response = table.query( + KeyConditionExpression=Key('thread_id').eq(thread_id) + ) + + items = response.get('Items', []) + + # Continue querying if there are more items (pagination) + while 'LastEvaluatedKey' in response: + response = table.query( + KeyConditionExpression=Key('thread_id').eq(thread_id), + ExclusiveStartKey=response['LastEvaluatedKey'] + ) + items.extend(response.get('Items', [])) + + if not items: + print(f"No items found with thread_id: {thread_id}") + return + + # Prepare delete requests in batches of 25 (DynamoDB limit for BatchWriteItem) + with table.batch_writer() as batch: + for item in items: + key = { + 'thread_id': item['thread_id'], + 'sort_key': item['sort_key'] # Ensure you use the correct sort key name + } + batch.delete_item(Key=key) + + print(f"Successfully deleted {len(items)} items with thread_id: {thread_id}") + + except ClientError as e: + print(f"An error occurred: {e.response['Error']['Message']}") diff --git a/chat/src/event_config.py b/chat/src/event_config.py index 3b5fd28..1fe355c 100644 --- a/chat/src/event_config.py +++ b/chat/src/event_config.py @@ -60,6 +60,7 @@ class EventConfig: api_token: ApiToken = field(init=False) debug_mode: bool = field(init=False) event: dict = field(default_factory=dict) + forget: bool = field(init=False) is_dev_team: bool = field(init=False) is_logged_in: bool = field(init=False) is_superuser: bool = field(init=False) @@ -81,6 +82,7 @@ def __post_init__(self): self.payload = json.loads(self.event.get("body", "{}")) self.api_token = ApiToken(signed_token=self.payload.get("auth")) self.debug_mode = self._is_debug_mode_enabled() + self.forget = self.payload.get("forget", False) self.is_dev_team = self.api_token.is_dev_team() self.is_logged_in = self.api_token.is_logged_in() self.is_superuser = self.api_token.is_superuser() diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index 88cbc14..6e3c088 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -5,6 +5,7 @@ from datetime import datetime from event_config import EventConfig # from honeybadger import honeybadger +from agent.dynamodb_cleaner import delete_checkpoint from agent.search_agent import search_agent from langchain_core.messages import HumanMessage from agent.agent_handler import AgentHandler @@ -70,6 +71,9 @@ def handler(event, context): logGroupName=log_group, logStreamName=log_stream, logEvents=log_events ) + if config.forget: + delete_checkpoint(config.ref) + callbacks = [AgentHandler(config.socket, config.ref)] try: search_agent.invoke(