diff --git a/api.js b/api.js index 5c0f05b6dd..49c210f54e 100644 --- a/api.js +++ b/api.js @@ -32,6 +32,7 @@ const obfuscate = require('./lib/util/sql/obfuscate') const { DESTINATIONS } = require('./lib/config/attribute-filter') const parse = require('module-details-from-path') const { isSimpleObject } = require('./lib/util/objects') +const { AsyncLocalStorage } = require('async_hooks') /* * @@ -1902,4 +1903,52 @@ API.prototype.ignoreApdex = function ignoreApdex() { transaction.ignoreApdex = true } +/** + * Run a function with the passed in LLM context as the active context and return its return value. + * + * An example of setting a custom attribute: + * + * newrelic.withLlmCustomAttributes({'llm.someAttribute': 'someValue'}, () => { + * return; + * }) + * @param {Object} context LLM custom attributes context + * @param {Function} callback The function to execute in context. + */ +API.prototype.withLlmCustomAttributes = function withLlmCustomAttributes(context, callback) { + context = context || {} + const metric = this.agent.metrics.getOrCreateMetric( + NAMES.SUPPORTABILITY.API + '/withLlmCustomAttributes' + ) + metric.incrementCallCount() + + const transaction = this.agent.tracer.getTransaction() + + if (!callback || typeof callback !== 'function') { + logger.warn('withLlmCustomAttributes must be used with a valid callback') + return + } + + if (!transaction) { + logger.warn('withLlmCustomAttributes must be called within the scope of a transaction.') + return callback() + } + + for (const [key, value] of Object.entries(context)) { + if (typeof value === 'object' || typeof value === 'function') { + logger.warn(`Invalid attribute type for ${key}. Skipped.`) + delete context[key] + } else if (key.indexOf('llm.') !== 0) { + logger.warn(`Invalid attribute name ${key}. 
Renamed to "llm.${key}".`) + delete context[key] + context[`llm.${key}`] = value + } + } + + transaction._llmContextManager = transaction._llmContextManager || new AsyncLocalStorage() + const parentContext = transaction._llmContextManager.getStore() || {} + + const fullContext = Object.assign({}, parentContext, context) + return transaction._llmContextManager.run(fullContext, callback) +} + module.exports = API diff --git a/lib/instrumentation/aws-sdk/v3/bedrock.js b/lib/instrumentation/aws-sdk/v3/bedrock.js index 055c572c63..7fe1b4d277 100644 --- a/lib/instrumentation/aws-sdk/v3/bedrock.js +++ b/lib/instrumentation/aws-sdk/v3/bedrock.js @@ -18,6 +18,7 @@ const { DESTINATIONS } = require('../../../config/attribute-filter') const { AI } = require('../../../metrics/names') const { RecorderSpec } = require('../../../shim/specs') const InstrumentationDescriptor = require('../../../instrumentation-descriptor') +const { extractLlmContext } = require('../../../util/llm-utils') let TRACKING_METRIC @@ -55,7 +56,12 @@ function isStreamingEnabled({ commandName, config }) { */ function recordEvent({ agent, type, msg }) { msg.serialize() - agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg]) + const llmContext = extractLlmContext(agent) + + agent.customEventAggregator.add([ + { type, timestamp: Date.now() }, + Object.assign({}, msg, llmContext) + ]) } /** diff --git a/lib/instrumentation/langchain/common.js b/lib/instrumentation/langchain/common.js index b8e5de272f..34e3d84c8e 100644 --- a/lib/instrumentation/langchain/common.js +++ b/lib/instrumentation/langchain/common.js @@ -7,6 +7,7 @@ const { AI: { LANGCHAIN } } = require('../../metrics/names') +const { extractLlmContext } = require('../../util/llm-utils') const common = module.exports @@ -49,7 +50,12 @@ common.mergeMetadata = function mergeMetadata(localMeta = {}, paramsMeta = {}) { */ common.recordEvent = function recordEvent({ agent, type, msg, pkgVersion }) { 
agent.metrics.getOrCreateMetric(`${LANGCHAIN.TRACKING_PREFIX}/${pkgVersion}`).incrementCallCount() - agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg]) + const llmContext = extractLlmContext(agent) + + agent.customEventAggregator.add([ + { type, timestamp: Date.now() }, + Object.assign({}, msg, llmContext) + ]) } /** diff --git a/lib/instrumentation/openai.js b/lib/instrumentation/openai.js index da0b4a22f4..36e4487057 100644 --- a/lib/instrumentation/openai.js +++ b/lib/instrumentation/openai.js @@ -12,6 +12,7 @@ const { LlmErrorMessage } = require('../../lib/llm-events/openai') const { RecorderSpec } = require('../../lib/shim/specs') +const { extractLlmContext } = require('../util/llm-utils') const MIN_VERSION = '4.0.0' const MIN_STREAM_VERSION = '4.12.2' @@ -75,7 +76,12 @@ function decorateSegment({ shim, result, apiKey }) { * @param {object} params.msg LLM event */ function recordEvent({ agent, type, msg }) { - agent.customEventAggregator.add([{ type, timestamp: Date.now() }, msg]) + const llmContext = extractLlmContext(agent) + + agent.customEventAggregator.add([ + { type, timestamp: Date.now() }, + Object.assign({}, msg, llmContext) + ]) } /** diff --git a/lib/util/llm-utils.js b/lib/util/llm-utils.js new file mode 100644 index 0000000000..7d26e692c7 --- /dev/null +++ b/lib/util/llm-utils.js @@ -0,0 +1,34 @@ +/* + * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +'use strict' + +exports = module.exports = { extractLlmContext, extractLlmAttributes } + +/** + * Extract LLM attributes from the LLM context + * + * @param {Object} context LLM context object + * @returns {Object} LLM custom attributes + */ +function extractLlmAttributes(context) { + return Object.keys(context).reduce((result, key) => { + if (key.indexOf('llm.') === 0) { + result[key] = context[key] + } + return result + }, {}) +} + +/** + * Extract LLM context from the active transaction + * + * @param {Agent} agent NR agent instance + * @returns {Object} LLM context object + */ +function extractLlmContext(agent) { + const context = agent.tracer.getTransaction()?._llmContextManager?.getStore() || {} + return extractLlmAttributes(context) +} diff --git a/test/unit/api/api-llm.test.js b/test/unit/api/api-llm.test.js index a2fb477d7c..26eab92c81 100644 --- a/test/unit/api/api-llm.test.js +++ b/test/unit/api/api-llm.test.js @@ -121,6 +121,116 @@ tap.test('Agent API LLM methods', (t) => { }) }) + t.test('withLlmCustomAttributes should handle no active transaction', (t) => { + const { api } = t.context + t.equal( + api.withLlmCustomAttributes({ test: 1 }, () => { + t.equal(loggerMock.warn.callCount, 1) + return 1 + }), + 1 + ) + t.end() + }) + + t.test('withLlmCustomAttributes should handle an empty store', (t) => { + const { api } = t.context + const agent = api.agent + + helper.runInTransaction(api.agent, (tx) => { + agent.tracer.getTransaction = () => { + return tx + } + t.equal( + api.withLlmCustomAttributes(null, () => { + return 1 + }), + 1 + ) + t.end() + }) + }) + + t.test('withLlmCustomAttributes should handle no callback', (t) => { + const { api } = t.context + const agent = api.agent + helper.runInTransaction(api.agent, (tx) => { + agent.tracer.getTransaction = () => { + return tx + } + api.withLlmCustomAttributes({ test: 1 }, null) + t.equal(loggerMock.warn.callCount, 1) + t.end() + }) + }) + + 
t.test('withLlmCustomAttributes should normalize attributes', (t) => { + const { api } = t.context + const agent = api.agent + helper.runInTransaction(api.agent, (tx) => { + agent.tracer.getTransaction = () => { + return tx + } + api.withLlmCustomAttributes( + { + 'toRename': 'value1', + 'llm.number': 1, + 'llm.boolean': true, + 'toDelete': () => {}, + 'toDelete2': {}, + 'toDelete3': [] + }, + () => { + const contextManager = tx._llmContextManager + const parentContext = contextManager.getStore() + t.equal(parentContext['llm.toRename'], 'value1') + t.notOk(parentContext.toDelete) + t.notOk(parentContext.toDelete2) + t.notOk(parentContext.toDelete3) + t.equal(parentContext['llm.number'], 1) + t.equal(parentContext['llm.boolean'], true) + t.end() + } + ) + }) + }) + + t.test('withLlmCustomAttributes should support branching', (t) => { + const { api } = t.context + const agent = api.agent + t.autoend() + helper.runInTransaction(api.agent, (tx) => { + agent.tracer.getTransaction = () => { + return tx + } + api.withLlmCustomAttributes( + { 'llm.step': '1', 'llm.path': 'root', 'llm.name': 'root' }, + () => { + const contextManager = tx._llmContextManager + const context = contextManager.getStore() + t.equal(context[`llm.step`], '1') + t.equal(context['llm.path'], 'root') + t.equal(context['llm.name'], 'root') + api.withLlmCustomAttributes({ 'llm.step': '1.1', 'llm.path': 'root/1' }, () => { + const contextManager = tx._llmContextManager + const context = contextManager.getStore() + t.equal(context[`llm.step`], '1.1') + t.equal(context['llm.path'], 'root/1') + t.equal(context['llm.name'], 'root') + }) + api.withLlmCustomAttributes({ 'llm.step': '1.2', 'llm.path': 'root/2' }, () => { + const contextManager = tx._llmContextManager + const context = contextManager.getStore() + t.equal(context[`llm.step`], '1.2') + t.equal(context['llm.path'], 'root/2') + t.equal(context['llm.name'], 'root') + t.end() + }) + } + ) + }) + }) + t.test('setLlmTokenCount should register callback 
to calculate token counts', async (t) => { const { api, agent } = t.context function callback(model, content) { diff --git a/test/unit/api/stub.test.js b/test/unit/api/stub.test.js index 68401c64a1..f31f0e9d07 100644 --- a/test/unit/api/stub.test.js +++ b/test/unit/api/stub.test.js @@ -8,7 +8,7 @@ const tap = require('tap') const API = require('../../../stub_api') -const EXPECTED_API_COUNT = 36 +const EXPECTED_API_COUNT = 37 tap.test('Agent API - Stubbed Agent API', (t) => { t.autoend() diff --git a/test/unit/instrumentation/openai.test.js b/test/unit/instrumentation/openai.test.js index 95be982840..49a8c2aace 100644 --- a/test/unit/instrumentation/openai.test.js +++ b/test/unit/instrumentation/openai.test.js @@ -119,5 +119,6 @@ test('openai unit tests', (t) => { t.equal(isWrapped, false, 'should not wrap chat completions create') t.end() }) + t.end() }) diff --git a/test/unit/util/llm-utils.test.js b/test/unit/util/llm-utils.test.js new file mode 100644 index 0000000000..666a42e3f2 --- /dev/null +++ b/test/unit/util/llm-utils.test.js @@ -0,0 +1,74 @@ +/* + * Copyright 2023 New Relic Corporation. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +'use strict' + +const tap = require('tap') +const { extractLlmAttributes, extractLlmContext } = require('../../../lib/util/llm-utils') +const { AsyncLocalStorage } = require('async_hooks') + +tap.test('extractLlmAttributes', (t) => { + const context = { + 'skip': 1, + 'llm.get': 2, + 'fllm.skip': 3 + } + + const llmContext = extractLlmAttributes(context) + t.notOk(llmContext.skip) + t.notOk(llmContext['fllm.skip']) + t.equal(llmContext['llm.get'], 2) + t.end() +}) + +tap.test('extractLlmContext', (t) => { + t.beforeEach((t) => { + const tx = { + _llmContextManager: new AsyncLocalStorage() + } + t.context.agent = { + tracer: { + getTransaction: () => { + return tx + } + } + } + t.context.tx = tx + }) + + t.test('handle empty context', (t) => { + const { tx, agent } = t.context + tx._llmContextManager.run(null, () => { + const llmContext = extractLlmContext(agent) + t.equal(typeof llmContext, 'object') + t.equal(Object.entries(llmContext).length, 0) + t.end() + }) + }) + + t.test('extract LLM context', (t) => { + const { tx, agent } = t.context + tx._llmContextManager.run({ 'llm.test': 1, 'skip': 2 }, () => { + const llmContext = extractLlmContext(agent) + t.equal(llmContext['llm.test'], 1) + t.notOk(llmContext.skip) + t.end() + }) + }) + + t.test('no transaction', (t) => { + const { tx, agent } = t.context + agent.tracer.getTransaction = () => { + return null + } + tx._llmContextManager.run(null, () => { + const llmContext = extractLlmContext(agent) + t.equal(typeof llmContext, 'object') + t.equal(Object.entries(llmContext).length, 0) + t.end() + }) + }) + t.end() +}) diff --git a/test/versioned/aws-sdk-v3/bedrock-chat-completions.tap.js b/test/versioned/aws-sdk-v3/bedrock-chat-completions.tap.js index 61b2bc9b42..cda36a41cc 100644 --- a/test/versioned/aws-sdk-v3/bedrock-chat-completions.tap.js +++ b/test/versioned/aws-sdk-v3/bedrock-chat-completions.tap.js @@ -158,6 +158,32 @@ tap.afterEach(async (t) => { } ) + 
tap.test( + `${modelId}: supports custom attributes on LlmChatCompletionMessage(s) and LlmChatCompletionSummary events`, + (t) => { + const { bedrock, client, agent } = t.context + const prompt = `text ${resKey} ultimate question` + const input = requests[resKey](prompt, modelId) + const command = new bedrock.InvokeModelCommand(input) + + const api = helper.getAgentApi() + helper.runInTransaction(agent, async (tx) => { + api.addCustomAttribute('llm.conversation_id', 'convo-id') + api.withLlmCustomAttributes({ 'llm.contextAttribute': 'someValue' }, async () => { + await client.send(command) + const events = agent.customEventAggregator.events.toArray() + + const chatSummary = events.filter(([{ type }]) => type === 'LlmChatCompletionSummary')[0] + const [, message] = chatSummary + t.equal(message['llm.contextAttribute'], 'someValue') + + tx.end() + t.end() + }) + }) + } + ) + tap.test(`${modelId}: text answer (streamed)`, (t) => { if (modelId.includes('ai21')) { t.skip('model does not support streaming') diff --git a/test/versioned/langchain/runnables.tap.js b/test/versioned/langchain/runnables.tap.js index e6ced1cc4c..a2d3aacc7f 100644 --- a/test/versioned/langchain/runnables.tap.js +++ b/test/versioned/langchain/runnables.tap.js @@ -52,7 +52,6 @@ tap.test('Langchain instrumentation - runnable sequence', (t) => { t.test('should create langchain events for every invoke call', (t) => { const { agent, prompt, outputParser, model } = t.context - helper.runInTransaction(agent, async (tx) => { const input = { topic: 'scientist' } const options = { metadata: { key: 'value', hello: 'world' }, tags: ['tag1', 'tag2'] } @@ -95,11 +94,31 @@ tap.test('Langchain instrumentation - runnable sequence', (t) => { }) }) + t.test('should support custom attributes on the LLM events', (t) => { + const { agent, prompt, outputParser, model } = t.context + const api = helper.getAgentApi() + helper.runInTransaction(agent, async (tx) => { + api.withLlmCustomAttributes({ 'llm.contextAttribute': 
'someValue' }, async () => { + const input = { topic: 'scientist' } + const options = { metadata: { key: 'value', hello: 'world' }, tags: ['tag1', 'tag2'] } + + const chain = prompt.pipe(model).pipe(outputParser) + await chain.invoke(input, options) + const events = agent.customEventAggregator.events.toArray() + + const [[, message]] = events + t.equal(message['llm.contextAttribute'], 'someValue') + + tx.end() + t.end() + }) + }) + }) + t.test( 'should create langchain events for every invoke call on chat prompt + model + parser', (t) => { const { agent, prompt, outputParser, model } = t.context - helper.runInTransaction(agent, async (tx) => { const input = { topic: 'scientist' } const options = { metadata: { key: 'value', hello: 'world' }, tags: ['tag1', 'tag2'] } diff --git a/test/versioned/openai/chat-completions.tap.js b/test/versioned/openai/chat-completions.tap.js index 67334d7fd9..322691bebb 100644 --- a/test/versioned/openai/chat-completions.tap.js +++ b/test/versioned/openai/chat-completions.tap.js @@ -448,4 +448,64 @@ tap.test('OpenAI instrumentation - chat completions', (t) => { test.end() }) }) + + t.test('should record LLM custom events with attributes', (test) => { + const { client, agent } = t.context + const api = helper.getAgentApi() + + helper.runInTransaction(agent, () => { + api.withLlmCustomAttributes({ 'llm.shared': true, 'llm.path': 'root/' }, async () => { + await api.withLlmCustomAttributes( + { 'llm.path': 'root/branch1', 'llm.attr1': true }, + async () => { + agent.config.ai_monitoring.streaming.enabled = true + const model = 'gpt-3.5-turbo-0613' + const content = 'You are a mathematician.' + await client.chat.completions.create({ + max_tokens: 100, + temperature: 0.5, + model, + messages: [ + { role: 'user', content }, + { role: 'user', content: 'What does 1 plus 1 equal?' 
} + ] + }) + } + ) + + await api.withLlmCustomAttributes( + { 'llm.path': 'root/branch2', 'llm.attr2': true }, + async () => { + agent.config.ai_monitoring.streaming.enabled = true + const model = 'gpt-3.5-turbo-0613' + const content = 'You are a mathematician.' + await client.chat.completions.create({ + max_tokens: 100, + temperature: 0.5, + model, + messages: [ + { role: 'user', content }, + { role: 'user', content: 'What does 1 plus 2 equal?' } + ] + }) + } + ) + + const events = agent.customEventAggregator.events.toArray().map((event) => event[1]) + + events.forEach((event) => { + t.ok(event['llm.shared']) + if (event['llm.path'] === 'root/branch1') { + t.ok(event['llm.attr1']) + t.notOk(event['llm.attr2']) + } else { + t.ok(event['llm.attr2']) + t.notOk(event['llm.attr1']) + } + }) + + test.end() + }) + }) + }) })