diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java
index 547c76d887..904f3fdcfb 100644
--- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java
+++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAgent.java
@@ -7,6 +7,7 @@
 
 package com.newrelic.agent.bridge;
 
+import com.newrelic.api.agent.AiMonitoring;
 import com.newrelic.api.agent.Config;
 import com.newrelic.api.agent.ErrorApi;
 import com.newrelic.api.agent.Insights;
@@ -18,6 +19,8 @@
 import java.util.Collections;
 import java.util.Map;
 
+import static com.newrelic.agent.bridge.NoOpAiMonitoring.INSTANCE;
+
 class NoOpAgent implements Agent {
     static final Agent INSTANCE = new NoOpAgent();
@@ -65,6 +68,11 @@ public Insights getInsights() {
         return NoOpInsights.INSTANCE;
     }
 
+    @Override
+    public AiMonitoring getAiMonitoring() {
+        return NoOpAiMonitoring.INSTANCE;
+    }
+
     @Override
     public ErrorApi getErrorApi() {
         return NoOpErrorApi.INSTANCE;
diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAiMonitoring.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAiMonitoring.java
new file mode 100644
index 0000000000..db26d3ac71
--- /dev/null
+++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpAiMonitoring.java
@@ -0,0 +1,21 @@
+package com.newrelic.agent.bridge;
+
+import com.newrelic.api.agent.AiMonitoring;
+import com.newrelic.api.agent.LlmTokenCountCallback;
+
+import java.util.Map;
+
+public class NoOpAiMonitoring implements AiMonitoring {
+
+    static final AiMonitoring INSTANCE = new NoOpAiMonitoring();
+
+    private NoOpAiMonitoring() {}
+
+    @Override
+    public void recordLlmFeedbackEvent(Map<String, Object> llmFeedbackEventAttributes) {
+    }
+
+    @Override
+    public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) {
+    }
+}
diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java
index 4449df5d57..9c79f6964b 100644
--- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java
+++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/NoOpTransaction.java
@@ -27,6 +27,7 @@ public class NoOpTransaction implements Transaction {
 
     public static final Transaction INSTANCE = new NoOpTransaction();
     public static final NoOpMap<String, Object> AGENT_ATTRIBUTES = new NoOpMap<>();
+    public static final NoOpMap<String, Object> USER_ATTRIBUTES = new NoOpMap<>();
 
     @Override
     public void beforeSendResponseHeaders() {
@@ -153,6 +154,11 @@ public Map<String, Object> getAgentAttributes() {
         return AGENT_ATTRIBUTES;
     }
 
+    @Override
+    public Map<String, Object> getUserAttributes() {
+        return USER_ATTRIBUTES;
+    }
+
     @Override
     public boolean registerAsyncActivity(Object activityContext) {
         return false;
diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java
index 9f89f2064d..fc4061973e 100644
--- a/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java
+++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/Transaction.java
@@ -23,6 +23,7 @@
 public interface Transaction extends com.newrelic.api.agent.Transaction {
 
     Map<String, Object> getAgentAttributes();
 
+    Map<String, Object> getUserAttributes();
 
     /**
      * Sets the current transaction's name.
diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java new file mode 100644 index 0000000000..5332df36c8 --- /dev/null +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/AiMonitoringUtils.java @@ -0,0 +1,78 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.agent.bridge.aimonitoring; + +import com.newrelic.api.agent.Config; +import com.newrelic.api.agent.NewRelic; + +import java.util.logging.Level; + +public class AiMonitoringUtils { + // Enabled defaults + private static final boolean AI_MONITORING_ENABLED_DEFAULT = false; + private static final boolean AI_MONITORING_STREAMING_ENABLED_DEFAULT = true; + private static final boolean AI_MONITORING_RECORD_CONTENT_ENABLED_DEFAULT = true; + private static final boolean HIGH_SECURITY_ENABLED_DEFAULT = false; + + /** + * Check if ai_monitoring features are enabled. + * Indicates whether LLM instrumentation will be registered. If this is set to False, no metrics, events, or spans are to be sent. + * + * @return true if AI monitoring is enabled, else false + */ + public static boolean isAiMonitoringEnabled() { + Config config = NewRelic.getAgent().getConfig(); + Boolean aimEnabled = config.getValue("ai_monitoring.enabled", AI_MONITORING_ENABLED_DEFAULT); + Boolean highSecurity = config.getValue("high_security", HIGH_SECURITY_ENABLED_DEFAULT); + + if (highSecurity || !aimEnabled) { + aimEnabled = false; + String disabledReason = highSecurity ? "High Security Mode." : "agent config."; + NewRelic.getAgent().getLogger().log(Level.FINE, "AIM: AI Monitoring is disabled due to " + disabledReason); + NewRelic.incrementCounter("Supportability/Java/ML/Disabled"); + } else { + NewRelic.incrementCounter("Supportability/Java/ML/Enabled"); + } + + return aimEnabled; + } + + /** + * Check if ai_monitoring.streaming features are enabled. + * + * @return true if streaming is enabled, else false + */ + public static boolean isAiMonitoringStreamingEnabled() { + Boolean enabled = NewRelic.getAgent().getConfig().getValue("ai_monitoring.streaming.enabled", AI_MONITORING_STREAMING_ENABLED_DEFAULT); + + if (enabled) { + NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Enabled"); + } else { + NewRelic.incrementCounter("Supportability/Java/ML/Streaming/Disabled"); + } + + return enabled; + } + + /** + * Check if the input and output content should be added to LLM events. 
+ * + * @return true if adding content is enabled, else false + */ + public static boolean isAiMonitoringRecordContentEnabled() { + Boolean enabled = NewRelic.getAgent().getConfig().getValue("ai_monitoring.record_content.enabled", AI_MONITORING_RECORD_CONTENT_ENABLED_DEFAULT); + + if (enabled) { + NewRelic.incrementCounter("Supportability/Java/ML/RecordContent/Enabled"); + } else { + NewRelic.incrementCounter("Supportability/Java/ML/RecordContent/Disabled"); + } + + return enabled; + } +} diff --git a/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java new file mode 100644 index 0000000000..6c89efd7e0 --- /dev/null +++ b/agent-bridge/src/main/java/com/newrelic/agent/bridge/aimonitoring/LlmTokenCountCallbackHolder.java @@ -0,0 +1,30 @@ +package com.newrelic.agent.bridge.aimonitoring; + +import com.newrelic.api.agent.LlmTokenCountCallback; + +/** + * A thread-safe holder for an instance of {@link LlmTokenCountCallback}. + * This class provides methods for setting and retrieving the callback instance. + */ +public class LlmTokenCountCallbackHolder { + + private static volatile LlmTokenCountCallback llmTokenCountCallback = null; + + /** + * Sets the {@link LlmTokenCountCallback} instance to be stored. + * + * @param newLlmTokenCountCallback the callback instance + */ + public static void setLlmTokenCountCallback(LlmTokenCountCallback newLlmTokenCountCallback) { + llmTokenCountCallback = newLlmTokenCountCallback; + } + + /** + * Retrieves the stored {@link LlmTokenCountCallback} instance. + * + * @return stored callback instance + */ + public static LlmTokenCountCallback getLlmTokenCountCallback() { + return llmTokenCountCallback; + } +} \ No newline at end of file diff --git a/agent-model/src/main/java/com/newrelic/agent/model/LlmCustomInsightsEvent.java b/agent-model/src/main/java/com/newrelic/agent/model/LlmCustomInsightsEvent.java new file mode 100644 index 0000000000..44798d0f76 --- /dev/null +++ b/agent-model/src/main/java/com/newrelic/agent/model/LlmCustomInsightsEvent.java @@ -0,0 +1,30 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.agent.model; + +/** + * Represents an internal subtype of a CustomInsightsEvent that is sent to the + * custom_event_data collector endpoint but potentially subject to different + * validation rules and agent configuration. 
+ */
+public class LlmCustomInsightsEvent {
+    // LLM event types
+    private static final String LLM_EMBEDDING = "LlmEmbedding";
+    private static final String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary";
+    private static final String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage";
+
+    /**
+     * Determines if a CustomInsightsEvent should be treated as an LlmEvent
+     *
+     * @param eventType type of the current event
+     * @return true if eventType is an LlmEvent, else false
+     */
+    public static boolean isLlmEvent(String eventType) {
+        return eventType.equals(LLM_EMBEDDING) || eventType.equals(LLM_CHAT_COMPLETION_MESSAGE) || eventType.equals(LLM_CHAT_COMPLETION_SUMMARY);
+    }
+}
diff --git a/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java b/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java
index 660f6ab86c..e6bb7559d2 100644
--- a/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java
+++ b/functional_test/src/test/java/com/newrelic/agent/extension/FakeExtensionAgent.java
@@ -10,6 +10,7 @@
 import com.newrelic.agent.bridge.Agent;
 import com.newrelic.agent.bridge.TracedMethod;
 import com.newrelic.agent.bridge.Transaction;
+import com.newrelic.api.agent.AiMonitoring;
 import com.newrelic.api.agent.Config;
 import com.newrelic.api.agent.ErrorApi;
 import com.newrelic.api.agent.Insights;
@@ -38,6 +39,11 @@ public Logger getLogger() {
     @Override
     public Insights getInsights() { throw new RuntimeException(); }
 
+    @Override
+    public AiMonitoring getAiMonitoring() {
+        return null;
+    }
+
     @Override
     public ErrorApi getErrorApi() { throw new RuntimeException(); }
diff --git a/instrumentation/aws-bedrock-runtime-2.20/README.md b/instrumentation/aws-bedrock-runtime-2.20/README.md
new file mode 100644
index 0000000000..37d019c1c9
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/README.md
@@ -0,0 +1,191 @@
+# AWS Bedrock Runtime Instrumentation
+
+## About
+
+Instruments invocations of LLMs made by the AWS Bedrock Runtime SDK.
+
+## Support
+
+### Supported Clients/APIs
+
+The following AWS Bedrock Runtime clients and APIs are supported:
+
+* `BedrockRuntimeClient`
+  * `invokeModel`
+* `BedrockRuntimeAsyncClient`
+  * `invokeModel`
+
+Note: Currently, `invokeModelWithResponseStream` is not supported.
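+
+For reference, a minimal sketch of a supported invocation (the region, model ID, and request body below are illustrative; the body format varies by model):
+
+```java
+import software.amazon.awssdk.core.SdkBytes;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
+
+BedrockRuntimeClient client = BedrockRuntimeClient.builder()
+        .region(Region.US_EAST_1)
+        .build();
+
+InvokeModelRequest request = InvokeModelRequest.builder()
+        .modelId("anthropic.claude-v2")
+        .body(SdkBytes.fromUtf8String("{\"prompt\":\"\\n\\nHuman: Hello\\n\\nAssistant:\",\"max_tokens_to_sample\":256}"))
+        .build();
+
+// When made inside an active transaction, this is the call the instrumentation times and records
+InvokeModelResponse response = client.invokeModel(request);
+```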
+
+### Supported Models
+
+As of publication, the following text-based foundation models have been tested and confirmed as supported. As long as the model ID of an invoked LLM contains one of the prefixes defined in `SupportedModels`, the instrumentation will attempt to process the request/response; however, if the request/response structure changes significantly, processing may fail. See the `README` for each model in `llm.models.*` for more details on each.
+
+* AI21 Labs
+  * Jurassic-2 Ultra (`ai21.j2-ultra-v1`)
+  * Jurassic-2 Mid (`ai21.j2-mid-v1`)
+* Amazon
+  * Titan Embeddings G1 - Text (`amazon.titan-embed-text-v1`)
+  * Titan Text G1 - Lite (`amazon.titan-text-lite-v1`)
+  * Titan Text G1 - Express (`amazon.titan-text-express-v1`)
+  * Titan Multimodal Embeddings G1 (`amazon.titan-embed-image-v1`)
+* Anthropic
+  * Claude (`anthropic.claude-v2`, `anthropic.claude-v2:1`)
+  * Claude Instant (`anthropic.claude-instant-v1`)
+* Cohere
+  * Command (`cohere.command-text-v14`)
+  * Command Light (`cohere.command-light-text-v14`)
+  * Embed English (`cohere.embed-english-v3`)
+  * Embed Multilingual (`cohere.embed-multilingual-v3`)
+* Meta
+  * Llama 2 Chat 13B (`meta.llama2-13b-chat-v1`)
+  * Llama 2 Chat 70B (`meta.llama2-70b-chat-v1`)
+
+## Involved Pieces
+
+### LLM Events
+
+The main goal of this instrumentation is to generate the following LLM events to drive the UI.
+
+* `LlmEmbedding`: An event that captures data specific to the creation of an embedding.
+* `LlmChatCompletionSummary`: An event that captures high-level data about the creation of a chat completion including request, response, and call information.
+* `LlmChatCompletionMessage`: An event that corresponds to each message (sent and received) from a chat completion call including those created by the user, assistant, and the system.
+
+These events are custom events sent via the public `recordCustomEvent` API. Currently, they contribute towards the following Custom Insights Events limits (this will likely change in the future).
+
+```yaml
+  custom_insights_events:
+    max_samples_stored: 100000
+```
+
+Because of this, it is recommended to increase `custom_insights_events.max_samples_stored` to the maximum value of 100,000 to minimize sampling issues. LLM events are sent to the `custom_event_data` collector endpoint, but the backend assigns them a unique namespace to distinguish them from other custom events.
+
+### Attributes
+
+#### Agent Attributes
+
+An `llm: true` agent attribute will be set on all Transaction events where one of the supported Bedrock methods is invoked within an active transaction.
+
+#### LLM Event Attributes
+
+Attributes on LLM events use the same configuration and size limits as `custom_insights_events`, with two notable exceptions: the following LLM event attributes are never truncated:
+* `content`
+* `input`
+
+This is done so that token usage can be calculated on the backend based on the full input and output content.
+
+#### Custom LLM Attributes
+
+Any custom attributes added by customers using the `addCustomParameters` API that are prefixed with `llm.` will automatically be copied to `LlmEvent`s. For these attributes to be copied, the `addCustomParameters` calls must occur before invoking the Bedrock SDK, as shown in the sketch below.
+
+One potential custom attribute with special meaning that customers are encouraged to add is `llm.conversation_id`, which has implications in the UI and can be used to group LLM messages into specific conversations.
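+
+Building on the invocation sketch above, a minimal example of the pattern (attribute names and values are illustrative):
+
+```java
+import com.newrelic.api.agent.NewRelic;
+
+import java.util.HashMap;
+import java.util.Map;
+
+Map<String, Object> attributes = new HashMap<>();
+attributes.put("llm.conversation_id", "conversation-1234"); // special meaning: groups LLM messages in the UI
+attributes.put("llm.user_tier", "enterprise");              // copied to LlmEvents because of the llm. prefix
+attributes.put("page_view_id", "abc123");                   // not copied: lacks the llm. prefix
+NewRelic.addCustomParameters(attributes);
+
+// The Bedrock call must happen after the attributes are added
+InvokeModelResponse response = client.invokeModel(request);
+```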

### Metrics

When in an active transaction, a named span/segment is created for each LLM embedding and chat completion call using the following format:

`Llm/{operation_type}/{vendor_name}/{function_name}`

* `operation_type`: `completion` or `embedding`
* `vendor_name`: Name of LLM vendor (ex: `OpenAI`, `Bedrock`)
* `function_name`: Name of instrumented function (ex: `invokeModel`, `create`)

A supportability metric is reported each time an instrumented framework method is invoked. These metrics are detected and parsed by APM services to support entity tagging in the UI; if the metric hasn't been reported within the past day, the LLM UI will not display in APM. The metric uses the following format:

`Supportability/{language}/ML/{vendor_name}/{vendor_version}`

* `language`: Name of language agent (ex: `Java`)
* `vendor_name`: Name of LLM vendor (ex: `Bedrock`)
* `vendor_version`: Version of instrumented LLM library (ex: `2.20`)

Note: The vendor version isn't obtainable from the AWS Bedrock SDK for Java, so the instrumentation version is used instead.

Additionally, the following supportability metrics are recorded to indicate the agent config state.

```
Supportability/Java/ML/Enabled
Supportability/Java/ML/Disabled

Supportability/Java/ML/Streaming/Enabled
Supportability/Java/ML/Streaming/Disabled

Supportability/Java/ML/RecordContent/Enabled
Supportability/Java/ML/RecordContent/Disabled
```

## Config

### Yaml

* `ai_monitoring.enabled`: Provides control over all AI Monitoring functionality. Set to true to enable all AI Monitoring features.
* `ai_monitoring.record_content.enabled`: Provides control over whether attributes for the input and output content should be added to LLM events. Set to false to disable attributes for the input and output content.
* `ai_monitoring.streaming.enabled`: NOT SUPPORTED

### Environment Variable

```
NEW_RELIC_AI_MONITORING_ENABLED
NEW_RELIC_AI_MONITORING_RECORD_CONTENT_ENABLED
NEW_RELIC_AI_MONITORING_STREAMING_ENABLED
```

### System Property

```
-Dnewrelic.config.ai_monitoring.enabled
-Dnewrelic.config.ai_monitoring.record_content.enabled
-Dnewrelic.config.ai_monitoring.streaming.enabled
```

## Related Agent APIs

AI monitoring can be enhanced by using the following agent APIs:
* `recordLlmFeedbackEvent` - Can be used to record an LlmFeedback event that associates user feedback with a specific distributed trace.
* `setLlmTokenCountCallback` - Can be used to register a callback that provides a token count.
* `addCustomParameter` - Used to add custom attributes to LLM events. See [Custom LLM Attributes](#custom-llm-attributes)
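
A minimal sketch of the first two APIs (the feedback attribute keys are illustrative and the token calculation is a placeholder heuristic):

```java
import com.newrelic.api.agent.NewRelic;

import java.util.HashMap;
import java.util.Map;

// Register a callback the agent can use to compute token_count when usage data isn't available
NewRelic.getAgent().getAiMonitoring().setLlmTokenCountCallback(
        (model, content) -> content.length() / 4); // placeholder: roughly 4 characters per token

// Later, associate user feedback with the trace that produced the LLM response
String traceId = NewRelic.getAgent().getLinkingMetadata().get("trace.id");
Map<String, Object> feedback = new HashMap<>();
feedback.put("trace_id", traceId);
feedback.put("rating", "thumbs_up");
NewRelic.getAgent().getAiMonitoring().recordLlmFeedbackEvent(feedback);
```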

## Known Issues

When using the `BedrockRuntimeAsyncClient`, which returns the response as a `CompletableFuture`, the external call to AWS isn't being captured. Fixing this likely requires deeper instrumentation of the AWS SDK core classes, perhaps `software.amazon.awssdk.core.internal.http.AmazonAsyncHttpClient` or `software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient`. The external call is actually made by `new NettyRequestExecutor(ctx).execute()`:

```java
"http-nio-8081-exec-9@16674" tid=0x56 nid=NA runnable
  java.lang.Thread.State: RUNNABLE
      at software.amazon.awssdk.http.nio.netty.internal.NettyRequestExecutor.execute(NettyRequestExecutor.java:92)
      at software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient.execute(NettyNioAsyncHttpClient.java:123)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.doExecuteHttpRequest(MakeAsyncHttpRequestStage.java:189)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.executeHttpRequest(MakeAsyncHttpRequestStage.java:147)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.lambda$execute$1(MakeAsyncHttpRequestStage.java:99)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage$$Lambda/0x0000000800aefa78.accept(Unknown Source:-1)
      at java.util.concurrent.CompletableFuture.uniAcceptNow(CompletableFuture.java:757)
      at java.util.concurrent.CompletableFuture.uniAcceptStage(CompletableFuture.java:735)
      at java.util.concurrent.CompletableFuture.thenAccept(CompletableFuture.java:2214)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.execute(MakeAsyncHttpRequestStage.java:95)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.MakeAsyncHttpRequestStage.execute(MakeAsyncHttpRequestStage.java:60)
      at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallAttemptMetricCollectionStage.execute(AsyncApiCallAttemptMetricCollectionStage.java:56)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallAttemptMetricCollectionStage.execute(AsyncApiCallAttemptMetricCollectionStage.java:38)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage$RetryingExecutor.attemptExecute(AsyncRetryableStage.java:144)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage$RetryingExecutor.maybeAttemptExecute(AsyncRetryableStage.java:136)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage$RetryingExecutor.execute(AsyncRetryableStage.java:95)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage.execute(AsyncRetryableStage.java:79)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncRetryableStage.execute(AsyncRetryableStage.java:44)
      at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206)
      at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage.execute(AsyncExecutionFailureExceptionReportingStage.java:41)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncExecutionFailureExceptionReportingStage.execute(AsyncExecutionFailureExceptionReportingStage.java:29)
      at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallTimeoutTrackingStage.execute(AsyncApiCallTimeoutTrackingStage.java:64)
      at
software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallTimeoutTrackingStage.execute(AsyncApiCallTimeoutTrackingStage.java:36) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallMetricCollectionStage.execute(AsyncApiCallMetricCollectionStage.java:49) + at software.amazon.awssdk.core.internal.http.pipeline.stages.AsyncApiCallMetricCollectionStage.execute(AsyncApiCallMetricCollectionStage.java:32) + at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206) + at software.amazon.awssdk.core.internal.http.AmazonAsyncHttpClient$RequestExecutionBuilderImpl.execute(AmazonAsyncHttpClient.java:190) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.invoke(BaseAsyncClientHandler.java:285) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.doExecute(BaseAsyncClientHandler.java:227) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.lambda$execute$1(BaseAsyncClientHandler.java:82) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler$$Lambda/0x0000000800ab3088.get(Unknown Source:-1) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.measureApiCallSuccess(BaseAsyncClientHandler.java:291) + at software.amazon.awssdk.core.internal.handler.BaseAsyncClientHandler.execute(BaseAsyncClientHandler.java:75) + at software.amazon.awssdk.awscore.client.handler.AwsAsyncClientHandler.execute(AwsAsyncClientHandler.java:52) + at software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeAsyncClient.invokeModel(DefaultBedrockRuntimeAsyncClient.java:161) +``` diff --git a/instrumentation/aws-bedrock-runtime-2.20/build.gradle b/instrumentation/aws-bedrock-runtime-2.20/build.gradle new file mode 100644 index 0000000000..01da205a22 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/build.gradle @@ -0,0 +1,21 @@ +jar { + manifest { attributes 'Implementation-Title': 'com.newrelic.instrumentation.aws-bedrock-runtime-2.20' } +} + +dependencies { + implementation(project(":agent-bridge")) + implementation 'software.amazon.awssdk:bedrockruntime:2.20.157' + + testImplementation 'software.amazon.awssdk:bedrockruntime:2.20.157' + testImplementation 'org.mockito:mockito-inline:4.11.0' + testImplementation 'org.json:json:20240303' +} + +verifyInstrumentation { + passes 'software.amazon.awssdk:bedrockruntime:[2.20.157,)' +} + +site { + title 'AWS Bedrock' + type 'Other' +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java new file mode 100644 index 0000000000..221da47de2 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/events/LlmEvent.java @@ -0,0 +1,370 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.events;
+
+import com.newrelic.api.agent.NewRelic;
+import llm.models.ModelInvocation;
+import llm.models.ModelRequest;
+import llm.models.ModelResponse;
+import llm.vendor.Vendor;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringRecordContentEnabled;
+
+/**
+ * Class for building an LlmEvent
+ */
+public class LlmEvent {
+    private final Map<String, Object> eventAttributes;
+    private final Map<String, Object> userLlmAttributes;
+
+    // LLM event types
+    public static final String LLM_EMBEDDING = "LlmEmbedding";
+    public static final String LLM_CHAT_COMPLETION_SUMMARY = "LlmChatCompletionSummary";
+    public static final String LLM_CHAT_COMPLETION_MESSAGE = "LlmChatCompletionMessage";
+
+    // Optional LlmEvent attributes
+    private final String spanId;
+    private final String traceId;
+    private final String vendor;
+    private final String ingestSource;
+    private final String id;
+    private final String content;
+    private final String role;
+    private final Boolean isResponse;
+    private final String requestId;
+    private final String responseModel;
+    private final Integer sequence;
+    private final String completionId;
+    private final Integer responseNumberOfMessages;
+    private final Float duration;
+    private final Boolean error;
+    private final String input;
+    private final Float requestTemperature;
+    private final Integer requestMaxTokens;
+    private final String requestModel;
+    private final Integer tokenCount;
+    private final String responseChoicesFinishReason;
+
+    public static class Builder {
+        // Required builder parameters
+        private final Map<String, Object> userAttributes;
+        private final Map<String, String> linkingMetadata;
+        private final ModelRequest modelRequest;
+        private final ModelResponse modelResponse;
+
+        /*
+         * All optional builder attributes are defaulted to null so that they won't be added
+         * to the eventAttributes map unless they are explicitly set via the builder
+         * methods when constructing an LlmEvent. This allows the builder to create
+         * any type of LlmEvent with any combination of attributes solely determined
+         * by the builder methods that are called while omitting all unused attributes.
+ */ + + // Optional builder parameters + private String spanId = null; + private String traceId = null; + private String vendor = null; + private String ingestSource = null; + private String id = null; + private String content = null; + private String role = null; + private Boolean isResponse = null; + private String requestId = null; + private String responseModel = null; + private Integer sequence = null; + private String completionId = null; + private Integer responseNumberOfMessages = null; + private Float duration = null; + private Boolean error = null; + private String input = null; + private Float requestTemperature = null; + private Integer requestMaxTokens = null; + private String requestModel = null; + private Integer tokenCount = null; + private String responseChoicesFinishReason = null; + + public Builder(ModelInvocation modelInvocation) { + userAttributes = modelInvocation.getUserAttributes(); + linkingMetadata = modelInvocation.getLinkingMetadata(); + modelRequest = modelInvocation.getModelRequest(); + modelResponse = modelInvocation.getModelResponse(); + } + + public Builder spanId() { + spanId = ModelInvocation.getSpanId(linkingMetadata); + return this; + } + + public Builder traceId() { + traceId = ModelInvocation.getTraceId(linkingMetadata); + return this; + } + + public Builder vendor() { + vendor = Vendor.VENDOR; + return this; + } + + public Builder ingestSource() { + ingestSource = Vendor.INGEST_SOURCE; + return this; + } + + public Builder id(String modelId) { + id = modelId; + return this; + } + + public Builder content(String message) { + content = message; + return this; + } + + public Builder role(boolean isUser) { + if (isUser) { + role = "user"; + } else { + role = "assistant"; + } + return this; + } + + public Builder isResponse(boolean isUser) { + isResponse = !isUser; + return this; + } + + public Builder requestId() { + requestId = modelResponse.getAmznRequestId(); + return this; + } + + public Builder responseModel() { + responseModel = modelRequest.getModelId(); + return this; + } + + public Builder sequence(int eventSequence) { + sequence = eventSequence; + return this; + } + + public Builder completionId() { + completionId = modelResponse.getLlmChatCompletionSummaryId(); + return this; + } + + public Builder responseNumberOfMessages(int numberOfMessages) { + responseNumberOfMessages = numberOfMessages; + return this; + } + + public Builder duration(float callDuration) { + duration = callDuration; + return this; + } + + public Builder error() { + error = modelResponse.isErrorResponse(); + return this; + } + + public Builder input(int index) { + input = modelRequest.getInputText(index); + return this; + } + + public Builder requestTemperature() { + requestTemperature = modelRequest.getTemperature(); + return this; + } + + public Builder requestMaxTokens() { + requestMaxTokens = modelRequest.getMaxTokensToSample(); + return this; + } + + public Builder requestModel() { + requestModel = modelRequest.getModelId(); + return this; + } + + public Builder tokenCount(Integer count) { + tokenCount = count; + return this; + } + + public Builder responseChoicesFinishReason() { + responseChoicesFinishReason = modelResponse.getStopReason(); + return this; + } + + public LlmEvent build() { + return new LlmEvent(this); + } + } + + // This populates the LlmEvent attributes map with only the attributes that were explicitly set on the builder. 
+ private LlmEvent(Builder builder) { + // Map of custom user attributes with the llm prefix + userLlmAttributes = getUserLlmAttributes(builder.userAttributes); + + // Map of all LLM event attributes + eventAttributes = new HashMap<>(userLlmAttributes); + + spanId = builder.spanId; + if (spanId != null && !spanId.isEmpty()) { + eventAttributes.put("span_id", spanId); + } + + traceId = builder.traceId; + if (traceId != null && !traceId.isEmpty()) { + eventAttributes.put("trace_id", traceId); + } + + vendor = builder.vendor; + if (vendor != null && !vendor.isEmpty()) { + eventAttributes.put("vendor", vendor); + } + + ingestSource = builder.ingestSource; + if (ingestSource != null && !ingestSource.isEmpty()) { + eventAttributes.put("ingest_source", ingestSource); + } + + id = builder.id; + if (id != null && !id.isEmpty()) { + eventAttributes.put("id", id); + } + + content = builder.content; + if (isAiMonitoringRecordContentEnabled() && content != null && !content.isEmpty()) { + eventAttributes.put("content", content); + } + + role = builder.role; + if (role != null && !role.isEmpty()) { + eventAttributes.put("role", role); + } + + isResponse = builder.isResponse; + if (isResponse != null) { + eventAttributes.put("is_response", isResponse); + } + + requestId = builder.requestId; + if (requestId != null && !requestId.isEmpty()) { + eventAttributes.put("request_id", requestId); + } + + responseModel = builder.responseModel; + if (responseModel != null && !responseModel.isEmpty()) { + eventAttributes.put("response.model", responseModel); + } + + sequence = builder.sequence; + if (sequence != null && sequence >= 0) { + eventAttributes.put("sequence", sequence); + } + + completionId = builder.completionId; + if (completionId != null && !completionId.isEmpty()) { + eventAttributes.put("completion_id", completionId); + } + + responseNumberOfMessages = builder.responseNumberOfMessages; + if (responseNumberOfMessages != null && responseNumberOfMessages >= 0) { + eventAttributes.put("response.number_of_messages", responseNumberOfMessages); + } + + duration = builder.duration; + if (duration != null && duration >= 0) { + eventAttributes.put("duration", duration); + } + + error = builder.error; + if (error != null && error) { + eventAttributes.put("error", true); + } + + input = builder.input; + if (isAiMonitoringRecordContentEnabled() && input != null && !input.isEmpty()) { + eventAttributes.put("input", input); + } + + requestTemperature = builder.requestTemperature; + if (requestTemperature != null && requestTemperature >= 0) { + eventAttributes.put("request.temperature", requestTemperature); + } + + requestMaxTokens = builder.requestMaxTokens; + if (requestMaxTokens != null && requestMaxTokens > 0) { + eventAttributes.put("request.max_tokens", requestMaxTokens); + } + + requestModel = builder.requestModel; + if (requestModel != null && !requestModel.isEmpty()) { + eventAttributes.put("request.model", requestModel); + } + + tokenCount = builder.tokenCount; + if (tokenCount != null && tokenCount > 0) { + eventAttributes.put("token_count", tokenCount); + } + + responseChoicesFinishReason = builder.responseChoicesFinishReason; + if (responseChoicesFinishReason != null && !responseChoicesFinishReason.isEmpty()) { + eventAttributes.put("response.choices.finish_reason", responseChoicesFinishReason); + } + } + + /** + * Takes a map of all attributes added by the customer via the addCustomParameter API and returns a map + * containing only custom attributes with a llm. prefix to be added to LlmEvents. 
+     *
+     * @param userAttributes Map of all custom user attributes
+     * @return Map of user attributes prefixed with llm.
+     */
+    private Map<String, Object> getUserLlmAttributes(Map<String, Object> userAttributes) {
+        Map<String, Object> userLlmAttributes = new HashMap<>();
+
+        if (userAttributes != null && !userAttributes.isEmpty()) {
+            for (Map.Entry<String, Object> entry : userAttributes.entrySet()) {
+                String key = entry.getKey();
+                if (key.startsWith("llm.")) {
+                    userLlmAttributes.put(key, entry.getValue());
+                }
+            }
+        }
+        return userLlmAttributes;
+    }
+
+    /**
+     * Record an LlmChatCompletionMessage custom event
+     */
+    public void recordLlmChatCompletionMessageEvent() {
+        NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_MESSAGE, eventAttributes);
+    }
+
+    /**
+     * Record an LlmChatCompletionSummary custom event
+     */
+    public void recordLlmChatCompletionSummaryEvent() {
+        NewRelic.getAgent().getInsights().recordCustomEvent(LLM_CHAT_COMPLETION_SUMMARY, eventAttributes);
+    }
+
+    /**
+     * Record an LlmEmbedding custom event
+     */
+    public void recordLlmEmbeddingEvent() {
+        NewRelic.getAgent().getInsights().recordCustomEvent(LLM_EMBEDDING, eventAttributes);
+    }
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java
new file mode 100644
index 0000000000..0bb2fc4d8a
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelInvocation.java
@@ -0,0 +1,196 @@
+/*
+ *
+ * * Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.models;
+
+import com.newrelic.agent.bridge.Token;
+import com.newrelic.agent.bridge.Transaction;
+import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder;
+import com.newrelic.api.agent.NewRelic;
+import com.newrelic.api.agent.Segment;
+
+import java.util.Map;
+import java.util.Objects;
+import java.util.UUID;
+
+import static llm.vendor.Vendor.BEDROCK;
+
+public interface ModelInvocation {
+    /**
+     * Set name of the traced method for each LLM embedding and chat completion call
+     * Llm/{operation_type}/{vendor_name}/{function_name}
+     * <p>
+ * Used with the sync client + * + * @param txn current transaction + * @param functionName name of SDK function being invoked + */ + void setTracedMethodName(Transaction txn, String functionName); + + /** + * Set name of the async segment for each LLM embedding and chat completion call + * Llm/{operation_type}/{vendor_name}/{function_name} + *
+ * Used with the async client + * + * @param segment active segment for async timing + * @param functionName name of SDK function being invoked + */ + void setSegmentName(Segment segment, String functionName); + + /** + * Record an LlmEmbedding event that captures data specific to the creation of an embedding. + * + * @param startTime start time of SDK invoke method + * @param index of the input message in an array + */ + void recordLlmEmbeddingEvent(long startTime, int index); + + /** + * Record an LlmChatCompletionSummary event that captures high-level data about + * the creation of a chat completion including request, response, and call information. + * + * @param startTime start time of SDK invoke method + * @param numberOfMessages total number of LlmChatCompletionMessage events associated with the summary + */ + void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages); + + /** + * Record an LlmChatCompletionMessage event that corresponds to each message (sent and received) + * from a chat completion call including those created by the user, assistant, and the system. + * + * @param sequence index starting at 0 associated with each message + * @param message String representing the input/output message + * @param isUser boolean representing if the current message event is from a user input prompt or an assistant response message + */ + void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser); + + /** + * Record all LLM events when using the sync client. + * + * @param startTime start time of SDK invoke method + */ + void recordLlmEvents(long startTime); + + /** + * Record all LLM events when using the async client. + *
+     * This causes the txn to be active on the thread where the LlmEvents are created so
+     * that they are properly added to the event reservoir on the txn. This is used when the
+     * model response is returned asynchronously via CompletableFuture.
+     *
+     * @param startTime start time of SDK invoke method
+     * @param token     Token used to link the transaction to the thread that produces the response
+     */
+    void recordLlmEventsAsync(long startTime, Token token);
+
+    /**
+     * Report an LLM error.
+     */
+    void reportLlmError();
+
+    /**
+     * Get a map of linking metadata.
+     *
+     * @return Map of linking metadata
+     */
+    Map<String, String> getLinkingMetadata();
+
+    /**
+     * Get a map of user custom attributes.
+     *
+     * @return Map of user custom attributes
+     */
+    Map<String, Object> getUserAttributes();
+
+    /**
+     * Get a ModelRequest wrapper class for the SDK Request object.
+     *
+     * @return ModelRequest
+     */
+    ModelRequest getModelRequest();
+
+    /**
+     * Get a ModelResponse wrapper class for the SDK Response object.
+     *
+     * @return ModelResponse
+     */
+    ModelResponse getModelResponse();
+
+    /**
+     * Increment a Supportability metric indicating that the SDK was instrumented.
+     * <p>
+     * This needs to be incremented for every invocation of the SDK.
+     * Supportability/{language}/ML/{vendor_name}/{vendor_version}
+     *
+     * @param vendorVersion version of vendor
+     */
+    static void incrementInstrumentedSupportabilityMetric(String vendorVersion) {
+        NewRelic.incrementCounter("Supportability/Java/ML/" + BEDROCK + "/" + vendorVersion);
+    }
+
+    /**
+     * Set the llm:true attribute on the active transaction.
+     *
+     * @param txn current transaction
+     */
+    static void setLlmTrueAgentAttribute(Transaction txn) {
+        // If in a txn with LLM-related spans
+        txn.getAgentAttributes().put("llm", true);
+    }
+
+    /**
+     * Get the span.id attribute from the map of linking metadata.
+     *
+     * @param linkingMetadata Map of linking metadata
+     * @return String representing the span.id
+     */
+    static String getSpanId(Map<String, String> linkingMetadata) {
+        if (linkingMetadata != null && !linkingMetadata.isEmpty()) {
+            return linkingMetadata.get("span.id");
+        }
+        return "";
+    }
+
+    /**
+     * Get the trace.id attribute from the map of linking metadata.
+     *
+     * @param linkingMetadata Map of linking metadata
+     * @return String representing the trace.id
+     */
+    static String getTraceId(Map<String, String> linkingMetadata) {
+        if (linkingMetadata != null && !linkingMetadata.isEmpty()) {
+            return linkingMetadata.get("trace.id");
+        }
+        return "";
+    }
+
+    /**
+     * Generate a string representation of a random GUID
+     *
+     * @return String representation of a GUID
+     */
+    static String getRandomGuid() {
+        return UUID.randomUUID().toString();
+    }
+
+    /**
+     * Calculates the tokenCount based on a user-provided callback
+     *
+     * @param model   String representation of the LLM model
+     * @param content String representation of the message content or prompt
+     * @return int representing the tokenCount
+     */
+    static int getTokenCount(String model, String content) {
+        if (LlmTokenCountCallbackHolder.getLlmTokenCountCallback() == null || Objects.equals(content, "")) {
+            return 0;
+        }
+        return LlmTokenCountCallbackHolder
+                .getLlmTokenCountCallback()
+                .calculateLlmTokenCount(model, content);
+    }
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java
new file mode 100644
index 0000000000..9a2b491828
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelRequest.java
@@ -0,0 +1,80 @@
+/*
+ *
+ * * Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.models;
+
+import com.newrelic.api.agent.NewRelic;
+
+import java.util.logging.Level;
+
+public interface ModelRequest {
+    /**
+     * Get the max tokens allowed for the request.
+     *
+     * @return int representing the max tokens allowed for the request
+     */
+    int getMaxTokensToSample();
+
+    /**
+     * Get the temperature of the request.
+     *
+     * @return float representing the temperature of the request
+     */
+    float getTemperature();
+
+    /**
+     * Get the content of the request message, potentially from a specific array index
+     * if multiple messages are returned.
+     *
+     * @param index int indicating the index of a message in an array. May be ignored for request structures that always return a single message.
+ * @return String representing the content of the request message + */ + String getRequestMessage(int index); + + /** + * Get the number of request messages returned + * + * @return int representing the number of request messages returned + */ + int getNumberOfRequestMessages(); + + /** + * Get the input to the embedding creation call. + * + * @param index int indicating the index of a message in an array. May be ignored for request structures that always return a single message. + * @return String representing the input to the embedding creation call + */ + String getInputText(int index); + + /** + * Get the number of input text messages from the embedding request. + * + * @return int representing the number of request messages returned + */ + int getNumberOfInputTextMessages(); + + /** + * Get the LLM model ID. + * + * @return String representing the LLM model ID + */ + String getModelId(); + + /** + * Log when a parsing error occurs. + * + * @param e Exception encountered when parsing the request + * @param fieldBeingParsed field that was being parsed + */ + static void logParsingFailure(Exception e, String fieldBeingParsed) { + if (e != null) { + NewRelic.getAgent().getLogger().log(Level.FINEST, e, "AIM: Error parsing " + fieldBeingParsed + " from ModelRequest"); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Unable to parse empty/null " + fieldBeingParsed + " from ModelRequest"); + } + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java new file mode 100644 index 0000000000..e51b172a87 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ModelResponse.java @@ -0,0 +1,104 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models; + +import com.newrelic.api.agent.NewRelic; + +import java.util.logging.Level; + +public interface ModelResponse { + // Operation types + String COMPLETION = "completion"; + String EMBEDDING = "embedding"; + + /** + * Get the response message, potentially from a specific array index + * if multiple messages are returned. + * + * @param index int indicating the index of a message in an array. May be ignored for response structures that always return a single message. + * @return String representing the response message + */ + String getResponseMessage(int index); + + /** + * Get the number of response messages returned + * + * @return int representing the number of response messages returned + */ + int getNumberOfResponseMessages(); + + /** + * Get the stop reason. + * + * @return String representing the stop reason + */ + String getStopReason(); + + /** + * Get the Amazon Request ID. + * + * @return String representing the Amazon Request ID + */ + String getAmznRequestId(); + + /** + * Get the operation type. + * + * @return String representing the operation type + */ + String getOperationType(); + + /** + * Get the ID for the associated LlmChatCompletionSummary event. + * + * @return String representing the ID for the associated LlmChatCompletionSummary event + */ + String getLlmChatCompletionSummaryId(); + + /** + * Get the ID for the associated LlmEmbedding event. + * + * @return String representing the ID for the associated LlmEmbedding event + */ + String getLlmEmbeddingId(); + + /** + * Determine whether the response resulted in an error or not. 
+ * + * @return boolean true when the LLM response is an error, false when the response was successful + */ + boolean isErrorResponse(); + + /** + * Get the response status code. + * + * @return int representing the response status code + */ + int getStatusCode(); + + /** + * Get the response status text. + * + * @return String representing the response status text + */ + String getStatusText(); + + /** + * Log when a parsing error occurs. + * + * @param e Exception encountered when parsing the response + * @param fieldBeingParsed field that was being parsed + */ + static void logParsingFailure(Exception e, String fieldBeingParsed) { + if (e != null) { + NewRelic.getAgent().getLogger().log(Level.FINEST, e, "AIM: Error parsing " + fieldBeingParsed + " from ModelResponse"); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Unable to parse empty/null " + fieldBeingParsed + " from ModelResponse"); + } + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java new file mode 100644 index 0000000000..70300901f1 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/SupportedModels.java @@ -0,0 +1,23 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models; + +/** + * Prefixes for supported models. As long as the model ID for an invoked LLM model contains + * one of these prefixes the instrumentation should attempt to process the request/response. + *
+ * See the README for each model in llm.models.* for more details on supported models.
+ */
+public class SupportedModels {
+    public static final String ANTHROPIC_CLAUDE = "anthropic.claude";
+    public static final String AMAZON_TITAN = "amazon.titan";
+    public static final String META_LLAMA_2 = "meta.llama2";
+    public static final String COHERE_COMMAND = "cohere.command";
+    public static final String COHERE_EMBED = "cohere.embed";
+    public static final String AI_21_LABS_JURASSIC = "ai21.j2";
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java
new file mode 100644
index 0000000000..ea41563314
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelInvocation.java
@@ -0,0 +1,224 @@
+/*
+ *
+ * * Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.models.ai21labs.jurassic;
+
+import com.newrelic.agent.bridge.Token;
+import com.newrelic.agent.bridge.Transaction;
+import com.newrelic.api.agent.NewRelic;
+import com.newrelic.api.agent.Segment;
+import com.newrelic.api.agent.Trace;
+import llm.events.LlmEvent;
+import llm.models.ModelInvocation;
+import llm.models.ModelRequest;
+import llm.models.ModelResponse;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.logging.Level;
+
+import static llm.models.ModelResponse.COMPLETION;
+import static llm.models.ModelResponse.EMBEDDING;
+import static llm.vendor.Vendor.BEDROCK;
+
+public class JurassicModelInvocation implements ModelInvocation {
+    Map<String, String> linkingMetadata;
+    Map<String, Object> userAttributes;
+    ModelRequest modelRequest;
+    ModelResponse modelResponse;
+
+    public JurassicModelInvocation(Map<String, String> linkingMetadata, Map<String, Object> userCustomAttributes, InvokeModelRequest invokeModelRequest,
+            InvokeModelResponse invokeModelResponse) {
+        this.linkingMetadata = linkingMetadata;
+        this.userAttributes = userCustomAttributes;
+        this.modelRequest = new JurassicModelRequest(invokeModelRequest);
+        this.modelResponse = new JurassicModelResponse(invokeModelResponse);
+    }
+
+    @Override
+    public void setTracedMethodName(Transaction txn, String functionName) {
+        txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName);
+    }
+
+    @Override
+    public void setSegmentName(Segment segment, String functionName) {
+        segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName);
+    }
+
+    @Override
+    public void recordLlmEmbeddingEvent(long startTime, int index) {
+        if (modelResponse.isErrorResponse()) {
+            reportLlmError();
+        }
+
+        LlmEvent.Builder builder = new LlmEvent.Builder(this);
+
+        LlmEvent llmEmbeddingEvent = builder
+                .spanId()
+                .traceId()
+                .vendor()
+                .ingestSource()
+                .id(modelResponse.getLlmEmbeddingId())
+                .requestId()
+                .input(index)
+                .requestModel()
+                .responseModel()
+                .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0)))
+                .error()
+                .duration(System.currentTimeMillis() - startTime)
+                .build();
+
+        llmEmbeddingEvent.recordLlmEmbeddingEvent();
+    }
+
+    @Override
+    public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) {
+        if (modelResponse.isErrorResponse()) {
+            reportLlmError();
+        }
+
+        LlmEvent.Builder builder = new LlmEvent.Builder(this);
+
+        LlmEvent llmChatCompletionSummaryEvent = builder
+                .spanId()
+                .traceId()
+                .vendor()
+                .ingestSource()
+                .id(modelResponse.getLlmChatCompletionSummaryId())
+                .requestId()
+                .requestTemperature()
+                .requestMaxTokens()
+                .requestModel()
+                .responseModel()
+                .responseNumberOfMessages(numberOfMessages)
+                .responseChoicesFinishReason()
+                .error()
+                .duration(System.currentTimeMillis() - startTime)
+                .build();
+
+        llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent();
+    }
+
+    @Override
+    public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) {
+        LlmEvent.Builder builder = new LlmEvent.Builder(this);
+
+        LlmEvent llmChatCompletionMessageEvent = builder
+                .spanId()
+                .traceId()
+                .vendor()
+                .ingestSource()
+                .id(ModelInvocation.getRandomGuid())
+                .content(message)
+                .role(isUser)
+                .isResponse(isUser)
+                .requestId()
+                .responseModel()
+                .sequence(sequence)
+                .completionId()
+                .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message))
+                .build();
+
+        llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent();
+    }
+
+    @Override
+    public void recordLlmEvents(long startTime) {
+        String operationType = modelResponse.getOperationType();
+        if (operationType.equals(COMPLETION)) {
+            recordLlmChatCompletionEvents(startTime);
+        } else if (operationType.equals(EMBEDDING)) {
+            recordLlmEmbeddingEvents(startTime);
+        } else {
+            NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events");
+        }
+    }
+
+    @Trace(async = true)
+    @Override
+    public void recordLlmEventsAsync(long startTime, Token token) {
+        if (token != null && token.isActive()) {
+            token.linkAndExpire();
+        }
+        recordLlmEvents(startTime);
+    }
+
+    @Override
+    public void reportLlmError() {
+        Map<String, Object> errorParams = new HashMap<>();
+        errorParams.put("http.statusCode", modelResponse.getStatusCode());
+        errorParams.put("error.code", modelResponse.getStatusCode());
+        if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) {
+            errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId());
+        }
+        if (!modelResponse.getLlmEmbeddingId().isEmpty()) {
+            errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId());
+        }
+        NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams);
+    }
+
+    /**
+     * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event.
+     * The number of LlmChatCompletionMessage events produced can differ based on vendor.
+     */
+    private void recordLlmChatCompletionEvents(long startTime) {
+        int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages();
+        int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages();
+        int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages;
+
+        int sequence = 0;
+
+        // First, record all LlmChatCompletionMessage events representing the user input prompt
+        for (int i = 0; i < numberOfRequestMessages; i++) {
+            recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true);
+            sequence++;
+        }
+
+        // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response
+        for (int i = 0; i < numberOfResponseMessages; i++) {
+            recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false);
+            sequence++;
+        }
+
+        // Finally, record a summary event representing all LlmChatCompletionMessage events
+        recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages);
+    }
+
+    /**
+     * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request.
+     * The number of LlmEmbedding events produced can differ based on vendor.
+     */
+    private void recordLlmEmbeddingEvents(long startTime) {
+        int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages();
+        // Record an LlmEmbedding event for each input message in the request
+        for (int i = 0; i < numberOfRequestMessages; i++) {
+            recordLlmEmbeddingEvent(startTime, i);
+        }
+    }
+
+    @Override
+    public Map<String, String> getLinkingMetadata() {
+        return linkingMetadata;
+    }
+
+    @Override
+    public Map<String, Object> getUserAttributes() {
+        return userAttributes;
+    }
+
+    @Override
+    public ModelRequest getModelRequest() {
+        return modelRequest;
+    }
+
+    @Override
+    public ModelResponse getModelResponse() {
+        return modelResponse;
+    }
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java
new file mode 100644
index 0000000000..610ae4ea40
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelRequest.java
@@ -0,0 +1,165 @@
+/*
+ *
+ * * Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.models.ai21labs.jurassic;
+
+import com.newrelic.api.agent.NewRelic;
+import llm.models.ModelRequest;
+import software.amazon.awssdk.protocols.jsoncore.JsonNode;
+import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.logging.Level;
+
+import static llm.models.ModelRequest.logParsingFailure;
+
+/**
+ * Stores the required info from the Bedrock InvokeModelRequest without holding
+ * a reference to the actual request object to avoid potential memory issues.
+ */
+public class JurassicModelRequest implements ModelRequest {
+    private static final String MAX_TOKENS = "maxTokens";
+    private static final String TEMPERATURE = "temperature";
+    private static final String PROMPT = "prompt";
+
+    private String invokeModelRequestBody = "";
+    private String modelId = "";
+    private Map<String, JsonNode> requestBodyJsonMap = null;
+
+    public JurassicModelRequest(InvokeModelRequest invokeModelRequest) {
+        if (invokeModelRequest != null) {
+            invokeModelRequestBody = invokeModelRequest.body().asUtf8String();
+            modelId = invokeModelRequest.modelId();
+        } else {
+            NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest");
+        }
+    }
+
+    /**
+     * Get a map of the Request body contents.
+     * <p>
+     * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once.
+     *
+     * @return map of String to JsonNode
+     */
+    private Map<String, JsonNode> getRequestBodyJsonMap() {
+        if (requestBodyJsonMap == null) {
+            requestBodyJsonMap = parseInvokeModelRequestBodyMap();
+        }
+        return requestBodyJsonMap;
+    }
+
+    /**
+     * Convert JSON Request body string into a map.
+     *
+     * @return map of String to JsonNode
+     */
+    private Map<String, JsonNode> parseInvokeModelRequestBodyMap() {
+        // Use AWS SDK JSON parsing to parse request body
+        JsonNodeParser jsonNodeParser = JsonNodeParser.create();
+        JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody);
+
+        Map<String, JsonNode> requestBodyJsonMap = null;
+        try {
+            if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) {
+                requestBodyJsonMap = requestBodyJsonNode.asObject();
+            } else {
+                logParsingFailure(null, "request body");
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, "request body");
+        }
+        return requestBodyJsonMap != null ? requestBodyJsonMap : Collections.emptyMap();
+    }
+
+    @Override
+    public int getMaxTokensToSample() {
+        int maxTokensToSample = 0;
+        try {
+            if (!getRequestBodyJsonMap().isEmpty()) {
+                JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_TOKENS);
+                if (jsonNode.isNumber()) {
+                    String maxTokensToSampleString = jsonNode.asNumber();
+                    maxTokensToSample = Integer.parseInt(maxTokensToSampleString);
+                }
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, MAX_TOKENS);
+        }
+        if (maxTokensToSample == 0) {
+            logParsingFailure(null, MAX_TOKENS);
+        }
+        return maxTokensToSample;
+    }
+
+    @Override
+    public float getTemperature() {
+        float temperature = 0f;
+        try {
+            if (!getRequestBodyJsonMap().isEmpty()) {
+                JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE);
+                if (jsonNode.isNumber()) {
+                    String temperatureString = jsonNode.asNumber();
+                    temperature = Float.parseFloat(temperatureString);
+                }
+            } else {
+                logParsingFailure(null, TEMPERATURE);
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, TEMPERATURE);
+        }
+        return temperature;
+    }
+
+    @Override
+    public int getNumberOfRequestMessages() {
+        // The Jurassic request only ever contains a single prompt message
+        return 1;
+    }
+
+    @Override
+    public String getRequestMessage(int index) {
+        return parseStringValue(PROMPT);
+    }
+
+    @Override
+    public String getInputText(int index) {
+        // This is a NoOp for Jurassic as it doesn't support embeddings
+        return "";
+    }
+
+    @Override
+    public int getNumberOfInputTextMessages() {
+        // This is a NoOp for Jurassic as it doesn't support embeddings
+        return 0;
+    }
+
+    private String parseStringValue(String fieldToParse) {
+        String parsedStringValue = "";
+        try {
+            if (!getRequestBodyJsonMap().isEmpty()) {
+                JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse);
+                if (jsonNode.isString()) {
+                    parsedStringValue = jsonNode.asString();
+                }
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, fieldToParse);
+        }
+        if (parsedStringValue.isEmpty()) {
+            logParsingFailure(null, fieldToParse);
+        }
+        return parsedStringValue;
+    }
+
+    @Override
+    public String getModelId() {
+        return modelId;
+    }
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java
new file mode 100644
index 0000000000..8d78cacf4e
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/JurassicModelResponse.java
@@ -0,0
+1,255 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.ai21labs.jurassic; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; + +import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual request object to avoid potential memory issues. + */ +public class JurassicModelResponse implements ModelResponse { + private static final String FINISH_REASON = "finishReason"; + private static final String REASON = "reason"; + private static final String COMPLETIONS = "completions"; + private static final String DATA = "data"; + private static final String TEXT = "text"; + + private String amznRequestId = ""; + + // LLM operation type + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + public JurassicModelResponse(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText = s); + setOperationType(invokeModelResponseBody); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + /** + * Get a map of the Response body contents. + *
<p>
+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + // Jurassic for Bedrock doesn't support embedding operations + if (invokeModelResponseBody.contains(COMPLETION)) { + operationType = COMPLETION; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + @Override + public String getResponseMessage(int index) { + String parsedResponseMessage = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode completionsJsonNode = getResponseBodyJsonMap().get(COMPLETIONS); + if (completionsJsonNode.isArray()) { + List completionsJsonNodeArray = completionsJsonNode.asArray(); + if (!completionsJsonNodeArray.isEmpty()) { + JsonNode jsonNode = completionsJsonNodeArray.get(index); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode dataJsonNode = jsonNodeObject.get(DATA); + if (dataJsonNode.isObject()) { + Map dataJsonNodeObject = dataJsonNode.asObject(); + if (!dataJsonNodeObject.isEmpty()) { + JsonNode textJsonNode = dataJsonNodeObject.get(TEXT); + if (textJsonNode.isString()) { + parsedResponseMessage = textJsonNode.asString(); + } + } + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, TEXT); + } + if (parsedResponseMessage.isEmpty()) { + logParsingFailure(null, TEXT); + } + return parsedResponseMessage; + } + + @Override + public int getNumberOfResponseMessages() { + int numberOfResponseMessages = 0; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode completionsJsonNode = getResponseBodyJsonMap().get(COMPLETIONS); + if (completionsJsonNode.isArray()) { + List completionsJsonNodeArray = completionsJsonNode.asArray(); + if (!completionsJsonNodeArray.isEmpty()) { + numberOfResponseMessages = completionsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, COMPLETIONS); + } + if (numberOfResponseMessages == 0) { + logParsingFailure(null, COMPLETIONS); + } + return numberOfResponseMessages; + } + + @Override + public String getStopReason() { + String parsedStopReason = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode completionsJsonNode = getResponseBodyJsonMap().get(COMPLETIONS); + if 
(completionsJsonNode.isArray()) { + List jsonNodeArray = completionsJsonNode.asArray(); + if (!jsonNodeArray.isEmpty()) { + JsonNode jsonNode = jsonNodeArray.get(0); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode dataJsonNode = jsonNodeObject.get(FINISH_REASON); + if (dataJsonNode.isObject()) { + Map dataJsonNodeObject = dataJsonNode.asObject(); + if (!dataJsonNodeObject.isEmpty()) { + JsonNode textJsonNode = dataJsonNodeObject.get(REASON); + if (textJsonNode.isString()) { + parsedStopReason = textJsonNode.asString(); + } + } + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, FINISH_REASON); + } + if (parsedStopReason.isEmpty()) { + logParsingFailure(null, FINISH_REASON); + } + return parsedStopReason; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/README.md new file mode 100644 index 0000000000..6d2d1d71aa --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/ai21labs/jurassic/README.md @@ -0,0 +1,713 @@ +# AI21 Labs + +Examples of the request/response bodies for models that have been tested and verified to work. The instrumentation should continue to correctly process new +models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed +here. + +## Jurassic Models + +#### Text Completion Models + +The following models have been tested: + +* Jurassic-2 Mid (`ai21.j2-mid-v1`) +* Jurassic-2 Ultra (`ai21.j2-ultra-v1`) + +#### Sample Request + +```json +{ + "temperature": 0.5, + "maxTokens": 1000, + "prompt": "What is the color of the sky?" +} +``` + +#### Sample Response + +```json +{ + "id": 1234, + "prompt": { + "text": "What is the color of the sky?", + "tokens": [ + { + "generatedToken": { + "token": "▁What▁is▁the", + "logprob": -8.316551208496094, + "raw_logprob": -8.316551208496094 + }, + "topTokens": null, + "textRange": { + "start": 0, + "end": 11 + } + }, + { + "generatedToken": { + "token": "▁color", + "logprob": -7.189708709716797, + "raw_logprob": -7.189708709716797 + }, + "topTokens": null, + "textRange": { + "start": 11, + "end": 17 + } + }, + { + "generatedToken": { + "token": "▁of▁the▁sky", + "logprob": -5.750617027282715, + "raw_logprob": -5.750617027282715 + }, + "topTokens": null, + "textRange": { + "start": 17, + "end": 28 + } + }, + { + "generatedToken": { + "token": "?", + "logprob": -5.858178615570068, + "raw_logprob": -5.858178615570068 + }, + "topTokens": null, + "textRange": { + "start": 28, + "end": 29 + } + } + ] + }, + "completions": [ + { + "data": { + "text": "\nThe color of the sky on Earth is blue. This is because Earth's atmosphere scatters short-wavelength light more efficiently than long-wavelength light. 
When sunlight enters Earth's atmosphere, most of the blue light is scattered, leaving mostly red light to illuminate the sky. The scattering of blue light is more efficient because it travels as shorter, smaller waves.", + "tokens": [ + { + "generatedToken": { + "token": "<|newline|>", + "logprob": 0.0, + "raw_logprob": -6.305972783593461E-5 + }, + "topTokens": null, + "textRange": { + "start": 0, + "end": 1 + } + }, + { + "generatedToken": { + "token": "▁The▁color", + "logprob": -0.007753042038530111, + "raw_logprob": -0.18397575616836548 + }, + "topTokens": null, + "textRange": { + "start": 1, + "end": 10 + } + }, + { + "generatedToken": { + "token": "▁of▁the▁sky", + "logprob": -6.770858453819528E-5, + "raw_logprob": -0.0130088459700346 + }, + "topTokens": null, + "textRange": { + "start": 10, + "end": 21 + } + }, + { + "generatedToken": { + "token": "▁on▁Earth", + "logprob": -6.189814303070307E-4, + "raw_logprob": -0.06852064281702042 + }, + "topTokens": null, + "textRange": { + "start": 21, + "end": 30 + } + }, + { + "generatedToken": { + "token": "▁is", + "logprob": -0.5599813461303711, + "raw_logprob": -1.532042145729065 + }, + "topTokens": null, + "textRange": { + "start": 30, + "end": 33 + } + }, + { + "generatedToken": { + "token": "▁blue", + "logprob": -0.0358763113617897, + "raw_logprob": -0.2531339228153229 + }, + "topTokens": null, + "textRange": { + "start": 33, + "end": 38 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": -0.0022088908590376377, + "raw_logprob": -0.11807831376791 + }, + "topTokens": null, + "textRange": { + "start": 38, + "end": 39 + } + }, + { + "generatedToken": { + "token": "▁This▁is▁because", + "logprob": -0.7582850456237793, + "raw_logprob": -1.6503678560256958 + }, + "topTokens": null, + "textRange": { + "start": 39, + "end": 55 + } + }, + { + "generatedToken": { + "token": "▁Earth's▁atmosphere", + "logprob": -0.37150290608406067, + "raw_logprob": -1.086639404296875 + }, + "topTokens": null, + "textRange": { + "start": 55, + "end": 74 + } + }, + { + "generatedToken": { + "token": "▁scatter", + "logprob": -1.4662635294371285E-5, + "raw_logprob": -0.011443688534200191 + }, + "topTokens": null, + "textRange": { + "start": 74, + "end": 82 + } + }, + { + "generatedToken": { + "token": "s", + "logprob": -9.929640509653836E-5, + "raw_logprob": -0.01099079567939043 + }, + "topTokens": null, + "textRange": { + "start": 82, + "end": 83 + } + }, + { + "generatedToken": { + "token": "▁short", + "logprob": -2.97943115234375, + "raw_logprob": -1.8346563577651978 + }, + "topTokens": null, + "textRange": { + "start": 83, + "end": 89 + } + }, + { + "generatedToken": { + "token": "-wavelength", + "logprob": -1.5722469834145159E-4, + "raw_logprob": -0.020076051354408264 + }, + "topTokens": null, + "textRange": { + "start": 89, + "end": 100 + } + }, + { + "generatedToken": { + "token": "▁light", + "logprob": -1.8000440832111053E-5, + "raw_logprob": -0.008328350260853767 + }, + "topTokens": null, + "textRange": { + "start": 100, + "end": 106 + } + }, + { + "generatedToken": { + "token": "▁more▁efficiently", + "logprob": -0.11763446033000946, + "raw_logprob": -0.6382070779800415 + }, + "topTokens": null, + "textRange": { + "start": 106, + "end": 123 + } + }, + { + "generatedToken": { + "token": "▁than", + "logprob": -0.0850396677851677, + "raw_logprob": -0.4660969078540802 + }, + "topTokens": null, + "textRange": { + "start": 123, + "end": 128 + } + }, + { + "generatedToken": { + "token": "▁long", + "logprob": -0.21488533914089203, + "raw_logprob": 
-0.43275904655456543 + }, + "topTokens": null, + "textRange": { + "start": 128, + "end": 133 + } + }, + { + "generatedToken": { + "token": "-wavelength", + "logprob": -3.576272320060525E-6, + "raw_logprob": -0.0032024311367422342 + }, + "topTokens": null, + "textRange": { + "start": 133, + "end": 144 + } + }, + { + "generatedToken": { + "token": "▁light", + "logprob": -6.603976362384856E-5, + "raw_logprob": -0.021542951464653015 + }, + "topTokens": null, + "textRange": { + "start": 144, + "end": 150 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": -0.03969373181462288, + "raw_logprob": -0.24834078550338745 + }, + "topTokens": null, + "textRange": { + "start": 150, + "end": 151 + } + }, + { + "generatedToken": { + "token": "▁When", + "logprob": -0.8459960222244263, + "raw_logprob": -1.758193016052246 + }, + "topTokens": null, + "textRange": { + "start": 151, + "end": 156 + } + }, + { + "generatedToken": { + "token": "▁sunlight", + "logprob": -0.043000709265470505, + "raw_logprob": -0.413555383682251 + }, + "topTokens": null, + "textRange": { + "start": 156, + "end": 165 + } + }, + { + "generatedToken": { + "token": "▁enters", + "logprob": -2.2813825607299805, + "raw_logprob": -1.975184440612793 + }, + "topTokens": null, + "textRange": { + "start": 165, + "end": 172 + } + }, + { + "generatedToken": { + "token": "▁Earth's▁atmosphere", + "logprob": -0.04206264019012451, + "raw_logprob": -0.22090668976306915 + }, + "topTokens": null, + "textRange": { + "start": 172, + "end": 191 + } + }, + { + "generatedToken": { + "token": ",", + "logprob": -2.1300431399140507E-4, + "raw_logprob": -0.04065611585974693 + }, + "topTokens": null, + "textRange": { + "start": 191, + "end": 192 + } + }, + { + "generatedToken": { + "token": "▁most▁of▁the", + "logprob": -1.0895559787750244, + "raw_logprob": -1.4258980751037598 + }, + "topTokens": null, + "textRange": { + "start": 192, + "end": 204 + } + }, + { + "generatedToken": { + "token": "▁blue▁light", + "logprob": -2.7195115089416504, + "raw_logprob": -2.069707155227661 + }, + "topTokens": null, + "textRange": { + "start": 204, + "end": 215 + } + }, + { + "generatedToken": { + "token": "▁is", + "logprob": -3.036991402041167E-4, + "raw_logprob": -0.036258988082408905 + }, + "topTokens": null, + "textRange": { + "start": 215, + "end": 218 + } + }, + { + "generatedToken": { + "token": "▁scattered", + "logprob": -1.1086402082582936E-5, + "raw_logprob": -0.007142604328691959 + }, + "topTokens": null, + "textRange": { + "start": 218, + "end": 228 + } + }, + { + "generatedToken": { + "token": ",", + "logprob": -0.8132423162460327, + "raw_logprob": -1.204469919204712 + }, + "topTokens": null, + "textRange": { + "start": 228, + "end": 229 + } + }, + { + "generatedToken": { + "token": "▁leaving", + "logprob": -0.028648898005485535, + "raw_logprob": -0.24427929520606995 + }, + "topTokens": null, + "textRange": { + "start": 229, + "end": 237 + } + }, + { + "generatedToken": { + "token": "▁mostly", + "logprob": -0.012762418016791344, + "raw_logprob": -0.18833962082862854 + }, + "topTokens": null, + "textRange": { + "start": 237, + "end": 244 + } + }, + { + "generatedToken": { + "token": "▁red▁light", + "logprob": -0.3875422477722168, + "raw_logprob": -0.9608176350593567 + }, + "topTokens": null, + "textRange": { + "start": 244, + "end": 254 + } + }, + { + "generatedToken": { + "token": "▁to▁illuminate", + "logprob": -1.2177848815917969, + "raw_logprob": -1.6379175186157227 + }, + "topTokens": null, + "textRange": { + "start": 254, + "end": 268 + } + }, + { + 
"generatedToken": { + "token": "▁the▁sky", + "logprob": -0.004821578972041607, + "raw_logprob": -0.1349806934595108 + }, + "topTokens": null, + "textRange": { + "start": 268, + "end": 276 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": -2.7894584491150454E-5, + "raw_logprob": -0.01649152860045433 + }, + "topTokens": null, + "textRange": { + "start": 276, + "end": 277 + } + }, + { + "generatedToken": { + "token": "▁The", + "logprob": -4.816740989685059, + "raw_logprob": -3.04256534576416 + }, + "topTokens": null, + "textRange": { + "start": 277, + "end": 281 + } + }, + { + "generatedToken": { + "token": "▁scattering", + "logprob": -0.07598043233156204, + "raw_logprob": -0.4935254752635956 + }, + "topTokens": null, + "textRange": { + "start": 281, + "end": 292 + } + }, + { + "generatedToken": { + "token": "▁of", + "logprob": -2.1653952598571777, + "raw_logprob": -2.153515338897705 + }, + "topTokens": null, + "textRange": { + "start": 292, + "end": 295 + } + }, + { + "generatedToken": { + "token": "▁blue▁light", + "logprob": -0.0025517542380839586, + "raw_logprob": -0.0987434908747673 + }, + "topTokens": null, + "textRange": { + "start": 295, + "end": 306 + } + }, + { + "generatedToken": { + "token": "▁is", + "logprob": -0.04848421365022659, + "raw_logprob": -0.5477231740951538 + }, + "topTokens": null, + "textRange": { + "start": 306, + "end": 309 + } + }, + { + "generatedToken": { + "token": "▁more▁efficient", + "logprob": -1.145136833190918, + "raw_logprob": -1.6279737949371338 + }, + "topTokens": null, + "textRange": { + "start": 309, + "end": 324 + } + }, + { + "generatedToken": { + "token": "▁because▁it", + "logprob": -0.7712448835372925, + "raw_logprob": -1.402230143547058 + }, + "topTokens": null, + "textRange": { + "start": 324, + "end": 335 + } + }, + { + "generatedToken": { + "token": "▁travels", + "logprob": -1.0001159535022452E-4, + "raw_logprob": -0.03441037982702255 + }, + "topTokens": null, + "textRange": { + "start": 335, + "end": 343 + } + }, + { + "generatedToken": { + "token": "▁as", + "logprob": -2.169585604860913E-5, + "raw_logprob": -0.008925186470150948 + }, + "topTokens": null, + "textRange": { + "start": 343, + "end": 346 + } + }, + { + "generatedToken": { + "token": "▁shorter", + "logprob": -0.0026372435968369246, + "raw_logprob": -0.054399896413087845 + }, + "topTokens": null, + "textRange": { + "start": 346, + "end": 354 + } + }, + { + "generatedToken": { + "token": ",", + "logprob": -3.576214658096433E-5, + "raw_logprob": -0.011654269881546497 + }, + "topTokens": null, + "textRange": { + "start": 354, + "end": 355 + } + }, + { + "generatedToken": { + "token": "▁smaller", + "logprob": -1.0609570381348021E-5, + "raw_logprob": -0.007282733917236328 + }, + "topTokens": null, + "textRange": { + "start": 355, + "end": 363 + } + }, + { + "generatedToken": { + "token": "▁waves", + "logprob": -2.7418097943154862E-6, + "raw_logprob": -0.0030873988289386034 + }, + "topTokens": null, + "textRange": { + "start": 363, + "end": 369 + } + }, + { + "generatedToken": { + "token": ".", + "logprob": -0.19333261251449585, + "raw_logprob": -0.535153865814209 + }, + "topTokens": null, + "textRange": { + "start": 369, + "end": 370 + } + }, + { + "generatedToken": { + "token": "<|endoftext|>", + "logprob": -0.03163028880953789, + "raw_logprob": -0.6691970229148865 + }, + "topTokens": null, + "textRange": { + "start": 370, + "end": 370 + } + } + ] + }, + "finishReason": { + "reason": "endoftext" + } + } + ] +} +``` + +### Embedding Models + +Not supported by Jurassic. 
\ No newline at end of file diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/README.md new file mode 100644 index 0000000000..8947646586 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/README.md @@ -0,0 +1,74 @@ +# Amazon + +Examples of the request/response bodies for models that have been tested and verified to work. The instrumentation should continue to correctly process new +models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed +here. + +## Titan Models + +### Text Completion Models + +The following models have been tested: + +* Titan Text G1-Lite (`amazon.titan-text-lite-v1`) +* Titan Text G1-Express (`amazon.titan-text-express-v1`) + +#### Sample Request + +```json +{ + "inputText": "What is the color of the sky?", + "textGenerationConfig": { + "maxTokenCount": 1000, + "stopSequences": [ + "User:" + ], + "temperature": 0.5, + "topP": 0.9 + } +} +``` + +#### Sample Response + +```json +{ + "inputTextTokenCount": 8, + "results": [ + { + "tokenCount": 39, + "outputText": "\nThe color of the sky depends on the time of day, weather conditions, and location. It can range from blue to gray, depending on the presence of clouds and pollutants in the air.", + "completionReason": "FINISH" + } + ] +} +``` + +### Embedding Models + +The following models have been tested: + +* Titan Embeddings G1-Text (`amazon.titan-embed-text-v1`) +* Titan Multimodal Embeddings G1 (`amazon.titan-embed-image-v1`) + +#### Sample Request + +```json +{ + "inputText": "What is the color of the sky?" +} +``` + +#### Sample Response + +```json +{ + "embedding": [ + 0.328125, + ..., + 0.44335938 + ], + "inputTextTokenCount": 8 +} +``` + diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java new file mode 100644 index 0000000000..b919ae1e94 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelInvocation.java @@ -0,0 +1,224 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.amazon.titan; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class TitanModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public TitanModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new TitanModelRequest(invokeModelRequest); + this.modelResponse = new TitanModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime, int index) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmEmbeddingEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input(index) + .requestModel() + .responseModel() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + 
.completionId() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if (operationType.equals(EMBEDDING)) { + recordLlmEmbeddingEvents(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. + */ + private void recordLlmChatCompletionEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. 
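+ * <p>
+ * For Titan, the request only ever carries a single inputText message, so exactly one LlmEmbedding event is recorded per request.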
+ */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java new file mode 100644 index 0000000000..0221b3fee8 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelRequest.java @@ -0,0 +1,177 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.amazon.titan; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. + */ +public class TitanModelRequest implements ModelRequest { + private static final String MAX_TOKEN_COUNT = "maxTokenCount"; + private static final String TEMPERATURE = "temperature"; + private static final String TEXT_GENERATION_CONFIG = "textGenerationConfig"; + private static final String INPUT_TEXT = "inputText"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public TitanModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *
<p>
+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); + } + return requestBodyJsonMap != null ? requestBodyJsonMap : Collections.emptyMap(); + } + + @Override + public int getMaxTokensToSample() { + int maxTokensToSample = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textGenConfigJsonNode = getRequestBodyJsonMap().get(TEXT_GENERATION_CONFIG); + if (textGenConfigJsonNode.isObject()) { + Map textGenConfigJsonNodeObject = textGenConfigJsonNode.asObject(); + if (!textGenConfigJsonNodeObject.isEmpty()) { + JsonNode maxTokenCountJsonNode = textGenConfigJsonNodeObject.get(MAX_TOKEN_COUNT); + if (maxTokenCountJsonNode.isNumber()) { + String maxTokenCountString = maxTokenCountJsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokenCountString); + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, MAX_TOKEN_COUNT); + } + if (maxTokensToSample == 0) { + logParsingFailure(null, MAX_TOKEN_COUNT); + } + return maxTokensToSample; + } + + @Override + public float getTemperature() { + float temperature = 0f; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textGenConfigJsonNode = getRequestBodyJsonMap().get(TEXT_GENERATION_CONFIG); + if (textGenConfigJsonNode.isObject()) { + Map textGenConfigJsonNodeObject = textGenConfigJsonNode.asObject(); + if (!textGenConfigJsonNodeObject.isEmpty()) { + JsonNode temperatureJsonNode = textGenConfigJsonNodeObject.get(TEMPERATURE); + if (temperatureJsonNode.isNumber()) { + String temperatureString = temperatureJsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); + } + } + } + } else { + logParsingFailure(null, TEMPERATURE); + } + } catch (Exception e) { + logParsingFailure(e, TEMPERATURE); + } + return temperature; + } + + @Override + public int getNumberOfRequestMessages() { + // The Titan request only ever contains a single inputText message + return 1; + } + + @Override + public String getRequestMessage(int index) { + return parseStringValue(INPUT_TEXT); + } + + @Override + public String getInputText(int index) { + return parseStringValue(INPUT_TEXT); + } + + @Override + public int getNumberOfInputTextMessages() { + // There is only ever a single inputText message for Titan embeddings + return 1; + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + if (parsedStringValue.isEmpty()) { + 
logParsingFailure(null, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getModelId() { + return modelId; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java new file mode 100644 index 0000000000..ea1e741bdf --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/amazon/titan/TitanModelResponse.java @@ -0,0 +1,242 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.amazon.titan; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; + +import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual request object to avoid potential memory issues. + */ +public class TitanModelResponse implements ModelResponse { + private static final String COMPLETION_REASON = "completionReason"; + private static final String RESULTS = "results"; + private static final String OUTPUT_TEXT = "outputText"; + + private String amznRequestId = ""; + + // LLM operation type + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + public TitanModelResponse(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText = s); + setOperationType(invokeModelResponseBody); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + /** + * Get a map of the Response body contents. + *
<p>
+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.contains(COMPLETION_REASON)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.contains(EMBEDDING)) { + operationType = EMBEDDING; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + @Override + public String getResponseMessage(int index) { + String parsedResponseMessage = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(RESULTS); + if (jsonNode.isArray()) { + List resultsJsonNodeArray = jsonNode.asArray(); + if (!resultsJsonNodeArray.isEmpty()) { + JsonNode resultsJsonNode = resultsJsonNodeArray.get(index); + if (resultsJsonNode.isObject()) { + Map resultsJsonNodeObject = resultsJsonNode.asObject(); + if (!resultsJsonNodeObject.isEmpty()) { + JsonNode outputTextJsonNode = resultsJsonNodeObject.get(OUTPUT_TEXT); + if (outputTextJsonNode.isString()) { + parsedResponseMessage = outputTextJsonNode.asString(); + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, OUTPUT_TEXT); + } + if (parsedResponseMessage.isEmpty()) { + logParsingFailure(null, OUTPUT_TEXT); + } + return parsedResponseMessage; + } + + @Override + public int getNumberOfResponseMessages() { + int numberOfResponseMessages = 0; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(RESULTS); + if (jsonNode.isArray()) { + List resultsJsonNodeArray = jsonNode.asArray(); + if (!resultsJsonNodeArray.isEmpty()) { + numberOfResponseMessages = resultsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, RESULTS); + } + if (numberOfResponseMessages == 0) { + logParsingFailure(null, RESULTS); + } + return numberOfResponseMessages; + } + + @Override + public String getStopReason() { + String parsedStopReason = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(RESULTS); + if (jsonNode.isArray()) { + List resultsJsonNodeArray = jsonNode.asArray(); + if (!resultsJsonNodeArray.isEmpty()) { + JsonNode resultsJsonNode = resultsJsonNodeArray.get(0); + if 
(resultsJsonNode.isObject()) { + Map resultsJsonNodeObject = resultsJsonNode.asObject(); + if (!resultsJsonNodeObject.isEmpty()) { + JsonNode outputTextJsonNode = resultsJsonNodeObject.get(COMPLETION_REASON); + if (outputTextJsonNode.isString()) { + parsedStopReason = outputTextJsonNode.asString(); + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, COMPLETION_REASON); + } + if (parsedStopReason.isEmpty()) { + logParsingFailure(null, COMPLETION_REASON); + } + return parsedStopReason; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java new file mode 100644 index 0000000000..f1861fba2c --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelInvocation.java @@ -0,0 +1,224 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.anthropic.claude; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class ClaudeModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public ClaudeModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new ClaudeModelRequest(invokeModelRequest); + this.modelResponse = new ClaudeModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime, int index) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new 
LlmEvent.Builder(this); + + LlmEvent llmEmbeddingEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input(index) + .requestModel() + .responseModel() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if (operationType.equals(EMBEDDING)) { + recordLlmEmbeddingEvents(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
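+ * <p>
+ * For Claude, the request carries a single prompt and the response a single completion, so this records exactly two
+ * LlmChatCompletionMessage events (sequence 0 for the prompt, sequence 1 for the completion) plus one summary event.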
+ */ + private void recordLlmChatCompletionEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. + */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java new file mode 100644 index 0000000000..0355983642 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelRequest.java @@ -0,0 +1,165 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.anthropic.claude; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. 
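+ * <p>
+ * The request body is retained as a UTF-8 string at construction time and parsed into a JSON map lazily, on first access.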
+ */ +public class ClaudeModelRequest implements ModelRequest { + private static final String MAX_TOKENS_TO_SAMPLE = "max_tokens_to_sample"; + private static final String TEMPERATURE = "temperature"; + private static final String PROMPT = "prompt"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public ClaudeModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *
<p>
+     * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once.
+     *
+     * @return map of String to JsonNode
+     */
+    private Map getRequestBodyJsonMap() {
+        if (requestBodyJsonMap == null) {
+            requestBodyJsonMap = parseInvokeModelRequestBodyMap();
+        }
+        return requestBodyJsonMap;
+    }
+
+    /**
+     * Convert JSON Request body string into a map.
+     *
+     * @return map of String to JsonNode
+     */
+    private Map parseInvokeModelRequestBodyMap() {
+        // Use AWS SDK JSON parsing to parse request body
+        JsonNodeParser jsonNodeParser = JsonNodeParser.create();
+        JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody);
+
+        Map requestBodyJsonMap = null;
+        try {
+            if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) {
+                requestBodyJsonMap = requestBodyJsonNode.asObject();
+            } else {
+                logParsingFailure(null, "request body");
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, "request body");
+        }
+        return requestBodyJsonMap != null ? requestBodyJsonMap : Collections.emptyMap();
+    }
+
+    @Override
+    public int getMaxTokensToSample() {
+        int maxTokensToSample = 0;
+        try {
+            if (!getRequestBodyJsonMap().isEmpty()) {
+                JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_TOKENS_TO_SAMPLE);
+                if (jsonNode.isNumber()) {
+                    String maxTokensToSampleString = jsonNode.asNumber();
+                    maxTokensToSample = Integer.parseInt(maxTokensToSampleString);
+                }
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, MAX_TOKENS_TO_SAMPLE);
+        }
+        if (maxTokensToSample == 0) {
+            logParsingFailure(null, MAX_TOKENS_TO_SAMPLE);
+        }
+        return maxTokensToSample;
+    }
+
+    @Override
+    public float getTemperature() {
+        float temperature = 0f;
+        try {
+            if (!getRequestBodyJsonMap().isEmpty()) {
+                JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE);
+                if (jsonNode.isNumber()) {
+                    String temperatureString = jsonNode.asNumber();
+                    temperature = Float.parseFloat(temperatureString);
+                }
+            } else {
+                logParsingFailure(null, TEMPERATURE);
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, TEMPERATURE);
+        }
+        return temperature;
+    }
+
+    @Override
+    public int getNumberOfRequestMessages() {
+        // The Claude request only ever contains a single prompt message
+        return 1;
+    }
+
+    @Override
+    public String getRequestMessage(int index) {
+        return parseStringValue(PROMPT);
+    }
+
+    @Override
+    public String getInputText(int index) {
+        // This is a NoOp for Claude as it doesn't support embeddings
+        return "";
+    }
+
+    @Override
+    public int getNumberOfInputTextMessages() {
+        // This is a NoOp for Claude as it doesn't support embeddings
+        return 0;
+    }
+
+    private String parseStringValue(String fieldToParse) {
+        String parsedStringValue = "";
+        try {
+            if (!getRequestBodyJsonMap().isEmpty()) {
+                JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse);
+                if (jsonNode.isString()) {
+                    parsedStringValue = jsonNode.asString();
+                }
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, fieldToParse);
+        }
+        if (parsedStringValue.isEmpty()) {
+            logParsingFailure(null, fieldToParse);
+        }
+        return parsedStringValue;
+    }
+
+    @Override
+    public String getModelId() {
+        return modelId;
+    }
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java
new file mode 100644
index 0000000000..d34c3dd3ac
--- /dev/null
+++ 
b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/ClaudeModelResponse.java @@ -0,0 +1,188 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.anthropic.claude; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelResponse; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Level; + +import static llm.models.ModelInvocation.getRandomGuid; +import static llm.models.ModelResponse.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelResponse without holding + * a reference to the actual request object to avoid potential memory issues. + */ +public class ClaudeModelResponse implements ModelResponse { + private static final String STOP_REASON = "stop_reason"; + + private String amznRequestId = ""; + + // LLM operation type + private String operationType = ""; + + // HTTP response + private boolean isSuccessfulResponse = false; + private int statusCode = 0; + private String statusText = ""; + + private String llmChatCompletionSummaryId = ""; + private String llmEmbeddingId = ""; + + private String invokeModelResponseBody = ""; + private Map responseBodyJsonMap = null; + + public ClaudeModelResponse(InvokeModelResponse invokeModelResponse) { + if (invokeModelResponse != null) { + invokeModelResponseBody = invokeModelResponse.body().asUtf8String(); + isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful(); + statusCode = invokeModelResponse.sdkHttpResponse().statusCode(); + Optional statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText(); + statusTextOptional.ifPresent(s -> statusText = s); + setOperationType(invokeModelResponseBody); + amznRequestId = invokeModelResponse.responseMetadata().requestId(); + llmChatCompletionSummaryId = getRandomGuid(); + llmEmbeddingId = getRandomGuid(); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse"); + } + } + + /** + * Get a map of the Response body contents. + *
+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + // Claude for Bedrock doesn't support embedding operations + if (invokeModelResponseBody.contains(COMPLETION)) { + operationType = COMPLETION; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + @Override + public String getResponseMessage(int index) { + return parseStringValue(COMPLETION); + } + + @Override + public int getNumberOfResponseMessages() { + // There is only ever a single response message + return 1; + } + + @Override + public String getStopReason() { + return parseStringValue(STOP_REASON); + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/README.md new file mode 100644 index 0000000000..345695d03c --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/anthropic/claude/README.md @@ -0,0 +1,41 @@ +# Anthropic + +Examples of the request/response bodies for models that have been tested and verified to work. 
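Model families are routed by id prefix before any of these bodies are parsed. As a rough sketch (the literal below stands in for the `ANTHROPIC_CLAUDE` constant in `llm.models.SupportedModels`, whose exact value is assumed here):

```java
// Illustrative only: model ids are lowercased and matched by substring, so any
// id containing "anthropic.claude" is handled by ClaudeModelInvocation.
String modelId = "anthropic.claude-v2:1";
if (modelId.toLowerCase().contains("anthropic.claude")) {
    // routed to llm.models.anthropic.claude.ClaudeModelInvocation
}
```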
The instrumentation should continue to correctly process new +models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed +here. + +## Claude Models + +### Text Completion Models + +The following models have been tested: + +* Claude(`anthropic.claude-v2`, `anthropic.claude-v2:1`) +* Claude Instant(`anthropic.claude-instant-v1`) + +#### Sample Request + +```json +{ + "stop_sequences": [ + "\n\nHuman:" + ], + "max_tokens_to_sample": 1000, + "temperature": 0.5, + "prompt": "Human: What is the color of the sky?\n\nAssistant:" +} +``` + +#### Sample Response + +```json +{ + "completion": " The sky appears blue during the day because molecules in the air scatter blue light from the sun more than they scatter red light. The actual color of the sky varies some based on atmospheric conditions, but the primary color we perceive is blue.", + "stop_reason": "stop_sequence", + "stop": "\n\nHuman:" +} +``` + +### Embedding Models + +Not supported by Claude. diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java new file mode 100644 index 0000000000..729900e38b --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelInvocation.java @@ -0,0 +1,224 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.cohere.command; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class CommandModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public CommandModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new CommandModelRequest(invokeModelRequest); + this.modelResponse = new CommandModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime, int index) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent 
llmEmbeddingEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input(index) + .requestModel() + .responseModel() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if (operationType.equals(EMBEDDING)) { + recordLlmEmbeddingEvents(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
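+     * <p>
+     * For example, a Command completion with one prompt and one generation produces two
+     * LlmChatCompletionMessage events (sequence 0 for the prompt, sequence 1 for the
+     * completion), followed by a single LlmChatCompletionSummary event reporting a total
+     * of two messages.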
+ */ + private void recordLlmChatCompletionEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. + */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java new file mode 100644 index 0000000000..12c4218ce9 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelRequest.java @@ -0,0 +1,202 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.cohere.command; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. 
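+ * <p>
+ * Only the UTF-8 body String and the model id are copied out of the request in the
+ * constructor; the SDK object itself is never stored on a field, and the body is
+ * parsed lazily on first access.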
+ */ +public class CommandModelRequest implements ModelRequest { + private static final String MAX_TOKENS = "max_tokens"; + private static final String TEMPERATURE = "temperature"; + private static final String PROMPT = "prompt"; + private static final String TEXTS = "texts"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public CommandModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *
<p>
+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); + } + return requestBodyJsonMap != null ? requestBodyJsonMap : Collections.emptyMap(); + } + + @Override + public int getMaxTokensToSample() { + int maxTokensToSample = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_TOKENS); + if (jsonNode.isNumber()) { + String maxTokensToSampleString = jsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokensToSampleString); + } + } + } catch (Exception e) { + logParsingFailure(e, MAX_TOKENS); + } + if (maxTokensToSample == 0) { + logParsingFailure(null, MAX_TOKENS); + } + return maxTokensToSample; + } + + @Override + public float getTemperature() { + float temperature = 0f; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE); + if (jsonNode.isNumber()) { + String temperatureString = jsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); + } + } else { + logParsingFailure(null, TEMPERATURE); + } + } catch (Exception e) { + logParsingFailure(e, TEMPERATURE); + } + return temperature; + } + + @Override + public int getNumberOfRequestMessages() { + // The Command request only ever contains a single prompt message + return 1; + } + + @Override + public String getRequestMessage(int index) { + return parseStringValue(PROMPT); + } + + @Override + public String getInputText(int index) { + String parsedInputText = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textsJsonNode = getRequestBodyJsonMap().get(TEXTS); + if (textsJsonNode.isArray()) { + List textsJsonNodeArray = textsJsonNode.asArray(); + if (!textsJsonNodeArray.isEmpty()) { + JsonNode jsonNode = textsJsonNodeArray.get(index); + if (jsonNode.isString()) { + parsedInputText = jsonNode.asString(); + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, TEXTS); + } + if (parsedInputText.isEmpty()) { + logParsingFailure(null, TEXTS); + } + return parsedInputText; + } + + @Override + public int getNumberOfInputTextMessages() { + int numberOfInputTextMessages = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode textsJsonNode = getRequestBodyJsonMap().get(TEXTS); + if (textsJsonNode.isArray()) { + List textsJsonNodeArray = textsJsonNode.asArray(); + if (!textsJsonNodeArray.isEmpty()) { + numberOfInputTextMessages = textsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, TEXTS); + } + if (numberOfInputTextMessages == 0) { + logParsingFailure(null, TEXTS); + } + return numberOfInputTextMessages; + } + + private String 
parseStringValue(String fieldToParse) {
+        String parsedStringValue = "";
+        try {
+            if (!getRequestBodyJsonMap().isEmpty()) {
+                JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse);
+                if (jsonNode.isString()) {
+                    parsedStringValue = jsonNode.asString();
+                }
+            }
+        } catch (Exception e) {
+            logParsingFailure(e, fieldToParse);
+        }
+        if (parsedStringValue.isEmpty()) {
+            logParsingFailure(null, fieldToParse);
+        }
+        return parsedStringValue;
+    }
+
+    @Override
+    public String getModelId() {
+        return modelId;
+    }
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java
new file mode 100644
index 0000000000..1308759a70
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/CommandModelResponse.java
@@ -0,0 +1,243 @@
+/*
+ *
+ * * Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.models.cohere.command;
+
+import com.newrelic.api.agent.NewRelic;
+import llm.models.ModelResponse;
+import software.amazon.awssdk.protocols.jsoncore.JsonNode;
+import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Level;
+
+import static llm.models.ModelInvocation.getRandomGuid;
+import static llm.models.ModelResponse.logParsingFailure;
+
+/**
+ * Stores the required info from the Bedrock InvokeModelResponse without holding
+ * a reference to the actual response object to avoid potential memory issues.
+ */
+public class CommandModelResponse implements ModelResponse {
+    private static final String FINISH_REASON = "finish_reason";
+    private static final String GENERATIONS = "generations";
+    private static final String EMBEDDINGS = "embeddings";
+    private static final String TEXT = "text";
+
+    private String amznRequestId = "";
+
+    // LLM operation type
+    private String operationType = "";
+
+    // HTTP response
+    private boolean isSuccessfulResponse = false;
+    private int statusCode = 0;
+    private String statusText = "";
+
+    private String llmChatCompletionSummaryId = "";
+    private String llmEmbeddingId = "";
+
+    private String invokeModelResponseBody = "";
+    private Map<String, JsonNode> responseBodyJsonMap = null;
+
+    public CommandModelResponse(InvokeModelResponse invokeModelResponse) {
+        if (invokeModelResponse != null) {
+            invokeModelResponseBody = invokeModelResponse.body().asUtf8String();
+            isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful();
+            statusCode = invokeModelResponse.sdkHttpResponse().statusCode();
+            Optional<String> statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText();
+            statusTextOptional.ifPresent(s -> statusText = s);
+            setOperationType(invokeModelResponseBody);
+            amznRequestId = invokeModelResponse.responseMetadata().requestId();
+            llmChatCompletionSummaryId = getRandomGuid();
+            llmEmbeddingId = getRandomGuid();
+        } else {
+            NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse");
+        }
+    }
+
+    /**
+     * Get a map of the Response body contents.
+     * <p>
+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + if (invokeModelResponseBody.contains(GENERATIONS)) { + operationType = COMPLETION; + } else if (invokeModelResponseBody.contains(EMBEDDINGS)) { + operationType = EMBEDDING; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + @Override + public String getResponseMessage(int index) { + String parsedResponseMessage = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode generationsJsonNode = getResponseBodyJsonMap().get(GENERATIONS); + if (generationsJsonNode.isArray()) { + List generationsJsonNodeArray = generationsJsonNode.asArray(); + if (!generationsJsonNodeArray.isEmpty()) { + JsonNode jsonNode = generationsJsonNodeArray.get(index); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode textJsonNode = jsonNodeObject.get(TEXT); + if (textJsonNode.isString()) { + parsedResponseMessage = textJsonNode.asString(); + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, TEXT); + } + if (parsedResponseMessage.isEmpty()) { + logParsingFailure(null, TEXT); + } + return parsedResponseMessage; + } + + @Override + public int getNumberOfResponseMessages() { + int numberOfResponseMessages = 0; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode generationsJsonNode = getResponseBodyJsonMap().get(GENERATIONS); + if (generationsJsonNode.isArray()) { + List generationsJsonNodeArray = generationsJsonNode.asArray(); + if (!generationsJsonNodeArray.isEmpty()) { + numberOfResponseMessages = generationsJsonNodeArray.size(); + } + } + } + } catch (Exception e) { + logParsingFailure(e, GENERATIONS); + } + if (numberOfResponseMessages == 0) { + logParsingFailure(null, GENERATIONS); + } + return numberOfResponseMessages; + } + + @Override + public String getStopReason() { + String parsedStopReason = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode generationsJsonNode = getResponseBodyJsonMap().get(GENERATIONS); + if (generationsJsonNode.isArray()) { + List generationsJsonNodeArray = generationsJsonNode.asArray(); + if (!generationsJsonNodeArray.isEmpty()) { + JsonNode 
jsonNode = generationsJsonNodeArray.get(0); + if (jsonNode.isObject()) { + Map jsonNodeObject = jsonNode.asObject(); + if (!jsonNodeObject.isEmpty()) { + JsonNode finishReasonJsonNode = jsonNodeObject.get(FINISH_REASON); + if (finishReasonJsonNode.isString()) { + parsedStopReason = finishReasonJsonNode.asString(); + } + } + } + } + } + } + } catch (Exception e) { + logParsingFailure(e, FINISH_REASON); + } + if (parsedStopReason.isEmpty()) { + logParsingFailure(null, FINISH_REASON); + } + return parsedStopReason; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/README.md new file mode 100644 index 0000000000..c50a56a1cf --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/cohere/command/README.md @@ -0,0 +1,82 @@ +# Cohere + +Examples of the request/response bodies for models that have been tested and verified to work. The instrumentation should continue to correctly process new models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed here. + +## Command Models + +### Text Completion Models + +The following models have been tested: +* Command(`cohere.command-text-v14`) +* Command Light(`cohere.command-light-text-v14`) + +#### Sample Request + +```json +{ + "p": 0.9, + "stop_sequences": [ + "User:" + ], + "truncate": "END", + "max_tokens": 1000, + "stream": false, + "temperature": 0.5, + "k": 0, + "return_likelihoods": "NONE", + "prompt": "What is the color of the sky?" +} +``` + +#### Sample Response + +```json +{ + "generations": [ + { + "finish_reason": "COMPLETE", + "id": "f5700a48-0730-49f1-9756-227a993963aa", + "text": " The color of the sky can vary depending on the time of day, weather conditions, and location. In general, the color of the sky is a pale blue. During the day, the sky can appear to be a lighter shade of blue, while at night, it may appear to be a darker shade of blue or even black. The color of the sky can also be affected by the presence of clouds, which can appear as white, grey, or even pink or red in the morning or evening light. \n\nIt is important to note that the color of the sky is not a static or fixed color, but rather a dynamic and ever-changing one, which can be influenced by a variety of factors." + } + ], + "id": "c548295f-9064-49c5-a05f-c754e4c5c9f8", + "prompt": "What is the color of the sky?" +} +``` + +### Embedding Models + +The following models have been tested: +* Embed English(`cohere.embed-english-v3`) +* Embed Multilingual(`cohere.embed-multilingual-v3`) + +#### Sample Request + +```json +{ + "texts": [ + "What is the color of the sky?" 
+ ], + "truncate": "NONE", + "input_type": "search_document" +} +``` + +#### Sample Response + +```json +{ + "embeddings": [ + [ + -0.002828598, + ..., + 0.00541687 + ] + ], + "id": "e1e969ba-d526-4c76-aa92-a8a705288f6d", + "response_type": "embeddings_floats", + "texts": [ + "what is the color of the sky?" + ] +} +``` diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java new file mode 100644 index 0000000000..99e2820dde --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelInvocation.java @@ -0,0 +1,224 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.meta.llama2; + +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import llm.events.LlmEvent; +import llm.models.ModelInvocation; +import llm.models.ModelRequest; +import llm.models.ModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelResponse.COMPLETION; +import static llm.models.ModelResponse.EMBEDDING; +import static llm.vendor.Vendor.BEDROCK; + +public class Llama2ModelInvocation implements ModelInvocation { + Map linkingMetadata; + Map userAttributes; + ModelRequest modelRequest; + ModelResponse modelResponse; + + public Llama2ModelInvocation(Map linkingMetadata, Map userCustomAttributes, InvokeModelRequest invokeModelRequest, + InvokeModelResponse invokeModelResponse) { + this.linkingMetadata = linkingMetadata; + this.userAttributes = userCustomAttributes; + this.modelRequest = new Llama2ModelRequest(invokeModelRequest); + this.modelResponse = new Llama2ModelResponse(invokeModelResponse); + } + + @Override + public void setTracedMethodName(Transaction txn, String functionName) { + txn.getTracedMethod().setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void setSegmentName(Segment segment, String functionName) { + segment.setMetricName("Llm", modelResponse.getOperationType(), BEDROCK, functionName); + } + + @Override + public void recordLlmEmbeddingEvent(long startTime, int index) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmEmbeddingEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(modelResponse.getLlmEmbeddingId()) + .requestId() + .input(index) + .requestModel() + .responseModel() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), modelRequest.getInputText(0))) + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmEmbeddingEvent.recordLlmEmbeddingEvent(); + } + + @Override + public void recordLlmChatCompletionSummaryEvent(long startTime, int numberOfMessages) { + if (modelResponse.isErrorResponse()) { + reportLlmError(); + } + + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionSummaryEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + 
.id(modelResponse.getLlmChatCompletionSummaryId()) + .requestId() + .requestTemperature() + .requestMaxTokens() + .requestModel() + .responseModel() + .responseNumberOfMessages(numberOfMessages) + .responseChoicesFinishReason() + .error() + .duration(System.currentTimeMillis() - startTime) + .build(); + + llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent(); + } + + @Override + public void recordLlmChatCompletionMessageEvent(int sequence, String message, boolean isUser) { + LlmEvent.Builder builder = new LlmEvent.Builder(this); + + LlmEvent llmChatCompletionMessageEvent = builder + .spanId() + .traceId() + .vendor() + .ingestSource() + .id(ModelInvocation.getRandomGuid()) + .content(message) + .role(isUser) + .isResponse(isUser) + .requestId() + .responseModel() + .sequence(sequence) + .completionId() + .tokenCount(ModelInvocation.getTokenCount(modelRequest.getModelId(), message)) + .build(); + + llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent(); + } + + @Override + public void recordLlmEvents(long startTime) { + String operationType = modelResponse.getOperationType(); + if (operationType.equals(COMPLETION)) { + recordLlmChatCompletionEvents(startTime); + } else if (operationType.equals(EMBEDDING)) { + recordLlmEmbeddingEvents(startTime); + } else { + NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Unexpected operation type encountered when trying to record LLM events"); + } + } + + @Trace(async = true) + @Override + public void recordLlmEventsAsync(long startTime, Token token) { + if (token != null && token.isActive()) { + token.linkAndExpire(); + } + recordLlmEvents(startTime); + } + + @Override + public void reportLlmError() { + Map errorParams = new HashMap<>(); + errorParams.put("http.statusCode", modelResponse.getStatusCode()); + errorParams.put("error.code", modelResponse.getStatusCode()); + if (!modelResponse.getLlmChatCompletionSummaryId().isEmpty()) { + errorParams.put("completion_id", modelResponse.getLlmChatCompletionSummaryId()); + } + if (!modelResponse.getLlmEmbeddingId().isEmpty()) { + errorParams.put("embedding_id", modelResponse.getLlmEmbeddingId()); + } + NewRelic.noticeError("LlmError: " + modelResponse.getStatusText(), errorParams); + } + + /** + * Records multiple LlmChatCompletionMessage events and a single LlmChatCompletionSummary event. + * The number of LlmChatCompletionMessage events produced can differ based on vendor. 
+ */ + private void recordLlmChatCompletionEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfRequestMessages(); + int numberOfResponseMessages = modelResponse.getNumberOfResponseMessages(); + int totalNumberOfMessages = numberOfRequestMessages + numberOfResponseMessages; + + int sequence = 0; + + // First, record all LlmChatCompletionMessage events representing the user input prompt + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelRequest.getRequestMessage(i), true); + sequence++; + } + + // Second, record all LlmChatCompletionMessage events representing the completion message from the LLM response + for (int i = 0; i < numberOfResponseMessages; i++) { + recordLlmChatCompletionMessageEvent(sequence, modelResponse.getResponseMessage(i), false); + sequence++; + } + + // Finally, record a summary event representing all LlmChatCompletionMessage events + recordLlmChatCompletionSummaryEvent(startTime, totalNumberOfMessages); + } + + /** + * Records one, and potentially more, LlmEmbedding events based on the number of input messages in the request. + * The number of LlmEmbedding events produced can differ based on vendor. + */ + private void recordLlmEmbeddingEvents(long startTime) { + int numberOfRequestMessages = modelRequest.getNumberOfInputTextMessages(); + // Record an LlmEmbedding event for each input message in the request + for (int i = 0; i < numberOfRequestMessages; i++) { + recordLlmEmbeddingEvent(startTime, i); + } + } + + @Override + public Map getLinkingMetadata() { + return linkingMetadata; + } + + @Override + public Map getUserAttributes() { + return userAttributes; + } + + @Override + public ModelRequest getModelRequest() { + return modelRequest; + } + + @Override + public ModelResponse getModelResponse() { + return modelResponse; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java new file mode 100644 index 0000000000..adf248b3eb --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelRequest.java @@ -0,0 +1,165 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.meta.llama2; + +import com.newrelic.api.agent.NewRelic; +import llm.models.ModelRequest; +import software.amazon.awssdk.protocols.jsoncore.JsonNode; +import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; + +import java.util.Collections; +import java.util.Map; +import java.util.logging.Level; + +import static llm.models.ModelRequest.logParsingFailure; + +/** + * Stores the required info from the Bedrock InvokeModelRequest without holding + * a reference to the actual request object to avoid potential memory issues. 
+ */ +public class Llama2ModelRequest implements ModelRequest { + private static final String MAX_GEN_LEN = "max_gen_len"; + private static final String TEMPERATURE = "temperature"; + private static final String PROMPT = "prompt"; + + private String invokeModelRequestBody = ""; + private String modelId = ""; + private Map requestBodyJsonMap = null; + + public Llama2ModelRequest(InvokeModelRequest invokeModelRequest) { + if (invokeModelRequest != null) { + invokeModelRequestBody = invokeModelRequest.body().asUtf8String(); + modelId = invokeModelRequest.modelId(); + } else { + NewRelic.getAgent().getLogger().log(Level.FINEST, "AIM: Received null InvokeModelRequest"); + } + } + + /** + * Get a map of the Request body contents. + *
<p>
+ * Use this method to obtain the Request body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getRequestBodyJsonMap() { + if (requestBodyJsonMap == null) { + requestBodyJsonMap = parseInvokeModelRequestBodyMap(); + } + return requestBodyJsonMap; + } + + /** + * Convert JSON Request body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelRequestBodyMap() { + // Use AWS SDK JSON parsing to parse request body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode requestBodyJsonNode = jsonNodeParser.parse(invokeModelRequestBody); + + Map requestBodyJsonMap = null; + try { + if (requestBodyJsonNode != null && requestBodyJsonNode.isObject()) { + requestBodyJsonMap = requestBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "request body"); + } + } catch (Exception e) { + logParsingFailure(e, "request body"); + } + return requestBodyJsonMap != null ? requestBodyJsonMap : Collections.emptyMap(); + } + + @Override + public int getMaxTokensToSample() { + int maxTokensToSample = 0; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(MAX_GEN_LEN); + if (jsonNode.isNumber()) { + String maxTokensToSampleString = jsonNode.asNumber(); + maxTokensToSample = Integer.parseInt(maxTokensToSampleString); + } + } + } catch (Exception e) { + logParsingFailure(e, MAX_GEN_LEN); + } + if (maxTokensToSample == 0) { + logParsingFailure(null, MAX_GEN_LEN); + } + return maxTokensToSample; + } + + @Override + public float getTemperature() { + float temperature = 0f; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(TEMPERATURE); + if (jsonNode.isNumber()) { + String temperatureString = jsonNode.asNumber(); + temperature = Float.parseFloat(temperatureString); + } + } else { + logParsingFailure(null, TEMPERATURE); + } + } catch (Exception e) { + logParsingFailure(e, TEMPERATURE); + } + return temperature; + } + + @Override + public int getNumberOfRequestMessages() { + // The Llama request only ever contains a single prompt message + return 1; + } + + @Override + public String getRequestMessage(int index) { + return parseStringValue(PROMPT); + } + + @Override + public String getInputText(int index) { + // This is a NoOp for Llama as it doesn't support embeddings + return ""; + } + + @Override + public int getNumberOfInputTextMessages() { + // This is a NoOp for Llama as it doesn't support embeddings + return 0; + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getRequestBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getRequestBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getModelId() { + return modelId; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java new file mode 100644 index 0000000000..d15eef7136 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/Llama2ModelResponse.java @@ -0,0 +1,189 @@ +/* + * + * * 
Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.models.meta.llama2;
+
+import com.newrelic.api.agent.NewRelic;
+import llm.models.ModelResponse;
+import software.amazon.awssdk.protocols.jsoncore.JsonNode;
+import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;
+import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Level;
+
+import static llm.models.ModelInvocation.getRandomGuid;
+import static llm.models.ModelResponse.logParsingFailure;
+
+/**
+ * Stores the required info from the Bedrock InvokeModelResponse without holding
+ * a reference to the actual response object to avoid potential memory issues.
+ */
+public class Llama2ModelResponse implements ModelResponse {
+    private static final String STOP_REASON = "stop_reason";
+    private static final String GENERATION = "generation";
+
+    private String amznRequestId = "";
+
+    // LLM operation type
+    private String operationType = "";
+
+    // HTTP response
+    private boolean isSuccessfulResponse = false;
+    private int statusCode = 0;
+    private String statusText = "";
+
+    private String llmChatCompletionSummaryId = "";
+    private String llmEmbeddingId = "";
+
+    private String invokeModelResponseBody = "";
+    private Map<String, JsonNode> responseBodyJsonMap = null;
+
+    public Llama2ModelResponse(InvokeModelResponse invokeModelResponse) {
+        if (invokeModelResponse != null) {
+            invokeModelResponseBody = invokeModelResponse.body().asUtf8String();
+            isSuccessfulResponse = invokeModelResponse.sdkHttpResponse().isSuccessful();
+            statusCode = invokeModelResponse.sdkHttpResponse().statusCode();
+            Optional<String> statusTextOptional = invokeModelResponse.sdkHttpResponse().statusText();
+            statusTextOptional.ifPresent(s -> statusText = s);
+            setOperationType(invokeModelResponseBody);
+            amznRequestId = invokeModelResponse.responseMetadata().requestId();
+            llmChatCompletionSummaryId = getRandomGuid();
+            llmEmbeddingId = getRandomGuid();
+        } else {
+            NewRelic.getAgent().getLogger().log(Level.INFO, "AIM: Received null InvokeModelResponse");
+        }
+    }
+
+    /**
+     * Get a map of the Response body contents.
+     * <p>
+ * Use this method to obtain the Response body contents so that the map is lazily initialized and only parsed once. + * + * @return map of String to JsonNode + */ + private Map getResponseBodyJsonMap() { + if (responseBodyJsonMap == null) { + responseBodyJsonMap = parseInvokeModelResponseBodyMap(); + } + return responseBodyJsonMap; + } + + /** + * Convert JSON Response body string into a map. + * + * @return map of String to JsonNode + */ + private Map parseInvokeModelResponseBodyMap() { + Map responseBodyJsonMap = null; + try { + // Use AWS SDK JSON parsing to parse response body + JsonNodeParser jsonNodeParser = JsonNodeParser.create(); + JsonNode responseBodyJsonNode = jsonNodeParser.parse(invokeModelResponseBody); + + if (responseBodyJsonNode != null && responseBodyJsonNode.isObject()) { + responseBodyJsonMap = responseBodyJsonNode.asObject(); + } else { + logParsingFailure(null, "response body"); + } + } catch (Exception e) { + logParsingFailure(e, "response body"); + } + return responseBodyJsonMap != null ? responseBodyJsonMap : Collections.emptyMap(); + } + + /** + * Parses the operation type from the response body and assigns it to a field. + * + * @param invokeModelResponseBody response body String + */ + private void setOperationType(String invokeModelResponseBody) { + try { + if (!invokeModelResponseBody.isEmpty()) { + // Meta Llama 2 for Bedrock doesn't support embedding operations + if (invokeModelResponseBody.contains(GENERATION)) { + operationType = COMPLETION; + } else { + logParsingFailure(null, "operation type"); + } + } + } catch (Exception e) { + logParsingFailure(e, "operation type"); + } + } + + @Override + public String getResponseMessage(int index) { + return parseStringValue(GENERATION); + } + + @Override + public int getNumberOfResponseMessages() { + // There is only ever a single response message + return 1; + } + + @Override + public String getStopReason() { + return parseStringValue(STOP_REASON); + } + + private String parseStringValue(String fieldToParse) { + String parsedStringValue = ""; + try { + if (!getResponseBodyJsonMap().isEmpty()) { + JsonNode jsonNode = getResponseBodyJsonMap().get(fieldToParse); + if (jsonNode.isString()) { + parsedStringValue = jsonNode.asString(); + } + } + } catch (Exception e) { + logParsingFailure(e, fieldToParse); + } + if (parsedStringValue.isEmpty()) { + logParsingFailure(null, fieldToParse); + } + return parsedStringValue; + } + + @Override + public String getAmznRequestId() { + return amznRequestId; + } + + @Override + public String getOperationType() { + return operationType; + } + + @Override + public String getLlmChatCompletionSummaryId() { + return llmChatCompletionSummaryId; + } + + @Override + public String getLlmEmbeddingId() { + return llmEmbeddingId; + } + + @Override + public boolean isErrorResponse() { + return !isSuccessfulResponse; + } + + @Override + public int getStatusCode() { + return statusCode; + } + + @Override + public String getStatusText() { + return statusText; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/README.md b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/README.md new file mode 100644 index 0000000000..6bfba79a5a --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/models/meta/llama2/README.md @@ -0,0 +1,40 @@ +# Meta + +Examples of the request/response bodies for models that have been tested and verified to work. 
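The request and response classes above read these fields with the AWS SDK's bundled JSON core (`software.amazon.awssdk.protocols.jsoncore`). A minimal sketch of that parsing idiom, with an illustrative body literal:

```java
import software.amazon.awssdk.protocols.jsoncore.JsonNode;
import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;

public class ParseSketch {
    public static void main(String[] args) {
        // asNumber() returns the numeric value as a String, hence Integer.parseInt
        JsonNode body = JsonNodeParser.create()
                .parse("{\"max_gen_len\":1000,\"temperature\":0.5,\"prompt\":\"Hi\"}");
        int maxGenLen = Integer.parseInt(body.asObject().get("max_gen_len").asNumber());
        System.out.println(maxGenLen); // 1000
    }
}
```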
The instrumentation should continue to correctly process new models as long as they match the model naming prefixes in `llm.models.SupportedModels` and the request/response structure stays the same as the examples listed here.
+
+## Llama 2 Models
+
+### Text Completion Models
+
+The following models have been tested:
+
+* Llama 2 Chat 13B (`meta.llama2-13b-chat-v1`)
+* Llama 2 Chat 70B (`meta.llama2-70b-chat-v1`)
+
+#### Sample Request
+
+```json
+{
+  "top_p": 0.9,
+  "max_gen_len": 1000,
+  "temperature": 0.5,
+  "prompt": "What is the color of the sky?"
+}
+```
+
+#### Sample Response
+
+```json
+{
+  "generation": "\n\nThe color of the sky can vary depending on the time of day and atmospheric conditions. During the daytime, the sky typically appears blue, which is caused by a phenomenon called Rayleigh scattering, in which shorter (blue) wavelengths of light are scattered more than longer (red) wavelengths by the tiny molecules of gases in the atmosphere.\n\nIn the evening, as the sun sets, the sky can take on a range of colors, including orange, pink, and purple, due to the scattering of light by atmospheric particles. During sunrise and sunset, the sky can also appear more red or orange due to the longer wavelengths of light being scattered.\n\nAt night, the sky can appear dark, but it can also be illuminated by the moon, stars, and artificial light sources such as city lights. In areas with minimal light pollution, the night sky can be a deep indigo or black, with the stars and constellations visible as points of light.\n\nOverall, the color of the sky can vary greatly depending on the time of day, atmospheric conditions, and the observer's location.",
+  "prompt_token_count": 9,
+  "generation_token_count": 256,
+  "stop_reason": "stop"
+}
+```
+
+### Embedding Models
+
+Not supported by Llama 2.
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/vendor/Vendor.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/vendor/Vendor.java
new file mode 100644
index 0000000000..0381ffe818
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/llm/vendor/Vendor.java
@@ -0,0 +1,16 @@
+/*
+ *
+ * * Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package llm.vendor;
+
+public class Vendor {
+    public static final String VENDOR = "bedrock";
+    // Bedrock vendor_version isn't obtainable, so set it to instrumentation version instead
+    public static final String VENDOR_VERSION = "2.20";
+    public static final String BEDROCK = "Bedrock";
+    public static final String INGEST_SOURCE = "Java";
+}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java
new file mode 100644
index 0000000000..f99ab6e635
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_Instrumentation.java
@@ -0,0 +1,144 @@
+/*
+ *
+ * * Copyright 2024 New Relic Corporation. All rights reserved.
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.bridge.AgentBridge; +import com.newrelic.agent.bridge.NoOpTransaction; +import com.newrelic.agent.bridge.Token; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Segment; +import com.newrelic.api.agent.Trace; +import com.newrelic.api.agent.weaver.MatchType; +import com.newrelic.api.agent.weaver.Weave; +import com.newrelic.api.agent.weaver.Weaver; +import llm.models.ModelInvocation; +import llm.models.ai21labs.jurassic.JurassicModelInvocation; +import llm.models.amazon.titan.TitanModelInvocation; +import llm.models.anthropic.claude.ClaudeModelInvocation; +import llm.models.cohere.command.CommandModelInvocation; +import llm.models.meta.llama2.Llama2ModelInvocation; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; +import java.util.logging.Level; + +import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringEnabled; +import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringStreamingEnabled; +import static llm.models.SupportedModels.AI_21_LABS_JURASSIC; +import static llm.models.SupportedModels.AMAZON_TITAN; +import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +import static llm.models.SupportedModels.COHERE_COMMAND; +import static llm.models.SupportedModels.COHERE_EMBED; +import static llm.models.SupportedModels.META_LLAMA_2; +import static llm.vendor.Vendor.VENDOR_VERSION; + +/** + * Service client for accessing Amazon Bedrock Runtime asynchronously. 
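+ * <p>
+ * The segment is started before the SDK call completes and is renamed and ended from
+ * the completion callback, once the response (and therefore the LLM operation type) is
+ * available. The callback is a plain anonymous BiConsumer because, as noted below,
+ * instrumentation fails if it is replaced with a lambda.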
+ */
+@Weave(type = MatchType.Interface, originalName = "software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient")
+public abstract class BedrockRuntimeAsyncClient_Instrumentation {
+
+    @Trace
+    public CompletableFuture<InvokeModelResponse> invokeModel(InvokeModelRequest invokeModelRequest) {
+        long startTime = System.currentTimeMillis();
+        CompletableFuture<InvokeModelResponse> invokeModelResponseFuture = Weaver.callOriginal();
+
+        if (isAiMonitoringEnabled()) {
+            Transaction txn = AgentBridge.getAgent().getTransaction();
+            ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION);
+
+            if (!(txn instanceof NoOpTransaction)) {
+                // Segment will be renamed later when the response is available
+                Segment segment = txn.startSegment("");
+                // Set llm = true agent attribute, this is required on transaction events
+                ModelInvocation.setLlmTrueAgentAttribute(txn);
+
+                // This should never happen, but protecting against bad implementations
+                if (invokeModelResponseFuture == null) {
+                    segment.end();
+                } else {
+                    Map<String, Object> userAttributes = txn.getUserAttributes();
+                    Map<String, String> linkingMetadata = NewRelic.getAgent().getLinkingMetadata();
+                    String modelId = invokeModelRequest.modelId();
+
+                    Token token = txn.getToken();
+
+                    // Instrumentation fails if the BiConsumer is replaced with a lambda
+                    invokeModelResponseFuture.whenComplete(new BiConsumer<InvokeModelResponse, Throwable>() {
+                        @Override
+                        public void accept(InvokeModelResponse invokeModelResponse, Throwable throwable) {
+
+                            try {
+                                if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) {
+                                    ModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
+                                            invokeModelResponse);
+                                    // Set segment name based on LLM operation from response
+                                    claudeModelInvocation.setSegmentName(segment, "invokeModel");
+                                    claudeModelInvocation.recordLlmEventsAsync(startTime, token);
+                                } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) {
+                                    ModelInvocation titanModelInvocation = new TitanModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
+                                            invokeModelResponse);
+                                    // Set segment name based on LLM operation from response
+                                    titanModelInvocation.setSegmentName(segment, "invokeModel");
+                                    titanModelInvocation.recordLlmEventsAsync(startTime, token);
+                                } else if (modelId.toLowerCase().contains(META_LLAMA_2)) {
+                                    ModelInvocation llama2ModelInvocation = new Llama2ModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
+                                            invokeModelResponse);
+                                    // Set segment name based on LLM operation from response
+                                    llama2ModelInvocation.setSegmentName(segment, "invokeModel");
+                                    llama2ModelInvocation.recordLlmEventsAsync(startTime, token);
+                                } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) {
+                                    ModelInvocation commandModelInvocation = new CommandModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
+                                            invokeModelResponse);
+                                    // Set segment name based on LLM operation from response
+                                    commandModelInvocation.setSegmentName(segment, "invokeModel");
+                                    commandModelInvocation.recordLlmEventsAsync(startTime, token);
+                                } else if (modelId.toLowerCase().contains(AI_21_LABS_JURASSIC)) {
+                                    ModelInvocation jurassicModelInvocation = new JurassicModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
+                                            invokeModelResponse);
+                                    // Set segment name based on LLM operation from response
+                                    jurassicModelInvocation.setSegmentName(segment, "invokeModel");
+                                    jurassicModelInvocation.recordLlmEventsAsync(startTime, token);
+                                }
+                                if (segment != null) {
+                                    segment.endAsync();
+                                }
} catch (Throwable t) { + if (segment != null) { + segment.endAsync(); + } + AgentBridge.instrumentation.noticeInstrumentationError(t, Weaver.getImplementationTitle()); + } + } + }); + } + } + } + return invokeModelResponseFuture; + } + + public CompletableFuture invokeModelWithResponseStream( + InvokeModelWithResponseStreamRequest invokeModelWithResponseStreamRequest, + InvokeModelWithResponseStreamResponseHandler asyncResponseHandler) { + if (isAiMonitoringEnabled()) { + if (isAiMonitoringStreamingEnabled()) { + NewRelic.getAgent() + .getLogger() + .log(Level.FINER, + "aws-bedrock-runtime-2.20 instrumentation does not currently support response streaming. Enabling ai_monitoring.streaming will have no effect."); + } + } + return Weaver.callOriginal(); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java new file mode 100644 index 0000000000..c1411bbec1 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java @@ -0,0 +1,93 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.bridge.AgentBridge; +import com.newrelic.agent.bridge.NoOpTransaction; +import com.newrelic.agent.bridge.Transaction; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Trace; +import com.newrelic.api.agent.weaver.MatchType; +import com.newrelic.api.agent.weaver.Weave; +import com.newrelic.api.agent.weaver.Weaver; +import llm.models.ModelInvocation; +import llm.models.ai21labs.jurassic.JurassicModelInvocation; +import llm.models.amazon.titan.TitanModelInvocation; +import llm.models.anthropic.claude.ClaudeModelInvocation; +import llm.models.cohere.command.CommandModelInvocation; +import llm.models.meta.llama2.Llama2ModelInvocation; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Map; + +import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringEnabled; +import static llm.models.SupportedModels.AI_21_LABS_JURASSIC; +import static llm.models.SupportedModels.AMAZON_TITAN; +import static llm.models.SupportedModels.ANTHROPIC_CLAUDE; +import static llm.models.SupportedModels.COHERE_COMMAND; +import static llm.models.SupportedModels.COHERE_EMBED; +import static llm.models.SupportedModels.META_LLAMA_2; +import static llm.vendor.Vendor.VENDOR_VERSION; + +/** + * Service client for accessing Amazon Bedrock Runtime synchronously. 
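+ * <p>
+ * A minimal sketch of the application-side call that this weave intercepts (builder
+ * shape per the standard AWS SDK v2 API; the body literal is illustrative):
+ * <pre>{@code
+ * BedrockRuntimeClient client = BedrockRuntimeClient.create();
+ * InvokeModelRequest request = InvokeModelRequest.builder()
+ *         .modelId("anthropic.claude-v2")
+ *         .body(SdkBytes.fromUtf8String("{\"prompt\":\"Human: Hi\\n\\nAssistant:\",\"max_tokens_to_sample\":100}"))
+ *         .build();
+ * InvokeModelResponse response = client.invokeModel(request);
+ * }</pre>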
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java
new file mode 100644
index 0000000000..c1411bbec1
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/main/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_Instrumentation.java
@@ -0,0 +1,93 @@
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package software.amazon.awssdk.services.bedrockruntime;

import com.newrelic.agent.bridge.AgentBridge;
import com.newrelic.agent.bridge.NoOpTransaction;
import com.newrelic.agent.bridge.Transaction;
import com.newrelic.api.agent.NewRelic;
import com.newrelic.api.agent.Trace;
import com.newrelic.api.agent.weaver.MatchType;
import com.newrelic.api.agent.weaver.Weave;
import com.newrelic.api.agent.weaver.Weaver;
import llm.models.ModelInvocation;
import llm.models.ai21labs.jurassic.JurassicModelInvocation;
import llm.models.amazon.titan.TitanModelInvocation;
import llm.models.anthropic.claude.ClaudeModelInvocation;
import llm.models.cohere.command.CommandModelInvocation;
import llm.models.meta.llama2.Llama2ModelInvocation;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Map;

import static com.newrelic.agent.bridge.aimonitoring.AiMonitoringUtils.isAiMonitoringEnabled;
import static llm.models.SupportedModels.AI_21_LABS_JURASSIC;
import static llm.models.SupportedModels.AMAZON_TITAN;
import static llm.models.SupportedModels.ANTHROPIC_CLAUDE;
import static llm.models.SupportedModels.COHERE_COMMAND;
import static llm.models.SupportedModels.COHERE_EMBED;
import static llm.models.SupportedModels.META_LLAMA_2;
import static llm.vendor.Vendor.VENDOR_VERSION;

/**
 * Service client for accessing Amazon Bedrock Runtime synchronously.
 */
@Weave(type = MatchType.Interface, originalName = "software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient")
public abstract class BedrockRuntimeClient_Instrumentation {

    @Trace
    public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) {
        long startTime = System.currentTimeMillis();
        InvokeModelResponse invokeModelResponse = Weaver.callOriginal();

        if (isAiMonitoringEnabled()) {
            Transaction txn = AgentBridge.getAgent().getTransaction();
            ModelInvocation.incrementInstrumentedSupportabilityMetric(VENDOR_VERSION);

            if (!(txn instanceof NoOpTransaction)) {
                // Set the llm = true agent attribute; it is required on transaction events
                ModelInvocation.setLlmTrueAgentAttribute(txn);

                Map<String, Object> userAttributes = txn.getUserAttributes();
                Map<String, String> linkingMetadata = NewRelic.getAgent().getLinkingMetadata();
                String modelId = invokeModelRequest.modelId();

                // Set the traced method name based on the LLM operation from the response, then record LLM events
                if (modelId.toLowerCase().contains(ANTHROPIC_CLAUDE)) {
                    ModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
                            invokeModelResponse);
                    claudeModelInvocation.setTracedMethodName(txn, "invokeModel");
                    claudeModelInvocation.recordLlmEvents(startTime);
                } else if (modelId.toLowerCase().contains(AMAZON_TITAN)) {
                    ModelInvocation titanModelInvocation = new TitanModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
                            invokeModelResponse);
                    titanModelInvocation.setTracedMethodName(txn, "invokeModel");
                    titanModelInvocation.recordLlmEvents(startTime);
                } else if (modelId.toLowerCase().contains(META_LLAMA_2)) {
                    ModelInvocation llama2ModelInvocation = new Llama2ModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
                            invokeModelResponse);
                    llama2ModelInvocation.setTracedMethodName(txn, "invokeModel");
                    llama2ModelInvocation.recordLlmEvents(startTime);
                } else if (modelId.toLowerCase().contains(COHERE_COMMAND) || modelId.toLowerCase().contains(COHERE_EMBED)) {
                    ModelInvocation commandModelInvocation = new CommandModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
                            invokeModelResponse);
                    commandModelInvocation.setTracedMethodName(txn, "invokeModel");
                    commandModelInvocation.recordLlmEvents(startTime);
                } else if (modelId.toLowerCase().contains(AI_21_LABS_JURASSIC)) {
                    ModelInvocation jurassicModelInvocation = new JurassicModelInvocation(linkingMetadata, userAttributes, invokeModelRequest,
                            invokeModelResponse);
                    jurassicModelInvocation.setTracedMethodName(txn, "invokeModel");
                    jurassicModelInvocation.recordLlmEvents(startTime);
                }
            }
        }
        return invokeModelResponse;
    }
}
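Both weaved clients are gated on isAiMonitoringEnabled(), so none of this fires unless the feature is switched on (and high_security forces it off regardless). A hypothetical agent config fragment, assuming the usual YAML nesting for the ai_monitoring.enabled key this change reads:

    # Hypothetical newrelic.yml fragment; exact nesting depends on your config layout.
    ai_monitoring:
      enabled: true

With that in place, the synchronous path records LLM events on the calling thread before invokeModel() returns, while the async path defers recording to the completion callback shown earlier.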
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java
new file mode 100644
index 0000000000..f932918bd1
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/events/LlmEventTest.java
@@ -0,0 +1,325 @@
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.events;

import com.newrelic.agent.introspec.Event;
import com.newrelic.agent.introspec.InstrumentationTestConfig;
import com.newrelic.agent.introspec.InstrumentationTestRunner;
import com.newrelic.agent.introspec.Introspector;
import com.newrelic.api.agent.LlmTokenCountCallback;
import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder;
import llm.models.ModelInvocation;
import llm.models.amazon.titan.TitanModelInvocation;
import llm.models.anthropic.claude.ClaudeModelInvocation;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static llm.events.LlmEvent.Builder;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.events.LlmEvent.LLM_EMBEDDING;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@RunWith(InstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml")
public class LlmEventTest {
    private final Introspector introspector = InstrumentationTestRunner.getIntrospector();

    @Before
    public void before() {
        introspector.clear();
        setUp();
    }

    public void setUp() {
        LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13;
        LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback);
    }

    @Test
    public void testRecordLlmEmbeddingEvent() {
        // Given
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-890");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String()).thenReturn("{\"inputText\":\"What is the color of the sky?\"}");
        when(mockInvokeModelRequest.modelId()).thenReturn("amazon.titan-embed-text-v1");

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String()).thenReturn("{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}");

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);
        when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        when(mockSdkHttpResponse.statusCode()).thenReturn(200);
        when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        // Instantiate ModelInvocation
        TitanModelInvocation titanModelInvocation = new TitanModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);

        // When
        // Build LlmEmbedding event
        Builder builder = new Builder(titanModelInvocation);
        LlmEvent llmEmbeddingEvent = builder
                .spanId() // attribute 1
                .traceId() // attribute 2
                .vendor() // attribute 3
                .ingestSource() // attribute 4
                .id(titanModelInvocation.getModelResponse().getLlmEmbeddingId()) // attribute 5
                .requestId() // attribute 6
                .input(0) // attribute 7
                .requestModel() // attribute 8
                .responseModel() // attribute 9
                .tokenCount(123) // attribute 10
                .error() // not added
                .duration(9000f) // attribute 11
                .build();

        // attributes 12 & 13 should be the two llm.* prefixed userAttributes

        // Record LlmEmbedding event
        llmEmbeddingEvent.recordLlmEmbeddingEvent();

        // Then
        Collection<Event> customEvents = introspector.getCustomEvents(LLM_EMBEDDING);
        assertEquals(1, customEvents.size());

        Event event = customEvents.iterator().next();
        assertEquals(LLM_EMBEDDING, event.getType());

        Map<String, Object> attributes = event.getAttributes();
        assertEquals(13, attributes.size());
        assertEquals("span-id-123", attributes.get("span_id"));
        assertEquals("trace-id-xyz", attributes.get("trace_id"));
        assertEquals("bedrock", attributes.get("vendor"));
        assertEquals("Java", attributes.get("ingest_source"));
        assertFalse(((String) attributes.get("id")).isEmpty());
        assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id"));
        assertEquals("What is the color of the sky?", attributes.get("input"));
        assertEquals("amazon.titan-embed-text-v1", attributes.get("request.model"));
        assertEquals("amazon.titan-embed-text-v1", attributes.get("response.model"));
        assertEquals(123, attributes.get("token_count"));
        assertEquals(9000f, attributes.get("duration"));
        assertEquals("conversation-id-890", attributes.get("llm.conversation_id"));
        assertEquals("testPrefix", attributes.get("llm.testPrefix"));
    }

    @Test
    public void testRecordLlmChatCompletionMessageEvent() {
        // Given
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-890");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        String expectedUserPrompt = "Human: What is the color of the sky?\n\nAssistant:";

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String())
                .thenReturn(
                        "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}");
        when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2");

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String())
                .thenReturn(
                        "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}");

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);
        when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        when(mockSdkHttpResponse.statusCode()).thenReturn(200);
        when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        ClaudeModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);

        LlmEvent.Builder builder = new LlmEvent.Builder(claudeModelInvocation);
        LlmEvent llmChatCompletionMessageEvent = builder
                .spanId() // attribute 1
                .traceId() // attribute 2
                .vendor() // attribute 3
                .ingestSource() // attribute 4
                .id(ModelInvocation.getRandomGuid()) // attribute 5
                .content(expectedUserPrompt) // attribute 6
                .role(true) // attribute 7
                .isResponse(true) // attribute 8
                .requestId() // attribute 9
                .responseModel() // attribute 10
                .sequence(0) // attribute 11
                .completionId() // attribute 12
                .tokenCount(LlmTokenCountCallbackHolder.getLlmTokenCountCallback().calculateLlmTokenCount("model", "content")) // attribute 13
                .build();

        // attributes 14 & 15 should be the two llm.* prefixed userAttributes

        // Record LlmChatCompletionMessage event
        llmChatCompletionMessageEvent.recordLlmChatCompletionMessageEvent();

        // Then
        Collection<Event> customEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(1, customEvents.size());

        Event event = customEvents.iterator().next();
        assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType());

        Map<String, Object> attributes = event.getAttributes();
        assertEquals(15, attributes.size());
        assertEquals("span-id-123", attributes.get("span_id"));
        assertEquals("trace-id-xyz", attributes.get("trace_id"));
        assertEquals("bedrock", attributes.get("vendor"));
        assertEquals("Java", attributes.get("ingest_source"));
        assertFalse(((String) attributes.get("id")).isEmpty());
        assertEquals(expectedUserPrompt, attributes.get("content"));
        assertEquals("user", attributes.get("role"));
        assertEquals(false, attributes.get("is_response"));
        assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id"));
        assertEquals("anthropic.claude-v2", attributes.get("response.model"));
        assertEquals(0, attributes.get("sequence"));
        assertFalse(((String) attributes.get("completion_id")).isEmpty());
        assertEquals(13, attributes.get("token_count"));
        assertEquals("conversation-id-890", attributes.get("llm.conversation_id"));
        assertEquals("testPrefix", attributes.get("llm.testPrefix"));
    }

    @Test
    public void testRecordLlmChatCompletionSummaryEvent() {
        // Given
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-890");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String())
                .thenReturn(
                        "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}");
        when(mockInvokeModelRequest.modelId()).thenReturn("anthropic.claude-v2");

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String())
                .thenReturn(
                        "{\"completion\":\" The sky appears blue during the day because of how sunlight interacts with the gases in Earth's atmosphere.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}");

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);
        when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        when(mockSdkHttpResponse.statusCode()).thenReturn(200);
        when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        ClaudeModelInvocation claudeModelInvocation = new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);

        LlmEvent.Builder builder = new LlmEvent.Builder(claudeModelInvocation);
        LlmEvent llmChatCompletionSummaryEvent = builder
                .spanId() // attribute 1
                .traceId() // attribute 2
                .vendor() // attribute 3
                .ingestSource() // attribute 4
                .id(claudeModelInvocation.getModelResponse().getLlmChatCompletionSummaryId()) // attribute 5
                .requestId() // attribute 6
                .requestTemperature() // attribute 7
                .requestMaxTokens() // attribute 8
                .requestModel() // attribute 9
                .responseModel() // attribute 10
                .responseNumberOfMessages(2) // attribute 11
                .responseChoicesFinishReason() // attribute 12
                .error() // not added
                .duration(9000f) // attribute 13
                .build();

        // attributes 14 & 15 should be the two llm.* prefixed userAttributes

        // Record LlmChatCompletionSummary event
        llmChatCompletionSummaryEvent.recordLlmChatCompletionSummaryEvent();

        // Then
        Collection<Event> customEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, customEvents.size());

        Event event = customEvents.iterator().next();
        assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType());

        Map<String, Object> attributes = event.getAttributes();
        assertEquals(15, attributes.size());
        assertEquals("span-id-123", attributes.get("span_id"));
        assertEquals("trace-id-xyz", attributes.get("trace_id"));
        assertEquals("bedrock", attributes.get("vendor"));
        assertEquals("Java", attributes.get("ingest_source"));
        assertFalse(((String) attributes.get("id")).isEmpty());
        assertEquals("90a22e92-db1d-4474-97a9-28b143846301", attributes.get("request_id"));
        assertEquals(0.5f, attributes.get("request.temperature"));
        assertEquals(1000, attributes.get("request.max_tokens"));
        assertEquals("anthropic.claude-v2", attributes.get("request.model"));
        assertEquals("anthropic.claude-v2", attributes.get("response.model"));
        assertEquals(2, attributes.get("response.number_of_messages"));
        assertEquals("stop_sequence", attributes.get("response.choices.finish_reason"));
        assertEquals(9000f, attributes.get("duration"));
        assertEquals("conversation-id-890", attributes.get("llm.conversation_id"));
        assertEquals("testPrefix", attributes.get("llm.testPrefix"));
    }
}
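The tests above install the token-count callback directly on LlmTokenCountCallbackHolder. Application code would presumably register it through the AiMonitoring API added in this change instead; a sketch, assuming getAiMonitoring() is exposed on the public Agent the same way it is on the bridge, with a deliberately crude whitespace count standing in for a real tokenizer:

    // Hypothetical registration; the callback receives (model, content) and returns an int.
    LlmTokenCountCallback wordCount = (model, content) ->
            content == null ? 0 : content.trim().split("\\s+").length;
    NewRelic.getAgent().getAiMonitoring().setLlmTokenCountCallback(wordCount);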
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java
new file mode 100644
index 0000000000..4b3b86da9b
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/TestUtil.java
@@ -0,0 +1,105 @@
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.models;

import com.newrelic.agent.introspec.ErrorEvent;
import com.newrelic.agent.introspec.Event;

import java.util.Collection;
import java.util.Iterator;
import java.util.Map;

import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.events.LlmEvent.LLM_EMBEDDING;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

public class TestUtil {
    public static void assertLlmChatCompletionMessageAttributes(Event event, String modelId, String requestInput, String responseContent,
            boolean isResponse) {
        assertEquals(LLM_CHAT_COMPLETION_MESSAGE, event.getType());

        Map<String, Object> attributes = event.getAttributes();
        assertEquals("Java", attributes.get("ingest_source"));
        assertFalse(((String) attributes.get("completion_id")).isEmpty());
        assertFalse(((String) attributes.get("id")).isEmpty());
        assertFalse(((String) attributes.get("request_id")).isEmpty());
        assertEquals("bedrock", attributes.get("vendor"));
        assertEquals(modelId, attributes.get("response.model"));
        assertEquals("testPrefix", attributes.get("llm.testPrefix"));
        assertEquals("conversation-id-value", attributes.get("llm.conversation_id"));
        assertEquals(13, attributes.get("token_count"));

        if (isResponse) {
            assertEquals("assistant", attributes.get("role"));
            assertEquals(responseContent, attributes.get("content"));
            assertEquals(true, attributes.get("is_response"));
            assertEquals(1, attributes.get("sequence"));
        } else {
            assertEquals("user", attributes.get("role"));
            assertEquals(requestInput, attributes.get("content"));
            assertEquals(false, attributes.get("is_response"));
            assertEquals(0, attributes.get("sequence"));
        }
    }

    public static void assertLlmChatCompletionSummaryAttributes(Event event, String modelId, String finishReason) {
        assertEquals(LLM_CHAT_COMPLETION_SUMMARY, event.getType());

        Map<String, Object> attributes = event.getAttributes();
        assertEquals("Java", attributes.get("ingest_source"));
        assertEquals(0.5f, attributes.get("request.temperature"));
        assertTrue(((Float) attributes.get("duration")) >= 0);
        assertEquals(finishReason, attributes.get("response.choices.finish_reason"));
        assertEquals(modelId, attributes.get("request.model"));
        assertEquals("bedrock", attributes.get("vendor"));
        assertEquals(modelId, attributes.get("response.model"));
        assertFalse(((String) attributes.get("id")).isEmpty());
        assertFalse(((String) attributes.get("request_id")).isEmpty());
        assertEquals(2, attributes.get("response.number_of_messages"));
        assertEquals(1000, attributes.get("request.max_tokens"));
        assertEquals("testPrefix", attributes.get("llm.testPrefix"));
        assertEquals("conversation-id-value", attributes.get("llm.conversation_id"));
    }

    public static void assertLlmEmbeddingAttributes(Event event, String modelId, String requestInput) {
        assertEquals(LLM_EMBEDDING, event.getType());

        Map<String, Object> attributes = event.getAttributes();
        assertEquals("Java", attributes.get("ingest_source"));
        assertTrue(((Float) attributes.get("duration")) >= 0);
        assertEquals(requestInput, attributes.get("input"));
        assertEquals(modelId, attributes.get("request.model"));
        assertEquals(modelId, attributes.get("response.model"));
        assertEquals("bedrock", attributes.get("vendor"));
        assertFalse(((String) attributes.get("id")).isEmpty());
        assertFalse(((String) attributes.get("request_id")).isEmpty());
        assertEquals("testPrefix", attributes.get("llm.testPrefix"));
        assertEquals("conversation-id-value", attributes.get("llm.conversation_id"));
        assertEquals(13, attributes.get("token_count"));
    }

    public static void assertErrorEvent(boolean isError, Collection<ErrorEvent> errorEvents) {
        if (isError) {
            assertEquals(1, errorEvents.size());
            Iterator<ErrorEvent> errorEventIterator = errorEvents.iterator();
            ErrorEvent errorEvent = errorEventIterator.next();

            assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorClass());
            assertEquals("LlmError: BAD_REQUEST", errorEvent.getErrorMessage());

            Map<String, Object> errorEventAttributes = errorEvent.getAttributes();
            assertFalse(errorEventAttributes.isEmpty());
            assertEquals(400, errorEventAttributes.get("error.code"));
            assertEquals(400, errorEventAttributes.get("http.statusCode"));
        } else {
            assertTrue(errorEvents.isEmpty());
        }
    }
}
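The userAttributes maps passed into each ModelInvocation come from Transaction.getUserAttributes(), and as the assertions above show, only keys with the llm. prefix are copied onto LLM events (the bare "test" attribute is dropped). In an application those attributes would typically be set as custom parameters on the active transaction, for example:

    // Added inside a transaction; surfaces on subsequent LLM events because of the llm. prefix.
    NewRelic.addCustomParameter("llm.conversation_id", "conversation-id-value");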
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java
new file mode 100644
index 0000000000..320b54ac57
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/ai21labs/jurassic/JurassicModelInvocationTest.java
@@ -0,0 +1,166 @@
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.models.ai21labs.jurassic;

import com.newrelic.agent.introspec.Event;
import com.newrelic.agent.introspec.InstrumentationTestConfig;
import com.newrelic.agent.introspec.InstrumentationTestRunner;
import com.newrelic.agent.introspec.Introspector;
import com.newrelic.api.agent.LlmTokenCountCallback;
import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.models.TestUtil.assertErrorEvent;
import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes;
import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@RunWith(InstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml")
public class JurassicModelInvocationTest {

    private final Introspector introspector = InstrumentationTestRunner.getIntrospector();

    // Completion
    private final String completionModelId = "ai21.j2-mid-v1";
    private final String completionRequestBody = "{\"temperature\":0.5,\"maxTokens\":1000,\"prompt\":\"What is the color of the sky?\"}";
    private final String completionResponseBody =
            "{\"id\":1234,\"prompt\":{\"text\":\"What is the color of the sky?\",\"tokens\":[{\"generatedToken\":{\"token\":\"▁What▁is▁the\",\"logprob\":-9.992481231689453,\"raw_logprob\":-9.992481231689453}\n"
                    + ",\"topTokens\":null,\"textRange\":{\"start\":0,\"end\":11}}]},\"completions\":[{\"data\":{\"text\":\"\\nThe color of the sky is blue.\",\"tokens\":[{\"generatedToken\":{\"token\":\"<|newline|>\",\"logprob\":0.0,\"raw_logprob\":-1.389883691444993E-4},\"topTokens\":null,\"textRange\":{\"start\":0,\"end\":1}}]},\"finishReason\":{\"reason\":\"endoftext\"}}]}";
    private final String completionRequestInput = "What is the color of the sky?";
    private final String completionResponseContent = "\nThe color of the sky is blue.";
    private final String finishReason = "endoftext";

    @Before
    public void before() {
        introspector.clear();
        LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13;
        LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback);
    }

    @Test
    public void testCompletion() {
        boolean isError = false;

        JurassicModelInvocation jurassicModelInvocation = mockJurassicModelInvocation(completionModelId, completionRequestBody, completionResponseBody,
                isError);
        jurassicModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();
        Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false);

        Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Iterator<Event> llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator();
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    @Test
    public void testCompletionError() {
        boolean isError = true;

        JurassicModelInvocation jurassicModelInvocation = mockJurassicModelInvocation(completionModelId, completionRequestBody, completionResponseBody,
                isError);
        jurassicModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();
        Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false);

        Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Iterator<Event> llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator();
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    private JurassicModelInvocation mockJurassicModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) {
        // Given
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-value");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody);
        when(mockInvokeModelRequest.modelId()).thenReturn(modelId);

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody);

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);

        if (isError) {
            when(mockSdkHttpResponse.statusCode()).thenReturn(400);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(false);
        } else {
            when(mockSdkHttpResponse.statusCode()).thenReturn(200);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        }

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        // Instantiate ModelInvocation
        return new JurassicModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);
    }
}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java
new file mode 100644
index 0000000000..521472afb5
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/amazon/titan/TitanModelInvocationTest.java
@@ -0,0 +1,204 @@
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.models.amazon.titan;

import com.newrelic.agent.introspec.Event;
import com.newrelic.agent.introspec.InstrumentationTestConfig;
import com.newrelic.agent.introspec.InstrumentationTestRunner;
import com.newrelic.agent.introspec.Introspector;
import com.newrelic.api.agent.LlmTokenCountCallback;
import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.events.LlmEvent.LLM_EMBEDDING;
import static llm.models.TestUtil.assertErrorEvent;
import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes;
import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes;
import static llm.models.TestUtil.assertLlmEmbeddingAttributes;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@RunWith(InstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml")
public class TitanModelInvocationTest {

    private final Introspector introspector = InstrumentationTestRunner.getIntrospector();

    // Embedding
    private final String embeddingModelId = "amazon.titan-embed-text-v1";
    private final String embeddingRequestBody = "{\"inputText\":\"What is the color of the sky?\"}";
    private final String embeddingResponseBody = "{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}";
    private final String embeddingRequestInput = "What is the color of the sky?";

    // Completion
    private final String completionModelId = "amazon.titan-text-lite-v1";
    private final String completionRequestBody = "{\"inputText\":\"What is the color of the sky?\",\"textGenerationConfig\":{\"maxTokenCount\":1000,\"stopSequences\":[\"User:\"],\"temperature\":0.5,\"topP\":0.9}}";
    private final String completionResponseBody = "{\"inputTextTokenCount\":8,\"results\":[{\"tokenCount\":9,\"outputText\":\"\\nThe color of the sky is blue.\",\"completionReason\":\"FINISH\"}]}";
    private final String completionRequestInput = "What is the color of the sky?";
    private final String completionResponseContent = "\nThe color of the sky is blue.";
    private final String finishReason = "FINISH";

    @Before
    public void before() {
        introspector.clear();
        LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13;
        LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback);
    }

    @Test
    public void testEmbedding() {
        boolean isError = false;

        TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(embeddingModelId, embeddingRequestBody, embeddingResponseBody, isError);
        titanModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING);
        assertEquals(1, llmEmbeddingEvents.size());
        Iterator<Event> llmEmbeddingEventIterator = llmEmbeddingEvents.iterator();
        Event llmEmbeddingEvent = llmEmbeddingEventIterator.next();

        assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    @Test
    public void testCompletion() {
        boolean isError = false;

        TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(completionModelId, completionRequestBody, completionResponseBody, isError);
        titanModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();
        Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false);

        Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Iterator<Event> llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator();
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    @Test
    public void testEmbeddingError() {
        boolean isError = true;

        TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(embeddingModelId, embeddingRequestBody, embeddingResponseBody, isError);
        titanModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING);
        assertEquals(1, llmEmbeddingEvents.size());
        Iterator<Event> llmEmbeddingEventIterator = llmEmbeddingEvents.iterator();
        Event llmEmbeddingEvent = llmEmbeddingEventIterator.next();

        assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    @Test
    public void testCompletionError() {
        boolean isError = true;

        TitanModelInvocation titanModelInvocation = mockTitanModelInvocation(completionModelId, completionRequestBody, completionResponseBody, isError);
        titanModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();
        Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false);

        Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Iterator<Event> llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator();
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    private TitanModelInvocation mockTitanModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) {
        // Given
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-value");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody);
        when(mockInvokeModelRequest.modelId()).thenReturn(modelId);

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody);

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);

        if (isError) {
            when(mockSdkHttpResponse.statusCode()).thenReturn(400);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(false);
        } else {
            when(mockSdkHttpResponse.statusCode()).thenReturn(200);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        }

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        // Instantiate ModelInvocation
        return new TitanModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);
    }
}
diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java
new file mode 100644
index 0000000000..36d53a92c3
--- /dev/null
+++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/anthropic/claude/ClaudeModelInvocationTest.java
@@ -0,0 +1,164 @@
/*
 *
 * * Copyright 2024 New Relic Corporation. All rights reserved.
 * * SPDX-License-Identifier: Apache-2.0
 *
 */

package llm.models.anthropic.claude;

import com.newrelic.agent.introspec.Event;
import com.newrelic.agent.introspec.InstrumentationTestConfig;
import com.newrelic.agent.introspec.InstrumentationTestRunner;
import com.newrelic.agent.introspec.Introspector;
import com.newrelic.api.agent.LlmTokenCountCallback;
import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;

import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE;
import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY;
import static llm.models.TestUtil.assertErrorEvent;
import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes;
import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes;
import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@RunWith(InstrumentationTestRunner.class)
@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml")
public class ClaudeModelInvocationTest {

    private final Introspector introspector = InstrumentationTestRunner.getIntrospector();

    // Completion
    private final String completionModelId = "anthropic.claude-v2";
    private final String completionRequestBody = "{\"stop_sequences\":[\"\\n\\nHuman:\"],\"max_tokens_to_sample\":1000,\"temperature\":0.5,\"prompt\":\"Human: What is the color of the sky?\\n\\nAssistant:\"}";
    private final String completionResponseBody = "{\"completion\":\" The color of the sky is blue.\",\"stop_reason\":\"stop_sequence\",\"stop\":\"\\n\\nHuman:\"}";
    private final String completionRequestInput = "Human: What is the color of the sky?\n\nAssistant:";
    private final String completionResponseContent = " The color of the sky is blue.";
    private final String finishReason = "stop_sequence";

    @Before
    public void before() {
        introspector.clear();
        LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13;
        LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback);
    }

    @Test
    public void testCompletion() {
        boolean isError = false;

        ClaudeModelInvocation claudeModelInvocation = mockClaudeModelInvocation(completionModelId, completionRequestBody, completionResponseBody,
                isError);
        claudeModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();
        Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false);

        Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Iterator<Event> llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator();
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    @Test
    public void testCompletionError() {
        boolean isError = true;

        ClaudeModelInvocation claudeModelInvocation = mockClaudeModelInvocation(completionModelId, completionRequestBody, completionResponseBody,
                isError);
        claudeModelInvocation.recordLlmEvents(System.currentTimeMillis());

        Collection<Event> llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE);
        assertEquals(2, llmChatCompletionMessageEvents.size());
        Iterator<Event> llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator();
        Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false);

        Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next();

        assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true);

        Collection<Event> llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY);
        assertEquals(1, llmChatCompletionSummaryEvents.size());
        Iterator<Event> llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator();
        Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next();

        assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason);

        assertErrorEvent(isError, introspector.getErrorEvents());
    }

    private ClaudeModelInvocation mockClaudeModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) {
        // Given
        Map<String, String> linkingMetadata = new HashMap<>();
        linkingMetadata.put("span.id", "span-id-123");
        linkingMetadata.put("trace.id", "trace-id-xyz");

        Map<String, Object> userAttributes = new HashMap<>();
        userAttributes.put("llm.conversation_id", "conversation-id-value");
        userAttributes.put("llm.testPrefix", "testPrefix");
        userAttributes.put("test", "test");

        // Mock out ModelRequest
        InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class);
        SdkBytes mockRequestSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes);
        when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody);
        when(mockInvokeModelRequest.modelId()).thenReturn(modelId);

        // Mock out ModelResponse
        InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class);
        SdkBytes mockResponseSdkBytes = mock(SdkBytes.class);
        when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes);
        when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody);

        SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class);
        when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse);

        if (isError) {
            when(mockSdkHttpResponse.statusCode()).thenReturn(400);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(false);
        } else {
            when(mockSdkHttpResponse.statusCode()).thenReturn(200);
            when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK"));
            when(mockSdkHttpResponse.isSuccessful()).thenReturn(true);
        }

        BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class);
        when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata);
        when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301");

        // Instantiate ModelInvocation
        return new ClaudeModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest,
                mockInvokeModelResponse);
    }
}
"{\"texts\":[\"What is the color of the sky?\"],\"truncate\":\"NONE\",\"input_type\":\"search_document\"}"; + private final String embeddingResponseBody = "{\"embeddings\":[[-0.002828598,0.012145996]],\"id\":\"c2c5c119-1268-4155-8c98-50ae199ffa16\",\"response_type\":\"embeddings_floats\",\"texts\":[\"what is the color of the sky?\"]}"; + private final String embeddingRequestInput = "What is the color of the sky?"; + + // Completion + private final String completionModelId = "cohere.command-light-text-v14"; + private final String completionRequestBody = "{\"p\":0.9,\"stop_sequences\":[\"User:\"],\"truncate\":\"END\",\"max_tokens\":1000,\"stream\":false,\"temperature\":0.5,\"k\":0,\"return_likelihoods\":\"NONE\",\"prompt\":\"What is the color of the sky?\"}"; + private final String completionResponseBody = "{\"generations\":[{\"finish_reason\":\"COMPLETE\",\"id\":\"314ba8cf-778d-49ed-a2cb-cf260008a2cc\",\"text\":\" The color of the sky is blue.\"}],\"id\":\"3070a2a7-b5a3-44cf-9908-554fa25473a6\",\"prompt\":\"What is the color of the sky?\"}"; + private final String completionRequestInput = "What is the color of the sky?"; + private final String completionResponseContent = " The color of the sky is blue."; + private final String finishReason = "COMPLETE"; + + @Before + public void before() { + introspector.clear(); + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); + } + + @Test + public void testEmbedding() { + boolean isError = false; + + CommandModelInvocation commandModelInvocation = mockCommandModelInvocation(embeddingModelId, embeddingRequestBody, embeddingResponseBody, isError); + commandModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testCompletion() { + boolean isError = false; + + CommandModelInvocation commandModelInvocation = mockCommandModelInvocation(completionModelId, completionRequestBody, completionResponseBody, isError); + commandModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = 
llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testEmbeddingError() { + boolean isError = true; + + CommandModelInvocation commandModelInvocation = mockCommandModelInvocation(embeddingModelId, embeddingRequestBody, embeddingResponseBody, isError); + commandModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testCompletionError() { + boolean isError = true; + + CommandModelInvocation commandModelInvocation = mockCommandModelInvocation(completionModelId, completionRequestBody, completionResponseBody, isError); + commandModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + private CommandModelInvocation mockCommandModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) { + // Given + Map linkingMetadata = new HashMap<>(); + linkingMetadata.put("span.id", "span-id-123"); + linkingMetadata.put("trace.id", "trace-id-xyz"); + + Map userAttributes = new HashMap<>(); + userAttributes.put("llm.conversation_id", "conversation-id-value"); + userAttributes.put("llm.testPrefix", "testPrefix"); + userAttributes.put("test", "test"); + + // Mock out ModelRequest + InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody); + when(mockInvokeModelRequest.modelId()).thenReturn(modelId); + + // Mock out ModelResponse + InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = mock(SdkBytes.class); + 
when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody); + + SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class); + when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + + if (isError) { + when(mockSdkHttpResponse.statusCode()).thenReturn(400); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(false); + } else { + when(mockSdkHttpResponse.statusCode()).thenReturn(200); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + } + + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class); + when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + + // Instantiate ModelInvocation + return new CommandModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, + mockInvokeModelResponse); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java new file mode 100644 index 0000000000..da1399310a --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/llm/models/meta/llama2/Llama2ModelInvocationTest.java @@ -0,0 +1,164 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package llm.models.meta.llama2; + +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponseMetadata; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.models.TestUtil.assertErrorEvent; +import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes; +import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") +public class Llama2ModelInvocationTest { + + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + // Completion + private final String 
completionModelId = "meta.llama2-13b-chat-v1"; + private final String completionRequestBody = "{\"top_p\":0.9,\"max_gen_len\":1000,\"temperature\":0.5,\"prompt\":\"What is the color of the sky?\"}"; + private final String completionResponseBody = "{\"generation\":\"\\n\\nThe color of the sky is blue.\",\"prompt_token_count\":9,\"generation_token_count\":306,\"stop_reason\":\"stop\"}"; + private final String completionRequestInput = "What is the color of the sky?"; + private final String completionResponseContent = "\n\nThe color of the sky is blue."; + private final String finishReason = "stop"; + + @Before + public void before() { + introspector.clear(); + LlmTokenCountCallback llmTokenCountCallback = (model, content) -> 13; + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); + } + + @Test + public void testCompletion() { + boolean isError = false; + + Llama2ModelInvocation llama2ModelInvocation = mockLlama2ModelInvocation(completionModelId, completionRequestBody, completionResponseBody, + isError); + llama2ModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testCompletionError() { + boolean isError = true; + + Llama2ModelInvocation llama2ModelInvocation = mockLlama2ModelInvocation(completionModelId, completionRequestBody, completionResponseBody, + isError); + llama2ModelInvocation.recordLlmEvents(System.currentTimeMillis()); + + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, false); + + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, true); + + Collection llmChatCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + 
assertEquals(1, llmChatCompletionSummaryEvents.size()); + Iterator llmChatCompletionSummaryEventIterator = llmChatCompletionSummaryEvents.iterator(); + Event llmChatCompletionSummaryEvent = llmChatCompletionSummaryEventIterator.next(); + + assertLlmChatCompletionSummaryAttributes(llmChatCompletionSummaryEvent, completionModelId, finishReason); + + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + private Llama2ModelInvocation mockLlama2ModelInvocation(String modelId, String requestBody, String responseBody, boolean isError) { + // Given + Map linkingMetadata = new HashMap<>(); + linkingMetadata.put("span.id", "span-id-123"); + linkingMetadata.put("trace.id", "trace-id-xyz"); + + Map userAttributes = new HashMap<>(); + userAttributes.put("llm.conversation_id", "conversation-id-value"); + userAttributes.put("llm.testPrefix", "testPrefix"); + userAttributes.put("test", "test"); + + // Mock out ModelRequest + InvokeModelRequest mockInvokeModelRequest = mock(InvokeModelRequest.class); + SdkBytes mockRequestSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelRequest.body()).thenReturn(mockRequestSdkBytes); + when(mockRequestSdkBytes.asUtf8String()).thenReturn(requestBody); + when(mockInvokeModelRequest.modelId()).thenReturn(modelId); + + // Mock out ModelResponse + InvokeModelResponse mockInvokeModelResponse = mock(InvokeModelResponse.class); + SdkBytes mockResponseSdkBytes = mock(SdkBytes.class); + when(mockInvokeModelResponse.body()).thenReturn(mockResponseSdkBytes); + when(mockResponseSdkBytes.asUtf8String()).thenReturn(responseBody); + + SdkHttpResponse mockSdkHttpResponse = mock(SdkHttpResponse.class); + when(mockInvokeModelResponse.sdkHttpResponse()).thenReturn(mockSdkHttpResponse); + + if (isError) { + when(mockSdkHttpResponse.statusCode()).thenReturn(400); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("BAD_REQUEST")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(false); + } else { + when(mockSdkHttpResponse.statusCode()).thenReturn(200); + when(mockSdkHttpResponse.statusText()).thenReturn(Optional.of("OK")); + when(mockSdkHttpResponse.isSuccessful()).thenReturn(true); + } + + BedrockRuntimeResponseMetadata mockBedrockRuntimeResponseMetadata = mock(BedrockRuntimeResponseMetadata.class); + when(mockInvokeModelResponse.responseMetadata()).thenReturn(mockBedrockRuntimeResponseMetadata); + when(mockBedrockRuntimeResponseMetadata.requestId()).thenReturn("90a22e92-db1d-4474-97a9-28b143846301"); + + // Instantiate ModelInvocation + return new Llama2ModelInvocation(linkingMetadata, userAttributes, mockInvokeModelRequest, + mockInvokeModelResponse); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClientMock.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClientMock.java new file mode 100644 index 0000000000..6558c5957c --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClientMock.java @@ -0,0 +1,108 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import software.amazon.awssdk.awscore.AwsResponseMetadata; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; + +import java.util.HashMap; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +public class BedrockRuntimeAsyncClientMock implements BedrockRuntimeAsyncClient { + + // Embedding + public static final String embeddingModelId = "amazon.titan-embed-text-v1"; + public static final String embeddingResponseBody = "{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"; + public static final String embeddingRequestInput = "What is the color of the sky?"; + + // Completion + public static final String completionModelId = "amazon.titan-text-lite-v1"; + public static final String completionResponseBody = "{\"inputTextTokenCount\":8,\"results\":[{\"tokenCount\":9,\"outputText\":\"\\nThe color of the sky is blue.\",\"completionReason\":\"FINISH\"}]}"; + public static final String completionRequestInput = "What is the color of the sky?"; + public static final String completionResponseContent = "\nThe color of the sky is blue."; + public static final String finishReason = "FINISH"; + + @Override + public String serviceName() { + return null; + } + + @Override + public void close() { + + } + + @Override + public CompletableFuture invokeModel(InvokeModelRequest invokeModelRequest) { + HashMap metadata = new HashMap<>(); + metadata.put("AWS_REQUEST_ID", "9d32a71a-e285-4b14-a23d-4f7d67b50ac3"); + AwsResponseMetadata awsResponseMetadata = new BedrockRuntimeResponseMetadataMock(metadata); + SdkHttpFullResponse sdkHttpFullResponse; + SdkResponse sdkResponse = null; + + boolean isError = invokeModelRequest.body().asUtf8String().contains("\"errorTest\":true"); + + if (invokeModelRequest.modelId().equals(completionModelId)) { + // This case will mock out a chat completion request/response + if (isError) { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); + } else { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(200).statusText("OK").build(); + } + + sdkResponse = InvokeModelResponse.builder() + .body(SdkBytes.fromUtf8String(completionResponseBody)) + .contentType("application/json") + .responseMetadata(awsResponseMetadata) + .sdkHttpResponse(sdkHttpFullResponse) + .build(); + } else if (invokeModelRequest.modelId().equals(embeddingModelId)) { + // This case will mock out an embedding request/response + if (isError) { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); + } else { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(200).statusText("OK").build(); + } + + sdkResponse = InvokeModelResponse.builder() + .body(SdkBytes.fromUtf8String(embeddingResponseBody)) + .contentType("application/json") + .responseMetadata(awsResponseMetadata) + .sdkHttpResponse(sdkHttpFullResponse) + .build(); + } + return 
CompletableFuture.completedFuture((InvokeModelResponse) sdkResponse); + } + + @Override + public CompletableFuture invokeModelWithResponseStream(InvokeModelWithResponseStreamRequest invokeModelWithResponseStreamRequest, + InvokeModelWithResponseStreamResponseHandler asyncResponseHandler) { + return BedrockRuntimeAsyncClient.super.invokeModelWithResponseStream(invokeModelWithResponseStreamRequest, asyncResponseHandler); + // Streaming not currently supported + } + + @Override + public CompletableFuture invokeModelWithResponseStream(Consumer invokeModelWithResponseStreamRequest, + InvokeModelWithResponseStreamResponseHandler asyncResponseHandler) { + return BedrockRuntimeAsyncClient.super.invokeModelWithResponseStream(invokeModelWithResponseStreamRequest, asyncResponseHandler); + // Streaming not currently supported + } + + @Override + public BedrockRuntimeServiceClientConfiguration serviceClientConfiguration() { + return null; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_InstrumentationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_InstrumentationTest.java new file mode 100644 index 0000000000..928ecc87b3 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeAsyncClient_InstrumentationTest.java @@ -0,0 +1,209 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import com.newrelic.agent.introspec.TracedMetricData; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Trace; +import llm.models.ModelResponse; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.events.LlmEvent.LLM_EMBEDDING; +import static llm.models.TestUtil.assertErrorEvent; +import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes; +import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes; +import static llm.models.TestUtil.assertLlmEmbeddingAttributes; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.completionModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.completionRequestInput; +import static 
software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.completionResponseContent; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.embeddingModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.embeddingRequestInput; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientMock.finishReason; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") + +public class BedrockRuntimeAsyncClient_InstrumentationTest { + private static final BedrockRuntimeAsyncClientMock mockBedrockRuntimeAsyncClient = new BedrockRuntimeAsyncClientMock(); + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + @Before + public void before() { + introspector.clear(); + } + + @Test + public void testInvokeModelCompletion() throws ExecutionException, InterruptedException { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelEmbedding() throws ExecutionException, InterruptedException { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.EMBEDDING); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelCompletionError() throws ExecutionException, InterruptedException { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelEmbeddingError() throws ExecutionException, InterruptedException { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.EMBEDDING); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + private static InvokeModelRequest buildAmazonTitanCompletionRequest(boolean isError) { + JSONObject textGenerationConfig = new JSONObject() + .put("maxTokenCount", 1000) + .put("stopSequences", Collections.singletonList("User:")) + .put("temperature", 0.5) + .put("topP", 0.9); + + String payload = new JSONObject() + .put("inputText", completionRequestInput) + .put("textGenerationConfig", textGenerationConfig) + .put("errorTest", isError) // this is not a 
real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(completionModelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + private static InvokeModelRequest buildAmazonTitanEmbeddingRequest(boolean isError) { + String payload = new JSONObject() + .put("inputText", embeddingRequestInput) + .put("errorTest", isError) // this is not a real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(embeddingModelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + @Trace(dispatcher = true) + private InvokeModelResponse invokeModelInTransaction(InvokeModelRequest invokeModelRequest) throws ExecutionException, InterruptedException { + NewRelic.addCustomParameter("llm.conversation_id", "conversation-id-value"); // Will be added to LLM events + NewRelic.addCustomParameter("llm.testPrefix", "testPrefix"); // Will be added to LLM events + NewRelic.addCustomParameter("test", "test"); // Will NOT be added to LLM events + CompletableFuture invokeModelResponseCompletableFuture = mockBedrockRuntimeAsyncClient.invokeModel(invokeModelRequest); + return invokeModelResponseCompletableFuture.get(); + } + + private void assertTransaction(String operationType) { + assertEquals(1, introspector.getFinishedTransactionCount(TimeUnit.SECONDS.toMillis(2))); + Collection transactionNames = introspector.getTransactionNames(); + String transactionName = transactionNames.iterator().next(); + Map metrics = introspector.getMetricsForTransaction(transactionName); + assertTrue(metrics.containsKey("Llm/" + operationType + "/Bedrock/invokeModel")); + assertEquals(1, metrics.get("Llm/" + operationType + "/Bedrock/invokeModel").getCallCount()); + } + + private void assertSupportabilityMetrics() { + Map unscopedMetrics = introspector.getUnscopedMetrics(); + assertTrue(unscopedMetrics.containsKey("Supportability/Java/ML/Bedrock/2.20")); + } + + private void assertLlmEvents(String operationType) { + if (ModelResponse.COMPLETION.equals(operationType)) { + // LlmChatCompletionMessage events + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + // LlmChatCompletionMessage event for user request message + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, + false); + + // LlmChatCompletionMessage event for assistant response message + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, + true); + + // LlmCompletionSummary events + Collection llmCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmCompletionSummaryEvents.size()); + + Iterator llmCompletionSummaryEventIterator = llmCompletionSummaryEvents.iterator(); + // Summary event for both LlmChatCompletionMessage events + Event llmCompletionSummaryEvent = llmCompletionSummaryEventIterator.next(); + 
assertLlmChatCompletionSummaryAttributes(llmCompletionSummaryEvent, completionModelId, finishReason); + } else if (ModelResponse.EMBEDDING.equals(operationType)) { + // LlmEmbedding events + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + // LlmEmbedding event + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); + } + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java new file mode 100644 index 0000000000..81b164861e --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClientMock.java @@ -0,0 +1,113 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import software.amazon.awssdk.awscore.AwsResponseMetadata; +import software.amazon.awssdk.awscore.exception.AwsServiceException; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.services.bedrockruntime.model.AccessDeniedException; +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException; +import software.amazon.awssdk.services.bedrockruntime.model.InternalServerException; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ModelErrorException; +import software.amazon.awssdk.services.bedrockruntime.model.ModelNotReadyException; +import software.amazon.awssdk.services.bedrockruntime.model.ModelTimeoutException; +import software.amazon.awssdk.services.bedrockruntime.model.ResourceNotFoundException; +import software.amazon.awssdk.services.bedrockruntime.model.ServiceQuotaExceededException; +import software.amazon.awssdk.services.bedrockruntime.model.ThrottlingException; +import software.amazon.awssdk.services.bedrockruntime.model.ValidationException; + +import java.util.HashMap; +import java.util.function.Consumer; + +public class BedrockRuntimeClientMock implements BedrockRuntimeClient { + + // Embedding + public static final String embeddingModelId = "amazon.titan-embed-text-v1"; + public static final String embeddingResponseBody = "{\"embedding\":[0.328125,0.44335938],\"inputTextTokenCount\":8}"; + public static final String embeddingRequestInput = "What is the color of the sky?"; + + // Completion + public static final String completionModelId = "amazon.titan-text-lite-v1"; + public static final String completionResponseBody = "{\"inputTextTokenCount\":8,\"results\":[{\"tokenCount\":9,\"outputText\":\"\\nThe color of the sky is blue.\",\"completionReason\":\"FINISH\"}]}"; + public static final String completionRequestInput = "What is the color of the sky?"; + public static final String 
completionResponseContent = "\nThe color of the sky is blue."; + public static final String finishReason = "FINISH"; + + @Override + public String serviceName() { + return null; + } + + @Override + public void close() { + + } + + @Override + public InvokeModelResponse invokeModel(InvokeModelRequest invokeModelRequest) + throws AccessDeniedException, ResourceNotFoundException, ThrottlingException, ModelTimeoutException, InternalServerException, ValidationException, + ModelNotReadyException, ServiceQuotaExceededException, ModelErrorException, AwsServiceException, SdkClientException, BedrockRuntimeException { + + HashMap metadata = new HashMap<>(); + metadata.put("AWS_REQUEST_ID", "9d32a71a-e285-4b14-a23d-4f7d67b50ac3"); + AwsResponseMetadata awsResponseMetadata = new BedrockRuntimeResponseMetadataMock(metadata); + SdkHttpFullResponse sdkHttpFullResponse; + SdkResponse sdkResponse = null; + + boolean isError = invokeModelRequest.body().asUtf8String().contains("\"errorTest\":true"); + + if (invokeModelRequest.modelId().equals(completionModelId)) { + // This case will mock out a chat completion request/response + if (isError) { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); + } else { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(200).statusText("OK").build(); + } + + sdkResponse = InvokeModelResponse.builder() + .body(SdkBytes.fromUtf8String(completionResponseBody)) + .contentType("application/json") + .responseMetadata(awsResponseMetadata) + .sdkHttpResponse(sdkHttpFullResponse) + .build(); + } else if (invokeModelRequest.modelId().equals(embeddingModelId)) { + // This case will mock out an embedding request/response + if (isError) { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(400).statusText("BAD_REQUEST").build(); + } else { + sdkHttpFullResponse = SdkHttpResponse.builder().statusCode(200).statusText("OK").build(); + } + + sdkResponse = InvokeModelResponse.builder() + .body(SdkBytes.fromUtf8String(embeddingResponseBody)) + .contentType("application/json") + .responseMetadata(awsResponseMetadata) + .sdkHttpResponse(sdkHttpFullResponse) + .build(); + } + return (InvokeModelResponse) sdkResponse; + } + + @Override + public InvokeModelResponse invokeModel(Consumer invokeModelRequest) + throws AccessDeniedException, ResourceNotFoundException, ThrottlingException, ModelTimeoutException, InternalServerException, ValidationException, + ModelNotReadyException, ServiceQuotaExceededException, ModelErrorException, AwsServiceException, SdkClientException, BedrockRuntimeException { + return null; + } + + @Override + public BedrockRuntimeServiceClientConfiguration serviceClientConfiguration() { + return null; + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java new file mode 100644 index 0000000000..a7dcc1b96f --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeClient_InstrumentationTest.java @@ -0,0 +1,206 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import com.newrelic.agent.introspec.Event; +import com.newrelic.agent.introspec.InstrumentationTestConfig; +import com.newrelic.agent.introspec.InstrumentationTestRunner; +import com.newrelic.agent.introspec.Introspector; +import com.newrelic.agent.introspec.TracedMetricData; +import com.newrelic.api.agent.NewRelic; +import com.newrelic.api.agent.Trace; +import llm.models.ModelResponse; +import org.json.JSONObject; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_MESSAGE; +import static llm.events.LlmEvent.LLM_CHAT_COMPLETION_SUMMARY; +import static llm.events.LlmEvent.LLM_EMBEDDING; +import static llm.models.TestUtil.assertErrorEvent; +import static llm.models.TestUtil.assertLlmChatCompletionMessageAttributes; +import static llm.models.TestUtil.assertLlmChatCompletionSummaryAttributes; +import static llm.models.TestUtil.assertLlmEmbeddingAttributes; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.completionModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.completionRequestInput; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.completionResponseContent; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.embeddingModelId; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.embeddingRequestInput; +import static software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientMock.finishReason; + +@RunWith(InstrumentationTestRunner.class) +@InstrumentationTestConfig(includePrefixes = { "software.amazon.awssdk.services.bedrockruntime" }, configName = "llm_enabled.yml") + +public class BedrockRuntimeClient_InstrumentationTest { + private static final BedrockRuntimeClientMock mockBedrockRuntimeClient = new BedrockRuntimeClientMock(); + private final Introspector introspector = InstrumentationTestRunner.getIntrospector(); + + @Before + public void before() { + introspector.clear(); + } + + @Test + public void testInvokeModelCompletion() { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelEmbedding() { + boolean isError = false; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.EMBEDDING); + 
assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelCompletionError() { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAmazonTitanCompletionRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.COMPLETION); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.COMPLETION); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + @Test + public void testInvokeModelEmbeddingError() { + boolean isError = true; + InvokeModelRequest invokeModelRequest = buildAmazonTitanEmbeddingRequest(isError); + InvokeModelResponse invokeModelResponse = invokeModelInTransaction(invokeModelRequest); + + assertNotNull(invokeModelResponse); + assertTransaction(ModelResponse.EMBEDDING); + assertSupportabilityMetrics(); + assertLlmEvents(ModelResponse.EMBEDDING); + assertErrorEvent(isError, introspector.getErrorEvents()); + } + + private static InvokeModelRequest buildAmazonTitanCompletionRequest(boolean isError) { + JSONObject textGenerationConfig = new JSONObject() + .put("maxTokenCount", 1000) + .put("stopSequences", Collections.singletonList("User:")) + .put("temperature", 0.5) + .put("topP", 0.9); + + String payload = new JSONObject() + .put("inputText", completionRequestInput) + .put("textGenerationConfig", textGenerationConfig) + .put("errorTest", isError) // this is not a real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(completionModelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + private static InvokeModelRequest buildAmazonTitanEmbeddingRequest(boolean isError) { + String payload = new JSONObject() + .put("inputText", embeddingRequestInput) + .put("errorTest", isError) // this is not a real model attribute, just adding for testing + .toString(); + + return InvokeModelRequest.builder() + .body(SdkBytes.fromUtf8String(payload)) + .modelId(embeddingModelId) + .contentType("application/json") + .accept("application/json") + .build(); + } + + @Trace(dispatcher = true) + private InvokeModelResponse invokeModelInTransaction(InvokeModelRequest invokeModelRequest) { + NewRelic.addCustomParameter("llm.conversation_id", "conversation-id-value"); // Will be added to LLM events + NewRelic.addCustomParameter("llm.testPrefix", "testPrefix"); // Will be added to LLM events + NewRelic.addCustomParameter("test", "test"); // Will NOT be added to LLM events + return mockBedrockRuntimeClient.invokeModel(invokeModelRequest); + } + + private void assertTransaction(String operationType) { + assertEquals(1, introspector.getFinishedTransactionCount(TimeUnit.SECONDS.toMillis(2))); + Collection transactionNames = introspector.getTransactionNames(); + String transactionName = transactionNames.iterator().next(); + Map metrics = introspector.getMetricsForTransaction(transactionName); + assertTrue(metrics.containsKey("Llm/" + operationType + "/Bedrock/invokeModel")); + assertEquals(1, metrics.get("Llm/" + operationType + "/Bedrock/invokeModel").getCallCount()); + } + + private void assertSupportabilityMetrics() { + Map unscopedMetrics = introspector.getUnscopedMetrics(); + assertTrue(unscopedMetrics.containsKey("Supportability/Java/ML/Bedrock/2.20")); + } + + private void assertLlmEvents(String 
operationType) { + if (ModelResponse.COMPLETION.equals(operationType)) { + // LlmChatCompletionMessage events + Collection llmChatCompletionMessageEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_MESSAGE); + assertEquals(2, llmChatCompletionMessageEvents.size()); + + Iterator llmChatCompletionMessageEventIterator = llmChatCompletionMessageEvents.iterator(); + // LlmChatCompletionMessage event for user request message + Event llmChatCompletionMessageEventOne = llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventOne, completionModelId, completionRequestInput, completionResponseContent, + false); + + // LlmChatCompletionMessage event for assistant response message + Event llmChatCompletionMessageEventTwo = llmChatCompletionMessageEventIterator.next(); + assertLlmChatCompletionMessageAttributes(llmChatCompletionMessageEventTwo, completionModelId, completionRequestInput, completionResponseContent, + true); + + // LlmCompletionSummary events + Collection llmCompletionSummaryEvents = introspector.getCustomEvents(LLM_CHAT_COMPLETION_SUMMARY); + assertEquals(1, llmCompletionSummaryEvents.size()); + + Iterator llmCompletionSummaryEventIterator = llmCompletionSummaryEvents.iterator(); + // Summary event for both LlmChatCompletionMessage events + Event llmCompletionSummaryEvent = llmCompletionSummaryEventIterator.next(); + assertLlmChatCompletionSummaryAttributes(llmCompletionSummaryEvent, completionModelId, finishReason); + } else if (ModelResponse.EMBEDDING.equals(operationType)) { + // LlmEmbedding events + Collection llmEmbeddingEvents = introspector.getCustomEvents(LLM_EMBEDDING); + assertEquals(1, llmEmbeddingEvents.size()); + + Iterator llmEmbeddingEventIterator = llmEmbeddingEvents.iterator(); + // LlmEmbedding event + Event llmEmbeddingEvent = llmEmbeddingEventIterator.next(); + assertLlmEmbeddingAttributes(llmEmbeddingEvent, embeddingModelId, embeddingRequestInput); + } + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeResponseMetadataMock.java b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeResponseMetadataMock.java new file mode 100644 index 0000000000..b7967e4381 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/java/software/amazon/awssdk/services/bedrockruntime/BedrockRuntimeResponseMetadataMock.java @@ -0,0 +1,18 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. 
+ * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package software.amazon.awssdk.services.bedrockruntime; + +import software.amazon.awssdk.awscore.AwsResponseMetadata; + +import java.util.Map; + +public class BedrockRuntimeResponseMetadataMock extends AwsResponseMetadata { + protected BedrockRuntimeResponseMetadataMock(Map metadata) { + super(metadata); + } +} diff --git a/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml b/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml new file mode 100644 index 0000000000..dbcd129940 --- /dev/null +++ b/instrumentation/aws-bedrock-runtime-2.20/src/test/resources/llm_enabled.yml @@ -0,0 +1,11 @@ +common: &default_settings + ai_monitoring: + enabled: true + record_content: + enabled: true + streaming: + enabled: true + + custom_insights_events: + max_samples_stored: 30000 + max_attribute_value: 255 diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java index b690795eca..2acd5efd7a 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/AgentImpl.java @@ -7,6 +7,7 @@ package com.newrelic.agent; +import com.newrelic.agent.aimonitoring.AiMonitoringImpl; import com.newrelic.agent.bridge.AgentBridge; import com.newrelic.agent.bridge.NoOpMetricAggregator; import com.newrelic.agent.bridge.NoOpTracedMethod; @@ -15,6 +16,7 @@ import com.newrelic.agent.bridge.Transaction; import com.newrelic.agent.service.ServiceFactory; import com.newrelic.agent.tracers.Tracer; +import com.newrelic.api.agent.AiMonitoring; import com.newrelic.api.agent.ErrorApi; import com.newrelic.api.agent.Insights; import com.newrelic.api.agent.Logger; @@ -137,6 +139,11 @@ public Insights getInsights() { return ServiceFactory.getServiceManager().getInsights(); } + @Override + public AiMonitoring getAiMonitoring() { + return new AiMonitoringImpl(); + } + @Override public Logs getLogSender() { return ServiceFactory.getServiceManager().getLogSenderService(); diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java b/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java index 948e7ede9c..0a426c81b7 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/MetricNames.java @@ -498,6 +498,9 @@ public class MetricNames { public static final String SUPPORTABILITY_SLOW_TXN_DETECTION_ENABLED = "Supportability/SlowTransactionDetection/enabled"; public static final String SUPPORTABILITY_SLOW_TXN_DETECTION_DISABLED = "Supportability/SlowTransactionDetection/disabled"; + // AiMonitoring Callback Set + public static final String SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET = "Supportability/AiMonitoringTokenCountCallback/Set"; + /** * Utility method for adding supportability metrics to APIs * diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java index 292091654c..fbc7839994 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/TransactionApiImpl.java @@ -313,6 +313,12 @@ public Map getAgentAttributes() { return (tx != null) ? 
tx.getAgentAttributes() : NoOpTransaction.INSTANCE.getAgentAttributes(); } + @Override + public Map getUserAttributes() { + Transaction tx = getTransactionIfExists(); + return (tx != null) ? tx.getUserAttributes() : NoOpTransaction.INSTANCE.getUserAttributes(); + } + @Override public void provideHeaders(InboundHeaders headers) { Transaction tx = getTransactionIfExists(); diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java new file mode 100644 index 0000000000..bb6f7c79c8 --- /dev/null +++ b/newrelic-agent/src/main/java/com/newrelic/agent/aimonitoring/AiMonitoringImpl.java @@ -0,0 +1,66 @@ +package com.newrelic.agent.aimonitoring; + +import com.newrelic.agent.MetricNames; +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; +import com.newrelic.api.agent.AiMonitoring; +import com.newrelic.api.agent.LlmFeedbackEventAttributes; +import com.newrelic.api.agent.LlmTokenCountCallback; +import com.newrelic.api.agent.NewRelic; + +import java.util.Map; + +/** + * A utility class for interacting with the AI Monitoring API to record LlmFeedbackMessage events. + * This class implements the {@link AiMonitoring} interface and provides methods for feedback event recording + * and setting callbacks for token calculation. + */ + +public class AiMonitoringImpl implements AiMonitoring { + private static final String SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET = "Supportability/AiMonitoringTokenCountCallback/set"; + + /** + * Records an LlmFeedbackMessage event. + * + * @param llmFeedbackEventAttributes A map containing the attributes of an LlmFeedbackMessage event. To construct + * the llmFeedbackEventAttributes map, use + * {@link LlmFeedbackEventAttributes.Builder} + *
<p>Required Attributes:</p> + * <ul> + * <li>"traceId" (String): Trace ID where the chat completion related to the + * feedback event occurred</li> + * <li>"rating" (Integer/String): Rating provided by an end user</li> + * </ul> + * <p>Optional Attributes:</p> + * <ul> + * <li>"category" (String): Category of the feedback as provided by the end user</li> + * <li>"message" (String): Freeform text feedback from an end user.</li> + * <li>"metadata" (Map&lt;String, String&gt;): Set of key-value pairs to store + * additional data to submit with the feedback event</li> + * </ul>
+ */ + + @Override + public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) { + if (llmFeedbackEventAttributes == null) { + throw new IllegalArgumentException("llmFeedbackEventAttributes cannot be null"); + } + // Delegate to Insights API for event recording + NewRelic.getAgent().getInsights().recordCustomEvent("LlmFeedbackMessage", llmFeedbackEventAttributes); + } + + /** + * Sets the callback for token calculation and reports a supportability metric. + * + * @param llmTokenCountCallback The callback instance implementing {@link LlmTokenCountCallback} interface. + * This callback will be used for token calculation. + * @see LlmTokenCountCallback + */ + @Override + public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) { + if (llmTokenCountCallback == null) { + throw new IllegalArgumentException("llmTokenCountCallback cannot be null"); + } + LlmTokenCountCallbackHolder.setLlmTokenCountCallback(llmTokenCountCallback); + NewRelic.getAgent().getMetricAggregator().incrementCounter(MetricNames.SUPPORTABILITY_AI_MONITORING_TOKEN_COUNT_CALLBACK_SET); + } +} diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java index 650beeefb9..70e0ac3429 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/CustomEventAttributeValidator.java @@ -12,8 +12,11 @@ /** * Attribute validator with truncation rules specific to custom events. */ -public class CustomEventAttributeValidator extends AttributeValidator{ - private static final int MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE = ServiceFactory.getConfigService().getDefaultAgentConfig().getInsightsConfig().getMaxAttributeValue(); +public class CustomEventAttributeValidator extends AttributeValidator { + private static final int MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE = ServiceFactory.getConfigService() + .getDefaultAgentConfig() + .getInsightsConfig() + .getMaxAttributeValue(); public CustomEventAttributeValidator(String attributeType) { super(attributeType); diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java new file mode 100644 index 0000000000..fc1a2dcf0b --- /dev/null +++ b/newrelic-agent/src/main/java/com/newrelic/agent/attributes/LlmEventAttributeValidator.java @@ -0,0 +1,40 @@ +/* + * + * * Copyright 2024 New Relic Corporation. All rights reserved. + * * SPDX-License-Identifier: Apache-2.0 + * + */ + +package com.newrelic.agent.attributes; + +import com.newrelic.agent.service.ServiceFactory; + +/** + * Attribute validator with truncation rules specific to LLM events. + */ +public class LlmEventAttributeValidator extends AttributeValidator { + private static final int MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE = ServiceFactory.getConfigService() + .getDefaultAgentConfig() + .getInsightsConfig() + .getMaxAttributeValue(); + + public LlmEventAttributeValidator(String attributeType) { + super(attributeType); + } + + @Override + protected String truncateValue(String key, String value, String methodCalled) { + /* + * The 'input' and output 'content' attributes should be added to LLM events + * without being truncated as per the LLMs agent spec. 
This is because the + * backend will use these attributes to calculate LLM token usage in cases + * where token counts aren't available on LLM events. + */ + if (key.equals("content") || key.equals("input")) { + return value; + } + String truncatedVal = truncateString(value, MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE); + logTruncatedValue(key, value, truncatedVal, methodCalled, MAX_CUSTOM_EVENT_ATTRIBUTE_SIZE); + return truncatedVal; + } +} diff --git a/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java b/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java index 5ed82abb4a..39679a3a14 100644 --- a/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java +++ b/newrelic-agent/src/main/java/com/newrelic/agent/service/analytics/InsightsServiceImpl.java @@ -18,10 +18,12 @@ import com.newrelic.agent.TransactionData; import com.newrelic.agent.attributes.AttributeSender; import com.newrelic.agent.attributes.CustomEventAttributeValidator; +import com.newrelic.agent.attributes.LlmEventAttributeValidator; import com.newrelic.agent.config.AgentConfig; import com.newrelic.agent.config.AgentConfigListener; import com.newrelic.agent.model.AnalyticsEvent; import com.newrelic.agent.model.CustomInsightsEvent; +import com.newrelic.agent.model.LlmCustomInsightsEvent; import com.newrelic.agent.service.AbstractService; import com.newrelic.agent.service.ServiceFactory; import com.newrelic.agent.stats.StatsEngine; @@ -332,7 +334,7 @@ public String getEventHarvestLimitMetric() { } private void recordSupportabilityMetrics(StatsEngine statsEngine, long durationInNanoseconds, - DistributedSamplingPriorityQueue reservoir) { + DistributedSamplingPriorityQueue reservoir) { statsEngine.getStats(MetricNames.SUPPORTABILITY_INSIGHTS_SERVICE_CUSTOMER_SENT) .incrementCallCount(reservoir.size()); statsEngine.getStats(MetricNames.SUPPORTABILITY_INSIGHTS_SERVICE_CUSTOMER_SEEN) @@ -366,16 +368,25 @@ private static String mapInternString(String value) { private static CustomInsightsEvent createValidatedEvent(String eventType, Map attributes) { Map userAttributes = new HashMap<>(attributes.size()); - CustomInsightsEvent event = new CustomInsightsEvent(mapInternString(eventType), System.currentTimeMillis(), userAttributes, DistributedTraceServiceImpl.nextTruncatedFloat()); + CustomInsightsEvent event = new CustomInsightsEvent(mapInternString(eventType), System.currentTimeMillis(), userAttributes, + DistributedTraceServiceImpl.nextTruncatedFloat()); // Now add the attributes from the argument map to the event using an AttributeSender. // An AttributeSender is the way to reuse all the existing attribute validations. We // also locally "intern" Strings because we anticipate a lot of reuse of the keys and, // possibly, the values. But there's an interaction: if the key or value is chopped // within the attribute sender, the modified value won't be "interned" in our map. 
+ AttributeSender sender; + final String method; - AttributeSender sender = new CustomEventAttributeSender(userAttributes); - final String method = "add custom event attribute"; + // CustomInsightsEvents are being overloaded to support some internal event types being sent to the same agent endpoint + if (LlmCustomInsightsEvent.isLlmEvent(eventType)) { + sender = new LlmEventAttributeSender(userAttributes); + method = "add llm event attribute"; + } else { + sender = new CustomEventAttributeSender(userAttributes); + method = "add custom event attribute"; + } for (Map.Entry entry : attributes.entrySet()) { String key = entry.getKey(); @@ -384,7 +395,7 @@ private static CustomInsightsEvent createValidatedEvent(String eventType, Map getAttributeMap() { } } + /** + * LlmEvent attribute validation rules differ from those of a standard CustomInsightsEvent + */ + private static class LlmEventAttributeSender extends AttributeSender { + private static final String ATTRIBUTE_TYPE = "llm"; + + private final Map userAttributes; + + public LlmEventAttributeSender(Map userAttributes) { + super(new LlmEventAttributeValidator(ATTRIBUTE_TYPE)); + this.userAttributes = userAttributes; + // This will have the effect of only copying attributes onto LlmEvents if there is an active transaction + setTransactional(true); + } + + @Override + protected String getAttributeType() { + return ATTRIBUTE_TYPE; + } + + @Override + protected Map getAttributeMap() { + if (ServiceFactory.getConfigService().getDefaultAgentConfig().isCustomParametersAllowed()) { + return userAttributes; + } + return null; + } + } + @Override public Insights getTransactionInsights(AgentConfig config) { return new TransactionInsights(config); diff --git a/newrelic-agent/src/main/resources/newrelic.yml b/newrelic-agent/src/main/resources/newrelic.yml index 1924ab8687..259456e73f 100644 --- a/newrelic-agent/src/main/resources/newrelic.yml +++ b/newrelic-agent/src/main/resources/newrelic.yml @@ -83,6 +83,20 @@ common: &default_settings # Default is the logs directory in the newrelic.jar parent directory. #log_file_path: + # AI Monitoring captures insights on the performance, quality, and cost of interactions with LLM models made with instrumented SDKs. + ai_monitoring: + + # Provides control over all AI Monitoring functionality. Set as true to enable all AI Monitoring features. + # Default is false. + enabled: false + + # Provides control over whether attributes for the input and output content should be added to LLM events. + record_content: + + # Set as false to disable attributes for the input and output content. + # Default is true. + enabled: true + # Provides the ability to forward application logs to New Relic, generate log usage metrics, # and decorate local application log files with agent metadata for use with third party log forwarders. # The application_logging.forwarding and application_logging.local_decorating should not be used together. 
diff --git a/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java new file mode 100644 index 0000000000..46254fa6e8 --- /dev/null +++ b/newrelic-agent/src/test/java/com/newrelic/agent/aimonitoring/AiMonitoringImplTest.java @@ -0,0 +1,91 @@ +package com.newrelic.agent.aimonitoring; + +import com.newrelic.agent.bridge.aimonitoring.LlmTokenCountCallbackHolder; +import com.newrelic.api.agent.LlmFeedbackEventAttributes; +import com.newrelic.api.agent.LlmTokenCountCallback; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +@RunWith(MockitoJUnitRunner.class) +public class AiMonitoringImplTest { + + @Mock + private AiMonitoringImpl aiMonitoringImpl; + + private LlmTokenCountCallback callback; + Map llmFeedbackEventAttributes; + + @Before + public void setup() { + String traceId = "123456"; + Integer rating = 5; + LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); + llmFeedbackEventAttributes = llmFeedbackEventBuilder + .category("General") + .message("Great experience") + .build(); + callback = getCallback(); + aiMonitoringImpl = new AiMonitoringImpl(); + } + + @Test + public void testRecordLlmFeedbackEventSent() { + try { + aiMonitoringImpl.recordLlmFeedbackEvent(llmFeedbackEventAttributes); + } catch (IllegalArgumentException e) { + // test should not catch an exception + } + + } + + @Test + public void testRecordLlmFeedbackEventWithNullAttributes() { + + try { + aiMonitoringImpl.recordLlmFeedbackEvent(null); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // Expected exception thrown, test passed + System.out.println("IllegalArgumentException successfully thrown!"); + } + } + + @Test + public void testCallbackSetSuccessfully() { + aiMonitoringImpl.setLlmTokenCountCallback(callback); + assertEquals(callback, LlmTokenCountCallbackHolder.getLlmTokenCountCallback()); + } + + @Test + public void testSetLlmTokenCountCallbackWithNull() { + + try { + aiMonitoringImpl.setLlmTokenCountCallback(null); + Assert.fail("Expected IllegalArgumentException to be thrown"); + } catch (IllegalArgumentException e) { + // Expected exception thrown, test passes + System.out.println("IllegalArgumentException successfully thrown!"); + } + + } + + public LlmTokenCountCallback getCallback() { + class TestCallback implements LlmTokenCountCallback { + + @Override + public int calculateLlmTokenCount(String model, String content) { + return 13; + } + } + return new TestCallback(); + } + +} diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java b/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java index ef289ca4ad..023c34d278 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/Agent.java @@ -63,6 +63,13 @@ public interface Agent { */ Insights getInsights(); + /** + * Provides access to the AI Monitoring custom events API. + * + * @return Object for recording custom events. 
+ */ + AiMonitoring getAiMonitoring(); + ErrorApi getErrorApi(); /** diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java new file mode 100644 index 0000000000..42494b1aab --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/AiMonitoring.java @@ -0,0 +1,53 @@ +package com.newrelic.api.agent; + +import java.util.Map; + +/** + * This interface defines methods for recording LlmFeedbackMessage events and setting a callback for token calculation. + */ +public interface AiMonitoring { + /** + * Records an LlmFeedbackMessage event. + * + * @param llmFeedbackEventAttributes A map containing the attributes of an LlmFeedbackMessage event. To construct + * the llmFeedbackEventAttributes map, use + * {@link LlmFeedbackEventAttributes.Builder} + *
+     *                                   <p>The map must include:</p>
+     *                                   <ul>
+     *                                   <li>"traceId" (String): Trace ID where the chat completion related to the
+     *                                   feedback event occurred</li>
+     *                                   <li>"rating" (Integer/String): Rating provided by an end user</li>
+     *                                   </ul>
+     *                                   Optional attributes:
+     *                                   <ul>
+     *                                   <li>"category" (String): Category of the feedback as provided by the end user</li>
+     *                                   <li>"message" (String): Freeform text feedback from an end user</li>
+     *                                   <li>"metadata" (Map&lt;String, String&gt;): Set of key-value pairs to store
+     *                                   additional data to submit with the feedback event</li>
+     *                                   </ul>
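+     *                                   <p>For example (attribute values shown are illustrative, not part of the API):</p>
+     *                                   <pre>{@code
+     *                                   Map<String, Object> attributes = new LlmFeedbackEventAttributes.Builder("traceId123", 5)
+     *                                           .category("informative")
+     *                                           .message("Great answer!")
+     *                                           .build();
+     *                                   NewRelic.getAgent().getAiMonitoring().recordLlmFeedbackEvent(attributes);
+     *                                   }</pre>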
+ * + * + */ + void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes); + + /** + * Sets the callback function for calculating LLM tokens. + * + * @param llmTokenCountCallback The callback function to be invoked for counting LLM tokens. + * Example usage: + *
+     *                              <pre>{@code
+     *                              LlmTokenCountCallback llmTokenCountCallback = new LlmTokenCountCallback() {
+     *                                  {@literal @}Override
+     *                                  public int calculateLlmTokenCount(String model, String content) {
+     *                                      // Token calculation based on model and content goes here
+     *                                      // Return the calculated token count
+     *                                  }
+     *                               };
+     *
+     *                               // Set the created callback instance
+     *                               NewRelic.getAgent().getAiMonitoring().setLlmTokenCountCallback(llmTokenCountCallback);
+     *                               }</pre>
+ */ + void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback); + +} diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java new file mode 100644 index 0000000000..9211f5c1c8 --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmFeedbackEventAttributes.java @@ -0,0 +1,107 @@ +package com.newrelic.api.agent; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +public class LlmFeedbackEventAttributes { + private final String traceId; + private final Object rating; + private final String category; + private final String message; + private final Map metadata; + private final UUID id; + private final String ingestSource; + private static final String INGEST_SOURCE = "Java"; + + protected LlmFeedbackEventAttributes(String traceId, Object rating, String category, String message, Map metadata, UUID id, String ingestSource) { + this.traceId = traceId; + this.rating = rating; + this.category = category; + this.message = message; + this.metadata = metadata; + this.id = id; + this.ingestSource = ingestSource; + } + + public String getTraceId() { + return traceId; + } + + public Object getRating() { + return rating; + } + + + public String getCategory() { + return category; + } + + public String getMessage() { + return message; + } + + public Map getMetadata() { + return metadata; + } + + public UUID getId() { + return id; + } + + public String getIngestSource() { + return ingestSource; + } + + public Map toMap() { + Map feedbackAttributesMap = new HashMap<>(); + feedbackAttributesMap.put("trace_id", getTraceId()); + feedbackAttributesMap.put("rating", getRating()); + feedbackAttributesMap.put("id", getId()); + feedbackAttributesMap.put("ingest_source", getIngestSource()); + if (category != null) { + feedbackAttributesMap.put("category", getCategory()); + } + if (message != null) { + feedbackAttributesMap.put("message", getMessage()); + } + if (metadata != null) { + feedbackAttributesMap.put("metadata", getMetadata()); + } + return feedbackAttributesMap; + } + + public static class Builder { + private final String traceId; + private final Object rating; + private String category = null; + private String message = null; + private Map metadata = null; + private final UUID id = UUID.randomUUID(); + + public Builder(String traceId, Object rating) { + this.traceId = traceId; + this.rating = rating; + } + + public Builder category(String category) { + this.category = category; + return this; + } + + public Builder message(String message) { + this.message = message; + return this; + } + + public Builder metadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public Map build() { + return new LlmFeedbackEventAttributes(traceId, rating, category, message, metadata, id, INGEST_SOURCE).toMap(); + + } + } +} diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java new file mode 100644 index 0000000000..28f1d3864f --- /dev/null +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/LlmTokenCountCallback.java @@ -0,0 +1,40 @@ +package com.newrelic.api.agent; + +/** + * An interface for calculating the number of tokens used for a given LLM (Large Language Model) and content. + *
+ * <p>
+ * Implement this interface to define custom logic for token calculation based on your application's requirements.
+ * <p>
+ * Example usage:
+ * <pre>{@code
+ * class MyTokenCountCallback implements LlmTokenCountCallback {
+ *
+ *     @Override
+ *     public int calculateLlmTokenCount(String model, String content) {
+ *         // Implement your custom token calculating logic here
+ *         // This example calculates the number of tokens based on the length of the content
+ *         return content.length();
+ *     }
+ * }
+ *
+ * LlmTokenCountCallback myCallback = new MyTokenCountCallback();
+ * // After creating the myCallback instance, it should be passed as an argument to the setLlmTokenCountCallback
+ * // method of the AI Monitoring API.
+ * NewRelic.getAgent().getAiMonitoring().setLlmTokenCountCallback(myCallback);
+ * }</pre>
+ *
+ */ +public interface LlmTokenCountCallback { + + + /** + * Calculates the number of tokens used for a given LLM model and content. + * + * @param model The name of the LLM model. + * @param content The message content or prompt. + * @return An integer representing the number of tokens used for the given model and content. + * If the count cannot be determined or is less than or equal to 0, 0 is returned. + */ + public int calculateLlmTokenCount(String model, String content); +} diff --git a/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java b/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java index b207ad799f..dc61aee4f2 100644 --- a/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java +++ b/newrelic-api/src/main/java/com/newrelic/api/agent/NoOpAgent.java @@ -360,6 +360,14 @@ public void recordCustomEvent(String eventType, Map attributes) { } }; + private static final AiMonitoring AI_MONITORING = new AiMonitoring() { + @Override + public void recordLlmFeedbackEvent(Map llmFeedbackEventAttributes) {} + + @Override + public void setLlmTokenCountCallback(LlmTokenCountCallback llmTokenCountCallback) {} + }; + private static final Segment SEGMENT = new Segment() { @Override public void setMetricName(String... metricNameParts) { @@ -458,6 +466,11 @@ public Insights getInsights() { return INSIGHTS; } + @Override + public AiMonitoring getAiMonitoring() { + return AI_MONITORING; + } + @Override public ErrorApi getErrorApi() { return ERROR_API; diff --git a/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java new file mode 100644 index 0000000000..fb3a7c9898 --- /dev/null +++ b/newrelic-api/src/test/java/com/newrelic/api/agent/LlmFeedbackEventAttributesTest.java @@ -0,0 +1,91 @@ +package com.newrelic.api.agent; + +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; + +public class LlmFeedbackEventAttributesTest { + + LlmFeedbackEventAttributes.Builder llmFeedbackEventBuilder; + Map llmFeedbackEventAttributes; + + @Before + public void setup() { + String traceId = "123456"; + Object rating = 3; + llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId, rating); + } + + @Test + public void testBuilderWithRequiredParamsOnly() { + llmFeedbackEventAttributes = llmFeedbackEventBuilder.build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); + assertEquals(3, llmFeedbackEventAttributes.get("rating")); + assertNotNull(llmFeedbackEventAttributes.get("id")); + assertEquals("Java", llmFeedbackEventAttributes.get("ingest_source")); + assertFalse(llmFeedbackEventAttributes.containsKey("category")); + assertFalse(llmFeedbackEventAttributes.containsKey("message")); + assertFalse(llmFeedbackEventAttributes.containsKey("metadata")); + } + + @Test + public void testBuilderWithRequiredAndOptionalParams() { + llmFeedbackEventAttributes = llmFeedbackEventBuilder + .category("exampleCategory") + .message("exampleMessage") + .metadata(createMetadataMap()) + .build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); + assertEquals(3, llmFeedbackEventAttributes.get("rating")); + 
assertEquals("exampleCategory", llmFeedbackEventAttributes.get("category")); + assertEquals("exampleMessage", llmFeedbackEventAttributes.get("message")); + } + + @Test + public void testBuilderWithOptionalParamsSetToNull() { + llmFeedbackEventAttributes = llmFeedbackEventBuilder + .category(null) + .message(null) + .metadata(null) + .build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); + assertEquals(3, llmFeedbackEventAttributes.get("rating")); + assertNull(llmFeedbackEventAttributes.get("category")); + assertNull(llmFeedbackEventAttributes.get("message")); + assertNull(llmFeedbackEventAttributes.get("metadata")); + assertNotNull(llmFeedbackEventAttributes.get("id")); + assertEquals("Java", llmFeedbackEventAttributes.get("ingest_source")); + } + + @Test + public void testBuilderWithRatingParamAsStringType() { + String traceId2 = "123456"; + Object rating2 = "3"; + llmFeedbackEventBuilder = new LlmFeedbackEventAttributes.Builder(traceId2, rating2); + llmFeedbackEventAttributes = llmFeedbackEventBuilder.build(); + + assertNotNull(llmFeedbackEventAttributes); + assertEquals("123456", llmFeedbackEventAttributes.get("trace_id")); + assertEquals("3", llmFeedbackEventAttributes.get("rating")); + } + + public Map createMetadataMap() { + Map map = new HashMap<>(); + map.put("key1", "val1"); + map.put("key2", "val2"); + return map; + } +} \ No newline at end of file diff --git a/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java b/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java index 8fd5de5bbf..ce1fdeb23a 100644 --- a/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java +++ b/newrelic-opentelemetry-agent-extension/src/main/java/com/newrelic/opentelemetry/OpenTelemetryAgent.java @@ -8,6 +8,7 @@ package com.newrelic.opentelemetry; import com.newrelic.api.agent.Agent; +import com.newrelic.api.agent.AiMonitoring; import com.newrelic.api.agent.Config; import com.newrelic.api.agent.Insights; import com.newrelic.api.agent.Logger; @@ -73,6 +74,11 @@ public Insights getInsights() { return openTelemetryInsights; } + @Override + public AiMonitoring getAiMonitoring() { + return null; + } + @Override public TraceMetadata getTraceMetadata() { OpenTelemetryNewRelic.logUnsupportedMethod("Agent", "getTraceMetadata"); diff --git a/settings.gradle b/settings.gradle index 3eb6445904..e439ffec06 100644 --- a/settings.gradle +++ b/settings.gradle @@ -68,6 +68,7 @@ if (JavaVersion.current().isJava11Compatible()) { // Weaver Instrumentation include 'instrumentation:anorm-2.3' include 'instrumentation:anorm-2.4' +include 'instrumentation:aws-bedrock-runtime-2.20' include 'instrumentation:aws-java-sdk-sqs-1.10.44' include 'instrumentation:aws-java-sdk-s3-1.2.13' include 'instrumentation:aws-java-sdk-s3-2.0'