Skip to content

Commit

Permalink
Refactor chat package/module layout
Browse files Browse the repository at this point in the history
Get all tests passing
Add tests for the S3 Checkpointer
  • Loading branch information
mbklein committed Dec 17, 2024
1 parent f3e646d commit 5119178
Show file tree
Hide file tree
Showing 47 changed files with 1,045 additions and 506 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: '3.9'
python-version: '3.12'
cache-dependency-path: chat/src/requirements.txt
- run: pip install -r requirements.txt && pip install -r requirements-dev.txt
working-directory: ./chat/src
Expand All @@ -28,3 +28,5 @@ jobs:
run: |
coverage run --include='src/**/*' -m unittest
coverage report
env:
AWS_REGION: us-east-1
4 changes: 2 additions & 2 deletions chat-playground/playground.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
" pass\n",
"\n",
"sys.path.insert(0, os.path.join(os.curdir, \"../chat/src\"))\n",
"import secrets # noqa"
"import helpers.secrets # noqa"
]
},
{
Expand Down Expand Up @@ -61,7 +61,7 @@
"source": [
"import agent.search_agent\n",
"from agent.search_agent import SearchAgent\n",
"from handlers.model import chat_model\n",
"from core.setup import chat_model\n",
"\n",
"model = chat_model(model=\"us.anthropic.claude-3-5-sonnet-20241022-v2:0\", streaming=False)\n",
"agent = SearchAgent(model=model)\n",
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from langchain_core.messages.tool import ToolMessage
import json

class MetricsHandler(BaseCallbackHandler):
class MetricsCallbackHandler(BaseCallbackHandler):
def __init__(self, *args, **kwargs):
self.accumulator = {}
self.answers = []
self.artifacts = []
super().__init__(*args, **kwargs)

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
if response is None:
return
Expand All @@ -20,18 +21,16 @@ def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
for generation in response.generations[0]:
if generation.text != "":
self.answers.append(generation.text)
for k, v in generation.message.usage_metadata.items():
self.accumulator[k] = self.accumulator.get(k, 0) + v

if not hasattr(generation, 'message') or generation.message is None:
continue
metadata = getattr(generation.message, 'usage_metadata', None)
if metadata is None:
continue
for k, v in metadata.items():
self.accumulator[k] = self.accumulator.get(k, 0) + v
if not hasattr(generation, 'message') or generation.message is None:
continue

metadata = getattr(generation.message, 'usage_metadata', None)
if metadata is None:
continue

for k, v in metadata.items():
self.accumulator[k] = self.accumulator.get(k, 0) + v

def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
match output.name:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional

from websocket import Websocket
from core.websocket import Websocket

from json.decoder import JSONDecodeError
from langchain_core.callbacks import BaseCallbackHandler
Expand All @@ -19,7 +19,7 @@ def deserialize_input(input_str):
except JSONDecodeError:
return input_str

class AgentHandler(BaseCallbackHandler):
class SocketCallbackHandler(BaseCallbackHandler):
def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict[str, Any]):
if socket is None:
raise ValueError("Socket not provided to agent callback handler")
Expand Down
12 changes: 0 additions & 12 deletions chat/src/agent/checkpoints.py

This file was deleted.

2 changes: 1 addition & 1 deletion chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Literal, List

from agent.checkpoints import checkpoint_saver
from agent.tools import aggregate, discover_fields, search
from langchain_core.messages import HumanMessage
from langchain_core.messages.base import BaseMessage
Expand All @@ -9,6 +8,7 @@
from langchain_core.messages.system import SystemMessage
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from core.setup import checkpoint_saver

DEFAULT_SYSTEM_MESSAGE = """
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
Expand Down
2 changes: 1 addition & 1 deletion chat/src/agent/tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json

from langchain_core.tools import tool
from setup import opensearch_vector_store
from core.setup import opensearch_vector_store

def get_keyword_fields(properties, prefix=''):
"""
Expand Down
36 changes: 0 additions & 36 deletions chat/src/content_handler.py

This file was deleted.

File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions chat/src/event_config.py → chat/src/core/event_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from langchain_core.prompts import ChatPromptTemplate

from helpers.apitoken import ApiToken
from helpers.prompts import prompt_template
from websocket import Websocket
from core.apitoken import ApiToken
from core.prompts import prompt_template
from core.websocket import Websocket
from uuid import uuid4

CHAIN_TYPE = "stuff"
Expand Down
File renamed without changes.
File renamed without changes.
12 changes: 11 additions & 1 deletion chat/src/setup.py → chat/src/core/setup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from handlers.opensearch_neural_search import OpenSearchNeuralSearch
from persistence.s3_checkpointer import S3Checkpointer
from search.opensearch_neural_search import OpenSearchNeuralSearch
from langchain_aws import ChatBedrock
from langchain_core.language_models.base import BaseModel
from langgraph.checkpoint.base import BaseCheckpointSaver
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
from urllib.parse import urlparse
import os
import boto3

def chat_model(**kwargs) -> BaseModel:
return ChatBedrock(**kwargs)

def checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME")
return S3Checkpointer(bucket_name=checkpoint_bucket, **kwargs)

def prefix(value):
env_prefix = os.getenv("ENV_PREFIX")
Expand Down
2 changes: 1 addition & 1 deletion chat/src/websocket.py → chat/src/core/websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from setup import websocket_client
from core.setup import websocket_client

class Websocket:
def __init__(self, client=None, endpoint_url=None, connection_id=None, ref=None):
Expand Down
63 changes: 48 additions & 15 deletions chat/src/handlers/chat.py → chat/src/handlers.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,54 @@
import secrets # noqa
import core.secrets # noqa
import boto3
import json
import logging
import os
import traceback
from datetime import datetime
from event_config import EventConfig
# from honeybadger import honeybadger
from core.event_config import EventConfig
from honeybadger import honeybadger
from agent.search_agent import SearchAgent
from agent.agent_handler import AgentHandler
from agent.metrics_handler import MetricsHandler
from handlers.model import chat_model
from agent.callbacks.socket import SocketCallbackHandler
from agent.callbacks.metrics import MetricsCallbackHandler
from core.setup import chat_model

# honeybadger.configure()
# logging.getLogger("honeybadger").addHandler(logging.StreamHandler())
honeybadger.configure()
logging.getLogger("honeybadger").addHandler(logging.StreamHandler())

def handler(event, context):
def chat_sync(event, context):
config = EventConfig(event)

if not config.is_logged_in:
return {"statusCode": 401, "body": "Unauthorized"}

if config.question is None or config.question == "":
return {"statusCode": 400, "body": "Question cannot be blank"}

model = chat_model(model=config.model, streaming=False)
search_agent = SearchAgent(model=model)
result = MetricsCallbackHandler()
search_agent.invoke(
config.question, config.ref, forget=config.forget, callbacks=[result]
)

return {
"statusCode": 200,
"headers": {"Content-Type": "application/json"},
"body": json.dumps(
{
"answer": result.answers,
"is_dev_team": config.api_token.is_dev_team(),
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"model": config.model,
"question": config.question,
"ref": config.ref,
"artifacts": result.artifacts,
"token_counts": result.accumulator,
}
),
}

def chat(event, context):
config = EventConfig(event)
socket = event.get("socket", None)
config.setup_websocket(socket)
Expand All @@ -27,20 +61,19 @@ def handler(event, context):
config.socket.send({"type": "error", "message": "Question cannot be blank"})
return {"statusCode": 400, "body": "Question cannot be blank"}

metrics = MetricsHandler()
callbacks = [AgentHandler(config.socket, config.ref), metrics]
metrics = MetricsCallbackHandler()
callbacks = [SocketCallbackHandler(config.socket, config.ref), metrics]
model = chat_model(model=config.model, streaming=config.stream_response)
search_agent = SearchAgent(model=model)

try:
search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks)
log_metrics(context, metrics, config)
except Exception as e:
print(f"Error: {e}")
print(traceback.format_exc())
error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."}
if config.socket:
config.socket.send(error_response)
return {"statusCode": 500, "body": json.dumps(error_response)}
raise e

return {"statusCode": 200}

Expand Down
43 changes: 0 additions & 43 deletions chat/src/handlers/chat_sync.py

This file was deleted.

5 changes: 0 additions & 5 deletions chat/src/handlers/model.py

This file was deleted.

22 changes: 0 additions & 22 deletions chat/src/handlers/streaming_socket_callback_handler.py

This file was deleted.

Loading

0 comments on commit 5119178

Please sign in to comment.