feat: Added support for Claude 3+ Chat API in Bedrock #2870

Open · wants to merge 4 commits into base: main
27 changes: 17 additions & 10 deletions lib/instrumentation/aws-sdk/v3/bedrock.js
@@ -114,25 +114,32 @@ function recordChatCompletionMessages({
     isError: err !== null
   })
 
-  const msg = new LlmChatCompletionMessage({
-    agent,
-    segment,
-    bedrockCommand,
-    bedrockResponse,
-    index: 0,
-    completionId: summary.id
+  // Record context message(s)
+  const promptContextMessages = bedrockCommand.prompt
+  promptContextMessages.forEach((contextMessage, promptIndex) => {
+    const msg = new LlmChatCompletionMessage({
+      agent,
+      segment,
+      bedrockCommand,
+      content: contextMessage.content,
+      role: contextMessage.role,
+      bedrockResponse,
+      index: promptIndex,
+      completionId: summary.id
+    })
+    recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })
   })
-  recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })
 
-  bedrockResponse.completions.forEach((content, index) => {
+  bedrockResponse.completions.forEach((content, completionIndex) => {
     const chatCompletionMessage = new LlmChatCompletionMessage({
       agent,
       segment,
       bedrockCommand,
       bedrockResponse,
       isResponse: true,
-      index: index + 1,
+      index: promptContextMessages.length + completionIndex,
       content,
+      role: 'assistant',
       completionId: summary.id
     })
     recordEvent({ agent, type: 'LlmChatCompletionMessage', msg: chatCompletionMessage })
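Note on the new indexing scheme: each context message from the prompt is now recorded as its own LlmChatCompletionMessage at sequence 0..n-1, and response completions continue at n. A minimal sketch of the layout, using hypothetical message values rather than anything from this PR's tests:

// Hypothetical inputs, for illustration only.
const promptContextMessages = [
  { role: 'user', content: 'What is the ultimate answer?' },
  { role: 'assistant', content: 'Do you want the short version?' },
  { role: 'user', content: 'Yes.' }
]
const completions = ['42']

// Context messages take sequence indices 0..n-1.
promptContextMessages.forEach((m, i) => console.log(i, m.role))
// 0 'user'  1 'assistant'  2 'user'

// Completions continue at n..n+m-1, so indices never collide.
completions.forEach((c, i) => console.log(promptContextMessages.length + i, 'assistant'))
// 3 'assistant'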
76 changes: 55 additions & 21 deletions lib/llm-events/aws-bedrock/bedrock-command.js
@@ -5,6 +5,8 @@
 
 'use strict'
 
+const { stringifyClaudeChunkedMessage } = require('./utils')
+
 /**
  * Parses an AWS invoke command instance into a re-usable entity.
  */
@@ -68,37 +70,34 @@ class BedrockCommand {
   /**
    * The question posed to the LLM.
    *
-   * @returns {string|string[]|undefined}
+   * @returns {object[]} The array of context messages passed to the LLM (or a single user prompt for legacy "non-chat" models)
    */
   get prompt() {
-    let result
     if (this.isTitan() === true || this.isTitanEmbed() === true) {
-      result = this.#body.inputText
+      return [
+        {
+          role: 'user',
+          content: this.#body.inputText
+        }
+      ]
     } else if (this.isCohereEmbed() === true) {
-      result = this.#body.texts.join(' ')
+      return [
+        {
+          role: 'user',
+          content: this.#body.texts.join(' ')
+        }
+      ]
     } else if (
-      this.isClaude() === true ||
+      this.isClaudeTextCompletionApi() === true ||
       this.isAi21() === true ||
       this.isCohere() === true ||
       this.isLlama() === true
     ) {
-      result = this.#body.prompt
-    } else if (this.isClaude3() === true) {
-      const collected = []
-      for (const message of this.#body?.messages) {
-        if (message?.role === 'assistant') {
-          continue
-        }
-        if (typeof message?.content === 'string') {
-          collected.push(message?.content)
-          continue
-        }
-        const mappedMsgObj = message?.content.map((msgContent) => msgContent.text)
-        collected.push(mappedMsgObj)
-      }
-      result = collected.join(' ')
+      return [{ role: 'user', content: this.#body.prompt }]
+    } else if (this.isClaudeMessagesApi() === true) {
+      return normalizeClaude3Messages(this.#body?.messages ?? [])
     }
-    return result
+    return []
   }
 
   /**
@@ -151,6 +150,41 @@ class BedrockCommand {
   isTitanEmbed() {
     return this.#modelId.startsWith('amazon.titan-embed')
   }
+
+  isClaudeMessagesApi() {
+    return (this.isClaude3() === true || this.isClaude() === true) && 'messages' in this.#body
+  }
+
+  isClaudeTextCompletionApi() {
+    return this.isClaude() === true && 'prompt' in this.#body
+  }
 }
+
+/**
+ * Claude v3 requests in Bedrock can have two different "chat" flavors. This function normalizes them into a consistent
+ * format per the AIM agent spec.
+ *
+ * @param {object[]} messages - The raw array of messages passed to the invoke API
+ * @returns {object[]} - The normalized messages
+ */
+function normalizeClaude3Messages(messages) {
+  const result = []
+  for (const message of messages ?? []) {
+    if (message == null) {
+      continue
+    }
+    if (typeof message.content === 'string') {
+      // Messages can be specified with plain string content
+      result.push({ role: message.role, content: message.content })
+    } else if (Array.isArray(message.content)) {
+      // Or in a "chunked" format for multi-modal support
+      result.push({
+        role: message.role,
+        content: stringifyClaudeChunkedMessage(message.content)
+      })
+    }
+  }
+  return result
+}
 
 module.exports = BedrockCommand
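For clarity, a short sketch of what the reworked prompt getter yields for the two Claude request flavors. It assumes BedrockCommand accepts the InvokeModel input shape ({ modelId, body }) with a JSON-string body; the payloads are hypothetical:

const BedrockCommand = require('./bedrock-command')

// Hypothetical Messages API payload (isClaudeMessagesApi() === true).
const chatCmd = new BedrockCommand({
  modelId: 'anthropic.claude-3-haiku-20240307-v1:0',
  body: JSON.stringify({
    messages: [
      { role: 'user', content: 'Hi' },
      { role: 'assistant', content: [{ type: 'text', text: 'Hello!' }] }
    ]
  })
})
console.log(chatCmd.prompt)
// => [ { role: 'user', content: 'Hi' }, { role: 'assistant', content: 'Hello!' } ]

// Hypothetical legacy Text Completions payload (isClaudeTextCompletionApi() === true).
const textCmd = new BedrockCommand({
  modelId: 'anthropic.claude-v2',
  body: JSON.stringify({ prompt: '\n\nHuman: Hi\n\nAssistant:' })
})
console.log(textCmd.prompt)
// => [ { role: 'user', content: '\n\nHuman: Hi\n\nAssistant:' } ]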
4 changes: 3 additions & 1 deletion lib/llm-events/aws-bedrock/bedrock-response.js
@@ -5,6 +5,8 @@
 
 'use strict'
 
+const { stringifyClaudeChunkedMessage } = require('./utils')
+
 /**
  * @typedef {object} AwsBedrockMiddlewareResponse
  * @property {object} response Has a `body` property that is an IncomingMessage,
@@ -63,7 +65,7 @@
       // Streamed response
       this.#completions = body.completions
     } else {
-      this.#completions = body?.content?.map((c) => c.text)
+      this.#completions = [stringifyClaudeChunkedMessage(body?.content)]
     }
     this.#id = body.id
   } else if (cmd.isCohere() === true) {
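One behavioral consequence worth calling out: a multi-chunk (multi-modal) Claude response now produces a single completion string instead of one array entry per text chunk, with undefined holes for non-text chunks. A rough before/after with a hypothetical body:

// Hypothetical non-streamed Claude 3 response content.
const content = [
  { type: 'text', text: "Here's a nice picture of a 42" },
  { type: 'image', source: { type: 'base64', media_type: 'image/jpeg', data: '...' } }
]

// Before: body?.content?.map((c) => c.text)
// => [ "Here's a nice picture of a 42", undefined ]

// After: [stringifyClaudeChunkedMessage(body?.content)]
// => [ "Here's a nice picture of a 42\n\n<image>" ]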
17 changes: 4 additions & 13 deletions lib/llm-events/aws-bedrock/chat-completion-message.js
@@ -39,28 +39,19 @@ class LlmChatCompletionMessage extends LlmEvent {
     params = Object.assign({}, defaultParams, params)
     super(params)
 
-    const { agent, content, isResponse, index, completionId } = params
+    const { agent, content, isResponse, index, completionId, role } = params
     const recordContent = agent.config?.ai_monitoring?.record_content?.enabled
     const tokenCB = agent?.llm?.tokenCountCallback
 
     this.is_response = isResponse
     this.completion_id = completionId
     this.sequence = index
     this.content = recordContent === true ? content : undefined
-    this.role = ''
+    this.role = role
 
     this.#setId(index)
-    if (this.is_response === true) {
-      this.role = 'assistant'
-      if (typeof tokenCB === 'function') {
-        this.token_count = tokenCB(this.bedrockCommand.modelId, content)
-      }
-    } else {
-      this.role = 'user'
-      this.content = recordContent === true ? this.bedrockCommand.prompt : undefined
-      if (typeof tokenCB === 'function') {
-        this.token_count = tokenCB(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
-      }
+    if (typeof tokenCB === 'function') {
+      this.token_count = tokenCB(this.bedrockCommand.modelId, content)
     }
   }
 
2 changes: 1 addition & 1 deletion lib/llm-events/aws-bedrock/chat-completion-summary.js
@@ -36,7 +36,7 @@ class LlmChatCompletionSummary extends LlmEvent {
     const cmd = this.bedrockCommand
     this[cfr] = this.bedrockResponse.finishReason
     this[rt] = cmd.temperature
-    this[nm] = 1 + this.bedrockResponse.completions.length
+    this[nm] = (this.bedrockCommand.prompt?.length ?? 0) + this.bedrockResponse.completions.length
   }
 }
 
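The summary's message count now matches the events actually emitted: every recorded context message plus every completion, where the old code assumed exactly one prompt message. A small worked example with hypothetical values:

// Hypothetical values, for illustration only.
const prompt = [
  { role: 'user', content: 'Hi' },
  { role: 'user', content: 'Answer briefly.' }
]
const completions = ['42']

// Old: 1 + completions.length                    => 2 (undercounts multi-message chats)
// New: (prompt?.length ?? 0) + completions.length => 3
console.log((prompt?.length ?? 0) + completions.length) // 3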
4 changes: 2 additions & 2 deletions lib/llm-events/aws-bedrock/embedding.js
@@ -10,7 +10,7 @@ const LlmEvent = require('./event')
 /**
  * @typedef {object} LlmEmbeddingParams
  * @augments LlmEventParams
- * @property
+ * @property {string} input - The input message for the embedding call
  */
 /**
  * @type {LlmEmbeddingParams}
@@ -24,7 +24,7 @@ class LlmEmbedding extends LlmEvent {
     const tokenCb = agent?.llm?.tokenCountCallback
 
     this.input = agent.config?.ai_monitoring?.record_content?.enabled
-      ? this.bedrockCommand.prompt
+      ? this.bedrockCommand.prompt?.[0]?.content
       : undefined
     this.error = params.isError
     this.duration = params.segment.getDurationInMillis()
36 changes: 36 additions & 0 deletions lib/llm-events/aws-bedrock/utils.js
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2024 New Relic Corporation. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+'use strict'
+
+/**
+ * Stringifies a Claude "chunked" message into a single text representation.
+ * @param {object[]} chunks - The "chunks" that make up a single conceptual message. In a multi-modal scenario, a single message
+ * might have a number of different-typed chunks interspersed
+ * @returns {string} - A stringified version of the message. We make a best-effort attempt to represent non-text chunks. In the future
+ * we may want to extend the agent to support these non-text chunks in a richer way. Placeholders are represented in an XML-like format but
+ * are NOT intended to be parsed as valid XML
+ */
+function stringifyClaudeChunkedMessage(chunks) {
+  const stringifiedChunks = chunks.map((msgContent) => {
+    switch (msgContent.type) {
+      case 'text':
+        return msgContent.text
+      case 'image':
+        return '<image>'
+      case 'tool_use':
+        return `<tool_use>${msgContent.name}</tool_use>`
+      case 'tool_result':
+        return `<tool_result>${msgContent.content}</tool_result>`
+      default:
+        return '<unknown_chunk>'
+    }
+  })
+  return stringifiedChunks.join('\n\n')
+}
+
+module.exports = {
+  stringifyClaudeChunkedMessage
+}
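A quick usage sketch of the new helper; the chunk values are hypothetical:

const { stringifyClaudeChunkedMessage } = require('./utils')

const out = stringifyClaudeChunkedMessage([
  { type: 'text', text: 'What is in this image?' },
  { type: 'image', source: { type: 'base64', media_type: 'image/jpeg', data: '...' } },
  { type: 'tool_use', id: 'toolu_01', name: 'get_answer', input: {} }
])

console.log(out)
// => 'What is in this image?\n\n<image>\n\n<tool_use>get_answer</tool_use>'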
6 changes: 6 additions & 0 deletions test/lib/aws-server-stubs/ai-server/index.js
@@ -101,6 +101,12 @@ function handler(req, res) {
       break
     }
 
+    // Chunked claude model
+    case 'anthropic.claude-3-5-sonnet-20240620-v1:0': {
+      response = responses.claude3.get(payload?.messages?.[0]?.content?.[0].text)
+      break
+    }
+
     case 'cohere.command-text-v14':
     case 'cohere.command-light-text-v14': {
       response = responses.cohere.get(payload.prompt)
34 changes: 34 additions & 0 deletions test/lib/aws-server-stubs/ai-server/responses/claude3.js
@@ -34,6 +34,40 @@ responses.set('text claude3 ultimate question', {
   }
 })
 
+responses.set('text claude3 ultimate question chunked', {
+  headers: {
+    'content-type': contentType,
+    'x-amzn-requestid': reqId,
+    'x-amzn-bedrock-invocation-latency': '926',
+    'x-amzn-bedrock-output-token-count': '36',
+    'x-amzn-bedrock-input-token-count': '14'
+  },
+  statusCode: 200,
+  body: {
+    id: 'msg_bdrk_019V7ABaw8ZZZYuRDSTWK7VE',
+    type: 'message',
+    role: 'assistant',
+    model: 'claude-3-haiku-20240307',
+    stop_sequence: null,
+    usage: { input_tokens: 30, output_tokens: 265 },
+    content: [
+      {
+        type: 'text',
+        text: "Here's a nice picture of a 42"
+      },
+      {
+        type: 'image',
+        source: {
+          type: 'base64',
+          media_type: 'image/jpeg',
+          data: 'U2hoLiBUaGlzIGlzbid0IHJlYWxseSBhbiBpbWFnZQ=='
+        }
+      }
+    ],
+    stop_reason: 'endoftext'
+  }
+})
+
 responses.set('text claude3 ultimate question streamed', {
   headers: {
     'content-type': 'application/vnd.amazon.eventstream',