diff --git a/lib/instrumentation/aws-sdk/v3/bedrock.js b/lib/instrumentation/aws-sdk/v3/bedrock.js
index 54d5ef4616..13de6e6254 100644
--- a/lib/instrumentation/aws-sdk/v3/bedrock.js
+++ b/lib/instrumentation/aws-sdk/v3/bedrock.js
@@ -118,18 +118,24 @@ function recordChatCompletionMessages({
     isError: err !== null
   })
 
-  const msg = new LlmChatCompletionMessage({
-    agent,
-    segment,
-    bedrockCommand,
-    bedrockResponse,
-    transaction,
-    index: 0,
-    completionId: summary.id
+  // Record context message(s)
+  const promptContextMessages = bedrockCommand.prompt
+  promptContextMessages.forEach((contextMessage, promptIndex) => {
+    const msg = new LlmChatCompletionMessage({
+      agent,
+      segment,
+      transaction,
+      bedrockCommand,
+      content: contextMessage.content,
+      role: contextMessage.role,
+      bedrockResponse,
+      index: promptIndex,
+      completionId: summary.id
+    })
+    recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })
   })
-  recordEvent({ agent, type: 'LlmChatCompletionMessage', msg })
 
-  bedrockResponse.completions.forEach((content, index) => {
+  bedrockResponse.completions.forEach((content, completionIndex) => {
     const chatCompletionMessage = new LlmChatCompletionMessage({
       agent,
       segment,
@@ -137,8 +143,9 @@ function recordChatCompletionMessages({
       bedrockCommand,
       bedrockResponse,
       isResponse: true,
-      index: index + 1,
+      index: promptContextMessages.length + completionIndex,
       content,
+      role: 'assistant',
       completionId: summary.id
     })
     recordEvent({ agent, type: 'LlmChatCompletionMessage', msg: chatCompletionMessage })
@@ -179,18 +186,22 @@ function recordEmbeddingMessage({
     return
   }
 
-  const embedding = new LlmEmbedding({
+  const embeddings = bedrockCommand.prompt.map(prompt => new LlmEmbedding({
     agent,
     segment,
     transaction,
     bedrockCommand,
+    input: prompt.content,
     bedrockResponse,
     isError: err !== null
+  }))
+
+  embeddings.forEach(embedding => {
+    recordEvent({ agent, type: 'LlmEmbedding', msg: embedding })
   })
-  recordEvent({ agent, type: 'LlmEmbedding', msg: embedding })
 
   if (err) {
-    const llmError = new LlmError({ bedrockResponse, err, embedding })
+    const llmError = new LlmError({ bedrockResponse, err, embedding: embeddings.length === 1 ? embeddings[0] : undefined })
    agent.errors.add(transaction, err, llmError)
   }
 }
diff --git a/lib/llm-events/aws-bedrock/bedrock-command.js b/lib/llm-events/aws-bedrock/bedrock-command.js
index 471f0b0e37..9baaaf49a0 100644
--- a/lib/llm-events/aws-bedrock/bedrock-command.js
+++ b/lib/llm-events/aws-bedrock/bedrock-command.js
@@ -5,6 +5,8 @@
 
 'use strict'
 
+const { stringifyClaudeChunkedMessage } = require('./utils')
+
 /**
  * Parses an AWS invoke command instance into a re-usable entity.
  */
@@ -68,37 +70,34 @@ class BedrockCommand {
   /**
    * The question posed to the LLM.
    *
-   * @returns {string|string[]|undefined}
+   * @returns {object[]} The array of context messages passed to the LLM (or a single user prompt for legacy "non-chat" models)
    */
   get prompt() {
-    let result
     if (this.isTitan() === true || this.isTitanEmbed() === true) {
-      result = this.#body.inputText
+      return [
+        {
+          role: 'user',
+          content: this.#body.inputText
+        }
+      ]
     } else if (this.isCohereEmbed() === true) {
-      result = this.#body.texts.join(' ')
+      return [
+        {
+          role: 'user',
+          content: this.#body.texts.join(' ')
+        }
+      ]
     } else if (
-      this.isClaude() === true ||
+      this.isClaudeTextCompletionApi() === true ||
       this.isAi21() === true ||
       this.isCohere() === true ||
       this.isLlama() === true
     ) {
-      result = this.#body.prompt
-    } else if (this.isClaude3() === true) {
-      const collected = []
-      for (const message of this.#body?.messages) {
-        if (message?.role === 'assistant') {
-          continue
-        }
-        if (typeof message?.content === 'string') {
-          collected.push(message?.content)
-          continue
-        }
-        const mappedMsgObj = message?.content.map((msgContent) => msgContent.text)
-        collected.push(mappedMsgObj)
-      }
-      result = collected.join(' ')
+      return [{ role: 'user', content: this.#body.prompt }]
+    } else if (this.isClaudeMessagesApi() === true) {
+      return normalizeClaude3Messages(this.#body?.messages)
     }
-    return result
+    return []
   }
 
   /**
@@ -151,6 +150,41 @@ class BedrockCommand {
   isTitanEmbed() {
     return this.#modelId.startsWith('amazon.titan-embed')
   }
+
+  isClaudeMessagesApi() {
+    return (this.isClaude3() === true || this.isClaude() === true) && 'messages' in this.#body
+  }
+
+  isClaudeTextCompletionApi() {
+    return this.isClaude() === true && 'prompt' in this.#body
+  }
+}
+
+/**
+ * Claude v3 requests in Bedrock can have two different "chat" flavors. This function normalizes them into a consistent
+ * format per the AIM agent spec
+ *
+ * @param messages - The raw array of messages passed to the invoke API
+ * @returns {object[]} - The normalized messages
+ */
+function normalizeClaude3Messages(messages) {
+  const result = []
+  for (const message of messages ?? []) {
+    if (message == null) {
+      continue
+    }
+    if (typeof message.content === 'string') {
+      // Messages can be specified with plain string content
+      result.push({ role: message.role, content: message.content })
+    } else if (Array.isArray(message.content)) {
+      // Or in a "chunked" format for multi-modal support
+      result.push({
+        role: message.role,
+        content: stringifyClaudeChunkedMessage(message.content)
+      })
+    }
+  }
+  return result
 }
 
 module.exports = BedrockCommand
diff --git a/lib/llm-events/aws-bedrock/bedrock-response.js b/lib/llm-events/aws-bedrock/bedrock-response.js
index 7b645cfc06..d7fb47329e 100644
--- a/lib/llm-events/aws-bedrock/bedrock-response.js
+++ b/lib/llm-events/aws-bedrock/bedrock-response.js
@@ -5,6 +5,8 @@
 
 'use strict'
 
+const { stringifyClaudeChunkedMessage } = require('./utils')
+
 /**
  * @typedef {object} AwsBedrockMiddlewareResponse
  * @property {object} response Has a `body` property that is an IncomingMessage,
@@ -63,7 +65,7 @@ class BedrockResponse {
         // Streamed response
         this.#completions = body.completions
       } else {
-        this.#completions = body?.content?.map((c) => c.text)
+        this.#completions = [stringifyClaudeChunkedMessage(body?.content)]
       }
       this.#id = body.id
     } else if (cmd.isCohere() === true) {
diff --git a/lib/llm-events/aws-bedrock/chat-completion-message.js b/lib/llm-events/aws-bedrock/chat-completion-message.js
index 6f0c564794..0609529717 100644
--- a/lib/llm-events/aws-bedrock/chat-completion-message.js
+++ b/lib/llm-events/aws-bedrock/chat-completion-message.js
@@ -39,7 +39,7 @@ class LlmChatCompletionMessage extends LlmEvent {
     params = Object.assign({}, defaultParams, params)
     super(params)
 
-    const { agent, content, isResponse, index, completionId } = params
+    const { agent, content, isResponse, index, completionId, role } = params
     const recordContent = agent.config?.ai_monitoring?.record_content?.enabled
     const tokenCB = agent?.llm?.tokenCountCallback
 
@@ -47,20 +47,11 @@
     this.completion_id = completionId
     this.sequence = index
     this.content = recordContent === true ? content : undefined
-    this.role = ''
+    this.role = role
 
     this.#setId(index)
-    if (this.is_response === true) {
-      this.role = 'assistant'
-      if (typeof tokenCB === 'function') {
-        this.token_count = tokenCB(this.bedrockCommand.modelId, content)
-      }
-    } else {
-      this.role = 'user'
-      this.content = recordContent === true ? this.bedrockCommand.prompt : undefined
-      if (typeof tokenCB === 'function') {
-        this.token_count = tokenCB(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
-      }
+    if (typeof tokenCB === 'function') {
+      this.token_count = tokenCB(this.bedrockCommand.modelId, content)
     }
   }
 
diff --git a/lib/llm-events/aws-bedrock/chat-completion-summary.js b/lib/llm-events/aws-bedrock/chat-completion-summary.js
index 36492a0e73..b5e07c82eb 100644
--- a/lib/llm-events/aws-bedrock/chat-completion-summary.js
+++ b/lib/llm-events/aws-bedrock/chat-completion-summary.js
@@ -36,7 +36,7 @@ class LlmChatCompletionSummary extends LlmEvent {
     const cmd = this.bedrockCommand
     this[cfr] = this.bedrockResponse.finishReason
     this[rt] = cmd.temperature
-    this[nm] = 1 + this.bedrockResponse.completions.length
+    this[nm] = (this.bedrockCommand.prompt.length) + this.bedrockResponse.completions.length
   }
 }
 
diff --git a/lib/llm-events/aws-bedrock/embedding.js b/lib/llm-events/aws-bedrock/embedding.js
index a1eec6e1bf..f754737697 100644
--- a/lib/llm-events/aws-bedrock/embedding.js
+++ b/lib/llm-events/aws-bedrock/embedding.js
@@ -10,7 +10,7 @@ const LlmEvent = require('./event')
 /**
  * @typedef {object} LlmEmbeddingParams
  * @augments LlmEventParams
- * @property
+ * @property {string} input - The input message for the embedding call
  */
 /**
  * @type {LlmEmbeddingParams}
@@ -20,16 +20,18 @@ const defaultParams = {}
 class LlmEmbedding extends LlmEvent {
   constructor(params = defaultParams) {
     super(params)
-    const { agent } = params
+    const { agent, input } = params
     const tokenCb = agent?.llm?.tokenCountCallback
 
     this.input = agent.config?.ai_monitoring?.record_content?.enabled
-      ? this.bedrockCommand.prompt
+      ? input
       : undefined
     this.error = params.isError
     this.duration = params.segment.getDurationInMillis()
+
+    // Even if not recording content, we should use the local token counting callback to record token usage
     if (typeof tokenCb === 'function') {
-      this.token_count = tokenCb(this.bedrockCommand.modelId, this.bedrockCommand.prompt)
+      this.token_count = tokenCb(this.bedrockCommand.modelId, input)
     }
   }
 }
diff --git a/lib/llm-events/aws-bedrock/utils.js b/lib/llm-events/aws-bedrock/utils.js
new file mode 100644
index 0000000000..0c547f81c7
--- /dev/null
+++ b/lib/llm-events/aws-bedrock/utils.js
@@ -0,0 +1,36 @@
+/*
+ * Copyright 2024 New Relic Corporation. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+'use strict'
+
+/**
+ *
+ * @param {object[]} chunks - The "chunks" that make up a single conceptual message. In a multi-modal scenario, a single message
+ * might have a number of different-typed chunks interspersed
+ * @returns {string} - A stringified version of the message. We make a best-effort attempt to represent non-text chunks. In the future
+ * we may want to extend the agent to support these non-text chunks in a richer way. Placeholders are represented in an XML-like format but
+ * are NOT intended to be parsed as valid XML
+ */
+function stringifyClaudeChunkedMessage(chunks) {
+  const stringifiedChunks = chunks.map((msgContent) => {
+    switch (msgContent.type) {
+      case 'text':
+        return msgContent.text
+      case 'image':
+        return '<image>'
+      case 'tool_use':
+        return `<tool_use>${msgContent.name}</tool_use>`
+      case 'tool_result':
+        return `<tool_result>${msgContent.content}</tool_result>`
+      default:
+        return '<unknown_chunk>'
+    }
+  })
+  return stringifiedChunks.join('\n\n')
+}
+
+module.exports = {
+  stringifyClaudeChunkedMessage
+}
diff --git a/test/lib/aws-server-stubs/ai-server/index.js b/test/lib/aws-server-stubs/ai-server/index.js
index 5ab12bdd04..ef713c0542 100644
--- a/test/lib/aws-server-stubs/ai-server/index.js
+++ b/test/lib/aws-server-stubs/ai-server/index.js
@@ -101,6 +101,12 @@ function handler(req, res) {
       break
     }
 
+    // Chunked claude model
+    case 'anthropic.claude-3-5-sonnet-20240620-v1:0': {
+      response = responses.claude3.get(payload?.messages?.[0]?.content?.[0].text)
+      break
+    }
+
     case 'cohere.command-text-v14':
     case 'cohere.command-light-text-v14': {
       response = responses.cohere.get(payload.prompt)
diff --git a/test/lib/aws-server-stubs/ai-server/responses/claude3.js b/test/lib/aws-server-stubs/ai-server/responses/claude3.js
index 2e0fbf0505..3728867b1b 100644
--- a/test/lib/aws-server-stubs/ai-server/responses/claude3.js
+++ b/test/lib/aws-server-stubs/ai-server/responses/claude3.js
@@ -34,6 +34,40 @@ responses.set('text claude3 ultimate question', {
   }
 })
 
+responses.set('text claude3 ultimate question chunked', {
+  headers: {
+    'content-type': contentType,
+    'x-amzn-requestid': reqId,
+    'x-amzn-bedrock-invocation-latency': '926',
+    'x-amzn-bedrock-output-token-count': '36',
+    'x-amzn-bedrock-input-token-count': '14'
+  },
+  statusCode: 200,
+  body: {
+    id: 'msg_bdrk_019V7ABaw8ZZZYuRDSTWK7VE',
+    type: 'message',
+    role: 'assistant',
+    model: 'claude-3-haiku-20240307',
+    stop_sequence: null,
+    usage: { input_tokens: 30, output_tokens: 265 },
+    content: [
+      {
+        type: 'text',
+        text: "Here's a nice picture of a 42"
+      },
+      {
+        type: 'image',
+        source: {
+          type: 'base64',
+          media_type: 'image/jpeg',
+          data: 'U2hoLiBUaGlzIGlzbid0IHJlYWxseSBhbiBpbWFnZQ=='
+        }
+      }
+    ],
+    stop_reason: 'endoftext'
+  }
+})
+
 responses.set('text claude3 ultimate question streamed', {
   headers: {
     'content-type': 'application/vnd.amazon.eventstream',
diff --git a/test/unit/llm-events/aws-bedrock/bedrock-command.test.js b/test/unit/llm-events/aws-bedrock/bedrock-command.test.js
index 9395a328c2..0a62e84117 100644
--- a/test/unit/llm-events/aws-bedrock/bedrock-command.test.js
+++ b/test/unit/llm-events/aws-bedrock/bedrock-command.test.js
@@ -38,7 +38,7 @@ const claude35 = {
 
 const claude3 = {
   modelId: 'anthropic.claude-3-haiku-20240307-v1:0',
   body: {
-    messages: [{ content: 'who are you' }]
+    messages: [{ role: 'user', content: 'who are you' }]
   }
 }
 
@@ -114,7 +114,7 @@ test('non-conforming command is handled gracefully', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, '')
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, undefined)
+  assert.deepEqual(cmd.prompt, [])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -125,7 +125,7 @@ test('ai21 minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, ai21.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, ai21.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: ai21.body.prompt }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -139,7 +139,7 @@ test('ai21 complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, payload.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: payload.body.prompt }])
   assert.equal(cmd.temperature, payload.body.temperature)
 })
 
@@ -150,7 +150,7 @@ test('claude minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, claude.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, claude.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: claude.body.prompt }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -164,7 +164,7 @@ test('claude complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, payload.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: payload.body.prompt }])
   assert.equal(cmd.temperature, payload.body.temperature)
 })
 
@@ -175,7 +175,7 @@ test('claude3 minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, claude3.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, claude3.body.messages[0].content)
+  assert.deepEqual(cmd.prompt, claude3.body.messages)
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -189,7 +189,7 @@ test('claude3 complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, payload.body.messages[0].content)
+  assert.deepEqual(cmd.prompt, payload.body.messages)
   assert.equal(cmd.temperature, payload.body.temperature)
 })
 
@@ -200,7 +200,20 @@ test('claude35 minimal command works with claude 3 api', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, claude3.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, claude3.body.messages[0].content)
+  assert.deepEqual(cmd.prompt, claude3.body.messages)
   assert.equal(cmd.temperature, undefined)
 })
+
+test('claude35 malformed payload produces reasonable values', async (t) => {
+  const malformedPayload = structuredClone(claude35)
+  malformedPayload.body = {}
+  t.nr.updatePayload(malformedPayload)
+  const cmd = new BedrockCommand(t.nr.input)
+  assert.equal(cmd.isClaude3(), true)
+  assert.equal(cmd.maxTokens, undefined)
+  assert.equal(cmd.modelId, claude35.modelId)
+  assert.equal(cmd.modelType, 'completion')
+  assert.deepEqual(cmd.prompt, [])
+  assert.equal(cmd.temperature, undefined)
+})
 
@@ -211,7 +224,7 @@ test('claude35 minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, claude35.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, 'who are you')
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: 'who are' }, { role: 'assistant', content: 'researching' }, { role: 'user', content: 'you' }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -225,7 +238,7 @@ test('claude35 complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, 'who are you')
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: 'who are' }, { role: 'assistant', content: 'researching' }, { role: 'user', content: 'you' }])
   assert.equal(cmd.temperature, payload.body.temperature)
 })
 
@@ -236,7 +249,7 @@ test('cohere minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, cohere.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, cohere.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: cohere.body.prompt }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -250,7 +263,7 @@ test('cohere complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, payload.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: payload.body.prompt }])
   assert.equal(cmd.temperature, payload.body.temperature)
 })
 
@@ -261,7 +274,7 @@ test('cohere embed minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, cohereEmbed.modelId)
   assert.equal(cmd.modelType, 'embedding')
-  assert.deepStrictEqual(cmd.prompt, cohereEmbed.body.texts.join(' '))
+  assert.deepStrictEqual(cmd.prompt, [{ role: 'user', content: cohereEmbed.body.texts.join(' ') }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -272,7 +285,7 @@ test('llama2 minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, llama2.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, llama2.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: llama2.body.prompt }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -286,7 +299,7 @@ test('llama2 complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, payload.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: payload.body.prompt }])
   assert.equal(cmd.temperature, payload.body.temperature)
 })
 
@@ -297,7 +310,7 @@ test('llama3 minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, llama3.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, llama3.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: llama3.body.prompt }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -311,7 +324,7 @@ test('llama3 complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, payload.body.prompt)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: payload.body.prompt }])
   assert.equal(cmd.temperature, payload.body.temperature)
 })
 
@@ -322,7 +335,7 @@ test('titan minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, titan.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, titan.body.inputText)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: titan.body.inputText }])
   assert.equal(cmd.temperature, undefined)
 })
 
@@ -338,7 +351,7 @@ test('titan complete command works', async (t) => {
   assert.equal(cmd.maxTokens, 25)
   assert.equal(cmd.modelId, payload.modelId)
   assert.equal(cmd.modelType, 'completion')
-  assert.equal(cmd.prompt, payload.body.inputText)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: payload.body.inputText }])
   assert.equal(cmd.temperature, payload.body.textGenerationConfig.temperature)
 })
 
@@ -349,6 +362,6 @@ test('titan embed minimal command works', async (t) => {
   assert.equal(cmd.maxTokens, undefined)
   assert.equal(cmd.modelId, titanEmbed.modelId)
   assert.equal(cmd.modelType, 'embedding')
-  assert.equal(cmd.prompt, titanEmbed.body.inputText)
+  assert.deepEqual(cmd.prompt, [{ role: 'user', content: titanEmbed.body.inputText }])
   assert.equal(cmd.temperature, undefined)
 })
diff --git a/test/unit/llm-events/aws-bedrock/bedrock-response.test.js b/test/unit/llm-events/aws-bedrock/bedrock-response.test.js
index b4047c7324..3805df2cb5 100644
--- a/test/unit/llm-events/aws-bedrock/bedrock-response.test.js
+++ b/test/unit/llm-events/aws-bedrock/bedrock-response.test.js
@@ -29,6 +29,14 @@ const claude = {
   stop_reason: 'done'
 }
 
+const claude35 = {
+  content: [
+    { type: 'text', text: 'Hello' },
+    { type: 'text', text: 'world' }
+  ],
+  stop_reason: 'done'
+}
+
 const cohere = {
   id: 'cohere-response-1',
   generations: [
@@ -151,6 +159,18 @@ test('claude complete responses work', async (t) => {
   assert.equal(res.statusCode, 200)
 })
 
+test('claude 3.5 complete responses work', async (t) => {
+  t.nr.bedrockCommand.isClaude3 = () => true
+  t.nr.updatePayload(structuredClone(claude35))
+  const res = new BedrockResponse(t.nr)
+  assert.deepStrictEqual(res.completions, ['Hello\n\nworld'])
+  assert.equal(res.finishReason, 'done')
+  assert.deepStrictEqual(res.headers, t.nr.response.response.headers)
+  assert.equal(res.id, undefined)
+  assert.equal(res.requestId, 'aws-request-1')
+  assert.equal(res.statusCode, 200)
+})
+
 test('cohere malformed responses work', async (t) => {
   t.nr.bedrockCommand.isCohere = () => true
   const res = new BedrockResponse(t.nr)
diff --git a/test/unit/llm-events/aws-bedrock/chat-completion-message.test.js b/test/unit/llm-events/aws-bedrock/chat-completion-message.test.js
index f0af45c441..0b19f056af 100644
--- a/test/unit/llm-events/aws-bedrock/chat-completion-message.test.js
+++ b/test/unit/llm-events/aws-bedrock/chat-completion-message.test.js
@@ -56,6 +56,7 @@ test.beforeEach((ctx) => {
   ctx.nr.segment = {
     id: 'segment-1'
   }
+  ctx.nr.role = 'assistant'
 
   ctx.nr.bedrockResponse = {
     headers: {
@@ -71,7 +72,6 @@ test.beforeEach((ctx) => {
 
   ctx.nr.bedrockCommand = {
     id: 'cmd-1',
-    prompt: 'who are you',
     isAi21() {
       return false
     },
@@ -92,12 +92,13 @@
 
 test('create creates a non-response instance', async (t) => {
   t.nr.agent.llm.tokenCountCallback = () => 3
+  t.nr.role = 'user'
   const event = new LlmChatCompletionMessage(t.nr)
   assert.equal(event.is_response, false)
   assert.equal(event['llm.conversation_id'], 'conversation-1')
   assert.equal(event.completion_id, 'completion-1')
   assert.equal(event.sequence, 0)
-  assert.equal(event.content, 'who are you')
+  assert.equal(event.content, 'a prompt')
   assert.equal(event.role, 'user')
   assert.match(event.id, /[\w-]{36}/)
   assert.equal(event.token_count, 3)
diff --git a/test/unit/llm-events/aws-bedrock/chat-completion-summary.test.js b/test/unit/llm-events/aws-bedrock/chat-completion-summary.test.js
index cd37b0d020..57f56af28a 100644
--- a/test/unit/llm-events/aws-bedrock/chat-completion-summary.test.js
+++ b/test/unit/llm-events/aws-bedrock/chat-completion-summary.test.js
@@ -50,6 +50,9 @@ test.beforeEach((ctx) => {
   ctx.nr.bedrockCommand = {
     maxTokens: 25,
     temperature: 0.5,
+    prompt: [
+      { role: 'user', content: 'Hello!' }
+    ],
     isAi21() {
       return false
     },
diff --git a/test/unit/llm-events/aws-bedrock/embedding.test.js b/test/unit/llm-events/aws-bedrock/embedding.test.js
index b2908401a6..bc5bdcdc3c 100644
--- a/test/unit/llm-events/aws-bedrock/embedding.test.js
+++ b/test/unit/llm-events/aws-bedrock/embedding.test.js
@@ -46,9 +46,10 @@ test.beforeEach((ctx) => {
   }
 
   ctx.nr.bedrockCommand = {
-    prompt: 'who are you'
   }
 
+  ctx.nr.input = 'who are you'
+
   ctx.nr.bedrockResponse = {
     headers: {
       'x-amzn-requestid': 'request-1'
diff --git a/test/unit/llm-events/aws-bedrock/utils.test.js b/test/unit/llm-events/aws-bedrock/utils.test.js
new file mode 100644
index 0000000000..18b8106713
--- /dev/null
+++ b/test/unit/llm-events/aws-bedrock/utils.test.js
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2024 New Relic Corporation. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+'use strict'
+
+const test = require('node:test')
+const assert = require('node:assert')
+const { stringifyClaudeChunkedMessage } = require('../../../../lib/llm-events/aws-bedrock/utils')
+
+test('interleaves text chunks with other types', async (t) => {
+  const out = stringifyClaudeChunkedMessage([
+    { type: 'text', text: 'Hello' },
+    { type: 'image', source: { type: 'base64', data: 'U29tZSByYW5kb20gaW1hZ2U=', media_type: 'image/jpeg' } },
+    { type: 'text', text: 'world' }
+  ])
+
+  assert.equal(out, 'Hello\n\n<image>\n\nworld')
+})
+
+test('adds a placeholder for unrecognized chunk types', async (t) => {
+  const out = stringifyClaudeChunkedMessage([
+    { type: 'text', text: 'Hello' },
+    { type: 'direct_neural_upload', data: 'V2hhdCBzaG91bGQgSSBtYWtlIGZvciBkaW5uZXI/' },
+    { type: 'text', text: 'world' }
+  ])
+
+  assert.equal(out, 'Hello\n\n<unknown_chunk>\n\nworld')
+})
+
+test('adds information about tool calls', async (t) => {
+  const out = stringifyClaudeChunkedMessage([
+    { type: 'text', text: 'I will look up the weather in Philly' },
+    { type: 'tool_use', name: 'lookup_weather', input: { city: 'Philly' }, id: 'abc123' },
+  ])
+
+  assert.equal(out, 'I will look up the weather in Philly\n\n<tool_use>lookup_weather</tool_use>')
+})
+
+test('adds information about tool results', async (t) => {
+  const out = stringifyClaudeChunkedMessage([
+    { type: 'text', text: 'Here is the weather in philly' },
+    { type: 'tool_result', name: 'lookup_weather', content: 'Nice!', tool_use_id: 'abc123' },
+  ])
+
+  assert.equal(out, 'Here is the weather in philly\n\n<tool_result>Nice!</tool_result>')
+})
diff --git a/test/versioned/aws-sdk-v3/bedrock-chat-completions.test.js b/test/versioned/aws-sdk-v3/bedrock-chat-completions.test.js
index e7dc8dcce7..9932329425 100644
--- a/test/versioned/aws-sdk-v3/bedrock-chat-completions.test.js
+++ b/test/versioned/aws-sdk-v3/bedrock-chat-completions.test.js
@@ -6,7 +6,12 @@
 'use strict'
 const assert = require('node:assert')
 const test = require('node:test')
-const { afterEach, assertChatCompletionMessages, assertChatCompletionSummary } = require('./common')
+const {
+  afterEach,
+  assertChatCompletionMessages,
+  assertChatCompletionSummary,
+  assertChatCompletionMessage
+} = require('./common')
 const helper = require('../../lib/agent_helper')
 const createAiResponseServer = require('../../lib/aws-server-stubs/ai-server')
 const { FAKE_CREDENTIALS } = require('../../lib/aws-server-stubs')
@@ -50,6 +55,16 @@ const requests = {
     }),
     modelId
   }),
+  claude3Chunked: (chunks, modelId) => ({
+    body: JSON.stringify({
+      anthropic_version: 'bedrock-2023-05-31',
+      max_tokens: 100,
+      temperature: 0.5,
+      system: 'Please respond in the style of Christopher Walken',
+      messages: chunks
+    }),
+    modelId
+  }),
   cohere: (prompt, modelId) => ({
     body: JSON.stringify({ prompt, temperature: 0.5, max_tokens: 100 }),
     modelId
@@ -466,6 +481,82 @@ test('ai21: should properly create errors on create completion (streamed)', asyn
   })
 })
 
+test('anthropic-claude-3: should properly create events for chunked messages', async (t) => {
+  const { bedrock, client, agent } = t.nr
+  const modelId = 'anthropic.claude-3-5-sonnet-20240620-v1:0'
+  const prompt = 'text claude3 ultimate question chunked'
+  const promptFollowUp = 'And please include an image in the response'
+  const input = requests.claude3Chunked(
+    [
+      {
+        role: 'user',
+        content: [
+          {
+            type: 'text',
+            text: prompt
+          }
+        ]
+      },
+      {
+        role: 'user',
+        content: [
+          {
+            type: 'text',
+            text: promptFollowUp
+          }
+        ]
+      }
+    ],
+    modelId
+  )
+
+  const command = new bedrock.InvokeModelCommand(input)
+
+  const api = helper.getAgentApi()
+  await helper.runInTransaction(agent, async (tx) => {
+    api.addCustomAttribute('llm.conversation_id', 'convo-id')
+    await client.send(command)
+
+    const events = agent.customEventAggregator.events.toArray()
+    assert.equal(events.length, 4)
+    const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0]
+    const chatMsgs = events
+      .filter(([{ type }]) => type === 'LlmChatCompletionMessage')
+      .sort(([, a], [, b]) => a.sequence - b.sequence)
+
+    assertChatCompletionMessage({
+      tx,
+      message: chatMsgs[0],
+      modelId,
+      expectedContent: prompt,
+      isResponse: false,
+      expectedRole: 'user'
+    })
+
+    assertChatCompletionMessage({
+      tx,
+      message: chatMsgs[1],
+      modelId,
+      expectedContent: promptFollowUp,
+      isResponse: false,
+      expectedRole: 'user'
+    })
+
+    // Note the placeholder for the image chunk
+    assertChatCompletionMessage({
+      tx,
+      message: chatMsgs[2],
+      modelId,
+      expectedContent: "Here's a nice picture of a 42\n\n<image>",
+      isResponse: true,
+      expectedRole: 'assistant'
+    })
+
+    assertChatCompletionSummary({ tx, modelId, chatSummary, numMsgs: 3 })
+    tx.end()
+  })
+})
+
 test('models that do not support streaming should be handled', async (t) => {
   const { bedrock, client, agent, expectedExternalPath } = t.nr
   const modelId = 'amazon.titan-embed-text-v1'
diff --git a/test/versioned/aws-sdk-v3/common.js b/test/versioned/aws-sdk-v3/common.js
index f1aa64eb08..b64b677eeb 100644
--- a/test/versioned/aws-sdk-v3/common.js
+++ b/test/versioned/aws-sdk-v3/common.js
@@ -77,6 +77,35 @@ function checkExternals({ service, operations, tx, end }) {
 }
 
 function assertChatCompletionMessages({ tx, chatMsgs, expectedId, modelId, prompt, resContent }) {
+  chatMsgs.forEach((msg) => {
+    if (msg[1].sequence > 1) {
+      // Streamed responses may have more than two messages.
+      // We only care about the start and end of the conversation.
+      return
+    }
+
+    const isResponse = msg[1].sequence === 1
+    assertChatCompletionMessage({
+      tx,
+      message: msg,
+      expectedId,
+      modelId,
+      expectedContent: isResponse ? resContent : prompt,
+      isResponse,
+      expectedRole: isResponse ? 'assistant' : 'user'
+    })
+  })
+}
+
+function assertChatCompletionMessage({
+  tx,
+  message,
+  expectedId,
+  modelId,
+  expectedContent,
+  isResponse,
+  expectedRole
+}) {
   const [segment] = tx.trace.getChildren(tx.trace.root.id)
   const baseMsg = {
     appName: 'New Relic for Node.js tests',
@@ -91,30 +120,19 @@ function assertChatCompletionMessages({ tx, chatMsgs, expectedId, modelId, promp
     completion_id: /\w{8}-\w{4}-\w{4}-\w{4}-\w{12}/
   }
 
-  chatMsgs.forEach((msg) => {
-    if (msg[1].sequence > 1) {
-      // Streamed responses may have more than two messages.
-      // We only care about the start and end of the conversation.
-      return
-    }
+  const [messageBase, messageData] = message
 
-    const expectedChatMsg = { ...baseMsg }
-    const id = expectedId ? `${expectedId}-${msg[1].sequence}` : msg[1].id
-    if (msg[1].sequence === 0) {
-      expectedChatMsg.sequence = 0
-      expectedChatMsg.id = id
-      expectedChatMsg.content = prompt
-    } else if (msg[1].sequence === 1) {
-      expectedChatMsg.sequence = 1
-      expectedChatMsg.role = 'assistant'
-      expectedChatMsg.id = id
-      expectedChatMsg.content = resContent
-      expectedChatMsg.is_response = true
-    }
+  const expectedChatMsg = { ...baseMsg }
+  const id = expectedId ? `${expectedId}-${messageData.sequence}` : messageData.id
 
-    assert.equal(msg[0].type, 'LlmChatCompletionMessage')
-    match(msg[1], expectedChatMsg)
-  })
+  expectedChatMsg.sequence = messageData.sequence
+  expectedChatMsg.role = expectedRole
+  expectedChatMsg.id = id
+  expectedChatMsg.content = expectedContent
+  expectedChatMsg.is_response = isResponse
+
+  assert.equal(messageBase.type, 'LlmChatCompletionMessage')
+  match(messageData, expectedChatMsg)
 }
 
 function assertChatCompletionSummary({ tx, modelId, chatSummary, error = false, numMsgs = 2 }) {
@@ -162,6 +180,7 @@ module.exports = {
   afterEach,
   assertChatCompletionSummary,
   assertChatCompletionMessages,
+  assertChatCompletionMessage,
   DATASTORE_PATTERN,
   EXTERN_PATTERN,
   SNS_PATTERN,
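Reviewer note: a minimal sketch of how the new pieces compose, runnable from the repository root with this branch checked out (the sample payload below is illustrative and not one of the fixtures above):

```js
'use strict'

// Exercises the chunk stringifier that backs the new BedrockCommand#prompt
// normalization for Claude "messages"-style payloads.
const { stringifyClaudeChunkedMessage } = require('./lib/llm-events/aws-bedrock/utils')

// A chunked Claude 3 message mixing a text chunk with a non-text chunk.
const content = [
  { type: 'text', text: 'What is in this image?' },
  { type: 'image', source: { type: 'base64', media_type: 'image/jpeg', data: 'Zm9vYmFy' } }
]

// Non-text chunks collapse to XML-like placeholders, joined by blank lines.
console.log(stringifyClaudeChunkedMessage(content))
// => 'What is in this image?\n\n<image>'
```

Under the new indexing scheme, context messages occupy sequences `0..prompt.length - 1` and completion messages follow, so the two-prompt/one-completion invocation in the versioned test above produces sequences 0, 1, and 2 and a summary asserted with `numMsgs: 3`.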