Skip to content

Commit

Permalink
feat: Added support for Claude 3+ Chat API in Bedrock (#2870)
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanKadri authored Jan 17, 2025
1 parent 7dceae9 commit 6a83abf
Show file tree
Hide file tree
Showing 17 changed files with 413 additions and 101 deletions.
39 changes: 25 additions & 14 deletions lib/instrumentation/aws-sdk/v3/bedrock.js
Original file line number Diff line number Diff line change
Expand Up @@ -118,27 +118,34 @@ function recordChatCompletionMessages({
isError: err !== null
})

const msg = new LlmChatCompletionMessage({
agent,
segment,
bedrockCommand,
bedrockResponse,
transaction,
index: 0,
completionId: summary.id
// Record context message(s)
const promptContextMessages = bedrockCommand.prompt
promptContextMessages.forEach((contextMessage, promptIndex) => {
const msg = new LlmChatCompletionMessage({
agent,
segment,
transaction,
bedrockCommand,
content: contextMessage.content,
role: contextMessage.role,
bedrockResponse,
index: promptIndex,
completionId: summary.id
})
recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })
})
recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })

bedrockResponse.completions.forEach((content, index) => {
bedrockResponse.completions.forEach((content, completionIndex) => {
const chatCompletionMessage = new LlmChatCompletionMessage({
agent,
segment,
transaction,
bedrockCommand,
bedrockResponse,
isResponse: true,
index: index + 1,
index: promptContextMessages.length + completionIndex,
content,
role: 'assistant',
completionId: summary.id
})
recordEvent({ agent, type: 'LlmChatCompletionMessage', msg: chatCompletionMessage })
Expand Down Expand Up @@ -179,18 +186,22 @@ function recordEmbeddingMessage({
return
}

const embedding = new LlmEmbedding({
const embeddings = bedrockCommand.prompt.map(prompt => new LlmEmbedding({
agent,
segment,
transaction,
bedrockCommand,
input: prompt.content,
bedrockResponse,
isError: err !== null
}))

embeddings.forEach(embedding => {
recordEvent({ agent, type: 'LlmEmbedding', msg: embedding })
})

recordEvent({ agent, type: 'LlmEmbedding', msg: embedding })
if (err) {
const llmError = new LlmError({ bedrockResponse, err, embedding })
const llmError = new LlmError({ bedrockResponse, err, embedding: embeddings.length === 1 ? embeddings[0] : undefined })
agent.errors.add(transaction, err, llmError)
}
}
Expand Down
76 changes: 55 additions & 21 deletions lib/llm-events/aws-bedrock/bedrock-command.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

'use strict'

const { stringifyClaudeChunkedMessage } = require('./utils')

/**
* Parses an AWS invoke command instance into a re-usable entity.
*/
Expand Down Expand Up @@ -68,37 +70,34 @@ class BedrockCommand {
/**
* The question posed to the LLM.
*
* @returns {string|string[]|undefined}
* @returns {object[]} The array of context messages passed to the LLM (or a single user prompt for legacy "non-chat" models)
*/
get prompt() {
let result
if (this.isTitan() === true || this.isTitanEmbed() === true) {
result = this.#body.inputText
return [
{
role: 'user',
content: this.#body.inputText
}
]
} else if (this.isCohereEmbed() === true) {
result = this.#body.texts.join(' ')
return [
{
role: 'user',
content: this.#body.texts.join(' ')
}
]
} else if (
this.isClaude() === true ||
this.isClaudeTextCompletionApi() === true ||
this.isAi21() === true ||
this.isCohere() === true ||
this.isLlama() === true
) {
result = this.#body.prompt
} else if (this.isClaude3() === true) {
const collected = []
for (const message of this.#body?.messages) {
if (message?.role === 'assistant') {
continue
}
if (typeof message?.content === 'string') {
collected.push(message?.content)
continue
}
const mappedMsgObj = message?.content.map((msgContent) => msgContent.text)
collected.push(mappedMsgObj)
}
result = collected.join(' ')
return [{ role: 'user', content: this.#body.prompt }]
} else if (this.isClaudeMessagesApi() === true) {
return normalizeClaude3Messages(this.#body?.messages)
}
return result
return []
}

/**
Expand Down Expand Up @@ -151,6 +150,41 @@ class BedrockCommand {
isTitanEmbed() {
return this.#modelId.startsWith('amazon.titan-embed')
}

isClaudeMessagesApi() {
return (this.isClaude3() === true || this.isClaude() === true) && 'messages' in this.#body
}

isClaudeTextCompletionApi() {
return this.isClaude() === true && 'prompt' in this.#body
}
}

/**
 * Claude v3 requests in Bedrock can have two different "chat" flavors: plain
 * string message content, or a "chunked" array of typed content blocks
 * (multi-modal). This function normalizes both into a consistent
 * `{ role, content }` shape per the AIM agent spec.
 *
 * @param {object[]|undefined} messages - The raw array of messages passed to the invoke API
 * @returns {object[]} The normalized messages; each entry has a `role` and a
 * string `content`. Entries that are nullish, or whose content is neither a
 * string nor an array, are dropped.
 */
function normalizeClaude3Messages(messages) {
  const result = []
  for (const message of messages ?? []) {
    if (message == null) {
      continue
    }
    if (typeof message.content === 'string') {
      // Messages can be specified with plain string content
      result.push({ role: message.role, content: message.content })
    } else if (Array.isArray(message.content)) {
      // Or in a "chunked" format for multi-modal support; flatten the chunks
      // into a single best-effort string representation
      result.push({
        role: message.role,
        content: stringifyClaudeChunkedMessage(message.content)
      })
    }
  }
  return result
}

module.exports = BedrockCommand
4 changes: 3 additions & 1 deletion lib/llm-events/aws-bedrock/bedrock-response.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

'use strict'

const { stringifyClaudeChunkedMessage } = require('./utils')

/**
* @typedef {object} AwsBedrockMiddlewareResponse
* @property {object} response Has a `body` property that is an IncomingMessage,
Expand Down Expand Up @@ -63,7 +65,7 @@ class BedrockResponse {
// Streamed response
this.#completions = body.completions
} else {
this.#completions = body?.content?.map((c) => c.text)
this.#completions = [stringifyClaudeChunkedMessage(body?.content)]
}
this.#id = body.id
} else if (cmd.isCohere() === true) {
Expand Down
17 changes: 4 additions & 13 deletions lib/llm-events/aws-bedrock/chat-completion-message.js
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,19 @@ class LlmChatCompletionMessage extends LlmEvent {
params = Object.assign({}, defaultParams, params)
super(params)

const { agent, content, isResponse, index, completionId } = params
const { agent, content, isResponse, index, completionId, role } = params
const recordContent = agent.config?.ai_monitoring?.record_content?.enabled
const tokenCB = agent?.llm?.tokenCountCallback

this.is_response = isResponse
this.completion_id = completionId
this.sequence = index
this.content = recordContent === true ? content : undefined
this.role = ''
this.role = role

this.#setId(index)
if (this.is_response === true) {
this.role = 'assistant'
if (typeof tokenCB === 'function') {
this.token_count = tokenCB(this.bedrockCommand.modelId, content)
}
} else {
this.role = 'user'
this.content = recordContent === true ? this.bedrockCommand.prompt : undefined
if (typeof tokenCB === 'function') {
this.token_count = tokenCB(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
}
if (typeof tokenCB === 'function') {
this.token_count = tokenCB(this.bedrockCommand.modelId, content)
}
}

Expand Down
2 changes: 1 addition & 1 deletion lib/llm-events/aws-bedrock/chat-completion-summary.js
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class LlmChatCompletionSummary extends LlmEvent {
const cmd = this.bedrockCommand
this[cfr] = this.bedrockResponse.finishReason
this[rt] = cmd.temperature
this[nm] = 1 + this.bedrockResponse.completions.length
this[nm] = (this.bedrockCommand.prompt.length) + this.bedrockResponse.completions.length
}
}

Expand Down
10 changes: 6 additions & 4 deletions lib/llm-events/aws-bedrock/embedding.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const LlmEvent = require('./event')
/**
* @typedef {object} LlmEmbeddingParams
* @augments LlmEventParams
* @property
* @property {string} input - The input message for the embedding call
*/
/**
* @type {LlmEmbeddingParams}
Expand All @@ -20,16 +20,18 @@ const defaultParams = {}
class LlmEmbedding extends LlmEvent {
constructor(params = defaultParams) {
super(params)
const { agent } = params
const { agent, input } = params
const tokenCb = agent?.llm?.tokenCountCallback

this.input = agent.config?.ai_monitoring?.record_content?.enabled
? this.bedrockCommand.prompt
? input
: undefined
this.error = params.isError
this.duration = params.segment.getDurationInMillis()

// Even if not recording content, we should use the local token counting callback to record token usage
if (typeof tokenCb === 'function') {
this.token_count = tokenCb(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
this.token_count = tokenCb(this.bedrockCommand.modelId, input)
}
}
}
Expand Down
36 changes: 36 additions & 0 deletions lib/llm-events/aws-bedrock/utils.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright 2024 New Relic Corporation. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/

'use strict'

/**
 * Flattens the typed content "chunks" of a single conceptual Claude message
 * into one string. In a multi-modal scenario a message may interleave chunks
 * of several types; non-text chunks are represented with best-effort,
 * XML-like placeholders that are NOT intended to be parsed as valid XML. In
 * the future we may want to extend the agent to support non-text chunks in a
 * richer way.
 *
 * @param {object[]} chunks - The content blocks that make up one message
 * @returns {string} The chunks rendered as text, joined by blank lines
 */
function stringifyClaudeChunkedMessage(chunks) {
  const rendered = []
  for (const chunk of chunks) {
    if (chunk.type === 'text') {
      rendered.push(chunk.text)
    } else if (chunk.type === 'image') {
      rendered.push('<image>')
    } else if (chunk.type === 'tool_use') {
      rendered.push(`<tool_use>${chunk.name}</tool_use>`)
    } else if (chunk.type === 'tool_result') {
      rendered.push(`<tool_result>${chunk.content}</tool_result>`)
    } else {
      rendered.push('<unknown_chunk>')
    }
  }
  return rendered.join('\n\n')
}

module.exports = {
stringifyClaudeChunkedMessage
}
6 changes: 6 additions & 0 deletions test/lib/aws-server-stubs/ai-server/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,12 @@ function handler(req, res) {
break
}

// Chunked claude model
case 'anthropic.claude-3-5-sonnet-20240620-v1:0': {
response = responses.claude3.get(payload?.messages?.[0]?.content?.[0].text)
break
}

case 'cohere.command-text-v14':
case 'cohere.command-light-text-v14': {
response = responses.cohere.get(payload.prompt)
Expand Down
34 changes: 34 additions & 0 deletions test/lib/aws-server-stubs/ai-server/responses/claude3.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,40 @@ responses.set('text claude3 ultimate question', {
}
})

// Fixture for a multi-modal ("chunked") Claude 3 response: the content array
// interleaves a text chunk with an image chunk, exercising the agent's
// chunk-stringification path.
// NOTE(review): the header token counts (14 in / 36 out) differ from the body
// `usage` values (30 in / 265 out), and the body reports a haiku model id —
// confirm these mismatches are intentional in the fixture data.
responses.set('text claude3 ultimate question chunked', {
  headers: {
    'content-type': contentType,
    'x-amzn-requestid': reqId,
    'x-amzn-bedrock-invocation-latency': '926',
    'x-amzn-bedrock-output-token-count': '36',
    'x-amzn-bedrock-input-token-count': '14'
  },
  statusCode: 200,
  body: {
    id: 'msg_bdrk_019V7ABaw8ZZZYuRDSTWK7VE',
    type: 'message',
    role: 'assistant',
    model: 'claude-3-haiku-20240307',
    stop_sequence: null,
    usage: { input_tokens: 30, output_tokens: 265 },
    content: [
      {
        type: 'text',
        text: "Here's a nice picture of a 42"
      },
      {
        type: 'image',
        source: {
          type: 'base64',
          media_type: 'image/jpeg',
          data: 'U2hoLiBUaGlzIGlzbid0IHJlYWxseSBhbiBpbWFnZQ=='
        }
      }
    ],
    stop_reason: 'endoftext'
  }
})

responses.set('text claude3 ultimate question streamed', {
headers: {
'content-type': 'application/vnd.amazon.eventstream',
Expand Down
Loading

0 comments on commit 6a83abf

Please sign in to comment.