Skip to content

Commit

Permalink
Add retry support to VertexAI embedding and chat models
Browse files Browse the repository at this point in the history
Resolves spring-projects#832

Introduces retry functionality to VertexAI embedding and
chat models, enhancing their resilience against transient failures.

It also corrects a typo in the VertexAiEmbeddingConnectionDetails
class name.

Key changes:

* Add RetryTemplate to VertexAiTextEmbeddingModel and VertexAiGeminiChatModel
* Introduce spring-ai-retry dependency
* Refactor code to support retry logic
* Update auto-configuration classes to incorporate retry functionality
* Fix typo in VertexAiEmbeddingConnectionDetails class name

Remove extraneous commented-out code

Add missing copyright headers, author tags, etc.
  • Loading branch information
markpollack authored and sobychacko committed Oct 2, 2024
1 parent 6fc76b7 commit bb88e2f
Show file tree
Hide file tree
Showing 14 changed files with 624 additions and 115 deletions.
6 changes: 6 additions & 0 deletions models/spring-ai-vertex-ai-embedding/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-retry</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;

/**
* VertexAiEmbeddigConnectionDetails represents the details of a connection to the Vertex
* VertexAiEmbeddingConnectionDetails represents the details of a connection to the Vertex
* AI embedding service. It provides methods to access the project ID, location,
* publisher, and PredictionServiceSettings.
*
* @author Christian Tzolov
* @author Mark Pollack
* @since 1.0.0
*/
public class VertexAiEmbeddigConnectionDetails {
public class VertexAiEmbeddingConnectionDetails {

private static final String DEFAULT_LOCATION = "us-central1";

Expand Down Expand Up @@ -55,7 +59,7 @@ public class VertexAiEmbeddigConnectionDetails {

private final String publisher;

public VertexAiEmbeddigConnectionDetails(String endpoint, String projectId, String location, String publisher) {
public VertexAiEmbeddingConnectionDetails(String endpoint, String projectId, String location, String publisher) {
this.projectId = projectId;
this.location = location;
this.publisher = publisher;
Expand Down Expand Up @@ -119,7 +123,7 @@ public Builder withPublisher(String publisher) {
return this;
}

public VertexAiEmbeddigConnectionDetails build() {
public VertexAiEmbeddingConnectionDetails build() {
if (!StringUtils.hasText(this.endpoint)) {
if (!StringUtils.hasText(this.location)) {
this.endpoint = DEFAULT_ENDPOINT;
Expand All @@ -134,7 +138,7 @@ public VertexAiEmbeddigConnectionDetails build() {
this.publisher = DEFAULT_PUBLISHER;
}

return new VertexAiEmbeddigConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
return new VertexAiEmbeddingConnectionDetails(this.endpoint, this.projectId, this.location, this.publisher);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.embedding.EmbeddingResultMetadata.ModalityType;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.ImageBuilder;
Expand All @@ -59,6 +59,7 @@
* is not yet fully functional and is subject to change.
*
* @author Christian Tzolov
* @author Mark Pollack
* @since 1.0.0
*/
public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel {
Expand All @@ -76,9 +77,9 @@ public class VertexAiMultimodalEmbeddingModel implements DocumentEmbeddingModel
private static final List<MimeType> SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG,
MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType("image/bmp"));

private final VertexAiEmbeddigConnectionDetails connectionDetails;
private final VertexAiEmbeddingConnectionDetails connectionDetails;

public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiMultimodalEmbeddingOptions defaultEmbeddingOptions) {

Assert.notNull(defaultEmbeddingOptions, "VertexAiMultimodalEmbeddingOptions must not be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUsage;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
Expand All @@ -47,22 +50,29 @@
* A class representing a Vertex AI Text Embedding Model.
*
* @author Christian Tzolov
* @author Mark Pollack
* @since 1.0.0
*/
public class VertexAiTextEmbeddingModel extends AbstractEmbeddingModel {

public final VertexAiTextEmbeddingOptions defaultOptions;

private final VertexAiEmbeddigConnectionDetails connectionDetails;
private final VertexAiEmbeddingConnectionDetails connectionDetails;

private final RetryTemplate retryTemplate;

public VertexAiTextEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails,
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
VertexAiTextEmbeddingOptions defaultEmbeddingOptions) {
this(connectionDetails, defaultEmbeddingOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

/**
 * Creates a text embedding model that runs its calls through the supplied retry
 * template.
 * @param connectionDetails the connection details stored for later use by the model
 * @param defaultEmbeddingOptions the default options; must not be {@code null}
 * @param retryTemplate the retry template; must not be {@code null}
 */
public VertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
		VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
	// Validate the required collaborators up front, before any state is assigned.
	Assert.notNull(defaultEmbeddingOptions, "VertexAiTextEmbeddingOptions must not be null");
	Assert.notNull(retryTemplate, "retryTemplate must not be null");

	this.connectionDetails = connectionDetails;
	this.retryTemplate = retryTemplate;
	this.defaultOptions = defaultEmbeddingOptions.initializeDefaults();
}

@Override
Expand All @@ -73,46 +83,23 @@ public float[] embed(Document document) {

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
return retryTemplate.execute(context -> {
VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;

VertexAiTextEmbeddingOptions finalOptions = this.defaultOptions;

if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
VertexAiTextEmbeddingOptions.class);
}

try (PredictionServiceClient client = PredictionServiceClient
.create(this.connectionDetails.getPredictionServiceSettings())) {

EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());

PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder()
.setEndpoint(endpointName.toString());

TextParametersBuilder parametersBuilder = TextParametersBuilder.of();

if (finalOptions.getAutoTruncate() != null) {
parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
}

if (finalOptions.getDimensions() != null) {
parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
if (request.getOptions() != null && request.getOptions() != EmbeddingOptions.EMPTY) {
var defaultOptionsCopy = VertexAiTextEmbeddingOptions.builder().from(this.defaultOptions).build();
finalOptions = ModelOptionsUtils.merge(request.getOptions(), defaultOptionsCopy,
VertexAiTextEmbeddingOptions.class);
}

predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));
PredictionServiceClient client = createPredictionServiceClient();

for (int i = 0; i < request.getInstructions().size(); i++) {
EndpointName endpointName = this.connectionDetails.getEndpointName(finalOptions.getModel());

TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
.withTaskType(finalOptions.getTaskType().name());
if (StringUtils.hasText(finalOptions.getTitle())) {
instanceBuilder.withTitle(finalOptions.getTitle());
}
predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
}
PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName,
finalOptions);

PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
PredictResponse embeddingResponse = getPredictResponse(client, predictRequestBuilder);

int index = 0;
int totalTokenCount = 0;
Expand All @@ -131,12 +118,53 @@ public EmbeddingResponse call(EmbeddingRequest request) {
}
return new EmbeddingResponse(embeddingList,
generateResponseMetadata(finalOptions.getModel(), totalTokenCount));
});
}

/**
 * Builds the {@link PredictRequest.Builder} for a text embedding request.
 * <p>
 * The endpoint is taken from {@code endpointName}. The auto-truncate and output
 * dimensionality parameters are set only when the corresponding option is non-null.
 * One instance is added per entry of {@code request.getInstructions()}, each carrying
 * the configured task type and, when it has text, the configured title.
 * @param request the embedding request whose instructions become prediction instances
 * @param endpointName the endpoint the prediction request is addressed to
 * @param finalOptions the options applied to the parameters and instances
 * @return the populated builder; the caller is responsible for calling {@code build()}
 */
protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
VertexAiTextEmbeddingOptions finalOptions) {
PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString());

TextParametersBuilder parametersBuilder = TextParametersBuilder.of();

// Optional parameter: only forwarded when explicitly configured.
if (finalOptions.getAutoTruncate() != null) {
parametersBuilder.withAutoTruncate(finalOptions.getAutoTruncate());
}
// NOTE(review): the "catch" line below looks like residue from the removed side of
// the diff rather than part of this method - confirm against the actual file.
catch (Exception e) {

// Optional parameter: only forwarded when explicitly configured.
if (finalOptions.getDimensions() != null) {
parametersBuilder.withOutputDimensionality(finalOptions.getDimensions());
}

predictRequestBuilder.setParameters(VertexAiEmbeddingUtils.valueOf(parametersBuilder.build()));

// One prediction instance per instruction, in request order.
for (int i = 0; i < request.getInstructions().size(); i++) {

// NOTE(review): getTaskType() is dereferenced without a null check - presumably
// initializeDefaults() guarantees a value; confirm.
TextInstanceBuilder instanceBuilder = TextInstanceBuilder.of(request.getInstructions().get(i))
.withTaskType(finalOptions.getTaskType().name());
if (StringUtils.hasText(finalOptions.getTitle())) {
instanceBuilder.withTitle(finalOptions.getTitle());
}
predictRequestBuilder.addInstances(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
}
return predictRequestBuilder;
}

/**
 * Creates a {@link PredictionServiceClient} from the settings exposed by the
 * connection details.
 * <p>
 * Package-private so that tests can override it and supply a mock client. This
 * method does not close the returned client.
 * @return a newly created prediction service client
 * @throws RuntimeException wrapping the {@link IOException} if the client cannot be
 * created
 */
PredictionServiceClient createPredictionServiceClient() {
	try {
		return PredictionServiceClient.create(this.connectionDetails.getPredictionServiceSettings());
	}
	catch (IOException e) {
		// Keep the unchecked type callers already see, but say what failed and preserve the cause.
		throw new RuntimeException("Failed to create Vertex AI PredictionServiceClient", e);
	}
}

/**
 * Builds the request from the given builder and sends it through the given client.
 * <p>
 * Package-private so that tests can override it.
 * @param client the client used to issue the prediction call
 * @param predictRequestBuilder the builder holding the request to send
 * @return the response returned by {@code client.predict(...)}
 */
PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
	return client.predict(predictRequestBuilder.build());
}

private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens) {
EmbeddingResponseMetadata metadata = new EmbeddingResponseMetadata();
metadata.setModel(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.springframework.ai.embedding.DocumentEmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -213,16 +213,16 @@ void textImageAndVideoEmbedding() {
static class Config {

@Bean
public VertexAiEmbeddigConnectionDetails connectionDetails() {
return VertexAiEmbeddigConnectionDetails.builder()
public VertexAiEmbeddingConnectionDetails connectionDetails() {
return VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();
}

@Bean
public VertexAiMultimodalEmbeddingModel vertexAiEmbeddingModel(
VertexAiEmbeddigConnectionDetails connectionDetails) {
VertexAiEmbeddingConnectionDetails connectionDetails) {

VertexAiMultimodalEmbeddingOptions options = VertexAiMultimodalEmbeddingOptions.builder()
.withModel(VertexAiMultimodalEmbeddingModelName.MULTIMODAL_EMBEDDING_001)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.vertexai.embedding.text;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.retry.support.RetryTemplate;

import java.io.IOException;

/**
 * Test subclass of {@link VertexAiTextEmbeddingModel} that lets a test substitute the
 * prediction service client and the predict request builder.
 * <p>
 * Each override falls back to the superclass behaviour while its corresponding mock
 * has not been set.
 */
public class TestVertexAiTextEmbeddingModel extends VertexAiTextEmbeddingModel {

	/** Client returned and used in place of a real one; {@code null} means "use the real one". */
	private PredictionServiceClient mockPredictionServiceClient;

	/** Builder returned in place of a freshly built one; {@code null} means "build normally". */
	private PredictRequest.Builder mockPredictRequestBuilder;

	public TestVertexAiTextEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails,
			VertexAiTextEmbeddingOptions defaultEmbeddingOptions, RetryTemplate retryTemplate) {
		super(connectionDetails, defaultEmbeddingOptions, retryTemplate);
	}

	public void setMockPredictionServiceClient(PredictionServiceClient mockPredictionServiceClient) {
		this.mockPredictionServiceClient = mockPredictionServiceClient;
	}

	public void setMockPredictRequestBuilder(PredictRequest.Builder mockPredictRequestBuilder) {
		this.mockPredictRequestBuilder = mockPredictRequestBuilder;
	}

	@Override
	PredictionServiceClient createPredictionServiceClient() {
		if (this.mockPredictionServiceClient == null) {
			return super.createPredictionServiceClient();
		}
		return this.mockPredictionServiceClient;
	}

	@Override
	PredictResponse getPredictResponse(PredictionServiceClient client, PredictRequest.Builder predictRequestBuilder) {
		if (this.mockPredictionServiceClient == null) {
			return super.getPredictResponse(client, predictRequestBuilder);
		}
		// When a mock client is configured it is used directly, regardless of the client argument.
		return this.mockPredictionServiceClient.predict(predictRequestBuilder.build());
	}

	@Override
	protected PredictRequest.Builder getPredictRequestBuilder(EmbeddingRequest request, EndpointName endpointName,
			VertexAiTextEmbeddingOptions finalOptions) {
		if (this.mockPredictRequestBuilder == null) {
			return super.getPredictRequestBuilder(request, endpointName, finalOptions);
		}
		return this.mockPredictRequestBuilder;
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddigConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -67,15 +67,15 @@ void defaultEmbedding(String modelName) {
static class Config {

@Bean
public VertexAiEmbeddigConnectionDetails connectionDetails() {
return VertexAiEmbeddigConnectionDetails.builder()
public VertexAiEmbeddingConnectionDetails connectionDetails() {
return VertexAiEmbeddingConnectionDetails.builder()
.withProjectId(System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"))
.withLocation(System.getenv("VERTEX_AI_GEMINI_LOCATION"))
.build();
}

@Bean
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddigConnectionDetails connectionDetails) {
public VertexAiTextEmbeddingModel vertexAiEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails) {

VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder()
.withModel(VertexAiTextEmbeddingOptions.DEFAULT_MODEL_NAME)
Expand Down
Loading

0 comments on commit bb88e2f

Please sign in to comment.