Push latest chat memory feature branch to prototype API #278

Merged
merged 11 commits into from Dec 17, 2024
5 changes: 4 additions & 1 deletion .github/workflows/test-python.yml
@@ -18,7 +18,7 @@ jobs:
       - uses: actions/checkout@v3
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.9'
+          python-version: '3.12'
           cache-dependency-path: chat/src/requirements.txt
       - run: pip install -r requirements.txt && pip install -r requirements-dev.txt
         working-directory: ./chat/src
@@ -28,3 +28,6 @@ jobs:
         run: |
           coverage run --include='src/**/*' -m unittest
           coverage report
+        env:
+          __SKIP_SECRETS__: true
+          AWS_REGION: us-east-1
266 changes: 10 additions & 256 deletions chat-playground/playground.ipynb

Large diffs are not rendered by default.

File renamed without changes.
@@ -4,21 +4,33 @@
 from langchain_core.messages.tool import ToolMessage
 import json

-class MetricsHandler(BaseCallbackHandler):
+class MetricsCallbackHandler(BaseCallbackHandler):
     def __init__(self, *args, **kwargs):
         self.accumulator = {}
         self.answers = []
+        self.artifacts = []
         super().__init__(*args, **kwargs)

     def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
+        if response is None:
+            return
+
+        if not response.generations or not response.generations[0]:
+            return
+
         for generation in response.generations[0]:
-            self.answers.append(generation.text)
-            for k, v in generation.message.usage_metadata.items():
-                if k not in self.accumulator:
-                    self.accumulator[k] = v
-                else:
-                    self.accumulator[k] += v
+            if generation.text != "":
+                self.answers.append(generation.text)
+
+            if not hasattr(generation, 'message') or generation.message is None:
+                continue
+
+            metadata = getattr(generation.message, 'usage_metadata', None)
+            if metadata is None:
+                continue
+
+            for k, v in metadata.items():
+                self.accumulator[k] = self.accumulator.get(k, 0) + v

     def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
         match output.name:
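Reviewer note: the rewritten accumulator merge is easy to exercise in isolation. A minimal sketch, feeding the handler usage dicts directly (the token-count keys are illustrative, not taken from this PR):

from agent.callbacks.metrics import MetricsCallbackHandler

handler = MetricsCallbackHandler()
# Merge two generations' worth of usage metadata, as on_llm_end now does.
for usage in [{"input_tokens": 120, "output_tokens": 48}, {"input_tokens": 95, "output_tokens": 31}]:
    for k, v in usage.items():
        handler.accumulator[k] = handler.accumulator.get(k, 0) + v

print(handler.accumulator)  # {'input_tokens': 215, 'output_tokens': 79}

The dict.get(k, 0) form replaces the old four-line if/else and tolerates usage_metadata dicts whose keys vary between generations.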
@@ -1,6 +1,6 @@
 from typing import Any, Dict, List, Optional

-from websocket import Websocket
+from core.websocket import Websocket

 from json.decoder import JSONDecodeError
 from langchain_core.callbacks import BaseCallbackHandler
@@ -19,7 +19,7 @@ def deserialize_input(input_str):
     except JSONDecodeError:
         return input_str

-class AgentHandler(BaseCallbackHandler):
+class SocketCallbackHandler(BaseCallbackHandler):
     def __init__(self, socket: Websocket, ref: str, *args: List[Any], **kwargs: Dict[str, Any]):
         if socket is None:
             raise ValueError("Socket not provided to agent callback handler")
@@ -56,12 +56,9 @@ def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
             case "discover_fields":
                 pass
             case "search":
-                try:
-                    result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
-                    docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
-                    self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
-                except json.decoder.JSONDecodeError as e:
-                    print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
+                result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
+                docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
+                self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
             case _:
                 print(f"Unhandled tool_end message: {output}")
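Reviewer note: the dropped try/except was effectively unreachable, since json.decoder.JSONDecodeError only arises from parsing JSON and the "search" branch never parses any; doc.metadata.get() simply yields None for missing keys. A sketch with a stand-in document (names and values illustrative):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class FakeDoc:  # stand-in for the documents carried on output.artifact
    metadata: Dict[str, Any] = field(default_factory=dict)

result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs = [FakeDoc(metadata={"id": "abc123", "title": "Train Station", "visibility": "Public"})]
message = [{k: d.metadata.get(k) for k in result_fields} for d in docs]
print(message)
# [{'id': 'abc123', 'title': 'Train Station', 'visibility': 'Public', 'work_type': None, 'thumbnail': None}]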
29 changes: 15 additions & 14 deletions chat/src/agent/search_agent.py
@@ -1,16 +1,14 @@
-import os
-
 from typing import Literal, List

-from agent.s3_saver import S3Saver, delete_checkpoints
 from agent.tools import aggregate, discover_fields, search
-from langchain_aws import ChatBedrock
 from langchain_core.messages import HumanMessage
 from langchain_core.messages.base import BaseMessage
+from langchain_core.language_models.chat_models import BaseModel
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.messages.system import SystemMessage
 from langgraph.graph import END, START, StateGraph, MessagesState
 from langgraph.prebuilt import ToolNode
+from core.setup import checkpoint_saver

 DEFAULT_SYSTEM_MESSAGE = """
 Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
@@ -21,16 +19,19 @@
 class SearchAgent:
     def __init__(
         self,
+        model: BaseModel,
         *,
-        checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME"),
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
-        **kwargs):
-
-        self.checkpoint_bucket = checkpoint_bucket
-
+        **kwargs
+    ):
         tools = [discover_fields, search, aggregate]
         tool_node = ToolNode(tools)
-        model = ChatBedrock(**kwargs).bind_tools(tools)
+
+        try:
+            model = model.bind_tools(tools)
+        except NotImplementedError:
+            pass

         # Define the function that determines whether to continue or not
         def should_continue(state: MessagesState) -> Literal["tools", END]:
@@ -67,13 +68,13 @@ def call_model(state: MessagesState):
         # Add a normal edge from `tools` to `agent`
         workflow.add_edge("tools", "agent")

-        checkpointer = S3Saver(bucket_name=checkpoint_bucket, compression="gzip")
-        self.search_agent = workflow.compile(checkpointer=checkpointer)
+        self.checkpointer = checkpoint_saver()
+        self.search_agent = workflow.compile(checkpointer=self.checkpointer)

     def invoke(self, question: str, ref: str, *, callbacks: List[BaseCallbackHandler] = [], forget: bool = False, **kwargs):
         if forget:
-            delete_checkpoints(self.checkpoint_bucket, ref)
-
+            self.checkpointer.delete_checkpoints(ref)
         return self.search_agent.invoke(
             {"messages": [HumanMessage(content=question)]},
             config={"configurable": {"thread_id": ref}, "callbacks": callbacks},
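With the model now injected, callers build it once via core.setup.chat_model and pass it in, and bind_tools failures degrade gracefully for models without tool support. A minimal usage sketch (model id and question are illustrative):

from core.setup import chat_model
from agent.search_agent import SearchAgent

model = chat_model(model="anthropic.claude-3-sonnet-20240229-v1:0", streaming=False)  # illustrative model id
agent = SearchAgent(model=model)

# ref doubles as the LangGraph thread_id, so calls sharing a ref share
# checkpointed chat memory; forget=True deletes those checkpoints first.
result = agent.invoke("What collections mention trains?", ref="thread-123", forget=False)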
2 changes: 1 addition & 1 deletion chat/src/agent/tools.py
@@ -1,7 +1,7 @@
 import json

 from langchain_core.tools import tool
-from setup import opensearch_vector_store
+from core.setup import opensearch_vector_store

 def get_keyword_fields(properties, prefix=''):
     """
36 changes: 0 additions & 36 deletions chat/src/content_handler.py

This file was deleted.

File renamed without changes.
File renamed without changes.
28 changes: 3 additions & 25 deletions chat/src/event_config.py → chat/src/core/event_config.py
@@ -4,9 +4,9 @@

 from langchain_core.prompts import ChatPromptTemplate

-from helpers.apitoken import ApiToken
-from helpers.prompts import prompt_template
-from websocket import Websocket
+from core.apitoken import ApiToken
+from core.prompts import prompt_template
+from core.websocket import Websocket
 from uuid import uuid4

 CHAIN_TYPE = "stuff"
@@ -92,20 +92,6 @@ def _get_temperature(self):
     def _get_text_key(self):
         return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

-    def debug_message(self):
-        return {
-            "type": "debug",
-            "message": {
-                "k": self.k,
-                "prompt": self.prompt_text,
-                "question": self.question,
-                "ref": self.ref,
-                "size": self.ref,
-                "temperature": self.temperature,
-                "text_key": self.text_key,
-            },
-        }
-
     def setup_websocket(self, socket=None):
         if socket is None:
             connection_id = self.request_context.get("connectionId")
@@ -120,11 +106,3 @@ def _is_debug_mode_enabled(self):
         debug = self.payload.get("debug", False)
         return debug and self.api_token.is_superuser()
-
-    def _to_bool(self, val):
-        """Converts a value to boolean. If the value is a string, it considers
-        "", "no", "false", "0" as False. Otherwise, it returns the boolean of the value.
-        """
-        if isinstance(val, str):
-            return val.lower() not in ["", "no", "false", "0"]
-        return bool(val)
File renamed without changes.
File renamed without changes.
12 changes: 11 additions & 1 deletion chat/src/setup.py → chat/src/core/setup.py
@@ -1,10 +1,20 @@
-from handlers.opensearch_neural_search import OpenSearchNeuralSearch
+from persistence.s3_checkpointer import S3Checkpointer
+from search.opensearch_neural_search import OpenSearchNeuralSearch
 from langchain_aws import ChatBedrock
+from langchain_core.language_models.base import BaseModel
+from langgraph.checkpoint.base import BaseCheckpointSaver
 from opensearchpy import OpenSearch, RequestsHttpConnection
 from requests_aws4auth import AWS4Auth
 from urllib.parse import urlparse
 import os
 import boto3

+def chat_model(**kwargs) -> BaseModel:
+    return ChatBedrock(**kwargs)
+
+def checkpoint_saver(**kwargs) -> BaseCheckpointSaver:
+    checkpoint_bucket: str = os.getenv("CHECKPOINT_BUCKET_NAME")
+    return S3Checkpointer(bucket_name=checkpoint_bucket, **kwargs)
+
 def prefix(value):
     env_prefix = os.getenv("ENV_PREFIX")
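The two factories give handlers a single construction point for models and checkpointers. A sketch, assuming CHECKPOINT_BUCKET_NAME is set and that S3Checkpointer accepts the same compression kwarg the old S3Saver took (both assumptions, not shown in this diff):

import os
os.environ.setdefault("CHECKPOINT_BUCKET_NAME", "my-checkpoint-bucket")  # illustrative bucket name

from core.setup import chat_model, checkpoint_saver

model = chat_model(model="anthropic.claude-3-sonnet-20240229-v1:0", streaming=True)  # illustrative model id
saver = checkpoint_saver(compression="gzip")  # kwargs pass straight through to S3Checkpointer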
8 changes: 4 additions & 4 deletions chat/src/websocket.py → chat/src/core/websocket.py
@@ -1,5 +1,5 @@
 import json
-from setup import websocket_client
+from core.setup import websocket_client

 class Websocket:
     def __init__(self, client=None, endpoint_url=None, connection_id=None, ref=None):
@@ -8,9 +8,9 @@ def __init__(self, client=None, endpoint_url=None, connection_id=None, ref=None):
         self.ref = ref if ref else {}

     def send(self, data):
-        # if isinstance(data, str):
-        #     data = {"message": data}
-        # data["ref"] = self.ref
+        if isinstance(data, str):
+            data = {"message": data}
+        data["ref"] = self.ref
         data_as_bytes = bytes(json.dumps(data), "utf-8")

         if self.connection_id == "debug":
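Un-commenting the normalization means bare strings are wrapped and tagged with the connection ref before serialization. A sketch with a stand-in client, since a real API Gateway management client needs AWS credentials (the post_to_connection interface is an assumption based on the boto3 management API, not shown in this diff):

import json
from core.websocket import Websocket

class FakeClient:
    def post_to_connection(self, Data=None, ConnectionId=None):  # assumed boto3-style interface
        print(ConnectionId, json.loads(Data))

ws = Websocket(client=FakeClient(), connection_id="abc123", ref="ref-1")
ws.send("hello")  # serialized as {"message": "hello", "ref": "ref-1"}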
65 changes: 50 additions & 15 deletions chat/src/handlers/chat.py → chat/src/handlers.py
@@ -1,19 +1,54 @@
-import secrets # noqa
+import core.secrets # noqa
 import boto3
 import json
 import logging
 import os
 import traceback
 from datetime import datetime
-from event_config import EventConfig
-# from honeybadger import honeybadger
+from core.event_config import EventConfig
+from honeybadger import honeybadger
 from agent.search_agent import SearchAgent
-from agent.agent_handler import AgentHandler
-from agent.metrics_handler import MetricsHandler
+from agent.callbacks.socket import SocketCallbackHandler
+from agent.callbacks.metrics import MetricsCallbackHandler
+from core.setup import chat_model

-# honeybadger.configure()
-# logging.getLogger("honeybadger").addHandler(logging.StreamHandler())
+honeybadger.configure()
+logging.getLogger("honeybadger").addHandler(logging.StreamHandler())

-def handler(event, context):
+def chat_sync(event, context):
     config = EventConfig(event)
+
+    if not config.is_logged_in:
+        return {"statusCode": 401, "body": "Unauthorized"}
+
+    if config.question is None or config.question == "":
+        return {"statusCode": 400, "body": "Question cannot be blank"}
+
+    model = chat_model(model=config.model, streaming=False)
+    search_agent = SearchAgent(model=model)
+    result = MetricsCallbackHandler()
+    search_agent.invoke(
+        config.question, config.ref, forget=config.forget, callbacks=[result]
+    )
+
+    return {
+        "statusCode": 200,
+        "headers": {"Content-Type": "application/json"},
+        "body": json.dumps(
+            {
+                "answer": result.answers,
+                "is_dev_team": config.api_token.is_dev_team(),
+                "is_superuser": config.api_token.is_superuser(),
+                "k": config.k,
+                "model": config.model,
+                "question": config.question,
+                "ref": config.ref,
+                "artifacts": result.artifacts,
+                "token_counts": result.accumulator,
+            }
+        ),
+    }
+
+def chat(event, context):
+    config = EventConfig(event)
     socket = event.get("socket", None)
     config.setup_websocket(socket)
@@ -26,19 +61,19 @@ def handler(event, context):
         config.socket.send({"type": "error", "message": "Question cannot be blank"})
         return {"statusCode": 400, "body": "Question cannot be blank"}

-    metrics = MetricsHandler()
-    callbacks = [AgentHandler(config.socket, config.ref), metrics]
-    search_agent = SearchAgent(model=config.model, streaming=True)
+    metrics = MetricsCallbackHandler()
+    callbacks = [SocketCallbackHandler(config.socket, config.ref), metrics]
+    model = chat_model(model=config.model, streaming=config.stream_response)
+    search_agent = SearchAgent(model=model)

     try:
         search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks)
         log_metrics(context, metrics, config)
     except Exception as e:
         print(f"Error: {e}")
         print(traceback.format_exc())
         error_response = {"type": "error", "message": "An unexpected error occurred. Please try again later."}
         if config.socket:
             config.socket.send(error_response)
-        return {"statusCode": 500, "body": json.dumps(error_response)}
+        raise e

     return {"statusCode": 200}
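Two behavioral notes for reviewers: chat_sync returns the whole answer in a single HTTP-style response rather than streaming over the socket, and the websocket path now re-raises after reporting errors, so failures reach the Lambda runtime (and Honeybadger) instead of being swallowed into a 500 payload. The keys below come from the json.dumps call in this diff; a sketch of what a successful chat_sync body decodes to (all values illustrative):

json.loads(response["body"]) == {
    "answer": ["A brief answer citing specific documents..."],
    "is_dev_team": False,
    "is_superuser": False,
    "k": 10,
    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
    "question": "What collections mention trains?",
    "ref": "thread-123",
    "artifacts": [],
    "token_counts": {"input_tokens": 215, "output_tokens": 79},
}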