forked from elastic/elasticsearch
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[8.x] [ML] Stream Bedrock Completion (elastic#114732) (elastic#114781)
* [ML] Stream Bedrock Completion (elastic#114732) Notes: - Adds a new API to the chatCompletionRequest to invoke the Bedrock Stream API - Create a StreamingChatProcessor that subscribes to streaming results from bedrock and handles the parsing on another thread. - There was no good way (that I could see) to extend the Provider-based CompletionRequestEntity, so they have been flattened into one RequestEntity that can be shared between ConverseRequest and ConverseStreamRequest. * Use jdk17 API
- Loading branch information
Showing
33 changed files
with
702 additions
and
920 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 114732 | ||
summary: Stream Bedrock Completion | ||
area: Machine Learning | ||
type: enhancement | ||
issues: [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
156 changes: 156 additions & 0 deletions
156
...ticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.amazonbedrock; | ||
|
||
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; | ||
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; | ||
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; | ||
|
||
import org.elasticsearch.ElasticsearchException; | ||
import org.elasticsearch.common.util.concurrent.EsExecutors; | ||
import org.elasticsearch.core.Strings; | ||
import org.elasticsearch.threadpool.ThreadPool; | ||
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; | ||
|
||
import java.util.ArrayDeque; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.Flow; | ||
import java.util.concurrent.atomic.AtomicBoolean; | ||
import java.util.concurrent.atomic.AtomicLong; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; | ||
|
||
class AmazonBedrockStreamingChatProcessor implements Flow.Processor<ConverseStreamOutput, StreamingChatCompletionResults.Results> { | ||
private final AtomicReference<Throwable> error = new AtomicReference<>(null); | ||
private final AtomicLong demand = new AtomicLong(0); | ||
private final AtomicBoolean isDone = new AtomicBoolean(false); | ||
private final AtomicBoolean onCompleteCalled = new AtomicBoolean(false); | ||
private final AtomicBoolean onErrorCalled = new AtomicBoolean(false); | ||
private final ThreadPool threadPool; | ||
private volatile Flow.Subscriber<? super StreamingChatCompletionResults.Results> downstream; | ||
private volatile Flow.Subscription upstream; | ||
|
||
AmazonBedrockStreamingChatProcessor(ThreadPool threadPool) { | ||
this.threadPool = threadPool; | ||
} | ||
|
||
@Override | ||
public void subscribe(Flow.Subscriber<? super StreamingChatCompletionResults.Results> subscriber) { | ||
if (downstream == null) { | ||
downstream = subscriber; | ||
downstream.onSubscribe(new StreamSubscription()); | ||
} else { | ||
subscriber.onError(new IllegalStateException("Subscriber already set.")); | ||
} | ||
} | ||
|
||
@Override | ||
public void onSubscribe(Flow.Subscription subscription) { | ||
if (upstream == null) { | ||
upstream = subscription; | ||
var currentRequestCount = demand.getAndUpdate(i -> 0); | ||
if (currentRequestCount > 0) { | ||
upstream.request(currentRequestCount); | ||
} | ||
} else { | ||
subscription.cancel(); | ||
} | ||
} | ||
|
||
@Override | ||
public void onNext(ConverseStreamOutput item) { | ||
if (item.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA) { | ||
demand.set(0); // reset demand before we fork to another thread | ||
item.accept(ConverseStreamResponseHandler.Visitor.builder().onContentBlockDelta(this::sendDownstreamOnAnotherThread).build()); | ||
} else { | ||
upstream.request(1); | ||
} | ||
} | ||
|
||
// this is always called from a netty thread maintained by the AWS SDK, we'll move it to our thread to process the response | ||
private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) { | ||
CompletableFuture.runAsync(() -> { | ||
var text = event.delta().text(); | ||
var result = new ArrayDeque<StreamingChatCompletionResults.Result>(1); | ||
result.offer(new StreamingChatCompletionResults.Result(text)); | ||
var results = new StreamingChatCompletionResults.Results(result); | ||
downstream.onNext(results); | ||
}, threadPool.executor(UTILITY_THREAD_POOL_NAME)); | ||
} | ||
|
||
@Override | ||
public void onError(Throwable amazonBedrockRuntimeException) { | ||
error.set( | ||
new ElasticsearchException( | ||
Strings.format("AmazonBedrock StreamingChatProcessor failure: [%s]", amazonBedrockRuntimeException.getMessage()), | ||
amazonBedrockRuntimeException | ||
) | ||
); | ||
if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onErrorCalled.compareAndSet(false, true)) { | ||
downstream.onError(error.get()); | ||
} | ||
} | ||
|
||
private boolean checkAndResetDemand() { | ||
return demand.getAndUpdate(i -> 0L) > 0L; | ||
} | ||
|
||
@Override | ||
public void onComplete() { | ||
if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onCompleteCalled.compareAndSet(false, true)) { | ||
downstream.onComplete(); | ||
} | ||
} | ||
|
||
private class StreamSubscription implements Flow.Subscription { | ||
@Override | ||
public void request(long n) { | ||
if (n > 0L) { | ||
demand.updateAndGet(i -> { | ||
var sum = i + n; | ||
return sum >= 0 ? sum : Long.MAX_VALUE; | ||
}); | ||
if (upstream == null) { | ||
// wait for upstream to subscribe before forwarding request | ||
return; | ||
} | ||
if (upstreamIsRunning()) { | ||
requestOnMlThread(n); | ||
} else if (error.get() != null && onErrorCalled.compareAndSet(false, true)) { | ||
downstream.onError(error.get()); | ||
} else if (onCompleteCalled.compareAndSet(false, true)) { | ||
downstream.onComplete(); | ||
} | ||
} else { | ||
cancel(); | ||
downstream.onError(new IllegalStateException("Cannot request a negative number.")); | ||
} | ||
} | ||
|
||
private boolean upstreamIsRunning() { | ||
return isDone.get() == false && error.get() == null; | ||
} | ||
|
||
private void requestOnMlThread(long n) { | ||
var currentThreadPool = EsExecutors.executorName(Thread.currentThread().getName()); | ||
if (UTILITY_THREAD_POOL_NAME.equalsIgnoreCase(currentThreadPool)) { | ||
upstream.request(n); | ||
} else { | ||
CompletableFuture.runAsync(() -> upstream.request(n), threadPool.executor(UTILITY_THREAD_POOL_NAME)); | ||
} | ||
} | ||
|
||
@Override | ||
public void cancel() { | ||
if (upstream != null && upstreamIsRunning()) { | ||
upstream.cancel(); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.