From e61928d8fd68911c8d52938016a82836e430c13e Mon Sep 17 00:00:00 2001 From: "Michael B. Klein" Date: Thu, 12 Dec 2024 18:08:35 +0000 Subject: [PATCH] Add delete_checkpoints to the s3 checkpointer Slightly change key pattern of checkpoints in s3 --- chat/src/agent/s3_saver.py | 70 +++++++++++++++++++++++++++++++------- chat/src/handlers/chat.py | 7 ++-- 2 files changed, 62 insertions(+), 15 deletions(-) diff --git a/chat/src/agent/s3_saver.py b/chat/src/agent/s3_saver.py index 4cc4293..0874808 100644 --- a/chat/src/agent/s3_saver.py +++ b/chat/src/agent/s3_saver.py @@ -73,21 +73,38 @@ def object_hook(dct): obj = json.loads(json_str, object_hook=object_hook) return obj +def _namespace(val): + return "__default__" if val == "" else val + +def _namespace_val(namespace): + return "" if namespace == "__default__" else namespace + +def _make_s3_thread_prefix(thread_id: str) -> str: + return f"checkpoints/{thread_id}" + +def _make_s3_namespace_prefix(thread_id: str, checkpoint_ns: str) -> str: + prefix = _make_s3_thread_prefix(thread_id) + return f"{prefix}/{_namespace(checkpoint_ns)}" + +def _make_s3_checkpoint_prefix(thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> str: + prefix = _make_s3_namespace_prefix(thread_id, checkpoint_ns) + return f"{prefix}/{checkpoint_id}" def _make_s3_checkpoint_key(thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> str: - return f"checkpoints/{thread_id}/{checkpoint_ns}/{checkpoint_id}.json" + prefix = _make_s3_checkpoint_prefix(thread_id, checkpoint_ns, checkpoint_id) + return f"{prefix}/checkpoint.json" def _make_s3_write_key(thread_id: str, checkpoint_ns: str, checkpoint_id: str, task_id: str, idx: int) -> str: - return f"checkpoints/{thread_id}/{checkpoint_ns}/{checkpoint_id}/writes/{task_id}/{idx}.json" + prefix = _make_s3_checkpoint_prefix(thread_id, checkpoint_ns, checkpoint_id) + return f"{prefix}/writes/{task_id}/{idx}.json" def _parse_s3_checkpoint_key(key: str) -> Dict[str, str]: parts = key.split("/") - if len(parts) < 4: + if len(parts) < 5 or parts[4] != "checkpoint.json": raise ValueError("Invalid checkpoint key format") thread_id = parts[1] - checkpoint_ns = parts[2] - filename = parts[3] - checkpoint_id = filename[:-5] # remove ".json" + checkpoint_ns = _namespace_val(parts[2]) + checkpoint_id = parts[3] return { "thread_id": thread_id, "checkpoint_ns": checkpoint_ns, @@ -239,10 +256,10 @@ def list( thread_id = config["configurable"]["thread_id"] checkpoint_ns = config["configurable"].get("checkpoint_ns", "") - prefix = f"checkpoints/{thread_id}/{checkpoint_ns}/" + prefix = _make_s3_namespace_prefix(thread_id, checkpoint_ns) paginator = self.s3.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) + pages = paginator.paginate(Bucket=self.bucket_name, Prefix=f"{prefix}/") keys = [] for page in pages: @@ -312,9 +329,9 @@ def list( ) def _get_latest_checkpoint_id(self, thread_id: str, checkpoint_ns: str) -> Optional[str]: - prefix = f"checkpoints/{thread_id}/{checkpoint_ns}/" + prefix = _make_s3_namespace_prefix(thread_id, checkpoint_ns) paginator = self.s3.get_paginator("list_objects_v2") - pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) + pages = paginator.paginate(Bucket=self.bucket_name, Prefix=f"{prefix}/") keys = [] for page in pages: for c in page.get("Contents", []): @@ -331,7 +348,7 @@ def _get_latest_checkpoint_id(self, thread_id: str, checkpoint_ns: str) -> Optio return latest_id def _load_pending_writes(self, thread_id: str, checkpoint_ns: str, checkpoint_id: str) -> List[PendingWrite]: - prefix = f"checkpoints/{thread_id}/{checkpoint_ns}/{checkpoint_id}/writes/" + prefix = _make_s3_checkpoint_prefix(thread_id, checkpoint_ns, checkpoint_id) + "/writes/" paginator = self.s3.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) @@ -351,4 +368,33 @@ def _load_pending_writes(self, thread_id: str, checkpoint_ns: str, checkpoint_id value = self.serde.loads_typed((value_type, value_data)) writes.append((task_id, channel, value)) - return writes \ No newline at end of file + return writes + +def delete_checkpoints(bucket_name, thread_id, region_name="us-east-1"): + """ + Deletes all items with the specified thread_id from the checkpoint + bucket. + + :param bucket_name: The name of the S3 checkpoint bucket + :param thread_id: The thread_id value to delete. + :param region_name: The S3 region the bucket is in + """ + session = boto3.Session(region_name=region_name) + client = session.client("s3") + + def delete_objects(objects): + if objects['Objects']: + client.delete_objects(Bucket=bucket_name, Delete=objects) + + paginator = client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=bucket_name, Prefix=f"checkpoints/{thread_id}/") + + to_delete = dict(Objects=[]) + for item in pages.search('Contents'): + if item is not None: + to_delete['Objects'].append(dict(Key=item['Key'])) + + if len(to_delete['Objects']) >= 1000: + delete_objects(to_delete) + + delete_objects(to_delete) diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index 85553c3..b34593a 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -6,6 +6,7 @@ from datetime import datetime from event_config import EventConfig # from honeybadger import honeybadger +from agent.s3_saver import delete_checkpoints from agent.search_agent import search_agent from langchain_core.messages import HumanMessage from agent.agent_handler import AgentHandler @@ -55,8 +56,8 @@ def handler(event, context): config.socket.send({"type": "error", "message": "Unauthorized"}) return {"statusCode": 401, "body": "Unauthorized"} - # if config.forget: - # delete_checkpoint(config.ref) + if config.forget: + delete_checkpoints(os.getenv("CHECKPOINT_BUCKET_NAME"), config.ref) if config.question is None or config.question == "": config.socket.send({"type": "error", "message": "Question cannot be blank"}) @@ -76,7 +77,7 @@ def handler(event, context): search_agent.invoke( {"messages": [HumanMessage(content=config.question)]}, config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks}, - debug=True + debug=False ) except Exception as e: print(f"Error: {e}")