
Commit

Cleanup model info from event config
charlesLoder committed Dec 4, 2024
1 parent 6df7611 commit 3d7701d
Showing 3 changed files with 4 additions and 78 deletions.
4 changes: 2 additions & 2 deletions chat/src/agent/agent.py → chat/src/agent/search_agent.py
@@ -28,7 +28,7 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:
 # Define the function that calls the model
 def call_model(state: MessagesState):
     messages = state["messages"]
-    response = model.invoke(messages, model=os.getenv("AZURE_DEPLOYMENT_NAME"))
+    response = model.invoke(messages, model=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"))
     # We return a list, because this will get added to the existing list
     return {"messages": [response]}

@@ -56,4 +56,4 @@ def call_model(state: MessagesState):
 checkpointer = MemorySaver()

 # Compile the graph
-app = workflow.compile(checkpointer=checkpointer, debug=True)
+search_agent = workflow.compile(checkpointer=checkpointer, debug=True)
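Two things change for importers here: the module moves from chat/src/agent/agent.py to chat/src/agent/search_agent.py, with the compiled graph exported as search_agent instead of app, and call_model now reads its deployment from AZURE_OPENAI_LLM_DEPLOYMENT_ID, the same variable EventConfig used as its default before this commit. The import callers need afterwards (the old import line is inferred from the removed export name):

    # Before (module and export no longer exist):
    # from chat.src.agent.agent import app

    # After:
    from chat.src.agent.search_agent import search_agent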
46 changes: 0 additions & 46 deletions chat/src/event_config.py
@@ -40,18 +40,14 @@ class EventConfig:

     api_token: ApiToken = field(init=False)
     attributes: List[str] = field(init=False)
-    azure_endpoint: str = field(init=False)
-    azure_resource_name: str = field(init=False)
     debug_mode: bool = field(init=False)
-    deployment_name: str = field(init=False)
     document_prompt: ChatPromptTemplate = field(init=False)
     event: dict = field(default_factory=dict)
     is_dev_team: bool = field(init=False)
     is_logged_in: bool = field(init=False)
     is_superuser: bool = field(init=False)
     k: int = field(init=False)
     max_tokens: int = field(init=False)
-    openai_api_version: str = field(init=False)
     payload: dict = field(default_factory=dict)
     prompt_text: str = field(init=False)
     prompt: ChatPromptTemplate = field(init=False)
@@ -68,16 +64,12 @@ def __post_init__(self):
         self.payload = json.loads(self.event.get("body", "{}"))
         self.api_token = ApiToken(signed_token=self.payload.get("auth"))
         self.attributes = self._get_attributes()
-        self.azure_endpoint = self._get_azure_endpoint()
-        self.azure_resource_name = self._get_azure_resource_name()
         self.debug_mode = self._is_debug_mode_enabled()
-        self.deployment_name = self._get_deployment_name()
         self.is_dev_team = self.api_token.is_dev_team()
         self.is_logged_in = self.api_token.is_logged_in()
         self.is_superuser = self.api_token.is_superuser()
         self.k = self._get_k()
         self.max_tokens = min(self.payload.get("max_tokens", MAX_TOKENS), MAX_TOKENS)
-        self.openai_api_version = self._get_openai_api_version()
         self.prompt_text = self._get_prompt_text()
         self.request_context = self.event.get("requestContext", {})
         self.question = self.payload.get("question")
@@ -105,22 +97,6 @@ def _get_attributes_function(self):

     def _get_attributes(self):
         return self._get_payload_value_with_superuser_check("attributes", self.DEFAULT_ATTRIBUTES)

-    def _get_azure_endpoint(self):
-        default = f"https://{self._get_azure_resource_name()}.openai.azure.com/"
-        return self._get_payload_value_with_superuser_check("azure_endpoint", default)
-
-    def _get_azure_resource_name(self):
-        azure_resource_name = self._get_payload_value_with_superuser_check(
-            "azure_resource_name", os.environ.get("AZURE_OPENAI_RESOURCE_NAME")
-        )
-        if not azure_resource_name:
-            raise EnvironmentError(
-                "Either payload must contain 'azure_resource_name' or environment variable 'AZURE_OPENAI_RESOURCE_NAME' must be set"
-            )
-        return azure_resource_name
-
-    def _get_deployment_name(self):
-        return self._get_payload_value_with_superuser_check(
-            "deployment_name", os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")
-        )
@@ -129,11 +105,6 @@ def _get_k(self):
         value = self._get_payload_value_with_superuser_check("k", K_VALUE)
         return min(value, MAX_K)

-    def _get_openai_api_version(self):
-        return self._get_payload_value_with_superuser_check(
-            "openai_api_version", VERSION
-        )
-
     def _get_prompt_text(self):
         return self._get_payload_value_with_superuser_check("prompt", prompt_template())
@@ -178,23 +149,6 @@ def setup_websocket(self, socket=None):
         self.socket = socket
         return self.socket

-    def setup_llm_request(self):
-        self._setup_vector_store()
-        self._setup_chat_client()
-
-    def _setup_vector_store(self):
-        self.opensearch = opensearch_vector_store()
-
-    def _setup_chat_client(self):
-        self.client = openai_chat_client(
-            azure_deployment=self.deployment_name,
-            azure_endpoint=self.azure_endpoint,
-            openai_api_version=self.openai_api_version,
-            callbacks=[StreamingSocketCallbackHandler(self.socket, stream=self.stream_response)],
-            streaming=True,
-            max_tokens=self.max_tokens
-        )
-
     def _is_debug_mode_enabled(self):
         debug = self.payload.get("debug", False)
         return debug and self.api_token.is_superuser()
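With these deletions, the request payload can no longer override the Azure endpoint, deployment name, or API version, and EventConfig no longer builds the chat client or vector store itself; model configuration presumably now lives with the agent, which reads AZURE_OPENAI_LLM_DEPLOYMENT_ID from the environment. A minimal sketch of an equivalent client constructed purely from environment variables, assuming langchain-openai's AzureChatOpenAI (the api_version default shown is illustrative, not from this repo):

    import os

    from langchain_openai import AzureChatOpenAI

    # AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are read from the
    # environment by the client itself; only the deployment is passed in.
    model = AzureChatOpenAI(
        azure_deployment=os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID"),
        api_version=os.getenv("OPENAI_API_VERSION", "2024-02-01"),
        streaming=True,
    )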
32 changes: 2 additions & 30 deletions chat/src/handlers/chat.py
@@ -1,12 +1,11 @@
 import secrets # noqa
-import boto3
 import json
 import logging
 import os
 from datetime import datetime
 from event_config import EventConfig
-from helpers.response import Response
 from honeybadger import honeybadger
+from chat.src.agent.search_agent import search_agent

 honeybadger.configure()
 logging.getLogger('honeybadger').addHandler(logging.StreamHandler())
@@ -30,36 +29,9 @@ def handler(event, context):
     if config.question is None or config.question == "":
         config.socket.send({"type": "error", "message": "Question cannot be blank"})
         return {"statusCode": 400, "body": "Question cannot be blank"}

-    debug_message = config.debug_message()
-    if config.debug_mode:
-        config.socket.send(debug_message)
-
-    if not os.getenv("SKIP_LLM_REQUEST"):
-        config.setup_llm_request()
-
-    response = Response(config)
-    final_response = response.prepare_response()
-    if "error" in final_response:
-        logging.error(f'Error: {final_response["error"]}')
-        config.socket.send({"type": "error", "message": "Internal Server Error"})
-        return {"statusCode": 500, "body": "Internal Server Error"}
-    else:
-        config.socket.send(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))
-
-    log_group = os.getenv('METRICS_LOG_GROUP')
-    log_stream = context.log_stream_name
-    if log_group and ensure_log_stream_exists(log_group, log_stream):
-        log_client = boto3.client('logs')
-        log_message = reshape_response(final_response, 'log')
-        log_events = [
-            {
-                'timestamp': timestamp(),
-                'message': json.dumps(log_message)
-            }
-        ]
-        log_client.put_log_events(logGroupName=log_group, logStreamName=log_stream, logEvents=log_events)
+    return {"statusCode": 200}


 def reshape_response(response, type):
     return {k: response[k] for k in RESPONSE_TYPES[type]}
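In the slimmed-down handler the Response pipeline, debug echo, and CloudWatch metrics logging are gone, and the handler simply returns 200 after validating the question; the newly imported search_agent is not yet called in this hunk, so presumably a follow-up change wires it in. A hypothetical sketch of that wiring, not code from this commit (the helper name, message shape, and thread-id choice are assumptions):

    from langchain_core.messages import HumanMessage

    def run_agent(question, thread_id):
        # Route the validated question through the compiled LangGraph agent;
        # MemorySaver keys the conversation's checkpointed state to thread_id.
        result = search_agent.invoke(
            {"messages": [HumanMessage(content=question)]},
            config={"configurable": {"thread_id": thread_id}},
        )
        return result["messages"][-1].content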
