From f273fdc49d5f1494ca0034e01e2a6ef8b014abd7 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Mon, 14 Oct 2024 21:37:00 -0400 Subject: [PATCH] [8.x] [ML] Stream Bedrock Completion (#114732) (#114781) * [ML] Stream Bedrock Completion (#114732) Notes: - Adds a new API to the chatCompletionRequest to invoke the Bedrock Stream API - Create a StreamingChatProcessor that subscribes to streaming results from bedrock and handles the parsing on another thread. - There was no good way (that I could see) to extend the Provider-based CompletionRequestEntity, so they have been flattened into one RequestEntity that can be shared between ConverseRequest and ConverseStreamRequest. * Use jdk17 API --- docs/changelog/114732.yaml | 5 + .../inference/src/main/java/module-info.java | 1 + .../AmazonBedrockChatCompletionExecutor.java | 18 +- .../amazonbedrock/AmazonBedrockClient.java | 5 + .../AmazonBedrockInferenceClient.java | 24 +- .../AmazonBedrockInferenceClientCache.java | 7 +- .../AmazonBedrockRequestSender.java | 6 +- .../AmazonBedrockStreamingChatProcessor.java | 156 +++++++++++ ...onBedrockChatCompletionRequestManager.java | 7 +- ...edrockAI21LabsCompletionRequestEntity.java | 60 ----- ...drockAnthropicCompletionRequestEntity.java | 67 ----- ...zonBedrockChatCompletionEntityFactory.java | 48 +--- .../AmazonBedrockChatCompletionRequest.java | 43 ++- ...nBedrockCohereCompletionRequestEntity.java | 67 ----- .../AmazonBedrockConverseRequestEntity.java | 23 +- .../AmazonBedrockConverseUtils.java | 33 +++ ...zonBedrockMetaCompletionRequestEntity.java | 60 ----- ...BedrockMistralCompletionRequestEntity.java | 67 ----- ...onBedrockTitanCompletionRequestEntity.java | 60 ----- .../amazonbedrock/AmazonBedrockService.java | 6 + .../AmazonBedrockExecutorTests.java | 10 +- ...mazonBedrockInferenceClientCacheTests.java | 2 +- .../AmazonBedrockMockInferenceClient.java | 12 +- ...zonBedrockStreamingChatProcessorTests.java | 251 ++++++++++++++++++ ...kAI21LabsCompletionRequestEntityTests.java | 
70 ----- ...AnthropicCompletionRequestEntityTests.java | 82 ------ ...drockChatCompletionEntityFactoryTests.java | 102 +++++++ ...ockCohereCompletionRequestEntityTests.java | 82 ------ .../AmazonBedrockConverseRequestUtils.java | 12 +- ...drockMetaCompletionRequestEntityTests.java | 70 ----- ...ckMistralCompletionRequestEntityTests.java | 82 ------ ...rockTitanCompletionRequestEntityTests.java | 70 ----- .../AmazonBedrockServiceTests.java | 14 +- 33 files changed, 702 insertions(+), 920 deletions(-) create mode 100644 docs/changelog/114732.yaml create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java delete mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessorTests.java delete mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactoryTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java delete mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java diff --git a/docs/changelog/114732.yaml b/docs/changelog/114732.yaml new file mode 100644 index 0000000000000..42176cdbda443 --- /dev/null +++ b/docs/changelog/114732.yaml @@ -0,0 +1,5 @@ +pr: 114732 +summary: Stream Bedrock Completion +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 53cb6ac154ced..60cb254e0afbe 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -32,6 +32,7 @@ requires software.amazon.awssdk.profiles; requires org.slf4j; requires software.amazon.awssdk.retries.api; + requires 
org.reactivestreams; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java index a4e0c399517c1..2afa91d4dc776 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseListener; @@ -33,11 +34,16 @@ protected AmazonBedrockChatCompletionExecutor( @Override protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) { - var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener( - chatCompletionRequest, - responseHandler, - inferenceResultsListener - ); - chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener); + if (chatCompletionRequest.isStreaming()) { + var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient); + inferenceResultsListener.onResponse(new StreamingChatCompletionResults(publisher)); + } else 
{ + var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener( + chatCompletionRequest, + responseHandler, + inferenceResultsListener + ); + chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java index 23b6884ddc33a..f1cfc84643b1c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java @@ -9,17 +9,22 @@ import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.xcontent.ChunkedToXContent; import java.time.Instant; +import java.util.concurrent.Flow; public interface AmazonBedrockClient { void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException; + Flow.Publisher converseStream(ConverseStreamRequest converseStreamRequest) throws ElasticsearchException; + void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) throws ElasticsearchException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java index b1486f4995b84..040aa99d81346 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java @@ -17,16 +17,21 @@ import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException; import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.SpecialPermission; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.reactivestreams.FlowAdapters; import org.slf4j.LoggerFactory; import java.security.AccessController; @@ -36,6 +41,7 @@ import java.util.Objects; import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; +import java.util.concurrent.Flow; /** * Not marking this as "final" so we can subclass it for mocking @@ -53,19 +59,21 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient { private static final Duration 
DEFAULT_CLIENT_TIMEOUT_MS = Duration.ofMillis(10000); private final BedrockRuntimeAsyncClient internalClient; + private final ThreadPool threadPool; private volatile Instant expiryTimestamp; - public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { + public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) { try { - return new AmazonBedrockInferenceClient(model, timeout); + return new AmazonBedrockInferenceClient(model, timeout, threadPool); } catch (Exception e) { throw new ElasticsearchException("Failed to create Amazon Bedrock Client", e); } } - protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) { super(model, timeout); this.internalClient = createAmazonBedrockClient(model, timeout); + this.threadPool = Objects.requireNonNull(threadPool); setExpiryTimestamp(); } @@ -79,6 +87,16 @@ public void converse(ConverseRequest converseRequest, ActionListener converseStream(ConverseStreamRequest request) throws ElasticsearchException { + var awsResponseProcessor = new AmazonBedrockStreamingChatProcessor(threadPool); + internalClient.converseStream( + request, + ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build() + ); + return awsResponseProcessor; + } + private void onFailure(ActionListener listener, Throwable t, String method) { var unwrappedException = t; if (t instanceof CompletionException || t instanceof ExecutionException) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java index 
21e5cfaf211e5..339673e1302ac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java @@ -29,12 +29,9 @@ public final class AmazonBedrockInferenceClientCache implements AmazonBedrockCli // not final for testing private Clock clock; - public AmazonBedrockInferenceClientCache( - BiFunction creator, - @Nullable Clock clock - ) { + public AmazonBedrockInferenceClientCache(BiFunction creator, Clock clock) { this.creator = Objects.requireNonNull(creator); - this.clock = Objects.requireNonNullElse(clock, Clock.systemUTC()); + this.clock = Objects.requireNonNull(clock); } public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java index e23b0274ede26..a8d85d896d684 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.inference.services.ServiceComponents; import java.io.IOException; +import java.time.Clock; import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -42,7 +43,10 @@ public Factory(ServiceComponents serviceComponents, ClusterService clusterServic } public Sender createSender() { - var clientCache = new AmazonBedrockInferenceClientCache(AmazonBedrockInferenceClient::create, null); + var clientCache = new 
AmazonBedrockInferenceClientCache( + (model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()), + Clock.systemUTC() + ); return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager())); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java new file mode 100644 index 0000000000000..439fc5b65efd5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java @@ -0,0 +1,156 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.core.Strings; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; + +import java.util.ArrayDeque; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +class AmazonBedrockStreamingChatProcessor implements Flow.Processor { + private final AtomicReference error = new AtomicReference<>(null); + private final AtomicLong demand = new AtomicLong(0); + private final AtomicBoolean isDone = new AtomicBoolean(false); + private final AtomicBoolean onCompleteCalled = new AtomicBoolean(false); + private final AtomicBoolean onErrorCalled = new AtomicBoolean(false); + private final ThreadPool threadPool; + private volatile Flow.Subscriber downstream; + private volatile Flow.Subscription upstream; + + AmazonBedrockStreamingChatProcessor(ThreadPool threadPool) { + this.threadPool = threadPool; + } + + @Override + public void subscribe(Flow.Subscriber subscriber) { + if (downstream == null) { + downstream = subscriber; + downstream.onSubscribe(new StreamSubscription()); + } else { + subscriber.onError(new IllegalStateException("Subscriber already set.")); + } + } + + @Override + public void onSubscribe(Flow.Subscription subscription) { + if (upstream == null) { 
+ upstream = subscription; + var currentRequestCount = demand.getAndUpdate(i -> 0); + if (currentRequestCount > 0) { + upstream.request(currentRequestCount); + } + } else { + subscription.cancel(); + } + } + + @Override + public void onNext(ConverseStreamOutput item) { + if (item.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA) { + demand.set(0); // reset demand before we fork to another thread + item.accept(ConverseStreamResponseHandler.Visitor.builder().onContentBlockDelta(this::sendDownstreamOnAnotherThread).build()); + } else { + upstream.request(1); + } + } + + // this is always called from a netty thread maintained by the AWS SDK, we'll move it to our thread to process the response + private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) { + CompletableFuture.runAsync(() -> { + var text = event.delta().text(); + var result = new ArrayDeque(1); + result.offer(new StreamingChatCompletionResults.Result(text)); + var results = new StreamingChatCompletionResults.Results(result); + downstream.onNext(results); + }, threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + + @Override + public void onError(Throwable amazonBedrockRuntimeException) { + error.set( + new ElasticsearchException( + Strings.format("AmazonBedrock StreamingChatProcessor failure: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ) + ); + if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onErrorCalled.compareAndSet(false, true)) { + downstream.onError(error.get()); + } + } + + private boolean checkAndResetDemand() { + return demand.getAndUpdate(i -> 0L) > 0L; + } + + @Override + public void onComplete() { + if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onCompleteCalled.compareAndSet(false, true)) { + downstream.onComplete(); + } + } + + private class StreamSubscription implements Flow.Subscription { + @Override + public void request(long n) { + if (n > 0L) { + demand.updateAndGet(i -> { + 
var sum = i + n; + return sum >= 0 ? sum : Long.MAX_VALUE; + }); + if (upstream == null) { + // wait for upstream to subscribe before forwarding request + return; + } + if (upstreamIsRunning()) { + requestOnMlThread(n); + } else if (error.get() != null && onErrorCalled.compareAndSet(false, true)) { + downstream.onError(error.get()); + } else if (onCompleteCalled.compareAndSet(false, true)) { + downstream.onComplete(); + } + } else { + cancel(); + downstream.onError(new IllegalStateException("Cannot request a negative number.")); + } + } + + private boolean upstreamIsRunning() { + return isDone.get() == false && error.get() == null; + } + + private void requestOnMlThread(long n) { + var currentThreadPool = EsExecutors.executorName(Thread.currentThread().getName()); + if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) { + upstream.request(n); + } else { + CompletableFuture.runAsync(() -> upstream.request(n), threadPool.executor(UTILITY_THREAD_POOL_NAME)); + } + } + + @Override + public void cancel() { + if (upstream != null && upstreamIsRunning()) { + upstream.cancel(); + } + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java index 1c6bb58717942..69a5c665feb86 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; 
-import java.util.List; import java.util.function.Supplier; public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager { @@ -45,9 +44,11 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var docsOnly = DocumentsOnlyInput.of(inferenceInputs); + var docsInput = docsOnly.getInputs(); + var stream = docsOnly.stream(); var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput); - var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); try { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java deleted file mode 100644 index aff01316838f8..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; - -import org.elasticsearch.core.Nullable; - -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; - -public record AmazonBedrockAI21LabsCompletionRequestEntity( - List messages, - @Nullable Double temperature, - @Nullable Double topP, - @Nullable Integer maxTokenCount -) implements AmazonBedrockConverseRequestEntity { - - public AmazonBedrockAI21LabsCompletionRequestEntity { - Objects.requireNonNull(messages); - } - - @Override - public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { - return request.messages(getConverseMessageList(messages)); - } - - @Override - public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { - if (temperature == null && topP == null && maxTokenCount == null) { - return request; - } - - return request.inferenceConfig(config -> { - if (temperature != null) { - config.temperature(temperature.floatValue()); - } - - if (topP != null) { - config.topP(topP.floatValue()); - } - - if (maxTokenCount != null) { - config.maxTokens(maxTokenCount); - } - }); - } - - @Override - public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { - return request; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java deleted file mode 100644 index 540012c221192..0000000000000 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Strings; - -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; - -public record AmazonBedrockAnthropicCompletionRequestEntity( - List messages, - @Nullable Double temperature, - @Nullable Double topP, - @Nullable Double topK, - @Nullable Integer maxTokenCount -) implements AmazonBedrockConverseRequestEntity { - - public AmazonBedrockAnthropicCompletionRequestEntity { - Objects.requireNonNull(messages); - } - - @Override - public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { - return request.messages(getConverseMessageList(messages)); - } - - @Override - public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { - if (temperature == null && topP == null && maxTokenCount == null) { - return request; - } - - return request.inferenceConfig(config -> { - if (temperature != null) { - config.temperature(temperature.floatValue()); - } - - if (topP != null) { - config.topP(topP.floatValue()); - } - - if (maxTokenCount != null) { - config.maxTokens(maxTokenCount); - } - }); - } - - @Override - public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { - if (topK == 
null) { - return request; - } - - String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); - return request.additionalModelResponseFieldPaths(topKField); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java index f86d2229d42ad..db902290ba0be 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java @@ -12,6 +12,8 @@ import java.util.List; import java.util.Objects; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.additionalTopK; + public final class AmazonBedrockChatCompletionEntityFactory { public static AmazonBedrockConverseRequestEntity createEntity(AmazonBedrockChatCompletionModel model, List messages) { Objects.requireNonNull(model); @@ -19,55 +21,21 @@ public static AmazonBedrockConverseRequestEntity createEntity(AmazonBedrockChatC var serviceSettings = model.getServiceSettings(); var taskSettings = model.getTaskSettings(); switch (serviceSettings.provider()) { - case AI21LABS -> { - return new AmazonBedrockAI21LabsCompletionRequestEntity( - messages, - taskSettings.temperature(), - taskSettings.topP(), - taskSettings.maxNewTokens() - ); - } - case AMAZONTITAN -> { - return new AmazonBedrockTitanCompletionRequestEntity( - messages, - taskSettings.temperature(), - taskSettings.topP(), - taskSettings.maxNewTokens() - ); - } - case ANTHROPIC -> { - return new AmazonBedrockAnthropicCompletionRequestEntity( + case AI21LABS, 
AMAZONTITAN, META -> { + return new AmazonBedrockConverseRequestEntity( messages, taskSettings.temperature(), taskSettings.topP(), - taskSettings.topK(), taskSettings.maxNewTokens() ); } - case COHERE -> { - return new AmazonBedrockCohereCompletionRequestEntity( + case ANTHROPIC, COHERE, MISTRAL -> { + return new AmazonBedrockConverseRequestEntity( messages, taskSettings.temperature(), taskSettings.topP(), - taskSettings.topK(), - taskSettings.maxNewTokens() - ); - } - case META -> { - return new AmazonBedrockMetaCompletionRequestEntity( - messages, - taskSettings.temperature(), - taskSettings.topP(), - taskSettings.maxNewTokens() - ); - } - case MISTRAL -> { - return new AmazonBedrockMistralCompletionRequestEntity( - messages, - taskSettings.temperature(), - taskSettings.topP(), - taskSettings.topK(), - taskSettings.maxNewTokens() + taskSettings.maxNewTokens(), + additionalTopK(taskSettings.topK()) ); } default -> { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java index 61e0504732462..05d7d90873a71 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java @@ -8,7 +8,9 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import org.elasticsearch.common.xcontent.ChunkedToXContent; import org.elasticsearch.core.Nullable; import 
org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; @@ -20,19 +22,26 @@ import java.io.IOException; import java.util.Objects; +import java.util.concurrent.Flow; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.inferenceConfig; public class AmazonBedrockChatCompletionRequest extends AmazonBedrockRequest { public static final String USER_ROLE = "user"; private final AmazonBedrockConverseRequestEntity requestEntity; private AmazonBedrockChatCompletionResponseListener listener; + private final boolean stream; public AmazonBedrockChatCompletionRequest( AmazonBedrockChatCompletionModel model, AmazonBedrockConverseRequestEntity requestEntity, - @Nullable TimeValue timeout + @Nullable TimeValue timeout, + boolean stream ) { super(model, timeout); this.requestEntity = Objects.requireNonNull(requestEntity); + this.stream = stream; } @Override @@ -52,10 +61,16 @@ public TaskType taskType() { } private ConverseRequest getConverseRequest() { - var converseRequest = ConverseRequest.builder().modelId(amazonBedrockModel.model()); - converseRequest = requestEntity.addMessages(converseRequest); - converseRequest = requestEntity.addInferenceConfig(converseRequest); - converseRequest = requestEntity.addAdditionalModelFields(converseRequest); + var converseRequest = ConverseRequest.builder() + .modelId(amazonBedrockModel.model()) + .messages(getConverseMessageList(requestEntity.messages())) + .additionalModelResponseFieldPaths(requestEntity.additionalModelFields()); + + inferenceConfig(requestEntity).ifPresent(converseRequest::inferenceConfig); + + if (requestEntity.additionalModelFields() != null) { + converseRequest.additionalModelResponseFieldPaths(requestEntity.additionalModelFields()); + } return converseRequest.build(); } @@ -66,4 +81,22 @@ 
public void executeChatCompletionRequest( this.listener = chatCompletionResponseListener; this.executeRequest(awsBedrockClient); } + + public Flow.Publisher executeStreamChatCompletionRequest(AmazonBedrockBaseClient awsBedrockClient) { + var converseStreamRequest = ConverseStreamRequest.builder() + .modelId(amazonBedrockModel.model()) + .messages(getConverseMessageList(requestEntity.messages())); + + inferenceConfig(requestEntity).ifPresent(converseStreamRequest::inferenceConfig); + + if (requestEntity.additionalModelFields() != null) { + converseStreamRequest.additionalModelResponseFieldPaths(requestEntity.additionalModelFields()); + } + return awsBedrockClient.converseStream(converseStreamRequest.build()); + } + + @Override + public boolean isStreaming() { + return stream; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java deleted file mode 100644 index f1ae04ad39516..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Strings; - -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; - -public record AmazonBedrockCohereCompletionRequestEntity( - List messages, - @Nullable Double temperature, - @Nullable Double topP, - @Nullable Double topK, - @Nullable Integer maxTokenCount -) implements AmazonBedrockConverseRequestEntity { - - public AmazonBedrockCohereCompletionRequestEntity { - Objects.requireNonNull(messages); - } - - @Override - public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { - return request.messages(getConverseMessageList(messages)); - } - - @Override - public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { - if (temperature == null && topP == null && maxTokenCount == null) { - return request; - } - - return request.inferenceConfig(config -> { - if (temperature != null) { - config.temperature(temperature.floatValue()); - } - - if (topP != null) { - config.topP(topP.floatValue()); - } - - if (maxTokenCount != null) { - config.maxTokens(maxTokenCount); - } - }); - } - - @Override - public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { - if (topK == null) { - return request; - } - - String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); - return request.additionalModelResponseFieldPaths(topKField); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java index d8e9fa43797cd..203b2820ab16f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java @@ -7,12 +7,23 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; +import org.elasticsearch.core.Nullable; -public interface AmazonBedrockConverseRequestEntity { - ConverseRequest.Builder addMessages(ConverseRequest.Builder request); +import java.util.List; - ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request); - - ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request); +public record AmazonBedrockConverseRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount, + @Nullable List additionalModelFields +) { + public AmazonBedrockConverseRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount + ) { + this(messages, temperature, topP, maxTokenCount, null); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java index 22e0d26a315a7..eb1652ff7ff6d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java @@ -8,9 +8,14 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration; import software.amazon.awssdk.services.bedrockruntime.model.Message; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; + import java.util.List; +import java.util.Optional; import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest.USER_ROLE; @@ -22,4 +27,32 @@ public static List getConverseMessageList(List texts) { .map(content -> Message.builder().role(USER_ROLE).content(content).build()) .toList(); } + + public static Optional inferenceConfig(AmazonBedrockConverseRequestEntity request) { + if (request.temperature() != null || request.topP() != null || request.maxTokenCount() != null) { + var builder = InferenceConfiguration.builder(); + if (request.temperature() != null) { + builder.temperature(request.temperature().floatValue()); + } + + if (request.topP() != null) { + builder.topP(request.topP().floatValue()); + } + + if (request.maxTokenCount() != null) { + builder.maxTokens(request.maxTokenCount()); + } + return Optional.of(builder.build()); + } + return Optional.empty(); + } + + @Nullable + public static List additionalTopK(@Nullable Double topK) { + if (topK == null) { + return null; + } + + return List.of(Strings.format("{\"top_k\":%f}", topK.floatValue())); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java deleted file mode 100644 index c21791ced02cb..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; - -import org.elasticsearch.core.Nullable; - -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; - -public record AmazonBedrockMetaCompletionRequestEntity( - List messages, - @Nullable Double temperature, - @Nullable Double topP, - @Nullable Integer maxTokenCount -) implements AmazonBedrockConverseRequestEntity { - - public AmazonBedrockMetaCompletionRequestEntity { - Objects.requireNonNull(messages); - } - - @Override - public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { - return request.messages(getConverseMessageList(messages)); - } - - @Override - public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { - if (temperature == null && topP == null && maxTokenCount == null) { - return request; - } - - return request.inferenceConfig(config -> { - if (temperature != null) { - config.temperature(temperature.floatValue()); - } - - if (topP != null) { - config.topP(topP.floatValue()); - } - - if (maxTokenCount != null) { - 
config.maxTokens(maxTokenCount); - } - }); - } - - @Override - public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { - return request; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java deleted file mode 100644 index 15931674cbabb..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Strings; - -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; - -public record AmazonBedrockMistralCompletionRequestEntity( - List messages, - @Nullable Double temperature, - @Nullable Double topP, - @Nullable Double topK, - @Nullable Integer maxTokenCount -) implements AmazonBedrockConverseRequestEntity { - - public AmazonBedrockMistralCompletionRequestEntity { - Objects.requireNonNull(messages); - } - - @Override - public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { - return request.messages(getConverseMessageList(messages)); - } - - @Override - public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { - if (temperature == null && topP == null && maxTokenCount == null) { - return request; - } - - return request.inferenceConfig(config -> { - if (temperature != null) { - config.temperature(temperature.floatValue()); - } - - if (topP != null) { - config.topP(topP.floatValue()); - } - - if (maxTokenCount != null) { - config.maxTokens(maxTokenCount); - } - }); - } - - @Override - public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { - if (topK == null) { - return request; - } - - String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); - return request.additionalModelResponseFieldPaths(topKField); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java deleted file mode 100644 index e267592dfd0ba..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest; - -import org.elasticsearch.core.Nullable; - -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; - -public record AmazonBedrockTitanCompletionRequestEntity( - List messages, - @Nullable Double temperature, - @Nullable Double topP, - @Nullable Integer maxTokenCount -) implements AmazonBedrockConverseRequestEntity { - - public AmazonBedrockTitanCompletionRequestEntity { - Objects.requireNonNull(messages); - } - - @Override - public ConverseRequest.Builder addMessages(ConverseRequest.Builder request) { - return request.messages(getConverseMessageList(messages)); - } - - @Override - public ConverseRequest.Builder addInferenceConfig(ConverseRequest.Builder request) { - if (temperature == null && topP == null && maxTokenCount == null) { - return request; - } - - return request.inferenceConfig(config -> { - if (temperature != null) { - config.temperature(temperature.floatValue()); - } - - if (topP != null) { - config.topP(topP.floatValue()); - } - - if (maxTokenCount != null) { - 
config.maxTokens(maxTokenCount); - } - }); - } - - @Override - public ConverseRequest.Builder addAdditionalModelFields(ConverseRequest.Builder request) { - return request; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 0e70ce0655b1e..64b9635175c54 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -43,6 +43,7 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -260,6 +261,11 @@ public TransportVersion getMinimalSupportedVersion() { return ML_INFERENCE_AMAZON_BEDROCK_ADDED; } + @Override + public Set supportedStreamingTasks() { + return COMPLETION_ONLY; + } + /** * For text embedding models get the embedding size and * update the service settings. 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java index 8f09c53c99366..6d601b4b08c53 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; -import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockTitanCompletionRequestEntity; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestEntity; import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockTitanEmbeddingsRequestEntity; import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler; @@ -100,8 +100,8 @@ public void testExecute_ChatCompletionRequest() throws CharacterCodingException "secretkey" ); - var requestEntity = new AmazonBedrockTitanCompletionRequestEntity(List.of("abc"), null, null, 512); - var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null); + var requestEntity = new AmazonBedrockConverseRequestEntity(List.of("abc"), null, null, 512); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null, false); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); var 
clientCache = new AmazonBedrockMockClientCache(getTestConverseResult("converse result"), null, null); @@ -124,8 +124,8 @@ public void testExecute_FailsProperly_WithElasticsearchException() { "secretkey" ); - var requestEntity = new AmazonBedrockTitanCompletionRequestEntity(List.of("abc"), null, null, 512); - var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null); + var requestEntity = new AmazonBedrockConverseRequestEntity(List.of("abc"), null, null, 512); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null, false); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); var clientCache = new AmazonBedrockMockClientCache(null, null, new ElasticsearchException("test exception")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java index 873b2e22497c6..bb7c669cdf09b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java @@ -25,7 +25,7 @@ public class AmazonBedrockInferenceClientCacheTests extends ESTestCase { public void testCache_ReturnsSameObject() throws IOException { AmazonBedrockInferenceClientCache cacheInstance; - try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, null)) { + try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, Clock.systemUTC())) { cacheInstance = cache; var model = AmazonBedrockEmbeddingsModelTests.createModel( "inferenceId", diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java index 5584e90b3264d..e6cd667b824b3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java @@ -14,15 +14,19 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; import java.util.concurrent.CompletableFuture; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class AmazonBedrockMockInferenceClient extends AmazonBedrockInferenceClient { private CompletableFuture converseResponseFuture = CompletableFuture.completedFuture(null); @@ -33,7 +37,13 @@ public static AmazonBedrockMockInferenceClient create(AmazonBedrockModel model, } protected AmazonBedrockMockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { - super(model, timeout); + super(model, timeout, mockThreadPool()); + } + + private static ThreadPool mockThreadPool() { + ThreadPool threadPool = mock(); + when(threadPool.executor(anyString())).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + return threadPool; } public void setExceptionToThrow(ElasticsearchException 
exceptionToThrow) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessorTests.java new file mode 100644 index 0000000000000..ba87bdfe16cdd --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessorTests.java @@ -0,0 +1,251 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Arrays; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Flow; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.hamcrest.Matchers.equalTo; +import static 
org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AmazonBedrockStreamingChatProcessorTests extends ESTestCase { + private AmazonBedrockStreamingChatProcessor processor; + + @Before + public void setUp() throws Exception { + super.setUp(); + ThreadPool threadPool = mock(); + when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); + processor = new AmazonBedrockStreamingChatProcessor(threadPool); + } + + /** + * We do not issue requests on subscribe because the downstream will control the pacing. + */ + public void testOnSubscribeBeforeDownstreamDoesNotRequest() { + var upstream = mock(Flow.Subscription.class); + processor.onSubscribe(upstream); + + verify(upstream, never()).request(anyLong()); + } + + /** + * If the downstream requests data before the upstream is set, when the upstream is set, we will forward the pending requests to it. 
+ */ + public void testOnSubscribeAfterDownstreamRequests() { + var expectedRequestCount = randomLongBetween(1, 500); + Flow.Subscriber subscriber = mock(); + doAnswer(ans -> { + Flow.Subscription sub = ans.getArgument(0); + sub.request(expectedRequestCount); + return null; + }).when(subscriber).onSubscribe(any()); + processor.subscribe(subscriber); + + var upstream = mock(Flow.Subscription.class); + processor.onSubscribe(upstream); + + verify(upstream, times(1)).request(anyLong()); + } + + public void testCancelDuplicateSubscriptions() { + processor.onSubscribe(mock()); + + var upstream = mock(Flow.Subscription.class); + processor.onSubscribe(upstream); + + verify(upstream, times(1)).cancel(); + verifyNoMoreInteractions(upstream); + } + + public void testMultiplePublishesCallsOnError() { + processor.subscribe(mock()); + + Flow.Subscriber subscriber = mock(); + processor.subscribe(subscriber); + + verify(subscriber, times(1)).onError(assertArg(e -> { + assertThat(e, isA(IllegalStateException.class)); + assertThat(e.getMessage(), equalTo("Subscriber already set.")); + })); + } + + public void testNonDeltaBlocksAreSkipped() { + var upstream = mock(Flow.Subscription.class); + processor.onSubscribe(upstream); + var counter = new AtomicInteger(); + Arrays.stream(ConverseStreamOutput.EventType.values()) + .filter(type -> type != ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA) + .forEach(type -> { + ConverseStreamOutput output = mock(); + when(output.sdkEventType()).thenReturn(type); + processor.onNext(output); + verify(upstream, times(counter.incrementAndGet())).request(eq(1L)); + }); + } + + public void testDeltaBlockForwardsDownstream() { + var expectedText = "hello"; + + // mock executorservice so we can make sure we handle the response on another thread + ExecutorService executorService = mock(); + ThreadPool threadPool = mock(); + when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(executorService); + processor = new 
AmazonBedrockStreamingChatProcessor(threadPool); + doAnswer(ans -> { + Runnable command = ans.getArgument(0); + command.run(); + return null; + }).when(executorService).execute(any()); + + Flow.Subscription upstream = mock(); + processor.onSubscribe(upstream); + Flow.Subscriber downstream = mock(); + processor.subscribe(downstream); + + ConverseStreamOutput output = output(expectedText); + + processor.onNext(output); + + verifyText(downstream, expectedText); + verify(executorService, times(1)).execute(any()); + verify(upstream, times(0)).request(anyLong()); + } + + private ConverseStreamOutput output(String text) { + ConverseStreamOutput output = mock(); + when(output.sdkEventType()).thenReturn(ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA); + doAnswer(ans -> { + ConverseStreamResponseHandler.Visitor visitor = ans.getArgument(0); + ContentBlockDelta delta = ContentBlockDelta.fromText(text); + ContentBlockDeltaEvent event = ContentBlockDeltaEvent.builder().delta(delta).build(); + visitor.visitContentBlockDelta(event); + return null; + }).when(output).accept(any()); + return output; + } + + private void verifyText(Flow.Subscriber downstream, String expectedText) { + verify(downstream, times(1)).onNext(assertArg(results -> { + assertThat(results, notNullValue()); + assertThat(results.results().size(), equalTo(1)); + assertThat(results.results().getFirst().delta(), equalTo(expectedText)); + })); + } + + public void verifyCompleteBeforeRequest() { + processor.onComplete(); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + verify(downstream, times(1)).onComplete(); + } + + public void verifyCompleteAfterRequest() { + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); 
+ + sub.getValue().request(1); + processor.onComplete(); + verify(downstream, times(1)).onComplete(); + } + + public void verifyOnErrorBeforeRequest() { + var expectedError = BedrockRuntimeException.builder().message("ahhhhhh").build(); + processor.onError(expectedError); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + verify(downstream, times(1)).onError(assertArg(e -> { + assertThat(e, isA(ElasticsearchException.class)); + assertThat(e.getCause(), is(expectedError)); + })); + } + + public void verifyOnErrorAfterRequest() { + var expectedError = BedrockRuntimeException.builder().message("ahhhhhh").build(); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + processor.onError(expectedError); + verify(downstream, times(1)).onError(assertArg(e -> { + assertThat(e, isA(ElasticsearchException.class)); + assertThat(e.getCause(), is(expectedError)); + })); + } + + public void verifyAsyncOnCompleteIsStillDeliveredSynchronously() { + mockUpstream(); + + Flow.Subscriber downstream = mock(); + var sub = ArgumentCaptor.forClass(Flow.Subscription.class); + processor.subscribe(downstream); + verify(downstream).onSubscribe(sub.capture()); + + sub.getValue().request(1); + verify(downstream, times(1)).onNext(any()); + processor.onComplete(); + verify(downstream, times(0)).onComplete(); + sub.getValue().request(1); + verify(downstream, times(1)).onComplete(); + } + + private void mockUpstream() { + Flow.Subscription upstream = mock(); + doAnswer(ans -> { + processor.onNext(output(randomIdentifier())); + return null; + }).when(upstream).request(anyLong()); + processor.onSubscribe(upstream); + } +} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java deleted file mode 100644 index 10c8943c75f6c..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import org.elasticsearch.test.ESTestCase; - -import java.util.List; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; -import static 
org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; -import static org.hamcrest.Matchers.is; - -public class AmazonBedrockAI21LabsCompletionRequestEntityTests extends ESTestCase { - public void testRequestEntity_CreatesProperRequest() { - var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.modelId(), is("testmodel")); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTemperature() { - var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), 1.0, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - 
assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopP() { - var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, 1.0, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { - var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, 128); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java deleted file mode 100644 index e8a3440a37294..0000000000000 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import org.elasticsearch.test.ESTestCase; - -import java.util.List; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; -import static 
org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; -import static org.hamcrest.Matchers.is; - -public class AmazonBedrockAnthropicCompletionRequestEntityTests extends ESTestCase { - public void testRequestEntity_CreatesProperRequest() { - var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.modelId(), is("testmodel")); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTemperature() { - var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopP() { - var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); - var 
builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { - var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, 128); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopK() { - var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactoryTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactoryTests.java
new file mode 100644
index 0000000000000..29ed3e4d89273
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactoryTests.java
@@ -0,0 +1,102 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider;
+import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;
+import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
+
+import java.io.IOException;
+import java.util.List;
+
+import static org.elasticsearch.xcontent.XContentParserConfiguration.EMPTY;
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.AI21LABS;
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.AMAZONTITAN;
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.ANTHROPIC;
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.COHERE;
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.META;
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider.MISTRAL;
+import static org.hamcrest.Matchers.closeTo;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class AmazonBedrockChatCompletionEntityFactoryTests extends ESTestCase { // createEntity() output per provider
+    public void testEntitiesWithoutAdditionalMessages() { // models here are built with a null topK
+        List.of(AI21LABS, AMAZONTITAN, META).forEach(provider -> {
+            var expectedTemp = randomDoubleBetween(1, 10, true);
+            var expectedTopP = randomDoubleBetween(1, 10, true);
+
+            var expectedMaxToken = randomIntBetween(1, 10);
+            var expectedMessage = List.of(randomIdentifier());
+            var model = model(provider, expectedTemp, expectedTopP, expectedMaxToken);
+
+            var entity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, expectedMessage);
+
+            assertThat(entity, notNullValue());
+            assertThat(entity.temperature(), equalTo(expectedTemp));
+            assertThat(entity.topP(), equalTo(expectedTopP));
+            assertThat(entity.maxTokenCount(), equalTo(expectedMaxToken));
+            assertThat(entity.additionalModelFields(), nullValue()); // expected to be absent for these providers
+            assertThat(entity.messages(), equalTo(expectedMessage));
+        });
+    }
+
+    public void testWithAdditionalMessages() { // models here are built with a topK value
+        List.of(ANTHROPIC, COHERE, MISTRAL).forEach(provider -> {
+            var expectedTemp = randomDoubleBetween(1, 10, true);
+            var expectedTopP = randomDoubleBetween(1, 10, true);
+            var expectedMaxToken = randomIntBetween(1, 10);
+            var expectedMessage = List.of(randomIdentifier());
+            var expectedTopK = randomDoubleBetween(1, 10, true);
+            var model = model(provider, expectedTemp, expectedTopP, expectedMaxToken, expectedTopK);
+
+            var entity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, expectedMessage);
+
+            assertThat(entity, notNullValue());
+            assertThat(entity.temperature(), equalTo(expectedTemp));
+            assertThat(entity.topP(), equalTo(expectedTopP));
+            assertThat(entity.maxTokenCount(), equalTo(expectedMaxToken));
+            assertThat(entity.messages(), equalTo(expectedMessage));
+            assertThat(entity.additionalModelFields(), notNullValue());
+            assertThat(entity.additionalModelFields().size(), equalTo(1)); // the single entry is parsed as JSON below
+            try (var parser = XContentFactory.xContent(XContentType.JSON).createParser(EMPTY, entity.additionalModelFields().get(0))) {
+                var additionalModelFields = parser.map();
+                assertThat((Double) additionalModelFields.get("top_k"), closeTo(expectedTopK, 0.1)); // topK under "top_k"
+            } catch (IOException e) {
+                fail(e); // a checked exception cannot propagate out of the forEach lambda
+            }
+        });
+    }
+
+    AmazonBedrockChatCompletionModel model(AmazonBedrockProvider provider, Double temperature, Double topP, Integer maxTokenCount) {
+        return model(provider, temperature, topP, maxTokenCount, null); // delegates with a null topK
+    }
+
+    AmazonBedrockChatCompletionModel model(AmazonBedrockProvider provider, Double temp, Double topP, Integer tokenCount, Double topK) {
+        var serviceSettings = mock(AmazonBedrockChatCompletionServiceSettings.class); // only provider() is stubbed
+        when(serviceSettings.provider()).thenReturn(provider);
+
+        var taskSettings = mock(AmazonBedrockChatCompletionTaskSettings.class); // four getters stubbed below
+        when(taskSettings.temperature()).thenReturn(temp);
+        when(taskSettings.topP()).thenReturn(topP);
+        when(taskSettings.maxNewTokens()).thenReturn(tokenCount);
+        when(taskSettings.topK()).thenReturn(topK);
+
+        var model = mock(AmazonBedrockChatCompletionModel.class); // mocked model wired to the two settings mocks
+        when(model.getServiceSettings()).thenReturn(serviceSettings);
+        when(model.getTaskSettings()).thenReturn(taskSettings);
+        return model;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java
deleted file
mode 100644 index c8e844d000240..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import org.elasticsearch.test.ESTestCase; - -import java.util.List; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; -import static 
org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; -import static org.hamcrest.Matchers.is; - -public class AmazonBedrockCohereCompletionRequestEntityTests extends ESTestCase { - public void testRequestEntity_CreatesProperRequest() { - var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.modelId(), is("testmodel")); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTemperature() { - var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopP() { - var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); - var builtRequest = 
getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { - var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, 128); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopK() { - var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java
index 17c3b4488bae4..0e7acd3337e0f 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java
@@ -15,12 +15,16 @@
 
 import java.util.Collection;
 
+import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList;
+import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.inferenceConfig;
+
 public final class AmazonBedrockConverseRequestUtils {
     public static ConverseRequest getConverseRequest(String modelId, AmazonBedrockConverseRequestEntity requestEntity) {
-        var converseRequest = ConverseRequest.builder().modelId(modelId);
-        converseRequest = requestEntity.addMessages(converseRequest);
-        converseRequest = requestEntity.addInferenceConfig(converseRequest);
-        converseRequest = requestEntity.addAdditionalModelFields(converseRequest);
+        var converseRequest = ConverseRequest.builder() // reads the entity's values instead of handing it the builder
+            .modelId(modelId)
+            .messages(getConverseMessageList(requestEntity.messages()))
+            .additionalModelResponseFieldPaths(requestEntity.additionalModelFields()); // NOTE(review): response field paths; confirm
+        inferenceConfig(requestEntity).ifPresent(converseRequest::inferenceConfig); // applied only when a config is present
         return converseRequest.build();
     }
 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java
deleted file mode 100644
index
25700f7c7aee1..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import org.elasticsearch.test.ESTestCase; - -import java.util.List; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; -import static 
org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; -import static org.hamcrest.Matchers.is; - -public class AmazonBedrockMetaCompletionRequestEntityTests extends ESTestCase { - public void testRequestEntity_CreatesProperRequest() { - var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.modelId(), is("testmodel")); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTemperature() { - var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), 1.0, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopP() { - var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, 1.0, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - 
assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { - var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, 128); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java deleted file mode 100644 index 8e321b0cb33a7..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import org.elasticsearch.test.ESTestCase; - -import java.util.List; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; -import static org.hamcrest.Matchers.is; - -public class AmazonBedrockMistralCompletionRequestEntityTests extends ESTestCase { - public void testRequestEntity_CreatesProperRequest() { - var request = new 
AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.modelId(), is("testmodel")); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTemperature() { - var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopP() { - var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void 
testRequestEntity_CreatesProperRequest_WithMaxTokens() { - var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, 128); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopK() { - var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java deleted file mode 100644 index 8d1c15499bfb6..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; - -import org.elasticsearch.test.ESTestCase; - -import java.util.List; - -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; -import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; -import static org.hamcrest.Matchers.is; - -public class AmazonBedrockTitanCompletionRequestEntityTests extends ESTestCase { - public void 
testRequestEntity_CreatesProperRequest() { - var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertThat(builtRequest.modelId(), is("testmodel")); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTemperature() { - var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), 1.0, null, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithTopP() { - var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, 1.0, null); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - 
assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); - } - - public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { - var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, 128); - var builtRequest = getConverseRequest("testmodel", request); - assertThat(builtRequest.modelId(), is("testmodel")); - assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); - assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); - assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); - assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 8c1af747c1b62..6c7eb613dd6cc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.services.amazonbedrock; +import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException; + import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; @@ -32,7 +34,6 @@ import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.Utils; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; -import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; import 
org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; @@ -1167,12 +1168,19 @@ public void testInfer_UnauthorizedResponse() throws IOException { var factory = mock(HttpRequestSender.Factory.class); when(factory.createSender()).thenReturn(sender); - var amazonBedrockFactory = new AmazonBedrockRequestSender.Factory( + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), mockClusterServiceEmpty() ); - try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try ( + var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool)); + var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender() + ) { + requestSender.enqueue( + BedrockRuntimeException.builder().message("The security token included in the request is invalid").build() + ); + var model = AmazonBedrockEmbeddingsModelTests.createModel( "id", "us-east-1",