diff --git a/chat-playground/playground.ipynb b/chat-playground/playground.ipynb index 37ce9989..102edb30 100644 --- a/chat-playground/playground.ipynb +++ b/chat-playground/playground.ipynb @@ -19,7 +19,7 @@ " pass\n", "\n", "sys.path.insert(0, os.path.join(os.curdir, \"../chat/src\"))\n", - "import secrets # noqa" + "import helpers.secrets # noqa" ] }, { @@ -61,7 +61,7 @@ "source": [ "import agent.search_agent\n", "from agent.search_agent import SearchAgent\n", - "from handlers.model import chat_model\n", + "from core.setup import chat_model\n", "\n", "model = chat_model(model=\"us.anthropic.claude-3-5-sonnet-20241022-v2:0\", streaming=False)\n", "agent = SearchAgent(model=model)\n", diff --git a/chat/src/handlers/__init__.py b/chat/src/agent/callbacks/__init__.py similarity index 100% rename from chat/src/handlers/__init__.py rename to chat/src/agent/callbacks/__init__.py diff --git a/chat/src/agent/metrics_handler.py b/chat/src/agent/callbacks/metrics.py similarity index 71% rename from chat/src/agent/metrics_handler.py rename to chat/src/agent/callbacks/metrics.py index fa8b06c2..e4d292f7 100644 --- a/chat/src/agent/metrics_handler.py +++ b/chat/src/agent/callbacks/metrics.py @@ -4,12 +4,13 @@ from langchain_core.messages.tool import ToolMessage import json -class MetricsHandler(BaseCallbackHandler): +class MetricsCallbackHandler(BaseCallbackHandler): def __init__(self, *args, **kwargs): self.accumulator = {} self.answers = [] self.artifacts = [] super().__init__(*args, **kwargs) + def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]): if response is None: return @@ -20,18 +21,16 @@ def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]): for generation in response.generations[0]: if generation.text != "": self.answers.append(generation.text) - for k, v in generation.message.usage_metadata.items(): - self.accumulator[k] = self.accumulator.get(k, 0) + v - if not hasattr(generation, 'message') or generation.message is None: - continue - - metadata = getattr(generation.message, 'usage_metadata', None) - if metadata is None: - continue - - for k, v in metadata.items(): - self.accumulator[k] = self.accumulator.get(k, 0) + v + if not hasattr(generation, 'message') or generation.message is None: + continue + + metadata = getattr(generation.message, 'usage_metadata', None) + if metadata is None: + continue + + for k, v in metadata.items(): + self.accumulator[k] = self.accumulator.get(k, 0) + v def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]): match output.name: diff --git a/chat/src/agent/agent_handler.py b/chat/src/agent/callbacks/socket.py similarity index 97% rename from chat/src/agent/agent_handler.py rename to chat/src/agent/callbacks/socket.py index 9e1a583f..7843c25f 100644 --- a/chat/src/agent/agent_handler.py +++ b/chat/src/agent/callbacks/socket.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from websocket import Websocket +from core.websocket import Websocket from json.decoder import JSONDecodeError from langchain_core.callbacks import BaseCallbackHandler @@ -19,7 +19,7 @@ def deserialize_input(input_str): except JSONDecodeError: return input_str -class AgentHandler(BaseCallbackHandler): +class SocketCallbackHandler(BaseCallbackHandler): def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict[str, Any]): if socket is None: raise ValueError("Socket not provided to agent callback handler") diff --git a/chat/src/agent/checkpoints.py b/chat/src/agent/checkpoints.py deleted file mode 100644 index 53b50e6c..00000000 --- a/chat/src/agent/checkpoints.py +++ /dev/null @@ -1,12 +0,0 @@ -import os -import logging - -from agent.s3_saver import S3Saver -from langgraph.checkpoint.base import BaseCheckpointSaver - -logger = logging.getLogger(__name__) - -def checkpoint_saver(**kwargs) -> BaseCheckpointSaver: - checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME") - - return S3Saver(bucket_name=checkpoint_bucket, **kwargs) \ No newline at end of file diff --git a/chat/src/agent/search_agent.py b/chat/src/agent/search_agent.py index 2e893f41..9f5ef703 100644 --- a/chat/src/agent/search_agent.py +++ b/chat/src/agent/search_agent.py @@ -1,6 +1,5 @@ from typing import Literal, List -from agent.checkpoints import checkpoint_saver from agent.tools import aggregate, discover_fields, search from langchain_core.messages import HumanMessage from langchain_core.messages.base import BaseMessage @@ -9,6 +8,7 @@ from langchain_core.messages.system import SystemMessage from langgraph.graph import END, START, StateGraph, MessagesState from langgraph.prebuilt import ToolNode +from core.setup import checkpoint_saver DEFAULT_SYSTEM_MESSAGE = """ Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that diff --git a/chat/src/agent/tools.py b/chat/src/agent/tools.py index 55688e68..4f78d66c 100644 --- a/chat/src/agent/tools.py +++ b/chat/src/agent/tools.py @@ -1,7 +1,7 @@ import json from langchain_core.tools import tool -from setup import opensearch_vector_store +from core.setup import opensearch_vector_store def get_keyword_fields(properties, prefix=''): """ diff --git a/chat/src/content_handler.py b/chat/src/content_handler.py deleted file mode 100644 index b75f98b9..00000000 --- a/chat/src/content_handler.py +++ /dev/null @@ -1,36 +0,0 @@ -import json -from typing import Dict, List -from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler - -class ContentHandler(EmbeddingsContentHandler): - content_type = "application/json" - accepts = "application/json" - - def transform_input(self, inputs: list[str], model_kwargs: Dict) -> bytes: - """ - Transforms the input into bytes that can be consumed by SageMaker endpoint. - Args: - inputs: List of input strings. - model_kwargs: Additional keyword arguments to be passed to the endpoint. - Returns: - The transformed bytes input. - """ - # Example: inference.py expects a JSON string with a "inputs" key: - input_str = json.dumps({"inputs": inputs, **model_kwargs}) - return input_str.encode("utf-8") - - def transform_output(self, output: bytes) -> List[List[float]]: - """ - Transforms the bytes output from the endpoint into a list of embeddings. - Args: - output: The bytes output from SageMaker endpoint. - Returns: - The transformed output - list of embeddings - Note: - The length of the outer list is the number of input strings. - The length of the inner lists is the embedding dimension. - """ - # Example: inference.py returns a JSON string with the list of - # embeddings in a "vectors" key: - response_json = json.loads(output.read().decode("utf-8")) - return [response_json["embedding"]] \ No newline at end of file diff --git a/chat/src/helpers/__init__.py b/chat/src/core/__init__.py similarity index 100% rename from chat/src/helpers/__init__.py rename to chat/src/core/__init__.py diff --git a/chat/src/helpers/apitoken.py b/chat/src/core/apitoken.py similarity index 100% rename from chat/src/helpers/apitoken.py rename to chat/src/core/apitoken.py diff --git a/chat/src/event_config.py b/chat/src/core/event_config.py similarity index 97% rename from chat/src/event_config.py rename to chat/src/core/event_config.py index bd9fc26a..c3bc45d4 100644 --- a/chat/src/event_config.py +++ b/chat/src/core/event_config.py @@ -4,9 +4,9 @@ from langchain_core.prompts import ChatPromptTemplate -from helpers.apitoken import ApiToken -from helpers.prompts import prompt_template -from websocket import Websocket +from core.apitoken import ApiToken +from core.prompts import prompt_template +from core.websocket import Websocket from uuid import uuid4 CHAIN_TYPE = "stuff" diff --git a/chat/src/helpers/prompts.py b/chat/src/core/prompts.py similarity index 100% rename from chat/src/helpers/prompts.py rename to chat/src/core/prompts.py diff --git a/chat/src/secrets.py b/chat/src/core/secrets.py similarity index 100% rename from chat/src/secrets.py rename to chat/src/core/secrets.py diff --git a/chat/src/setup.py b/chat/src/core/setup.py similarity index 77% rename from chat/src/setup.py rename to chat/src/core/setup.py index 8245be66..1aaef612 100644 --- a/chat/src/setup.py +++ b/chat/src/core/setup.py @@ -1,10 +1,20 @@ -from handlers.opensearch_neural_search import OpenSearchNeuralSearch +from persistence.s3_checkpointer import S3Checkpointer +from search.opensearch_neural_search import OpenSearchNeuralSearch +from langchain_aws import ChatBedrock +from langchain_core.language_models.base import BaseModel +from langgraph.checkpoint.base import BaseCheckpointSaver from opensearchpy import OpenSearch, RequestsHttpConnection from requests_aws4auth import AWS4Auth from urllib.parse import urlparse import os import boto3 +def chat_model(**kwargs) -> BaseModel: + return ChatBedrock(**kwargs) + +def checkpoint_saver(**kwargs) -> BaseCheckpointSaver: + checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME") + return S3Checkpointer(bucket_name=checkpoint_bucket, **kwargs) def prefix(value): env_prefix = os.getenv("ENV_PREFIX") diff --git a/chat/src/websocket.py b/chat/src/core/websocket.py similarity index 95% rename from chat/src/websocket.py rename to chat/src/core/websocket.py index 0607420a..3864d2f8 100644 --- a/chat/src/websocket.py +++ b/chat/src/core/websocket.py @@ -1,5 +1,5 @@ import json -from setup import websocket_client +from core.setup import websocket_client class Websocket: def __init__(self, client=None, endpoint_url=None, connection_id=None, ref=None): diff --git a/chat/src/handlers/chat.py b/chat/src/handlers.py similarity index 60% rename from chat/src/handlers/chat.py rename to chat/src/handlers.py index ac2051da..c8c2c775 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers.py @@ -1,20 +1,54 @@ -import secrets # noqa +import core.secrets # noqa import boto3 import json +import logging import os -import traceback from datetime import datetime -from event_config import EventConfig -# from honeybadger import honeybadger +from core.event_config import EventConfig +from honeybadger import honeybadger from agent.search_agent import SearchAgent -from agent.agent_handler import AgentHandler -from agent.metrics_handler import MetricsHandler -from handlers.model import chat_model +from agent.callbacks.socket import SocketCallbackHandler +from agent.callbacks.metrics import MetricsCallbackHandler +from core.setup import chat_model -# honeybadger.configure() -# logging.getLogger("honeybadger").addHandler(logging.StreamHandler()) +honeybadger.configure() +logging.getLogger("honeybadger").addHandler(logging.StreamHandler()) -def handler(event, context): +def chat_sync(event, context): + config = EventConfig(event) + + if not config.is_logged_in: + return {"statusCode": 401, "body": "Unauthorized"} + + if config.question is None or config.question == "": + return {"statusCode": 400, "body": "Question cannot be blank"} + + model = chat_model(model=config.model, streaming=False) + search_agent = SearchAgent(model=model) + result = MetricsCallbackHandler() + search_agent.invoke( + config.question, config.ref, forget=config.forget, callbacks=[result] + ) + + return { + "statusCode": 200, + "headers": {"Content-Type": "application/json"}, + "body": json.dumps( + { + "answer": result.answers, + "is_dev_team": config.api_token.is_dev_team(), + "is_superuser": config.api_token.is_superuser(), + "k": config.k, + "model": config.model, + "question": config.question, + "ref": config.ref, + "artifacts": result.artifacts, + "token_counts": result.accumulator, + } + ), + } + +def chat(event, context): config = EventConfig(event) socket = event.get("socket", None) config.setup_websocket(socket) @@ -27,20 +61,19 @@ def handler(event, context): config.socket.send({"type": "error", "message": "Question cannot be blank"}) return {"statusCode": 400, "body": "Question cannot be blank"} - metrics = MetricsHandler() - callbacks = [AgentHandler(config.socket, config.ref), metrics] + metrics = MetricsCallbackHandler() + callbacks = [SocketCallbackHandler(config.socket, config.ref), metrics] model = chat_model(model=config.model, streaming=config.stream_response) search_agent = SearchAgent(model=model) + try: search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks) log_metrics(context, metrics, config) except Exception as e: - print(f"Error: {e}") - print(traceback.format_exc()) error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."} if config.socket: config.socket.send(error_response) - return {"statusCode": 500, "body": json.dumps(error_response)} + raise e return {"statusCode": 200} diff --git a/chat/src/handlers/chat_sync.py b/chat/src/handlers/chat_sync.py deleted file mode 100644 index 3c593f65..00000000 --- a/chat/src/handlers/chat_sync.py +++ /dev/null @@ -1,43 +0,0 @@ -import secrets # noqa -import json -import logging -from agent.metrics_handler import MetricsHandler -from agent.search_agent import SearchAgent -from handlers.model import chat_model -from event_config import EventConfig -from honeybadger import honeybadger - -honeybadger.configure() -logging.getLogger('honeybadger').addHandler(logging.StreamHandler()) - -def handler(event, context): - config = EventConfig(event) - - if not config.is_logged_in: - return {"statusCode": 401, "body": "Unauthorized"} - - if config.question is None or config.question == "": - return {"statusCode": 400, "body": "Question cannot be blank"} - - model = chat_model(model=config.model, streaming=False) - search_agent = SearchAgent(model=model) - result = MetricsHandler() - search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=[result]) - - return { - "statusCode": 200, - "headers": { - "Content-Type": "application/json" - }, - "body": json.dumps({ - "answer": result.answers, - "is_dev_team": config.api_token.is_dev_team(), - "is_superuser": config.api_token.is_superuser(), - "k": config.k, - "model": config.model, - "question": config.question, - "ref": config.ref, - "artifacts": result.artifacts, - "token_counts": result.accumulator, - }) - } diff --git a/chat/src/handlers/model.py b/chat/src/handlers/model.py deleted file mode 100644 index 3cf8393d..00000000 --- a/chat/src/handlers/model.py +++ /dev/null @@ -1,5 +0,0 @@ -from langchain_aws import ChatBedrock -from langchain_core.language_models.base import BaseModel - -def chat_model(**kwargs) -> BaseModel: - return ChatBedrock(**kwargs) \ No newline at end of file diff --git a/chat/src/handlers/streaming_socket_callback_handler.py b/chat/src/handlers/streaming_socket_callback_handler.py deleted file mode 100644 index 8fe32272..00000000 --- a/chat/src/handlers/streaming_socket_callback_handler.py +++ /dev/null @@ -1,22 +0,0 @@ -from langchain.callbacks.base import BaseCallbackHandler -from websocket import Websocket -from typing import Any -from langchain_core.outputs.llm_result import LLMResult - -class StreamingSocketCallbackHandler(BaseCallbackHandler): - def __init__(self, socket: Websocket, stream: bool = True): - self.socket = socket - self.stream = stream - - def on_llm_new_token(self, token: str, **kwargs): - if len(token) > 0 and self.socket and self.stream: - return self.socket.send({"token": token}) - - def on_llm_end(self, response: LLMResult, **kwargs: Any): - try: - finish_reason = response.generations[0][0].generation_info["finish_reason"] - if self.socket: - return self.socket.send({"end": {"reason": finish_reason}}) - except Exception as err: - finish_reason = f'Unknown ({str(err)})' - print(f"Stream ended: {finish_reason}") diff --git a/chat/src/helpers/http_response.py b/chat/src/helpers/http_response.py deleted file mode 100644 index 11ba2962..00000000 --- a/chat/src/helpers/http_response.py +++ /dev/null @@ -1,63 +0,0 @@ -from helpers.metrics import debug_response -from langchain_core.output_parsers import StrOutputParser -from langchain_core.runnables import RunnableLambda, RunnablePassthrough - -def extract_prompt_value(v): - if isinstance(v, list): - return [extract_prompt_value(item) for item in v] - elif isinstance(v, dict) and 'label' in v: - return [v.get('label')] - else: - return v - -class HTTPResponse: - def __init__(self, config): - self.config = config - self.store = {} - - def debug_response_passthrough(self): - return RunnableLambda(lambda x: debug_response(self.config, x, self.original_question)) - - def original_question_passthrough(self): - def get_and_send_original_question(docs): - source_documents = [] - for doc in docs["context"]: - doc.metadata = {key: extract_prompt_value(doc.metadata.get(key)) for key in self.config.attributes if key in doc.metadata} - source_document = doc.metadata.copy() - source_document["content"] = doc.page_content - source_documents.append(source_document) - - self.context = source_documents - - original_question = { - "question": self.config.question, - "source_documents": source_documents, - } - - self.original_question = original_question - return docs - - return RunnablePassthrough(get_and_send_original_question) - - def prepare_response(self): - try: - retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}}) - chain = ( - {"context": retriever, "question": RunnablePassthrough()} - | self.original_question_passthrough() - | self.config.prompt - | self.config.client - | StrOutputParser() - | self.debug_response_passthrough() - ) - response = chain.invoke(self.config.question) - response["context"] = self.context - except Exception as err: - response = { - "question": self.config.question, - "error": str(err), - "source_documents": [], - } - return response - - \ No newline at end of file diff --git a/chat/src/helpers/utils.py b/chat/src/helpers/utils.py deleted file mode 100644 index d0d243d4..00000000 --- a/chat/src/helpers/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -def to_bool(val): - """Converts a value to boolean. If the value is a string, it considers - "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value. - """ - if isinstance(val, str): - return val.lower() not in ["", "no", "false", "0"] - return bool(val) diff --git a/chat/src/http_event_config.py b/chat/src/http_event_config.py deleted file mode 100644 index 47f479aa..00000000 --- a/chat/src/http_event_config.py +++ /dev/null @@ -1,188 +0,0 @@ -import os -import json - -from dataclasses import dataclass, field - -from langchain_core.prompts import ChatPromptTemplate -from setup import ( - opensearch_client, - opensearch_vector_store, - openai_chat_client, -) -from typing import List -from helpers.apitoken import ApiToken -from helpers.prompts import document_template, prompt_template - -CHAIN_TYPE = "stuff" -DOCUMENT_VARIABLE_NAME = "context" -K_VALUE = 40 -MAX_K = 100 -MAX_TOKENS = 1000 -SIZE = 5 -TEMPERATURE = 0.2 -TEXT_KEY = "id" -VERSION = "2024-02-01" - -@dataclass -class HTTPEventConfig: - """ - The EventConfig class represents the configuration for an event. - Default values are set for the following properties which can be overridden in the payload message. - """ - - DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection", - "contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier", - "keywords", "language", "notes", "physical_description_material", "physical_description_size", - "provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail", - "title", "visibility", "work_type"] - - api_token: ApiToken = field(init=False) - attributes: List[str] = field(init=False) - azure_endpoint: str = field(init=False) - azure_resource_name: str = field(init=False) - debug_mode: bool = field(init=False) - deployment_name: str = field(init=False) - document_prompt: ChatPromptTemplate = field(init=False) - event: dict = field(default_factory=dict) - is_logged_in: bool = field(init=False) - k: int = field(init=False) - max_tokens: int = field(init=False) - openai_api_version: str = field(init=False) - payload: dict = field(default_factory=dict) - prompt_text: str = field(init=False) - prompt: ChatPromptTemplate = field(init=False) - question: str = field(init=False) - ref: str = field(init=False) - request_context: dict = field(init=False) - temperature: float = field(init=False) - size: int = field(init=False) - stream_response: bool = field(init=False) - text_key: str = field(init=False) - - def __post_init__(self): - self.payload = json.loads(self.event.get("body", "{}")) - self.api_token = ApiToken(signed_token=self.payload.get("auth")) - self.attributes = self._get_attributes() - self.azure_endpoint = self._get_azure_endpoint() - self.azure_resource_name = self._get_azure_resource_name() - self.debug_mode = self._is_debug_mode_enabled() - self.deployment_name = self._get_deployment_name() - self.is_logged_in = self.api_token.is_logged_in() - self.k = self._get_k() - self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS) - self.openai_api_version = self._get_openai_api_version() - self.prompt_text = self._get_prompt_text() - self.request_context = self.event.get("requestContext", {}) - self.question = self.payload.get("question") - self.ref = self.payload.get("ref") - self.size = self._get_size() - self.stream_response = self.payload.get("stream_response", not self.debug_mode) - self.temperature = self._get_temperature() - self.text_key = self._get_text_key() - self.document_prompt = self._get_document_prompt() - self.prompt = ChatPromptTemplate.from_template(self.prompt_text) - - def _get_payload_value_with_superuser_check(self, key, default): - if self.api_token.is_superuser(): - return self.payload.get(key, default) - else: - return default - - def _get_attributes_function(self): - try: - opensearch = opensearch_client() - mapping = opensearch.indices.get_mapping(index="dc-v2-work") - return list(next(iter(mapping.values()))['mappings']['properties'].keys()) - except StopIteration: - return [] - - def _get_attributes(self): - return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES) - - def _get_azure_endpoint(self): - default = f"https://{self._get_azure_resource_name()}.openai.azure.com/" - return self._get_payload_value_with_superuser_check("azure_endpoint", default) - - def _get_azure_resource_name(self): - azure_resource_name = self._get_payload_value_with_superuser_check( - "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME") - ) - if not azure_resource_name: - raise EnvironmentError( - "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set" - ) - return azure_resource_name - - def _get_deployment_name(self): - return self._get_payload_value_with_superuser_check( - "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID") - ) - - def _get_k(self): - value = self._get_payload_value_with_superuser_check("k", K_VALUE) - return min(value, MAX_K) - - def _get_openai_api_version(self): - return self._get_payload_value_with_superuser_check( - "openai_api_version", VERSION - ) - - def _get_prompt_text(self): - return self._get_payload_value_with_superuser_check("prompt", prompt_template()) - - def _get_size(self): - return self._get_payload_value_with_superuser_check("size", SIZE) - - def _get_temperature(self): - return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE) - - def _get_text_key(self): - return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY) - - def _get_document_prompt(self): - return ChatPromptTemplate.from_template(document_template(self.attributes)) - - def debug_message(self): - return { - "type": "debug", - "message": { - "attributes": self.attributes, - "azure_endpoint": self.azure_endpoint, - "deployment_name": self.deployment_name, - "k": self.k, - "openai_api_version": self.openai_api_version, - "prompt": self.prompt_text, - "question": self.question, - "ref": self.ref, - "size": self.ref, - "temperature": self.temperature, - "text_key": self.text_key, - }, - } - - def setup_llm_request(self): - self._setup_vector_store() - self._setup_chat_client() - - def _setup_vector_store(self): - self.opensearch = opensearch_vector_store() - - def _setup_chat_client(self): - self.client = openai_chat_client( - azure_deployment=self.deployment_name, - azure_endpoint=self.azure_endpoint, - openai_api_version=self.openai_api_version, - max_tokens=self.max_tokens - ) - - def _is_debug_mode_enabled(self): - debug = self.payload.get("debug", False) - return debug and self.api_token.is_superuser() - - def _to_bool(self, val): - """Converts a value to boolean. If the value is a string, it considers - "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value. - """ - if isinstance(val, str): - return val.lower() not in ["", "no", "false", "0"] - return bool(val) diff --git a/chat/test/helpers/__init__.py b/chat/src/persistence/__init__.py similarity index 100% rename from chat/test/helpers/__init__.py rename to chat/src/persistence/__init__.py diff --git a/chat/src/persistence/compressible_json_serializer.py b/chat/src/persistence/compressible_json_serializer.py new file mode 100644 index 00000000..46965dca --- /dev/null +++ b/chat/src/persistence/compressible_json_serializer.py @@ -0,0 +1,67 @@ +from typing import Any, Optional, Tuple + +import base64 +import bz2 +import gzip +import json +import langchain_core.messages as langchain_messages +from langchain_core.messages import BaseMessage +from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer + +class CompressibleJsonSerializer(JsonPlusSerializer): + def __init__(self, compression: Optional[str] = None): + self.compression = compression + + def dumps_typed(self, obj: Any) -> Tuple[str, Any]: + def default(o): + if isinstance(o, BaseMessage): + return { + "__type__": o.__class__.__name__, + "data": o.model_dump(), + } + raise TypeError( + f"Object of type {o.__class__.__name__} is not JSON serializable" + ) + + json_str = json.dumps(obj, default=default) + + if self.compression is None: + return "json", json_str + elif self.compression == "bz2": + compressed_str = base64.b64encode( + bz2.compress(json_str.encode("utf-8")) + ).decode("utf-8") + return "bz2_json", compressed_str + elif self.compression == "gzip": + compressed_str = base64.b64encode( + gzip.compress(json_str.encode("utf-8")) + ).decode("utf-8") + return "gzip_json", compressed_str + else: + raise ValueError(f"Unsupported compression type: {self.compression}") + + def loads_typed(self, data: Tuple[str, Any]) -> Any: + type_, payload = data + + if type_ == "json": + json_str = payload + elif type_ == "bz2_json": + json_str = bz2.decompress(base64.b64decode(payload)).decode("utf-8") + elif type_ == "gzip_json": + json_str = gzip.decompress(base64.b64decode(payload)).decode("utf-8") + else: + raise ValueError(f"Unknown data type: {type_}") + + def object_hook(dct): + if "__type__" in dct: + type_name = dct["__type__"] + data = dct["data"] + cls = getattr(langchain_messages, type_name, None) + if cls and issubclass(cls, BaseMessage): + return cls.model_construct(**data) + else: + raise ValueError(f"Unknown type: {type_name}") + return dct + + obj = json.loads(json_str, object_hook=object_hook) + return obj diff --git a/chat/src/agent/s3_saver.py b/chat/src/persistence/s3_checkpointer.py similarity index 84% rename from chat/src/agent/s3_saver.py rename to chat/src/persistence/s3_checkpointer.py index 745df719..b7435bba 100644 --- a/chat/src/agent/s3_saver.py +++ b/chat/src/persistence/s3_checkpointer.py @@ -1,14 +1,10 @@ import boto3 import json -import base64 -import bz2 -import gzip import os import time +from persistence.compressible_json_serializer import CompressibleJsonSerializer from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, List from langchain_core.runnables import RunnableConfig -from langchain_core.messages import BaseMessage -import langchain_core.messages as langchain_messages from langgraph.checkpoint.base import ( BaseCheckpointSaver, @@ -19,61 +15,8 @@ PendingWrite, get_checkpoint_id, ) -from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer as BaseJsonPlusSerializer -class JsonPlusSerializer(BaseJsonPlusSerializer): - def __init__(self, compression: Optional[str] = None): - self.compression = compression - - def dumps_typed(self, obj: Any) -> Tuple[str, Any]: - def default(o): - if isinstance(o, BaseMessage): - return { - '__type__': o.__class__.__name__, - 'data': o.model_dump(), - } - raise TypeError(f'Object of type {o.__class__.__name__} is not JSON serializable') - - json_str = json.dumps(obj, default=default) - - if self.compression is None: - return 'json', json_str - elif self.compression == 'bz2': - compressed_str = base64.b64encode(bz2.compress(json_str.encode("utf-8"))).decode("utf-8") - return 'bz2_json', compressed_str - elif self.compression == 'gzip': - compressed_str = base64.b64encode(gzip.compress(json_str.encode("utf-8"))).decode("utf-8") - return 'gzip_json', compressed_str - else: - raise ValueError(f"Unsupported compression type: {self.compression}") - - def loads_typed(self, data: Tuple[str, Any]) -> Any: - type_, payload = data - - if type_ == 'json': - json_str = payload - elif type_ == 'bz2_json': - json_str = bz2.decompress(base64.b64decode(payload)).decode("utf-8") - elif type_ == 'gzip_json': - json_str = gzip.decompress(base64.b64decode(payload)).decode("utf-8") - else: - raise ValueError(f'Unknown data type: {type_}') - - def object_hook(dct): - if '__type__' in dct: - type_name = dct['__type__'] - data = dct['data'] - cls = getattr(langchain_messages, type_name, None) - if cls and issubclass(cls, BaseMessage): - return cls.model_construct(**data) - else: - raise ValueError(f'Unknown type: {type_name}') - return dct - - obj = json.loads(json_str, object_hook=object_hook) - return obj - def _namespace(val): return "__default__" if val == "" else val @@ -113,7 +56,7 @@ def _parse_s3_checkpoint_key(key: str) -> Dict[str, str]: } -class S3Saver(BaseCheckpointSaver): +class S3Checkpointer(BaseCheckpointSaver): """S3-based checkpoint saver implementation.""" def __init__( @@ -124,7 +67,7 @@ def __init__( compression: Optional[str] = None, ) -> None: super().__init__() - self.serde = JsonPlusSerializer(compression=compression) + self.serde = CompressibleJsonSerializer(compression=compression) self.s3 = boto3.client('s3', region_name=region_name, endpoint_url=endpoint_url) self.bucket_name = bucket_name diff --git a/chat/src/requirements-dev.txt b/chat/src/requirements-dev.txt index 73059330..528c4ac7 100644 --- a/chat/src/requirements-dev.txt +++ b/chat/src/requirements-dev.txt @@ -1,3 +1,4 @@ # Dev/Test Dependencies -ruff~=0.2.0 -coverage~=7.3.2 +moto~=5.0 +ruff~=0.2 +coverage~=7.3 diff --git a/chat/src/search/__init__.py b/chat/src/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chat/src/helpers/hybrid_query.py b/chat/src/search/hybrid_query.py similarity index 100% rename from chat/src/helpers/hybrid_query.py rename to chat/src/search/hybrid_query.py diff --git a/chat/src/handlers/opensearch_neural_search.py b/chat/src/search/opensearch_neural_search.py similarity index 98% rename from chat/src/handlers/opensearch_neural_search.py rename to chat/src/search/opensearch_neural_search.py index 83856ae6..9ff6e24d 100644 --- a/chat/src/handlers/opensearch_neural_search.py +++ b/chat/src/search/opensearch_neural_search.py @@ -2,7 +2,7 @@ from langchain_core.vectorstores import VectorStore from opensearchpy import OpenSearch from typing import Any, List, Tuple -from helpers.hybrid_query import hybrid_query +from search.hybrid_query import hybrid_query class OpenSearchNeuralSearch(VectorStore): """Read-only OpenSearch vectorstore with neural search.""" diff --git a/chat/template.yaml b/chat/template.yaml index ab8465fc..f4a02ce2 100644 --- a/chat/template.yaml +++ b/chat/template.yaml @@ -205,7 +205,7 @@ Resources: #* Layers: #* - !Ref ChatDependencies MemorySize: 1024 - Handler: handlers/chat.handler + Handler: handlers.chat Timeout: 300 Environment: Variables: @@ -267,7 +267,7 @@ Resources: #* Layers: #* - !Ref ChatDependencies MemorySize: 1024 - Handler: handlers/chat_sync.handler + Handler: handlers.chat_sync Timeout: 300 Environment: Variables: diff --git a/chat/test/agent/callbacks/__init__.py b/chat/test/agent/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chat/test/agent/callbacks/test_metrics.py b/chat/test/agent/callbacks/test_metrics.py new file mode 100644 index 00000000..d421f7eb --- /dev/null +++ b/chat/test/agent/callbacks/test_metrics.py @@ -0,0 +1,94 @@ +from unittest import TestCase +from unittest.mock import patch +import sys + +sys.path.append("./src") + +from agent.callbacks.metrics import MetricsCallbackHandler + +class TestSocketCallbackHandler(TestCase): + def setUp(self): + self.ref = "test_ref" + self.handler = MetricsCallbackHandler() + + def test_on_llm_end_with_content(self): + # Mocking LLMResult and Generations + class MockMessage: + def __init__(self, text, response_metadata={}, usage_metadata={}): + self.text = text + self.message = self # For simplicity, reuse same object for .message + self.response_metadata = response_metadata + self.usage_metadata = usage_metadata + + class MockLLMResult: + def __init__(self, text, stop_reason="end_turn"): + response_metadata = {"stop_reason": stop_reason} + usage_metadata = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30} + message = MockMessage(text, response_metadata, usage_metadata) + self.generations = [[message]] + + # When response has content and end_turn stop reason + response = MockLLMResult("Here is the answer", stop_reason="end_turn") + with patch.object(self.handler, "on_llm_end", wraps=self.handler.on_llm_end) as mock: + self.handler.on_llm_end(response) + mock.assert_called_once_with(response) + self.assertEqual(self.handler.answers, ["Here is the answer"]) + self.assertEqual(self.handler.accumulator, {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30}) + + def test_on_tool_end_search(self): + # Mock tool output + class MockDoc: + def __init__(self, metadata): + self.metadata = metadata + + class MockToolMessage: + def __init__(self, name, artifact): + self.name = name + self.artifact = artifact + + artifact = [ + MockDoc( + { + "id": 1, + "api_link": "https://example.edu/item/1", + "title": "Result 1", + "visibility": "public", + "work_type": "article", + "thumbnail": "img1", + } + ), + MockDoc( + { + "id": 2, + "api_link": "https://example.edu/item/2", + "title": "Result 2", + "visibility": "private", + "work_type": "document", + "thumbnail": "img2", + } + ), + ] + + output = MockToolMessage("search", artifact) + self.handler.on_tool_end(output) + self.assertEqual(self.handler.artifacts, [{"type": "source_urls", "artifact": ["https://example.edu/item/1", "https://example.edu/item/2"]}]) + + def test_on_tool_end_aggregate(self): + class MockToolMessage: + def __init__(self, name, artifact): + self.name = name + self.artifact = artifact + + output = MockToolMessage("aggregate", {"aggregation_result": {"count": 10}}) + self.handler.on_tool_end(output) + self.assertEqual(self.handler.artifacts, [{"type": "aggregation", "artifact": {"count": 10}}]) + + def test_on_tool_end_discover_fields(self): + class MockToolMessage: + def __init__(self, name, artifact): + self.name = name + self.artifact = artifact + + output = MockToolMessage("discover_fields", {}) + self.handler.on_tool_end(output) + self.assertEqual(self.handler.artifacts, []) diff --git a/chat/test/agent/test_agent_handler.py b/chat/test/agent/callbacks/test_socket.py similarity index 95% rename from chat/test/agent/test_agent_handler.py rename to chat/test/agent/callbacks/test_socket.py index d9d3f963..e90aaa6b 100644 --- a/chat/test/agent/test_agent_handler.py +++ b/chat/test/agent/callbacks/test_socket.py @@ -5,7 +5,7 @@ sys.path.append("./src") -from agent.agent_handler import AgentHandler +from agent.callbacks.socket import SocketCallbackHandler class MockClient: def __init__(self): @@ -14,11 +14,11 @@ def post_to_connection(self, Data, ConnectionId): self.received.append(Data) return Data -class TestAgentHandler(TestCase): +class TestSocketCallbackHandler(TestCase): def setUp(self): self.mock_socket = MagicMock() self.ref = "test_ref" - self.handler = AgentHandler(socket=self.mock_socket, ref=self.ref) + self.handler = SocketCallbackHandler(socket=self.mock_socket, ref=self.ref) def test_on_llm_start(self): # Given metadata that includes model name @@ -160,10 +160,10 @@ def test_on_agent_finish(self): "message": "Finished" }) -class TestAgentHandlerErrors(TestCase): +class TestSocketCallbackHandlerErrors(TestCase): def test_missing_socket(self): with self.assertRaises(ValueError) as context: - AgentHandler(socket=None, ref="abc123") + SocketCallbackHandler(socket=None, ref="abc123") self.assertIn("Socket not provided to agent callback handler", str(context.exception)) diff --git a/chat/test/core/__init__.py b/chat/test/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chat/test/helpers/test_apitoken.py b/chat/test/core/test_apitoken.py similarity index 97% rename from chat/test/helpers/test_apitoken.py rename to chat/test/core/test_apitoken.py index e23a1646..ecd6970e 100644 --- a/chat/test/helpers/test_apitoken.py +++ b/chat/test/core/test_apitoken.py @@ -4,7 +4,7 @@ sys.path.append('./src') -from helpers.apitoken import ApiToken +from core.apitoken import ApiToken from test.fixtures.apitoken import DEV_TEAM_TOKEN, SUPER_TOKEN, TEST_SECRET, TEST_TOKEN from unittest import mock, TestCase diff --git a/chat/test/test_event_config.py b/chat/test/core/test_event_config.py similarity index 96% rename from chat/test/test_event_config.py rename to chat/test/core/test_event_config.py index 401c8417..3d05a064 100644 --- a/chat/test/test_event_config.py +++ b/chat/test/core/test_event_config.py @@ -3,7 +3,7 @@ import sys sys.path.append('./src') -from event_config import EventConfig +from core.event_config import EventConfig from unittest import TestCase class TestEventConfig(TestCase): diff --git a/chat/test/helpers/test_prompts.py b/chat/test/core/test_prompts.py similarity index 93% rename from chat/test/helpers/test_prompts.py rename to chat/test/core/test_prompts.py index b9a7d950..a49caee2 100644 --- a/chat/test/helpers/test_prompts.py +++ b/chat/test/core/test_prompts.py @@ -2,7 +2,7 @@ import sys sys.path.append('./src') -from helpers.prompts import prompt_template, document_template +from core.prompts import prompt_template, document_template from unittest import TestCase diff --git a/chat/test/test_websocket.py b/chat/test/core/test_websocket.py similarity index 93% rename from chat/test/test_websocket.py rename to chat/test/core/test_websocket.py index 4d4d8b76..90103ed9 100644 --- a/chat/test/test_websocket.py +++ b/chat/test/core/test_websocket.py @@ -3,7 +3,7 @@ sys.path.append('./src') from unittest import TestCase -from websocket import Websocket +from core.websocket import Websocket class MockClient: diff --git a/chat/test/handlers/test_chat.py b/chat/test/handlers/test_chat.py index aa0f4acd..91fd5f41 100644 --- a/chat/test/handlers/test_chat.py +++ b/chat/test/handlers/test_chat.py @@ -7,11 +7,11 @@ sys.path.append('./src') -from handlers.chat import handler -from helpers.apitoken import ApiToken +from handlers import chat +from core.apitoken import ApiToken +from core.websocket import Websocket from langchain_core.language_models.fake_chat_models import FakeListChatModel from langgraph.checkpoint.memory import MemorySaver -from websocket import Websocket class MockClient: @@ -32,17 +32,17 @@ class TestHandler(TestCase): @patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver()) def test_handler_unauthorized(self, mock_create_saver, mock_is_logged_in): event = {"socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test")} - self.assertEqual(handler(event, MockContext()), {'statusCode': 401, 'body': 'Unauthorized'}) + self.assertEqual(chat(event, MockContext()), {'statusCode': 401, 'body': 'Unauthorized'}) @patch.object(ApiToken, 'is_logged_in', return_value=True) @patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver()) - @patch('handlers.chat.chat_model', return_value=FakeListChatModel(responses=["fake response"])) + @patch('handlers.chat_model', return_value=FakeListChatModel(responses=["fake response"])) def test_handler_success(self, mock_chat_model, mock_create_saver, mock_is_logged_in): event = { "socket": Websocket(client=MockClient(), endpoint_url="test", connection_id="test", ref="test"), "body": '{"question": "Question?"}' } - self.assertEqual(handler(event, MockContext()), {'statusCode': 200}) + self.assertEqual(chat(event, MockContext()), {'statusCode': 200}) @patch.object(ApiToken, 'is_logged_in', return_value=True) @patch('agent.search_agent.checkpoint_saver', return_value=MemorySaver()) @@ -50,7 +50,7 @@ def test_handler_question_missing(self, mock_create_saver, mock_is_logged_in): mock_client = MockClient() mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test") event = {"socket": mock_websocket} - handler(event, MockContext()) + chat(event, MockContext()) response = json.loads(mock_client.received_data) self.assertEqual(response["type"], "error") self.assertEqual(response["message"], "Question cannot be blank") @@ -61,7 +61,7 @@ def test_handler_question_typo(self, mock_create_saver, mock_is_logged_in): mock_client = MockClient() mock_websocket = Websocket(client=mock_client, endpoint_url="test", connection_id="test", ref="test") event = {"socket": mock_websocket, "body": '{"quesion": ""}'} - handler(event, MockContext()) + chat(event, MockContext()) response = json.loads(mock_client.received_data) self.assertEqual(response["type"], "error") self.assertEqual(response["message"], "Question cannot be blank") \ No newline at end of file diff --git a/chat/test/handlers/test_chat_sync.py b/chat/test/handlers/test_chat_sync.py index 18b1f89e..59d2a26a 100644 --- a/chat/test/handlers/test_chat_sync.py +++ b/chat/test/handlers/test_chat_sync.py @@ -1,7 +1,6 @@ # ruff: noqa: E402 import json -import os import sys from langchain_core.language_models.fake_chat_models import FakeListChatModel from langgraph.checkpoint.memory import MemorySaver @@ -10,8 +9,8 @@ from unittest import TestCase from unittest.mock import patch -from handlers.chat_sync import handler -from helpers.apitoken import ApiToken +from handlers import chat_sync +from core.apitoken import ApiToken class MockContext: def __init__(self): @@ -19,15 +18,15 @@ def __init__(self): class TestHandler(TestCase): def test_handler_unauthorized(self): - self.assertEqual(handler({"body": '{ "question": "Question?"}'}, MockContext()), {'body': 'Unauthorized', 'statusCode': 401}) + self.assertEqual(chat_sync({"body": '{ "question": "Question?"}'}, MockContext()), {'body': 'Unauthorized', 'statusCode': 401}) @patch.object(ApiToken, 'is_logged_in', return_value = True) def test_no_question(self, mock_is_logged_in): - self.assertEqual(handler({"body": '{ "question": ""}'}, MockContext()), {'statusCode': 400, 'body': 'Question cannot be blank'}) + self.assertEqual(chat_sync({"body": '{ "question": ""}'}, MockContext()), {'statusCode': 400, 'body': 'Question cannot be blank'}) @patch.object(ApiToken, 'is_logged_in', return_value = True) @patch("agent.search_agent.checkpoint_saver", return_value=MemorySaver()) - @patch('handlers.chat_sync.chat_model', return_value=FakeListChatModel(responses=["fake response"])) + @patch('handlers.chat_model', return_value=FakeListChatModel(responses=["fake response"])) def test_handler_success(self, mock_chat_model, mock_create_saver, mock_is_logged_in): expected_body = { "answer": ["fake response"], @@ -40,7 +39,7 @@ def test_handler_success(self, mock_chat_model, mock_create_saver, mock_is_logge "artifacts": [], "token_counts": {} } - response = handler({"body": '{"question": "Question?", "ref": "test_ref"}'}, MockContext()) + response = chat_sync({"body": '{"question": "Question?", "ref": "test_ref"}'}, MockContext()) self.assertEqual(json.loads(response.get("body")), expected_body) self.assertEqual(response.get("statusCode"), 200) diff --git a/chat/test/persistence/__init__.py b/chat/test/persistence/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chat/test/persistence/test_compressible_json_serializer.py b/chat/test/persistence/test_compressible_json_serializer.py new file mode 100644 index 00000000..3fd09e12 --- /dev/null +++ b/chat/test/persistence/test_compressible_json_serializer.py @@ -0,0 +1,135 @@ +# ruff: noqa: E402 +import sys +sys.path.append("./src") + +from unittest import TestCase + +from langchain_core.messages import HumanMessage +from persistence.compressible_json_serializer import CompressibleJsonSerializer +import warnings + +warnings.simplefilter("ignore", DeprecationWarning) +class TestCompressibleJsonSerializer(TestCase): + def test_dumps_typed(self): + serializer = CompressibleJsonSerializer() + obj = {"key": "value"} + data = serializer.dumps_typed(obj) + self.assertEqual(data, ("json", '{"key": "value"}')) + + def test_loads_typed(self): + serializer = CompressibleJsonSerializer() + data = ("json", '{"key": "value"}') + obj = serializer.loads_typed(data) + self.assertEqual(obj, {"key": "value"}) + + def test_dumps_typed_with_bz2_compression(self): + serializer = CompressibleJsonSerializer(compression="bz2") + obj = {"key": "value"} + data = serializer.dumps_typed(obj) + self.assertEqual(data[0], "bz2_json") + + def test_loads_typed_with_bz2_compression(self): + serializer = CompressibleJsonSerializer(compression="bz2") + data = ( + "bz2_json", + "QlpoOTFBWSZTWYByjU0AAAcZgFAAABAiDAMqIAAim0BkEDQNAFPUpFyhWjhdyRThQkIByjU0", + ) + obj = serializer.loads_typed(data) + self.assertEqual(obj, {"key": "value"}) + + def test_dumps_typed_with_gzip_compression(self): + serializer = CompressibleJsonSerializer(compression="gzip") + obj = {"key": "value"} + data = serializer.dumps_typed(obj) + self.assertEqual(data[0], "gzip_json") + + def test_loads_typed_with_gzip_compression(self): + serializer = CompressibleJsonSerializer(compression="gzip") + data = ("gzip_json", "H4sIALfEW2cC/6tWyk6tVLJSUCpLzClNVaoFABtINTMQAAAA") + obj = serializer.loads_typed(data) + self.assertEqual(obj, {"key": "value"}) + + def test_nested_complex_object(self): + serializer = CompressibleJsonSerializer(compression="gzip") + data = ( + "gzip_json", + "H4sIAGwoW2cC/2WQMW/CMBCF/0rktU1lpyGELAxdunTrViHrjC8Q4Zyj2IEilP/enAtSpUoezv7e3Tu/mziLJlPPmYhhKUQhizJXxXI+Vd2oqlHFy6belFI+SdlIKRZlZ1mpsDW1WR\ +W5WRuTVxWa3LRtm5dQlXVRrlZr+8rq/RGI0OkzuAnZ4ya0DhHGqHW69RgCHBL6YhavAyYk3qce6OMX8ygLEVLL3lNEiqx5A8qufsoCwrg/Zq0fs4sfTyED46e43H004NyW+8HaLnaewOnTBcZD2mZewIhh8BRQ9xjh4c\ +KAd2GXI2/CIwh6fqDJuUcS9xq/oR8cwxZcwHnezfPf7+MYFuv/AShWPagOiHSXdDRM94zSpAHJdnRYJGRTWLv5B3RajDe+AQAA", + ) + obj = serializer.loads_typed(data) + self.assertEqual(obj, { + "v": 1, + "ts": "2024-12-12T18:16:12.989400+00:00", + "id": "1efb8b52-b7bb-66eb-bfff-4a64824557d3", + "channel_values": { + "__start__": { + "messages": [ + HumanMessage( + content="Can you search for works about football?", + additional_kwargs={}, + response_metadata={}, + ) + ] + } + }, + "channel_versions": {"__start__": 1}, + "versions_seen": {"__input__": {}}, + "pending_sends": [], + }) + + def test_dumps_typed_unsupported_compression(self): + serializer = CompressibleJsonSerializer(compression="unsupported") + with self.assertRaises(ValueError) as context: + serializer.dumps_typed({"key": "value"}) + + self.assertIn("Unsupported compression type", str(context.exception)) + + def test_loads_typed_unknown_type(self): + serializer = CompressibleJsonSerializer() + data = ("unknown_type", "payload") + with self.assertRaises(ValueError) as context: + serializer.loads_typed(data) + + self.assertIn("Unknown data type", str(context.exception)) + + def test_object_hook_unknown_type(self): + serializer = CompressibleJsonSerializer() + invalid_json = '{"__type__": "UnknownType", "data": {}}' + data = ("json", invalid_json) + with self.assertRaises(ValueError) as context: + serializer.loads_typed(data) + + self.assertIn("Unknown type", str(context.exception)) + + def test_loads_typed_empty_payload(self): + from json.decoder import JSONDecodeError + serializer = CompressibleJsonSerializer() + data = ("json", "") + with self.assertRaises(JSONDecodeError): + serializer.loads_typed(data) + + def test_dumps_typed_base_message(self): + serializer = CompressibleJsonSerializer() + # Create a BaseMessage instance (HumanMessage is one) + message = HumanMessage(content="Hello") + data_type, data_str = serializer.dumps_typed(message) + # Verify it returns a JSON string with the correct type and data + self.assertEqual(data_type, "json") + # We know that it returns {"__type__": "HumanMessage", "data": {...}} + # Check if the resulting JSON contains the expected keys + self.assertIn('"__type__": "HumanMessage"', data_str) + self.assertIn('"data":', data_str) + self.assertIn('"content": "Hello"', data_str) + + def test_dumps_typed_type_error(self): + serializer = CompressibleJsonSerializer() + + # Define a class that is not a BaseMessage + class NotSerializable: + pass + + with self.assertRaises(TypeError) as context: + serializer.dumps_typed(NotSerializable()) + + self.assertIn("is not JSON serializable", str(context.exception)) diff --git a/chat/test/persistence/test_s3_checkpointer.py b/chat/test/persistence/test_s3_checkpointer.py new file mode 100644 index 00000000..3ca76949 --- /dev/null +++ b/chat/test/persistence/test_s3_checkpointer.py @@ -0,0 +1,632 @@ +# ruff: noqa: E402 +import sys +sys.path.append("./src") + +from unittest import TestCase + +import boto3 +import json +import time +from moto import mock_aws +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + Checkpoint, + CheckpointMetadata, +) +from typing import Optional +from persistence.s3_checkpointer import S3Checkpointer + +import bz2 +import base64 +import gzip + +BUCKET_NAME = "mybucket" +REGION = "us-east-1" +THREAD_ID = "thread1" +CHECKPOINT_NAMESPACE = "" +CHECKPOINT_ID_1 = "checkpoint1" +CHECKPOINT_ID_2 = "checkpoint2" + +CHECKPOINTS = [ + { + "id": CHECKPOINT_ID_1, + "key": "checkpoints/thread1/__default__/checkpoint1/checkpoint.json", + "body": json.dumps( + { + "checkpoint_type": "json", + "checkpoint_data": "{}", + "metadata_data": "{}", + "parent_checkpoint_id": None, + "timestamp": int(time.time() * 1000), + } + ), + }, + { + "id": CHECKPOINT_ID_2, + "key": "checkpoints/thread1/__default__/checkpoint2/checkpoint.json", + "body": json.dumps( + { + "checkpoint_type": "json", + "checkpoint_data": "{}", + "metadata_data": "{}", + "parent_checkpoint_id": CHECKPOINT_ID_1, + "timestamp": int(time.time() * 1000), + } + ), + }, +] + + +@mock_aws +class TestS3Checkpointer(TestCase): + def setUp(self): + """Initialize the mock S3 bucket and S3Checkpointer instance before each test.""" + self.s3 = boto3.client("s3", region_name=REGION) + self.s3.create_bucket(Bucket=BUCKET_NAME) + self.checkpointer = S3Checkpointer(bucket_name=BUCKET_NAME, region_name=REGION) + + def tearDown(self): + """Clean up after each test.""" + self.checkpointer.delete_checkpoints(THREAD_ID) + + def setup_s3_bucket(self): + """Upload sample checkpoints to the mock S3 bucket.""" + for checkpoint in CHECKPOINTS: + self.s3.put_object( + Bucket=BUCKET_NAME, + Key=checkpoint["key"], + Body=checkpoint["body"], + ) + + def create_config(self, checkpoint_id: Optional[str] = None) -> RunnableConfig: + """Helper method to create RunnableConfig.""" + config_data = { + "configurable": { + "thread_id": THREAD_ID, + "checkpoint_ns": CHECKPOINT_NAMESPACE, + } + } + if checkpoint_id: + config_data["configurable"]["checkpoint_id"] = checkpoint_id + return RunnableConfig(config_data) + + # + # Basic Put and Get Checkpoints + # + + def test_put_checkpoint(self): + """Test that S3Checkpointer.put correctly saves a checkpoint to S3.""" + new_checkpoint = Checkpoint(id="checkpoint3") + metadata = CheckpointMetadata() + config = self.create_config() + + returned_config = self.checkpointer.put(config, new_checkpoint, metadata, {}) + + self.assertEqual(returned_config["configurable"]["checkpoint_id"], "checkpoint3") + expected_key = ( + f"checkpoints/{THREAD_ID}/__default__/checkpoint3/checkpoint.json" + ) + response = self.s3.get_object(Bucket=BUCKET_NAME, Key=expected_key) + body = json.loads(response["Body"].read().decode("utf-8")) + + self.assertEqual(body["checkpoint_type"], "json") + checkpoint_data = json.loads(body["checkpoint_data"]) + self.assertEqual(checkpoint_data["id"], "checkpoint3") + self.assertEqual(body["metadata_data"], "{}") + assert body["parent_checkpoint_id"] is None + assert "timestamp" in body + + def test_put_overwrite_checkpoint(self): + """Test that putting a checkpoint with an existing ID overwrites it.""" + initial_checkpoint = Checkpoint(id="checkpoint6") + initial_metadata = CheckpointMetadata() + config = self.create_config() + self.checkpointer.put(config, initial_checkpoint, initial_metadata, {}) + + updated_checkpoint = Checkpoint(id="checkpoint6") + updated_metadata = CheckpointMetadata() + self.checkpointer.put(config, updated_checkpoint, updated_metadata, {}) + + checkpoint_tuple = self.checkpointer.get_tuple(config) + assert checkpoint_tuple is not None + self.assertEqual(checkpoint_tuple.config["configurable"]["checkpoint_id"], "checkpoint6") + + def test_put_invalid_checkpoint(self): + """Test putting an invalid checkpoint raises appropriate errors.""" + with self.assertRaises(KeyError): + invalid_checkpoint = {} + config = self.create_config() + self.checkpointer.put(config, invalid_checkpoint, CheckpointMetadata(), {}) + + def test_get_tuple(self): + """Test that S3Checkpointer.get_tuple correctly retrieves a checkpoint tuple.""" + self.setup_s3_bucket() + config = self.create_config(checkpoint_id=CHECKPOINT_ID_2) + checkpoint_tuple = self.checkpointer.get_tuple(config) + + assert checkpoint_tuple is not None + assert ( + checkpoint_tuple.config["configurable"]["checkpoint_id"] == CHECKPOINT_ID_2 + ) + self.assertEqual(checkpoint_tuple.checkpoint, {}) + self.assertEqual(checkpoint_tuple.metadata, {}) + assert checkpoint_tuple.parent_config is not None + assert ( + checkpoint_tuple.parent_config["configurable"]["checkpoint_id"] + == CHECKPOINT_ID_1 + ) + self.assertEqual(checkpoint_tuple.pending_writes, []) + + def test_get_tuple_nonexistent_checkpoint(self): + """Test retrieving a checkpoint tuple that does not exist.""" + config = self.create_config(checkpoint_id="nonexistent") + checkpoint_tuple = self.checkpointer.get_tuple(config) + assert checkpoint_tuple is None + + def test_get_tuple_no_checkpoint_id_no_existing_checkpoints(self): + """Test get_tuple with no checkpoint_id and no existing checkpoints.""" + config = self.create_config() + result = self.checkpointer.get_tuple(config) + assert result is None + + def test_get_tuple_missing_metadata(self): + """Test get_tuple when metadata is missing.""" + key = f"checkpoints/{THREAD_ID}/__default__/missing_meta/checkpoint.json" + checkpoint_body = json.dumps( + { + "checkpoint_type": "json", + "checkpoint_data": "{}", + # "metadata_data": "{}" is intentionally omitted + "parent_checkpoint_id": None, + "timestamp": int(time.time() * 1000), + } + ) + self.s3.put_object(Bucket=BUCKET_NAME, Key=key, Body=checkpoint_body) + + config = self.create_config(checkpoint_id="missing_meta") + with self.assertRaises(ValueError) as context: + self.checkpointer.get_tuple(config) + + self.assertIn("Metadata is missing", str(context.exception)) + # + # Writes (Pending Writes) Tests + # + + def test_put_writes(self): + """Test that S3Checkpointer.put_writes correctly saves writes to S3.""" + checkpoint = Checkpoint(id="checkpoint4") + metadata = CheckpointMetadata() + config = self.create_config() + returned_config = self.checkpointer.put(config, checkpoint, metadata, {}) + + writes = [ + ("channel1", {"data": "value1"}), + ("channel2", {"data": "value2"}), + ] + task_id = "task123" + self.checkpointer.put_writes(returned_config, writes, task_id) + + for idx, (channel, value) in enumerate(writes): + write_key = f"checkpoints/{THREAD_ID}/__default__/checkpoint4/writes/{task_id}/{idx}.json" + response = self.s3.get_object(Bucket=BUCKET_NAME, Key=write_key) + body = json.loads(response["Body"].read().decode("utf-8")) + self.assertEqual(body["channel"], channel) + self.assertEqual(body["type"], "json") + self.assertEqual(body["value"], json.dumps(value)) + assert "timestamp" in body + + def test_put_writes_empty(self): + """Test putting an empty list of writes.""" + checkpoint = Checkpoint(id="checkpoint_empty_writes") + metadata = CheckpointMetadata() + config = self.create_config() + returned_config = self.checkpointer.put(config, checkpoint, metadata, {}) + self.checkpointer.put_writes(returned_config, [], "task_empty") + + checkpoint_tuple = self.checkpointer.get_tuple(returned_config) + assert checkpoint_tuple is not None + self.assertEqual(checkpoint_tuple.pending_writes, []) + + def test_put_writes_multiple_tasks(self): + """Test putting writes from multiple tasks.""" + checkpoint = Checkpoint(id="checkpoint_multi_tasks") + metadata = CheckpointMetadata() + config = self.create_config() + returned_config = self.checkpointer.put(config, checkpoint, metadata, {}) + + writes_task1 = [ + ("channel1", {"data": "task1_value1"}), + ("channel2", {"data": "task1_value2"}), + ] + writes_task2 = [ + ("channel1", {"data": "task2_value1"}), + ] + + self.checkpointer.put_writes(returned_config, writes_task1, "task1") + self.checkpointer.put_writes(returned_config, writes_task2, "task2") + + checkpoint_tuple = self.checkpointer.get_tuple(returned_config) + assert checkpoint_tuple is not None + self.assertEqual(len(checkpoint_tuple.pending_writes), 3) + + task1_writes = [w for w in checkpoint_tuple.pending_writes if w[0] == "task1"] + self.assertEqual(len(task1_writes), 2) + self.assertEqual(task1_writes[0][1], "channel1") + self.assertEqual(task1_writes[0][2], {"data": "task1_value1"}) + self.assertEqual(task1_writes[1][1], "channel2") + self.assertEqual(task1_writes[1][2], {"data": "task1_value2"}) + + task2_writes = [w for w in checkpoint_tuple.pending_writes if w[0] == "task2"] + self.assertEqual(len(task2_writes), 1) + self.assertEqual(task2_writes[0][1], "channel1") + self.assertEqual(task2_writes[0][2], {"data": "task2_value1"}) + + def test_get_tuple_with_writes(self): + """Test retrieving a checkpoint tuple that includes pending writes.""" + checkpoint = Checkpoint(id="checkpoint5") + metadata = CheckpointMetadata() + config = self.create_config() + returned_config = self.checkpointer.put(config, checkpoint, metadata, {}) + + writes = [ + ("channelA", {"info": "dataA"}), + ("channelB", {"info": "dataB"}), + ] + task_id = "task456" + self.checkpointer.put_writes(returned_config, writes, task_id) + + checkpoint_tuple = self.checkpointer.get_tuple(returned_config) + assert checkpoint_tuple is not None + self.assertEqual(checkpoint_tuple.config["configurable"]["checkpoint_id"], "checkpoint5") + self.assertEqual(checkpoint_tuple.checkpoint["id"], "checkpoint5") + self.assertEqual(checkpoint_tuple.metadata, {}) + assert checkpoint_tuple.parent_config is None + self.assertEqual(len(checkpoint_tuple.pending_writes), 2) + for i, (task, channel, value) in enumerate(checkpoint_tuple.pending_writes): + self.assertEqual(task, task_id) + self.assertEqual(channel, writes[i][0]) + self.assertEqual(value, writes[i][1]) + + # + # Listing Checkpoints and Filters + # + + def test_list_checkpoints_with_filters(self): + """Test listing checkpoints with filters like 'before' and 'limit'.""" + self.setup_s3_bucket() + saver = self.checkpointer + config = self.create_config() + + all_checkpoints = list(saver.list(config)) + self.assertEqual(len(all_checkpoints), len(CHECKPOINTS)) + + limited_checkpoints = list(saver.list(config, limit=1)) + self.assertEqual(len(limited_checkpoints), 1) + assert ( + limited_checkpoints[0].config["configurable"]["checkpoint_id"] + == CHECKPOINT_ID_2 + ) + + before_config = self.create_config(checkpoint_id=CHECKPOINT_ID_2) + before_checkpoints = list(saver.list(config, before=before_config)) + self.assertEqual(len(before_checkpoints), 1) + assert ( + before_checkpoints[0].config["configurable"]["checkpoint_id"] + == CHECKPOINT_ID_1 + ) + + def test_list_no_checkpoints(self): + """Test listing checkpoints when none exist.""" + config = self.create_config() + retrieved_checkpoints = list(self.checkpointer.list(config)) + self.assertEqual(len(retrieved_checkpoints), 0) + + def test_list_with_limit(self): + """Test listing with a limit.""" + self.setup_s3_bucket() + config = self.create_config() + results = list(self.checkpointer.list(config, limit=1)) + self.assertEqual(len(results), 1) + + def test_list_no_config(self): + """Test listing when no config is provided.""" + with self.assertRaises(ValueError) as context: + list(self.checkpointer.list(None)) + + self.assertIn("config must be provided", str(context.exception)) + + def test_list_before_removes_all(self): + """Test listing with a 'before' config that removes all results.""" + self.setup_s3_bucket() + config = self.create_config() + before_config = self.create_config(checkpoint_id="checkpoint0") + results = list(self.checkpointer.list(config, before=before_config)) + self.assertEqual(len(results), 0) + + # + # Parent-Child Checkpoint Relationship + # + + def test_put_and_get_with_parent_checkpoint(self): + """Test putting a checkpoint with a parent and retrieving the parent config.""" + parent_checkpoint = Checkpoint(id="parent_checkpoint") + parent_metadata = CheckpointMetadata() + parent_config = self.create_config() + self.checkpointer.put(parent_config, parent_checkpoint, parent_metadata, {}) + + child_checkpoint = Checkpoint(id="child_checkpoint") + child_metadata = CheckpointMetadata() + child_config = RunnableConfig( + { + "configurable": { + "thread_id": THREAD_ID, + "checkpoint_ns": CHECKPOINT_NAMESPACE, + "checkpoint_id": "parent_checkpoint", + } + } + ) + self.checkpointer.put(child_config, child_checkpoint, child_metadata, {}) + + child_tuple = self.checkpointer.get_tuple( + RunnableConfig( + { + "configurable": { + "thread_id": THREAD_ID, + "checkpoint_ns": CHECKPOINT_NAMESPACE, + "checkpoint_id": "child_checkpoint", + } + } + ) + ) + assert child_tuple is not None + self.assertEqual(child_tuple.config["configurable"]["checkpoint_id"], "child_checkpoint") + assert child_tuple.parent_config is not None + assert ( + child_tuple.parent_config["configurable"]["checkpoint_id"] + == "parent_checkpoint" + ) + + # + # Namespaces + # + + def test_put_with_namespace(self): + """Test putting and retrieving a checkpoint within a specific namespace.""" + namespace = "custom_ns" + config = RunnableConfig( + { + "configurable": { + "thread_id": THREAD_ID, + "checkpoint_ns": namespace, + } + } + ) + checkpoint = Checkpoint(id="checkpoint_ns1") + metadata = CheckpointMetadata() + returned_config = self.checkpointer.put(config, checkpoint, metadata, {}) + + retrieved_tuple = self.checkpointer.get_tuple(returned_config) + assert retrieved_tuple is not None + self.assertEqual(retrieved_tuple.config["configurable"]["checkpoint_ns"], namespace) + assert ( + retrieved_tuple.config["configurable"]["checkpoint_id"] == "checkpoint_ns1" + ) + + retrieved_checkpoints = list(self.checkpointer.list(config)) + self.assertEqual(len(retrieved_checkpoints), 1) + assert ( + retrieved_checkpoints[0].config["configurable"]["checkpoint_id"] + == "checkpoint_ns1" + ) + + def test_list_with_non_default_namespace(self): + """Test listing checkpoints in a non-default namespace.""" + namespace = "ns1" + config = RunnableConfig( + { + "configurable": { + "thread_id": THREAD_ID, + "checkpoint_ns": namespace, + } + } + ) + + checkpoint_ns1 = Checkpoint(id="ns1_ckpt1") + checkpoint_ns2 = Checkpoint(id="ns2_ckpt1") + metadata = CheckpointMetadata() + + self.checkpointer.put(config, checkpoint_ns1, metadata, {}) + + config_ns2 = RunnableConfig( + { + "configurable": { + "thread_id": THREAD_ID, + "checkpoint_ns": "ns2", + } + } + ) + self.checkpointer.put(config_ns2, checkpoint_ns2, metadata, {}) + + retrieved_ns1 = list(self.checkpointer.list(config)) + self.assertEqual(len(retrieved_ns1), 1) + self.assertEqual(retrieved_ns1[0].config["configurable"]["checkpoint_id"], "ns1_ckpt1") + + retrieved_ns2 = list(self.checkpointer.list(config_ns2)) + self.assertEqual(len(retrieved_ns2), 1) + self.assertEqual(retrieved_ns2[0].config["configurable"]["checkpoint_id"], "ns2_ckpt1") + + # + # Compression + # + + def test_put_with_compression(self): + """Test putting a checkpoint with compression enabled.""" + import base64 + + saver_compressed = S3Checkpointer( + bucket_name=BUCKET_NAME, region_name=REGION, compression="gzip" + ) + + checkpoint = Checkpoint(id="checkpoint_compressed") + metadata = CheckpointMetadata() + config = self.create_config() + saver_compressed.put(config, checkpoint, metadata, {}) + + expected_key = ( + f"checkpoints/{THREAD_ID}/__default__/checkpoint_compressed/checkpoint.json" + ) + response = self.s3.get_object(Bucket=BUCKET_NAME, Key=expected_key) + body = json.loads(response["Body"].read().decode("utf-8")) + + checkpoint_data_encoded = body["checkpoint_data"] + checkpoint_data = base64.b64decode(checkpoint_data_encoded) + assert checkpoint_data.startswith(b"\x1f\x8b") # Gzip magic number + + def test_list_bz2_checkpoints(self): + """Test listing a checkpoint where checkpoint_type starts with 'bz2'.""" + compressed_data = base64.b64encode(bz2.compress(b"{}")).decode("utf-8") + key = f"checkpoints/{THREAD_ID}/__default__/bz2_ckpt/checkpoint.json" + data = { + "checkpoint_type": "bz2_json", + "checkpoint_data": compressed_data, + "metadata_data": compressed_data, + "parent_checkpoint_id": None, + "timestamp": int(time.time() * 1000), + } + self.s3.put_object(Bucket=BUCKET_NAME, Key=key, Body=json.dumps(data)) + + config = self.create_config(checkpoint_id="bz2_ckpt") + retrieved_checkpoints = list(self.checkpointer.list(config)) + self.assertEqual(len(retrieved_checkpoints), 1) + self.assertEqual(retrieved_checkpoints[0].checkpoint, {}) + self.assertEqual(retrieved_checkpoints[0].metadata, {}) + + def test_list_gzip_checkpoints(self): + """Test listing a checkpoint where checkpoint_type starts with 'gzip'.""" + compressed_data = base64.b64encode(gzip.compress(b"{}")).decode("utf-8") + key = f"checkpoints/{THREAD_ID}/__default__/gzip_ckpt/checkpoint.json" + data = { + "checkpoint_type": "gzip_json", + "checkpoint_data": compressed_data, + "metadata_data": compressed_data, + "parent_checkpoint_id": None, + "timestamp": int(time.time() * 1000), + } + self.s3.put_object(Bucket=BUCKET_NAME, Key=key, Body=json.dumps(data)) + + config = self.create_config(checkpoint_id="gzip_ckpt") + retrieved_checkpoints = list(self.checkpointer.list(config)) + self.assertEqual(len(retrieved_checkpoints), 1) + self.assertEqual(retrieved_checkpoints[0].checkpoint, {}) + self.assertEqual(retrieved_checkpoints[0].metadata, {}) + + # + # Concurrency + # + + def test_concurrent_puts(self): + """Test concurrent puts to ensure thread safety (basic simulation).""" + import threading + + def put_checkpoint(id_suffix): + checkpoint = Checkpoint(id=f"checkpoint_concurrent_{id_suffix}") + metadata = CheckpointMetadata() + config = self.create_config() + self.checkpointer.put(config, checkpoint, metadata, {}) + + threads = [] + for i in range(5): + t = threading.Thread(target=put_checkpoint, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + config = self.create_config() + retrieved_checkpoints = list(self.checkpointer.list(config)) + expected_ids = {f"checkpoint_concurrent_{i}" for i in range(5)} + retrieved_ids = { + ck.config["configurable"]["checkpoint_id"] for ck in retrieved_checkpoints + } + assert expected_ids.issubset(retrieved_ids) + + # + # Latest Checkpoint ID + # + + def test_get_latest_checkpoint_id(self): + """Test the internal method to get the latest checkpoint ID.""" + self.setup_s3_bucket() + latest_id = self.checkpointer._get_latest_checkpoint_id( + THREAD_ID, CHECKPOINT_NAMESPACE + ) + self.assertEqual(latest_id, CHECKPOINT_ID_2) + + def test_get_latest_checkpoint_id_no_keys(self): + """Test getting the latest checkpoint ID when none exist.""" + latest_id = self.checkpointer._get_latest_checkpoint_id( + THREAD_ID, CHECKPOINT_NAMESPACE + ) + assert latest_id is None + + # + # Deleting Checkpoints + # + + def test_delete_checkpoints(self): + """Test that delete_checkpoints correctly removes all checkpoints for a thread.""" + self.setup_s3_bucket() + config = self.create_config() + retrieved_checkpoints = list(self.checkpointer.list(config)) + self.assertEqual(len(retrieved_checkpoints), len(CHECKPOINTS)) + + self.checkpointer.delete_checkpoints(THREAD_ID) + retrieved_after_delete = list(self.checkpointer.list(config)) + self.assertEqual(len(retrieved_after_delete), 0) + + def test_delete_checkpoints_many(self): + """Test deleting multiple checkpoints in batches.""" + for i in range(3): + ckpt = Checkpoint(id=f"ckpt_del_{i}") + metadata = CheckpointMetadata() + config = self.create_config() + self.checkpointer.put(config, ckpt, metadata, {}) + + self.checkpointer.delete_checkpoints(THREAD_ID) + retrieved_after_delete = list(self.checkpointer.list(self.create_config())) + self.assertEqual(len(retrieved_after_delete), 0) + + # + # Invalid Key Formats and Other Edge Cases + # + + def test_load_pending_writes_invalid_key_format(self): + """Test _load_pending_writes handling invalid write key formats.""" + checkpoint = Checkpoint(id="ckpt_invalid_write") + metadata = CheckpointMetadata() + config = self.create_config() + returned_config = self.checkpointer.put(config, checkpoint, metadata, {}) + + invalid_write_key = f"checkpoints/{THREAD_ID}/__default__/ckpt_invalid_write/writes/invalid.json" + self.s3.put_object(Bucket=BUCKET_NAME, Key=invalid_write_key, Body="{}") + + tuple_result = self.checkpointer.get_tuple(returned_config) + assert tuple_result is not None + # No valid writes parsed due to invalid format + self.assertEqual(tuple_result.pending_writes, []) + + def test_invalid_checkpoint_key_format(self): + """Test handling of invalid checkpoint key formats.""" + invalid_key = "checkpoints/thread1/__default__/invalid_format.json" + self.s3.put_object( + Bucket=BUCKET_NAME, + Key=invalid_key, + Body='{"invalid": "data"}', + ) + + config = self.create_config() + with self.assertRaises(ValueError) as context: + list(self.checkpointer.list(config)) + + self.assertIn("Invalid checkpoint key format", str(context.exception)) diff --git a/chat/test/search/__init__.py b/chat/test/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/chat/test/helpers/test_hybrid_query.py b/chat/test/search/test_hybrid_query.py similarity index 94% rename from chat/test/helpers/test_hybrid_query.py rename to chat/test/search/test_hybrid_query.py index 4e38861e..c230f5d2 100644 --- a/chat/test/helpers/test_hybrid_query.py +++ b/chat/test/search/test_hybrid_query.py @@ -1,5 +1,5 @@ import sys -from helpers.hybrid_query import hybrid_query +from search.hybrid_query import hybrid_query from unittest import TestCase sys.path.append('./src') diff --git a/chat/test/handlers/test_opensearch_neural_search.py b/chat/test/search/test_opensearch_neural_search.py similarity index 95% rename from chat/test/handlers/test_opensearch_neural_search.py rename to chat/test/search/test_opensearch_neural_search.py index d7448679..b80ffb32 100644 --- a/chat/test/handlers/test_opensearch_neural_search.py +++ b/chat/test/search/test_opensearch_neural_search.py @@ -3,7 +3,7 @@ sys.path.append('./src') from unittest import TestCase -from handlers.opensearch_neural_search import OpenSearchNeuralSearch +from search.opensearch_neural_search import OpenSearchNeuralSearch from langchain_core.documents import Document class MockClient():