Skip to content

Commit

Permalink
(chore)APM: Refactor Bedrock Integration (#5137)
Browse files Browse the repository at this point in the history
* refactor apm tracing

* Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/index.js

Co-authored-by: Sam Brenner <[email protected]>

* Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/tracing.js

Co-authored-by: Sam Brenner <[email protected]>

* Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/tracing.js

Co-authored-by: Sam Brenner <[email protected]>

* Update packages/datadog-plugin-aws-sdk/src/services/bedrockruntime/utils.js

Co-authored-by: Sam Brenner <[email protected]>

* CODEOWNERS

* remove shouldSetChoiceId override

* remove shouldSetChoiceId override

* lint

* Update packages/datadog-instrumentations/src/aws-sdk.js

Co-authored-by: Thomas Hunter II <[email protected]>

---------

Co-authored-by: Sam Brenner <[email protected]>
Co-authored-by: Thomas Hunter II <[email protected]>
  • Loading branch information
3 people authored Jan 24, 2025
1 parent 30efc06 commit f41f5f7
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 78 deletions.
2 changes: 2 additions & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
/packages/datadog-plugin-langchain/ @DataDog/ml-observability
/packages/datadog-instrumentations/src/openai.js @DataDog/ml-observability
/packages/datadog-instrumentations/src/langchain.js @DataDog/ml-observability
/packages/datadog-plugin-aws-sdk/src/services/bedrockruntime @DataDog/ml-observability
/packages/datadog-plugin-aws-sdk/test/bedrockruntime.spec.js @DataDog/ml-observability

# CI
/.github/workflows/appsec.yml @DataDog/asm-js
Expand Down
4 changes: 3 additions & 1 deletion packages/datadog-instrumentations/src/aws-sdk.js
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ function getMessage (request, error, result) {
}

function getChannelSuffix (name) {
// some resource identifiers have spaces between ex: bedrock runtime
name = name.replaceAll(' ', '')
return [
'cloudwatchlogs',
'dynamodb',
Expand All @@ -168,7 +170,7 @@ function getChannelSuffix (name) {
'sqs',
'states',
'stepfunctions',
'bedrock runtime'
'bedrockruntime'
].includes(name)
? name
: 'default'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
const CompositePlugin = require('../../../../dd-trace/src/plugins/composite')
const BedrockRuntimeTracing = require('./tracing')
class BedrockRuntimePlugin extends CompositePlugin {
static get id () {
return 'bedrockruntime'
}

static get plugins () {
return {
tracing: BedrockRuntimeTracing
}
}
}
module.exports = BedrockRuntimePlugin
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
'use strict'

const BaseAwsSdkPlugin = require('../../base')
const { parseModelId, extractRequestParams, extractTextAndResponseReason } = require('./utils')

const enabledOperations = ['invokeModel']

class BedrockRuntime extends BaseAwsSdkPlugin {
static get id () { return 'bedrockruntime' }

isEnabled (request) {
const operation = request.operation
if (!enabledOperations.includes(operation)) {
return false
}

return super.isEnabled(request)
}

generateTags (params, operation, response) {
const { modelProvider, modelName } = parseModelId(params.modelId)

const requestParams = extractRequestParams(params, modelProvider)
const textAndResponseReason = extractTextAndResponseReason(response, modelProvider, modelName)

const tags = buildTagsFromParams(requestParams, textAndResponseReason, modelProvider, modelName, operation)

return tags
}
}

function buildTagsFromParams (requestParams, textAndResponseReason, modelProvider, modelName, operation) {
const tags = {}

// add request tags
tags['resource.name'] = operation
tags['aws.bedrock.request.model'] = modelName
tags['aws.bedrock.request.model_provider'] = modelProvider.toLowerCase()
tags['aws.bedrock.request.prompt'] = requestParams.prompt
tags['aws.bedrock.request.temperature'] = requestParams.temperature
tags['aws.bedrock.request.top_p'] = requestParams.topP
tags['aws.bedrock.request.top_k'] = requestParams.topK
tags['aws.bedrock.request.max_tokens'] = requestParams.maxTokens
tags['aws.bedrock.request.stop_sequences'] = requestParams.stopSequences
tags['aws.bedrock.request.input_type'] = requestParams.inputType
tags['aws.bedrock.request.truncate'] = requestParams.truncate
tags['aws.bedrock.request.stream'] = requestParams.stream
tags['aws.bedrock.request.n'] = requestParams.n

// add response tags
if (modelName.includes('embed')) {
tags['aws.bedrock.response.embedding_length'] = textAndResponseReason.message.length
}
if (textAndResponseReason.choiceId) {
tags['aws.bedrock.response.choices.id'] = textAndResponseReason.choiceId
}
tags['aws.bedrock.response.choices.text'] = textAndResponseReason.message
tags['aws.bedrock.response.choices.finish_reason'] = textAndResponseReason.finishReason

return tags
}

module.exports = BedrockRuntime
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
'use strict'

const BaseAwsSdkPlugin = require('../base')
const log = require('../../../dd-trace/src/log')
const log = require('../../../../dd-trace/src/log')

const MODEL_TYPE_IDENTIFIERS = [
'foundation-model/',
'custom-model/',
'provisioned-model/',
'imported-module/',
'prompt/',
'endpoint/',
'inference-profile/',
'default-prompt-router/'
]

const PROVIDER = {
AI21: 'AI21',
Expand All @@ -13,44 +23,6 @@ const PROVIDER = {
MISTRAL: 'MISTRAL'
}

const enabledOperations = ['invokeModel']

class BedrockRuntime extends BaseAwsSdkPlugin {
static get id () { return 'bedrock runtime' }

isEnabled (request) {
const operation = request.operation
if (!enabledOperations.includes(operation)) {
return false
}

return super.isEnabled(request)
}

generateTags (params, operation, response) {
let tags = {}
let modelName = ''
let modelProvider = ''
const modelMeta = params.modelId.split('.')
if (modelMeta.length === 2) {
[modelProvider, modelName] = modelMeta
modelProvider = modelProvider.toUpperCase()
} else {
[, modelProvider, modelName] = modelMeta
modelProvider = modelProvider.toUpperCase()
}

const shouldSetChoiceIds = modelProvider === PROVIDER.COHERE && !modelName.includes('embed')

const requestParams = extractRequestParams(params, modelProvider)
const textAndResponseReason = extractTextAndResponseReason(response, modelProvider, modelName, shouldSetChoiceIds)

tags = buildTagsFromParams(requestParams, textAndResponseReason, modelProvider, modelName, operation)

return tags
}
}

class Generation {
constructor ({ message = '', finishReason = '', choiceId = '' } = {}) {
// stringify message as it could be a single generated message as well as a list of embeddings
Expand All @@ -65,18 +37,19 @@ class RequestParams {
prompt = '',
temperature = undefined,
topP = undefined,
topK = undefined,
maxTokens = undefined,
stopSequences = [],
inputType = '',
truncate = '',
stream = '',
n = undefined
} = {}) {
// TODO: set a truncation limit to prompt
// stringify prompt as it could be a single prompt as well as a list of message objects
this.prompt = typeof prompt === 'string' ? prompt : JSON.stringify(prompt) || ''
this.temperature = temperature !== undefined ? temperature : undefined
this.topP = topP !== undefined ? topP : undefined
this.topK = topK !== undefined ? topK : undefined
this.maxTokens = maxTokens !== undefined ? maxTokens : undefined
this.stopSequences = stopSequences || []
this.inputType = inputType || ''
Expand All @@ -86,11 +59,53 @@ class RequestParams {
}
}

function parseModelId (modelId) {
// Best effort to extract the model provider and model name from the bedrock model ID.
// modelId can be a 1/2 period-separated string or a full AWS ARN, based on the following formats:
// 1. Base model: "{model_provider}.{model_name}"
// 2. Cross-region model: "{region}.{model_provider}.{model_name}"
// 3. Other: Prefixed by AWS ARN "arn:aws{+region?}:bedrock:{region}:{account-id}:"
// a. Foundation model: ARN prefix + "foundation-model/{region?}.{model_provider}.{model_name}"
// b. Custom model: ARN prefix + "custom-model/{model_provider}.{model_name}"
// c. Provisioned model: ARN prefix + "provisioned-model/{model-id}"
// d. Imported model: ARN prefix + "imported-module/{model-id}"
// e. Prompt management: ARN prefix + "prompt/{prompt-id}"
// f. Sagemaker: ARN prefix + "endpoint/{model-id}"
// g. Inference profile: ARN prefix + "{application-?}inference-profile/{model-id}"
// h. Default prompt router: ARN prefix + "default-prompt-router/{prompt-id}"
// If model provider cannot be inferred from the modelId formatting, then default to "custom"
modelId = modelId.toLowerCase()
if (!modelId.startsWith('arn:aws')) {
const modelMeta = modelId.split('.')
if (modelMeta.length < 2) {
return { modelProvider: 'custom', modelName: modelMeta[0] }
}
return { modelProvider: modelMeta[modelMeta.length - 2], modelName: modelMeta[modelMeta.length - 1] }
}

for (const identifier of MODEL_TYPE_IDENTIFIERS) {
if (!modelId.includes(identifier)) {
continue
}
modelId = modelId.split(identifier).pop()
if (['foundation-model/', 'custom-model/'].includes(identifier)) {
const modelMeta = modelId.split('.')
if (modelMeta.length < 2) {
return { modelProvider: 'custom', modelName: modelId }
}
return { modelProvider: modelMeta[modelMeta.length - 2], modelName: modelMeta[modelMeta.length - 1] }
}
return { modelProvider: 'custom', modelName: modelId }
}

return { modelProvider: 'custom', modelName: 'custom' }
}

function extractRequestParams (params, provider) {
const requestBody = JSON.parse(params.body)
const modelId = params.modelId

switch (provider) {
switch (provider.toUpperCase()) {
case PROVIDER.AI21: {
let userPrompt = requestBody.prompt
if (modelId.includes('jamba')) {
Expand Down Expand Up @@ -176,11 +191,11 @@ function extractRequestParams (params, provider) {
}
}

function extractTextAndResponseReason (response, provider, modelName, shouldSetChoiceIds) {
function extractTextAndResponseReason (response, provider, modelName) {
const body = JSON.parse(Buffer.from(response.body).toString('utf8'))

const shouldSetChoiceIds = provider.toUpperCase() === PROVIDER.COHERE && !modelName.includes('embed')
try {
switch (provider) {
switch (provider.toUpperCase()) {
case PROVIDER.AI21: {
if (modelName.includes('jamba')) {
const generations = body.choices || []
Expand Down Expand Up @@ -262,34 +277,11 @@ function extractTextAndResponseReason (response, provider, modelName, shouldSetC
return new Generation()
}

function buildTagsFromParams (requestParams, textAndResponseReason, modelProvider, modelName, operation) {
const tags = {}

// add request tags
tags['resource.name'] = operation
tags['aws.bedrock.request.model'] = modelName
tags['aws.bedrock.request.model_provider'] = modelProvider
tags['aws.bedrock.request.prompt'] = requestParams.prompt
tags['aws.bedrock.request.temperature'] = requestParams.temperature
tags['aws.bedrock.request.top_p'] = requestParams.topP
tags['aws.bedrock.request.max_tokens'] = requestParams.maxTokens
tags['aws.bedrock.request.stop_sequences'] = requestParams.stopSequences
tags['aws.bedrock.request.input_type'] = requestParams.inputType
tags['aws.bedrock.request.truncate'] = requestParams.truncate
tags['aws.bedrock.request.stream'] = requestParams.stream
tags['aws.bedrock.request.n'] = requestParams.n

// add response tags
if (modelName.includes('embed')) {
tags['aws.bedrock.response.embedding_length'] = textAndResponseReason.message.length
}
if (textAndResponseReason.choiceId) {
tags['aws.bedrock.response.choices.id'] = textAndResponseReason.choiceId
}
tags['aws.bedrock.response.choices.text'] = textAndResponseReason.message
tags['aws.bedrock.response.choices.finish_reason'] = textAndResponseReason.finishReason

return tags
module.exports = {
Generation,
RequestParams,
parseModelId,
extractRequestParams,
extractTextAndResponseReason,
PROVIDER
}

module.exports = BedrockRuntime
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ const PROVIDER = {
}

describe('Plugin', () => {
describe('aws-sdk (bedrock)', function () {
describe('aws-sdk (bedrockruntime)', function () {
setup()

withVersions('aws-sdk', ['@aws-sdk/smithy-client', 'aws-sdk'], '>=3', (version, moduleName) => {
Expand Down Expand Up @@ -217,7 +217,7 @@ describe('Plugin', () => {
expect(span.meta).to.include({
'aws.operation': 'invokeModel',
'aws.bedrock.request.model': model.modelId.split('.')[1],
'aws.bedrock.request.model_provider': model.provider,
'aws.bedrock.request.model_provider': model.provider.toLowerCase(),
'aws.bedrock.request.prompt': model.userPrompt
})
expect(span.metrics).to.include({
Expand Down

0 comments on commit f41f5f7

Please sign in to comment.