Skip to content

Commit

Permalink
[8.x] [ML] Stream Bedrock Completion (elastic#114732) (elastic#114781)
Browse files Browse the repository at this point in the history
* [ML] Stream Bedrock Completion (elastic#114732)

Notes:
- Adds a new API to the chatCompletionRequest to invoke the Bedrock
  Stream API
- Create a StreamingChatProcessor that subscribes to streaming results
  from bedrock and handles the parsing on another thread.
- There was no good way (that I could see) to extend the Provider-based
  CompletionRequestEntity, so they have been flattened into one
  RequestEntity that can be shared between ConverseRequest and
  ConverseStreamRequest.

* Use jdk17 API
  • Loading branch information
prwhelan authored Oct 15, 2024
1 parent ffcf87c commit f273fdc
Show file tree
Hide file tree
Showing 33 changed files with 702 additions and 920 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/114732.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 114732
summary: Stream Bedrock Completion
area: Machine Learning
type: enhancement
issues: []
1 change: 1 addition & 0 deletions x-pack/plugin/inference/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
requires software.amazon.awssdk.profiles;
requires org.slf4j;
requires software.amazon.awssdk.retries.api;
requires org.reactivestreams;

exports org.elasticsearch.xpack.inference.action;
exports org.elasticsearch.xpack.inference.registry;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest;
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler;
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseListener;
Expand All @@ -33,11 +34,16 @@ protected AmazonBedrockChatCompletionExecutor(

@Override
protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) {
var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener(
chatCompletionRequest,
responseHandler,
inferenceResultsListener
);
chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener);
if (chatCompletionRequest.isStreaming()) {
var publisher = chatCompletionRequest.executeStreamChatCompletionRequest(awsBedrockClient);
inferenceResultsListener.onResponse(new StreamingChatCompletionResults(publisher));
} else {
var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener(
chatCompletionRequest,
responseHandler,
inferenceResultsListener
);
chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,22 @@

import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.xcontent.ChunkedToXContent;

import java.time.Instant;
import java.util.concurrent.Flow;

public interface AmazonBedrockClient {
void converse(ConverseRequest converseRequest, ActionListener<ConverseResponse> responseListener) throws ElasticsearchException;

Flow.Publisher<? extends ChunkedToXContent> converseStream(ConverseStreamRequest converseStreamRequest) throws ElasticsearchException;

void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener<InvokeModelResponse> responseListener)
throws ElasticsearchException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.SpecialPermission;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.xcontent.ChunkedToXContent;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;
import org.reactivestreams.FlowAdapters;
import org.slf4j.LoggerFactory;

import java.security.AccessController;
Expand All @@ -36,6 +41,7 @@
import java.util.Objects;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Flow;

/**
* Not marking this as "final" so we can subclass it for mocking
Expand All @@ -53,19 +59,21 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient {
private static final Duration DEFAULT_CLIENT_TIMEOUT_MS = Duration.ofMillis(10000);

private final BedrockRuntimeAsyncClient internalClient;
private final ThreadPool threadPool;
private volatile Instant expiryTimestamp;

public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) {
public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) {
try {
return new AmazonBedrockInferenceClient(model, timeout);
return new AmazonBedrockInferenceClient(model, timeout, threadPool);
} catch (Exception e) {
throw new ElasticsearchException("Failed to create Amazon Bedrock Client", e);
}
}

protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout, ThreadPool threadPool) {
super(model, timeout);
this.internalClient = createAmazonBedrockClient(model, timeout);
this.threadPool = Objects.requireNonNull(threadPool);
setExpiryTimestamp();
}

Expand All @@ -79,6 +87,16 @@ public void converse(ConverseRequest converseRequest, ActionListener<ConverseRes
}
}

@Override
public Flow.Publisher<? extends ChunkedToXContent> converseStream(ConverseStreamRequest request) throws ElasticsearchException {
var awsResponseProcessor = new AmazonBedrockStreamingChatProcessor(threadPool);
internalClient.converseStream(
request,
ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build()
);
return awsResponseProcessor;
}

private void onFailure(ActionListener<?> listener, Throwable t, String method) {
var unwrappedException = t;
if (t instanceof CompletionException || t instanceof ExecutionException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@ public final class AmazonBedrockInferenceClientCache implements AmazonBedrockCli
// not final for testing
private Clock clock;

public AmazonBedrockInferenceClientCache(
BiFunction<AmazonBedrockModel, TimeValue, AmazonBedrockBaseClient> creator,
@Nullable Clock clock
) {
public AmazonBedrockInferenceClientCache(BiFunction<AmazonBedrockModel, TimeValue, AmazonBedrockBaseClient> creator, Clock clock) {
this.creator = Objects.requireNonNull(creator);
this.clock = Objects.requireNonNullElse(clock, Clock.systemUTC());
this.clock = Objects.requireNonNull(clock);
}

public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.xpack.inference.services.ServiceComponents;

import java.io.IOException;
import java.time.Clock;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand All @@ -42,7 +43,10 @@ public Factory(ServiceComponents serviceComponents, ClusterService clusterServic
}

public Sender createSender() {
var clientCache = new AmazonBedrockInferenceClientCache(AmazonBedrockInferenceClient::create, null);
var clientCache = new AmazonBedrockInferenceClientCache(
(model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
Clock.systemUTC()
);
return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager()));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.amazonbedrock;

import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Strings;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;

import java.util.ArrayDeque;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;

class AmazonBedrockStreamingChatProcessor implements Flow.Processor<ConverseStreamOutput, StreamingChatCompletionResults.Results> {
private final AtomicReference<Throwable> error = new AtomicReference<>(null);
private final AtomicLong demand = new AtomicLong(0);
private final AtomicBoolean isDone = new AtomicBoolean(false);
private final AtomicBoolean onCompleteCalled = new AtomicBoolean(false);
private final AtomicBoolean onErrorCalled = new AtomicBoolean(false);
private final ThreadPool threadPool;
private volatile Flow.Subscriber<? super StreamingChatCompletionResults.Results> downstream;
private volatile Flow.Subscription upstream;

AmazonBedrockStreamingChatProcessor(ThreadPool threadPool) {
this.threadPool = threadPool;
}

@Override
public void subscribe(Flow.Subscriber<? super StreamingChatCompletionResults.Results> subscriber) {
if (downstream == null) {
downstream = subscriber;
downstream.onSubscribe(new StreamSubscription());
} else {
subscriber.onError(new IllegalStateException("Subscriber already set."));
}
}

@Override
public void onSubscribe(Flow.Subscription subscription) {
if (upstream == null) {
upstream = subscription;
var currentRequestCount = demand.getAndUpdate(i -> 0);
if (currentRequestCount > 0) {
upstream.request(currentRequestCount);
}
} else {
subscription.cancel();
}
}

@Override
public void onNext(ConverseStreamOutput item) {
if (item.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA) {
demand.set(0); // reset demand before we fork to another thread
item.accept(ConverseStreamResponseHandler.Visitor.builder().onContentBlockDelta(this::sendDownstreamOnAnotherThread).build());
} else {
upstream.request(1);
}
}

// this is always called from a netty thread maintained by the AWS SDK, we'll move it to our thread to process the response
private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) {
CompletableFuture.runAsync(() -> {
var text = event.delta().text();
var result = new ArrayDeque<StreamingChatCompletionResults.Result>(1);
result.offer(new StreamingChatCompletionResults.Result(text));
var results = new StreamingChatCompletionResults.Results(result);
downstream.onNext(results);
}, threadPool.executor(UTILITY_THREAD_POOL_NAME));
}

@Override
public void onError(Throwable amazonBedrockRuntimeException) {
error.set(
new ElasticsearchException(
Strings.format("AmazonBedrock StreamingChatProcessor failure: [%s]", amazonBedrockRuntimeException.getMessage()),
amazonBedrockRuntimeException
)
);
if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onErrorCalled.compareAndSet(false, true)) {
downstream.onError(error.get());
}
}

private boolean checkAndResetDemand() {
return demand.getAndUpdate(i -> 0L) > 0L;
}

@Override
public void onComplete() {
if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onCompleteCalled.compareAndSet(false, true)) {
downstream.onComplete();
}
}

private class StreamSubscription implements Flow.Subscription {
@Override
public void request(long n) {
if (n > 0L) {
demand.updateAndGet(i -> {
var sum = i + n;
return sum >= 0 ? sum : Long.MAX_VALUE;
});
if (upstream == null) {
// wait for upstream to subscribe before forwarding request
return;
}
if (upstreamIsRunning()) {
requestOnMlThread(n);
} else if (error.get() != null && onErrorCalled.compareAndSet(false, true)) {
downstream.onError(error.get());
} else if (onCompleteCalled.compareAndSet(false, true)) {
downstream.onComplete();
}
} else {
cancel();
downstream.onError(new IllegalStateException("Cannot request a negative number."));
}
}

private boolean upstreamIsRunning() {
return isDone.get() == false && error.get() == null;
}

private void requestOnMlThread(long n) {
var currentThreadPool = EsExecutors.executorName(Thread.currentThread().getName());
if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) {
upstream.request(n);
} else {
CompletableFuture.runAsync(() -> upstream.request(n), threadPool.executor(UTILITY_THREAD_POOL_NAME));
}
}

@Override
public void cancel() {
if (upstream != null && upstreamIsRunning()) {
upstream.cancel();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel;

import java.util.List;
import java.util.function.Supplier;

public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager {
Expand All @@ -45,9 +44,11 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
var docsInput = docsOnly.getInputs();
var stream = docsOnly.stream();
var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput);
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout);
var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream);
var responseHandler = new AmazonBedrockChatCompletionResponseHandler();

try {
Expand Down
Loading

0 comments on commit f273fdc

Please sign in to comment.