Skip to content

Commit

Permalink
Merge pull request #276 from nulib/5279-chat-sync
Browse files Browse the repository at this point in the history
Get sync chat function working
  • Loading branch information
mbklein authored Dec 17, 2024
2 parents eea458f + bacc315 commit cbbcd75
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 328 deletions.
264 changes: 9 additions & 255 deletions chat-playground/playground.ipynb

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions chat/src/agent/metrics_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ def __init__(self, *args, **kwargs):

def on_llm_end(self, response: LLMResult, **kwargs: Dict[str, Any]):
for generation in response.generations[0]:
self.answers.append(generation.text)
if generation.text != "":
self.answers.append(generation.text)
for k, v in generation.message.usage_metadata.items():
if k not in self.accumulator:
self.accumulator[k] = v
else:
self.accumulator[k] += v
self.accumulator[k] = self.accumulator.get(k, 0) + v

def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
match output.name:
Expand Down
3 changes: 2 additions & 1 deletion chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def handler(event, context):

metrics = MetricsHandler()
callbacks = [AgentHandler(config.socket, config.ref), metrics]
search_agent = SearchAgent(model=chat_model(config), streaming=True)
model = chat_model(model=config.model, streaming=config.stream_response)
search_agent = SearchAgent(model=model)
try:
search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=callbacks)
log_metrics(context, metrics, config)
Expand Down
50 changes: 26 additions & 24 deletions chat/src/handlers/chat_sync.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,43 @@
import secrets # noqa
import json
import logging
import os
from http_event_config import HTTPEventConfig
from helpers.http_response import HTTPResponse
from agent.metrics_handler import MetricsHandler
from agent.search_agent import SearchAgent
from handlers.model import chat_model
from event_config import EventConfig
from honeybadger import honeybadger

honeybadger.configure()
logging.getLogger('honeybadger').addHandler(logging.StreamHandler())

RESPONSE_TYPES = {
"base": ["answer", "ref", "context"],
"debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts", "context"],
"log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"],
"error": ["question", "error", "source_documents"]
}

def handler(event, context):
config = HTTPEventConfig(event)
config = EventConfig(event)

if not config.is_logged_in:
return {"statusCode": 401, "body": "Unauthorized"}

if config.question is None or config.question == "":
return {"statusCode": 400, "body": "Question cannot be blank"}

if not os.getenv("SKIP_LLM_REQUEST"):
config.setup_llm_request()
response = HTTPResponse(config)
final_response = response.prepare_response()
if "error" in final_response:
logging.error(f'Error: {final_response["error"]}')
return {"statusCode": 500, "body": "Internal Server Error"}
else:
return {"statusCode": 200, "body": json.dumps(reshape_response(final_response, 'debug' if config.debug_mode else 'base'))}

return {"statusCode": 200}
model = chat_model(model=config.model, streaming=False)
search_agent = SearchAgent(model=model)
result = MetricsHandler()
search_agent.invoke(config.question, config.ref, forget=config.forget, callbacks=[result])

def reshape_response(response, type):
return {k: response[k] for k in RESPONSE_TYPES[type]}
return {
"statusCode": 200,
"headers": {
"Content-Type": "application/json"
},
"body": json.dumps({
"answer": result.answers,
"is_dev_team": config.api_token.is_dev_team(),
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"model": config.model,
"question": config.question,
"ref": config.ref,
"artifacts": result.artifacts,
"token_counts": result.accumulator,
})
}
5 changes: 2 additions & 3 deletions chat/src/handlers/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from event_config import EventConfig
from langchain_aws import ChatBedrock
from langchain_core.language_models.base import BaseModel

MODEL_OVERRIDE: BaseModel = None

def chat_model(event: EventConfig):
return MODEL_OVERRIDE or ChatBedrock(model=event.model)
def chat_model(**kwargs):
return MODEL_OVERRIDE or ChatBedrock(**kwargs)

def set_model_override(model: BaseModel):
global MODEL_OVERRIDE
Expand Down
2 changes: 1 addition & 1 deletion chat/src/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Dev/Test Dependencies
ruff~=0.1.0
ruff~=0.2.0
coverage~=7.3.2
95 changes: 56 additions & 39 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ Resources:
CheckpointBucket:
Type: 'AWS::S3::Bucket'
Properties:
VersioningConfiguration:
Status: Enabled
PublicAccessBlockConfiguration:
BlockPublicAcls: true
BlockPublicPolicy: true
Expand Down Expand Up @@ -259,43 +257,62 @@ Resources:
Resource: "*"
#* Metadata:
#* BuildMethod: nodejs20.x
# ChatSyncFunction:
# Type: AWS::Serverless::Function
# Properties:
# CodeUri: ./src
# Runtime: python3.12
# Architectures:
# - x86_64
# #* Layers:
# #* - !Ref ChatDependencies
# MemorySize: 1024
# Handler: handlers/chat_sync.handler
# Timeout: 300
# Environment:
# Variables:
# API_TOKEN_NAME: !Ref ApiTokenName
# ENV_PREFIX: !Ref EnvironmentPrefix
# HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
# HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv
# HONEYBADGER_REVISION: !Ref HoneybadgerRevision
# METRICS_LOG_GROUP: !Ref ChatMetricsLog
# SECRETS_PATH: !Ref SecretsPath
# FunctionUrlConfig:
# AuthType: NONE
# Policies:
# - !Ref SecretsPolicy
# - Statement:
# - Effect: Allow
# Action:
# - 'es:ESHttpGet'
# - 'es:ESHttpPost'
# Resource: '*'
# - Statement:
# - Effect: Allow
# Action:
# - logs:CreateLogStream
# - logs:PutLogEvents
# Resource: !Sub "${ChatMetricsLog.Arn}:*"
ChatSyncFunction:
Type: AWS::Serverless::Function
Properties:
CodeUri: ./src
Runtime: python3.12
Architectures:
- x86_64
#* Layers:
#* - !Ref ChatDependencies
MemorySize: 1024
Handler: handlers/chat_sync.handler
Timeout: 300
Environment:
Variables:
API_TOKEN_NAME: !Ref ApiTokenName
CHECKPOINT_BUCKET_NAME: !Ref CheckpointBucket
ENV_PREFIX: !Ref EnvironmentPrefix
HONEYBADGER_API_KEY: !Ref HoneybadgerApiKey
HONEYBADGER_ENVIRONMENT: !Ref HoneybadgerEnv
HONEYBADGER_REVISION: !Ref HoneybadgerRevision
METRICS_LOG_GROUP: !Ref ChatMetricsLog
SECRETS_PATH: !Ref SecretsPath
NO_COLOR: 1
FunctionUrlConfig:
AuthType: NONE
Policies:
- !Ref SecretsPolicy
- Statement:
- Effect: Allow
Action:
- 'es:ESHttpGet'
- 'es:ESHttpPost'
Resource: '*'
- Statement:
- Effect: Allow
Action:
- logs:CreateLogStream
- logs:PutLogEvents
Resource: !Sub "${ChatMetricsLog.Arn}:*"
- Statement:
- Effect: Allow
Action:
- s3:GetObject
- s3:ListBucket
- s3:PutObject
- s3:DeleteObject
- s3:ListObjectsV2
Resource:
- !Sub "arn:aws:s3:::${CheckpointBucket}"
- !Sub "arn:aws:s3:::${CheckpointBucket}/*"
- Statement:
- Effect: Allow
Action:
- bedrock:InvokeModel
- bedrock:InvokeModelWithResponseStream
Resource: "*"
#* Metadata:
#* BuildMethod: nodejs20.x
ChatMetricsLog:
Expand Down

0 comments on commit cbbcd75

Please sign in to comment.