Get agents working printing to logs
kdid committed Dec 4, 2024
1 parent 3d7701d commit 1e614e3
Showing 8 changed files with 111 additions and 62 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
@@ -43,7 +43,7 @@ jobs:
path: ".tfvars"
- uses: actions/setup-python@v2
with:
python-version: '3.10'
python-version: '3.12'
- uses: aws-actions/setup-sam@v1
- uses: aws-actions/configure-aws-credentials@master
with:
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.10'
python-version: '3.12'
- name: Install dependencies
run: pip install -r requirements.txt
working-directory: ./docs
1 change: 1 addition & 0 deletions chat/dependencies/requirements.txt
@@ -3,6 +3,7 @@ honeybadger
langchain~=0.2
langchain-aws~=0.1
langchain-openai~=0.1
langgraph~=0.2
openai~=1.35
opensearch-py
pyjwt~=2.6.0
2 changes: 1 addition & 1 deletion chat/src/agent/tools.py
@@ -1,7 +1,7 @@
import json

from langchain_core.tools import tool
from opensearch_client import opensearch_vector_store
from setup import opensearch_vector_store

@tool(response_format="content_and_artifact")
def search(query: str):
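The new langgraph dependency and this search tool are what agent/search_agent.py (imported by the handler further down) is built from. For reference, here is a minimal sketch of how such an agent is typically assembled with langgraph's prebuilt helpers; the model, checkpointer, and streaming flag are assumptions for illustration, not the repository's actual agent code.

```python
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

from agent.tools import search          # the @tool shown in the diff above
from setup import openai_chat_client    # Azure OpenAI chat client from chat/src/setup.py

# Hypothetical assembly of a ReAct-style agent; the real search_agent may differ.
search_agent = create_react_agent(
    openai_chat_client(streaming=True),
    tools=[search],
    checkpointer=MemorySaver(),  # enables the thread_id-based memory the handler passes in
)
```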
66 changes: 31 additions & 35 deletions chat/src/event_config.py
@@ -4,11 +4,7 @@
from dataclasses import dataclass, field

from langchain_core.prompts import ChatPromptTemplate
from setup import (
opensearch_client,
opensearch_vector_store,
openai_chat_client,
)

from typing import List
from handlers.streaming_socket_callback_handler import StreamingSocketCallbackHandler
from helpers.apitoken import ApiToken
@@ -25,23 +21,46 @@
TEXT_KEY = "id"
VERSION = "2024-02-01"


@dataclass
class EventConfig:
"""
The EventConfig class represents the configuration for an event.
Default values are set for the following properties which can be overridden in the payload message.
"""

DEFAULT_ATTRIBUTES = ["accession_number", "alternate_title", "api_link", "canonical_link", "caption", "collection",
"contributor", "date_created", "date_created_edtf", "description", "genre", "id", "identifier",
"keywords", "language", "notes", "physical_description_material", "physical_description_size",
"provenance", "publisher", "rights_statement", "subject", "table_of_contents", "thumbnail",
"title", "visibility", "work_type"]
DEFAULT_ATTRIBUTES = [
"accession_number",
"alternate_title",
"api_link",
"canonical_link",
"caption",
"collection",
"contributor",
"date_created",
"date_created_edtf",
"description",
"genre",
"id",
"identifier",
"keywords",
"language",
"notes",
"physical_description_material",
"physical_description_size",
"provenance",
"publisher",
"rights_statement",
"subject",
"table_of_contents",
"thumbnail",
"title",
"visibility",
"work_type",
]

api_token: ApiToken = field(init=False)
attributes: List[str] = field(init=False)
debug_mode: bool = field(init=False)
document_prompt: ChatPromptTemplate = field(init=False)
event: dict = field(default_factory=dict)
is_dev_team: bool = field(init=False)
is_logged_in: bool = field(init=False)
@@ -63,7 +82,6 @@ class EventConfig:
def __post_init__(self):
self.payload = json.loads(self.event.get("body", "{}"))
self.api_token = ApiToken(signed_token=self.payload.get("auth"))
self.attributes = self._get_attributes()
self.debug_mode = self._is_debug_mode_enabled()
self.is_dev_team = self.api_token.is_dev_team()
self.is_logged_in = self.api_token.is_logged_in()
@@ -78,7 +96,6 @@ def __post_init__(self):
self.stream_response = self.payload.get("stream_response", not self.debug_mode)
self.temperature = self._get_temperature()
self.text_key = self._get_text_key()
self.document_prompt = self._get_document_prompt()
self.prompt = ChatPromptTemplate.from_template(self.prompt_text)

def _get_payload_value_with_superuser_check(self, key, default):
@@ -87,20 +104,6 @@ def _get_payload_value_with_superuser_check(self, key, default):
else:
return default

def _get_attributes_function(self):
try:
opensearch = opensearch_client()
mapping = opensearch.indices.get_mapping(index="dc-v2-work")
return list(next(iter(mapping.values()))['mappings']['properties'].keys())
except StopIteration:
return []

def _get_attributes(self):
return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)
return self._get_payload_value_with_superuser_check(
"deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
)

def _get_k(self):
value = self._get_payload_value_with_superuser_check("k", K_VALUE)
return min(value, MAX_K)
@@ -117,18 +120,11 @@ def _get_temperature(self):
def _get_text_key(self):
return self._get_payload_value_with_superuser_check("text_key", TEXT_KEY)

def _get_document_prompt(self):
return ChatPromptTemplate.from_template(document_template(self.attributes))

def debug_message(self):
return {
"type": "debug",
"message": {
"attributes": self.attributes,
"azure_endpoint": self.azure_endpoint,
"deployment_name": self.deployment_name,
"k": self.k,
"openai_api_version": self.openai_api_version,
"prompt": self.prompt_text,
"question": self.question,
"ref": self.ref,
59 changes: 48 additions & 11 deletions chat/src/handlers/chat.py
@@ -1,25 +1,55 @@
import secrets # noqa
import secrets # noqa
import boto3
import logging
import os
from datetime import datetime
from event_config import EventConfig
from honeybadger import honeybadger
from chat.src.agent.search_agent import search_agent
from agent.search_agent import search_agent
from langchain_core.messages import HumanMessage

honeybadger.configure()
logging.getLogger('honeybadger').addHandler(logging.StreamHandler())
logging.getLogger("honeybadger").addHandler(logging.StreamHandler())

RESPONSE_TYPES = {
"base": ["answer", "ref"],
"debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"],
"log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts", "is_dev_team"],
"error": ["question", "error", "source_documents"]
"debug": [
"answer",
"attributes",
"azure_endpoint",
"deployment_name",
"is_superuser",
"k",
"openai_api_version",
"prompt",
"question",
"ref",
"temperature",
"text_key",
"token_counts",
],
"log": [
"answer",
"deployment_name",
"is_superuser",
"k",
"openai_api_version",
"prompt",
"question",
"ref",
"size",
"source_documents",
"temperature",
"token_counts",
"is_dev_team",
],
"error": ["question", "error", "source_documents"],
}


def handler(event, context):
config = EventConfig(event)
socket = event.get('socket', None)
socket = event.get("socket", None)
config.setup_websocket(socket)

if not (config.is_logged_in or config.is_superuser):
@@ -30,22 +60,29 @@ def handler(event, context):
config.socket.send({"type": "error", "message": "Question cannot be blank"})
return {"statusCode": 400, "body": "Question cannot be blank"}

response = search_agent.invoke(
{"messages": [HumanMessage(content=config.question)]},
config={"configurable": {"thread_id": config.ref}},
)

print(response)



def reshape_response(response, type):
return {k: response[k] for k in RESPONSE_TYPES[type]}


def ensure_log_stream_exists(log_group, log_stream):
log_client = boto3.client('logs')
log_client = boto3.client("logs")
try:
log_client.create_log_stream(logGroupName=log_group, logStreamName=log_stream)
return True
except log_client.exceptions.ResourceAlreadyExistsException:
return True
except Exception:
print(f'Could not create log stream: {log_group}:{log_stream}')
print(f"Could not create log stream: {log_group}:{log_stream}")
return False


def timestamp():
return round(datetime.timestamp(datetime.now()) * 1000)
return round(datetime.timestamp(datetime.now()) * 1000)
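Per the commit message, the handler now just invokes the agent and prints the raw response. A possible refinement, sketched below under the assumption that search_agent is a compiled LangGraph graph, is to stream the agent's intermediate steps and print each message so they land in the Lambda's CloudWatch logs; the stream_mode value and the message formatting are illustrative assumptions, not code from this commit.

```python
from agent.search_agent import search_agent
from langchain_core.messages import HumanMessage


def log_agent_run(config):
    """Hypothetical helper: stream the agent's steps and print each one to the logs."""
    inputs = {"messages": [HumanMessage(content=config.question)]}
    run_config = {"configurable": {"thread_id": config.ref}}
    # Compiled LangGraph graphs expose .stream(); stream_mode="values" yields the
    # full graph state after every step.
    for state in search_agent.stream(inputs, config=run_config, stream_mode="values"):
        state["messages"][-1].pretty_print()  # stdout ends up in CloudWatch
```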
33 changes: 24 additions & 9 deletions chat/src/setup.py
@@ -6,48 +6,63 @@
import os
import boto3


def prefix(value):
env_prefix = os.getenv("ENV_PREFIX")
env_prefix = None if env_prefix == "" else env_prefix
return '-'.join(filter(None, [env_prefix, value]))
return "-".join(filter(None, [env_prefix, value]))


def openai_chat_client(**kwargs):
return AzureChatOpenAI(
openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_API_VERSION", "2024-08-01-preview"),
azure_endpoint=f"https://{os.getenv("AZURE_OPENAI_RESOURCE_NAME")}.openai.azure.com",
**kwargs,
)


def opensearch_endpoint():
endpoint = os.getenv("OPENSEARCH_ENDPOINT")
parsed = urlparse(endpoint)
if parsed.netloc != '':
if parsed.netloc != "":
return parsed.netloc
else:
return endpoint



def opensearch_client(region_name=os.getenv("AWS_REGION")):
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())
awsauth = AWS4Auth(
region=region_name,
service="es",
refreshable_credentials=session.get_credentials(),
)
endpoint = opensearch_endpoint()

return OpenSearch(
hosts=[{'host': endpoint, 'port': 443}],
use_ssl = True,
hosts=[{"host": endpoint, "port": 443}],
use_ssl=True,
connection_class=RequestsHttpConnection,
http_auth=awsauth,
)


def opensearch_vector_store(region_name=os.getenv("AWS_REGION")):
session = boto3.Session(region_name=region_name)
awsauth = AWS4Auth(region=region_name, service="es", refreshable_credentials=session.get_credentials())
awsauth = AWS4Auth(
region=region_name,
service="es",
refreshable_credentials=session.get_credentials(),
)

docsearch = OpenSearchNeuralSearch(
index=prefix("dc-v2-work"),
model_id=os.getenv("OPENSEARCH_MODEL_ID"),
endpoint=opensearch_endpoint(),
connection_class=RequestsHttpConnection,
http_auth=awsauth,
text_field= "id"
text_field="id",
)
return docsearch

@@ -58,4 +73,4 @@ def websocket_client(endpoint_url: str):
client = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url)
return client
except Exception as e:
raise e
raise e
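For orientation, a short usage sketch of the helpers defined above; the ENV_PREFIX value and the call sites are illustrative, and the other environment variables (OPENSEARCH_ENDPOINT, AWS_REGION, Azure OpenAI settings) are assumed to be set the way the Lambda would set them.

```python
import os

os.environ.setdefault("ENV_PREFIX", "staging")  # assumed value for illustration

from setup import prefix, opensearch_client, opensearch_vector_store

print(prefix("dc-v2-work"))        # "staging-dc-v2-work"; with ENV_PREFIX unset or "", just "dc-v2-work"

client = opensearch_client()       # SigV4-signed OpenSearch client for the configured region
store = opensearch_vector_store()  # neural-search vector store over the prefixed work index
```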
8 changes: 4 additions & 4 deletions chat/template.yaml
@@ -183,15 +183,15 @@ Resources:
Description: Dependencies for streaming chat function
ContentUri: ./dependencies
CompatibleRuntimes:
- python3.10
- python3.12
LicenseInfo: "Apache-2.0"
Metadata:
BuildMethod: python3.10
BuildMethod: python3.12
ChatFunction:
Type: AWS::Serverless::Function
Properties:
CodeUri: ./src
Runtime: python3.10
Runtime: python3.12
Architectures:
- x86_64
Layers:
@@ -234,7 +234,7 @@ Resources:
Type: AWS::Serverless::Function
Properties:
CodeUri: ./src
Runtime: python3.10
Runtime: python3.12
Architectures:
- x86_64
Layers:
