Skip to content

Commit

Permalink
Switch from Azure OpenAI to Bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
mbklein committed Dec 12, 2024
1 parent a8effff commit c8980c7
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 23 deletions.
3 changes: 2 additions & 1 deletion chat/src/agent/agent_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def on_tool_end(self, output: ToolMessage, **kwargs: Dict[str, Any]):
pass
case "search":
try:
docs: List[Dict[str, Any]] = [doc.metadata for doc in output.artifact]
result_fields = ("id", "title", "visibility", "work_type", "thumbnail")
docs: List[Dict[str, Any]] = [{k: doc.metadata.get(k) for k in result_fields} for doc in output.artifact]
self.socket.send({"type": "search_result", "ref": self.ref, "message": docs})
except json.decoder.JSONDecodeError as e:
print(f"Invalid json ({e}) returned from {output.name} tool: {output.content}")
Expand Down
8 changes: 4 additions & 4 deletions chat/src/agent/search_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,21 @@
from langchain_core.messages.system import SystemMessage
from langgraph.graph import END, START, StateGraph, MessagesState
from langgraph.prebuilt import ToolNode
from setup import openai_chat_client
from setup import chat_client

# Keep your answer concise and keep reading time under 45 seconds.

# System prompt prepended to every conversation turn: answer briefly in raw
# markdown and cite source documents via their canonical_link field.
system_message = """
Please provide a brief answer to the question using the tools provided. Include specific details from multiple documents that
support your answer. Answer in raw markdown, but not within a code block. When citing source documents, construct Markdown
links using the document's canonical_link field. Do not include intermediate messages explaining your process.
"""

# Tools the agent may call; ToolNode dispatches tool invocations in the graph.
tools = [discover_fields, search, aggregate]

tool_node = ToolNode(tools)

# Streaming Bedrock chat model with the search tools bound to it.
model = chat_client(streaming=True).bind_tools(tools)

# Define the function that determines whether to continue or not
def should_continue(state: MessagesState) -> Literal["tools", END]:
Expand All @@ -38,7 +38,7 @@ def should_continue(state: MessagesState) -> Literal["tools", END]:
# Define the function that calls the model
def call_model(state: MessagesState):
    """Invoke the chat model on the conversation state.

    Prepends the system prompt to the accumulated messages, invokes the
    model, and returns the response wrapped in a list so LangGraph appends
    it to the existing message history.
    """
    messages = [SystemMessage(content=system_message)] + state["messages"]
    response: BaseMessage = model.invoke(messages)
    # We return a list, because this will get added to the existing list
    return {"messages": [response]}
Expand Down
7 changes: 2 additions & 5 deletions chat/src/handlers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
"debug": [
"answer",
"attributes",
"azure_endpoint",
"deployment_name",
"is_superuser",
"k",
"openai_api_version",
"prompt",
"question",
"ref",
Expand All @@ -36,7 +34,6 @@
"deployment_name",
"is_superuser",
"k",
"openai_api_version",
"prompt",
"question",
"ref",
Expand Down Expand Up @@ -79,8 +76,8 @@ def handler(event, context):
try:
search_agent.invoke(
{"messages": [HumanMessage(content=config.question)]},
config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks, "metadata": {"model_deployment": os.getenv("AZURE_OPENAI_LLM_DEPLOYMENT_ID")}},
debug=False
config={"configurable": {"thread_id": config.ref}, "callbacks": callbacks},
debug=True
)
except Exception as e:
print(f"Error: {e}")
Expand Down
2 changes: 0 additions & 2 deletions chat/src/helpers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ def debug_response(config, response, original_question):
return {
"answer": response,
"attributes": config.attributes,
"azure_endpoint": config.azure_endpoint,
"deployment_name": config.deployment_name,
"is_dev_team": config.api_token.is_dev_team(),
"is_superuser": config.api_token.is_superuser(),
"k": config.k,
"openai_api_version": config.openai_api_version,
"prompt": config.prompt_text,
"question": config.question,
"ref": config.ref,
Expand Down
6 changes: 1 addition & 5 deletions chat/src/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@ def load_secrets():
EnvironmentMap = [
['API_TOKEN_SECRET', 'dcapi', 'api_token_secret'],
['OPENSEARCH_ENDPOINT', 'index', 'endpoint'],
['OPENSEARCH_MODEL_ID', 'index', 'embedding_model'],
['AZURE_OPENAI_API_KEY', 'azure_openai', 'api_key'],
['AZURE_OPENAI_LLM_DEPLOYMENT_ID', 'azure_openai', 'llm_deployment_id'],
['AZURE_OPENAI_RESOURCE_NAME', 'azure_openai', 'resource_name']
['OPENSEARCH_MODEL_ID', 'index', 'embedding_model']
]

client = boto3.client("secretsmanager")
response = client.batch_get_secret_value(SecretIdList=[
f'{SecretsPath}/infrastructure/index',
f'{SecretsPath}/infrastructure/azure_openai',
f'{SecretsPath}/config/dcapi'
])

Expand Down
10 changes: 4 additions & 6 deletions chat/src/setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from langchain_openai import AzureChatOpenAI
from langchain_aws import ChatBedrock
from handlers.opensearch_neural_search import OpenSearchNeuralSearch
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
Expand All @@ -13,11 +13,9 @@ def prefix(value):
return "-".join(filter(None, [env_prefix, value]))


def chat_client(**kwargs):
    """Return a Bedrock-backed chat client.

    Keyword arguments (e.g. ``streaming=True``) are passed through to the
    ``ChatBedrock`` constructor. Callers may also override ``model``; using
    ``setdefault`` (rather than a hard-coded keyword) avoids a duplicate
    keyword-argument TypeError when they do.
    """
    # Default: cross-region Claude 3.5 Sonnet v2 inference profile.
    kwargs.setdefault("model", "us.anthropic.claude-3-5-sonnet-20241022-v2:0")
    return ChatBedrock(**kwargs)

Expand Down
6 changes: 6 additions & 0 deletions chat/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ Resources:
Resource:
- !Sub "arn:aws:s3:::${CheckpointBucket}"
- !Sub "arn:aws:s3:::${CheckpointBucket}/*"
- Statement:
- Effect: Allow
Action:
- bedrock:InvokeModel
- bedrock:InvokeModelWithResponseStream
Resource: "*"
#* Metadata:
#* BuildMethod: nodejs20.x
# ChatSyncFunction:
Expand Down

0 comments on commit c8980c7

Please sign in to comment.