Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [vertexai] add custom headers support in VertexAI #11085

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
Expand All @@ -63,6 +66,7 @@ public class VertexAI implements AutoCloseable {
private final String location;
private final String apiEndpoint;
private final Transport transport;
private final HeaderProvider headerProvider;
private final CredentialsProvider credentialsProvider;

private final transient Supplier<PredictionServiceClient> predictionClientSupplier;
Expand All @@ -85,6 +89,7 @@ public VertexAI(String projectId, String location) {
location,
Transport.GRPC,
ImmutableList.of(),
/* customHeaders= */ ImmutableMap.of(),
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
Expand All @@ -108,6 +113,7 @@ public VertexAI() {
null,
Transport.GRPC,
ImmutableList.of(),
/* customHeaders= */ ImmutableMap.of(),
/* credentials= */ Optional.empty(),
/* apiEndpoint= */ Optional.empty(),
/* predictionClientSupplierOpt= */ Optional.empty(),
Expand All @@ -119,6 +125,7 @@ private VertexAI(
String location,
Transport transport,
List<String> scopes,
Map<String, String> customHeaders,
Optional<Credentials> credentials,
Optional<String> apiEndpoint,
Optional<Supplier<PredictionServiceClient>> predictionClientSupplierOpt,
Expand All @@ -131,6 +138,15 @@ private VertexAI(
this.location = Strings.isNullOrEmpty(location) ? inferLocation() : location;
this.transport = transport;

String sdkHeader =
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class));
Map<String, String> headers = new HashMap<>(customHeaders);
headers.compute("user-agent", (k, v) -> v == null ? sdkHeader : sdkHeader + " " + v);
this.headerProvider = FixedHeaderProvider.create(headers);

if (credentials.isPresent()) {
this.credentialsProvider = FixedCredentialsProvider.create(credentials.get());
} else {
Expand Down Expand Up @@ -160,6 +176,7 @@ public static class Builder {
private String location;
private Transport transport = Transport.GRPC;
private ImmutableList<String> scopes = ImmutableList.of();
private ImmutableMap<String, String> customHeaders = ImmutableMap.of();
private Optional<Credentials> credentials = Optional.empty();
private Optional<String> apiEndpoint = Optional.empty();

Expand All @@ -174,6 +191,7 @@ public VertexAI build() {
location,
transport,
scopes,
customHeaders,
credentials,
apiEndpoint,
Optional.ofNullable(predictionClientSupplier),
Expand Down Expand Up @@ -240,6 +258,14 @@ public Builder setScopes(List<String> scopes) {
this.scopes = ImmutableList.copyOf(scopes);
return this;
}

@CanIgnoreReturnValue
public Builder setCustomHeaders(Map<String, String> customHeaders) {
checkNotNull(customHeaders, "customHeaders can't be null");

this.customHeaders = ImmutableMap.copyOf(customHeaders);
return this;
}
}

/**
Expand Down Expand Up @@ -278,6 +304,15 @@ public String getApiEndpoint() {
return apiEndpoint;
}

/**
* Returns the headers to use when making API calls.
*
* @return a map of headers to use when making API calls.
*/
public Map<String, String> getHeaders() {
return headerProvider.getHeaders();
}

/**
* Returns the default credentials to use when making API calls.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.when;

import com.google.api.gax.core.GaxProperties;
import com.google.api.gax.core.GoogleCredentialsProvider;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.vertexai.api.PredictionServiceClient;
import com.google.cloud.vertexai.api.PredictionServiceSettings;
import com.google.common.collect.ImmutableList;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import org.junit.Rule;
import org.junit.Test;
Expand Down Expand Up @@ -397,4 +400,59 @@ public void testInstantiateVertexAI_builderWithTransport_shouldContainRightField
assertThat(vertexAi.getTransport()).isEqualTo(Transport.REST);
assertThat(vertexAi.getApiEndpoint()).isEqualTo(TEST_DEFAULT_ENDPOINT);
}

@Test
public void testInstantiateVertexAI_builderWithCustomHeaders_shouldContainRightFields()
throws IOException {
Map<String, String> customHeaders = new HashMap<>();
customHeaders.put("test_key", "test_value");

vertexAi =
new VertexAI.Builder()
.setProjectId(TEST_PROJECT)
.setLocation(TEST_LOCATION)
.setCustomHeaders(customHeaders)
.build();

assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION);
// headers should include both the sdk header and the custom headers
Map<String, String> expectedHeaders = new HashMap<>(customHeaders);
expectedHeaders.put(
"user-agent",
String.format(
"%s/%s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class)));
assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
}

@Test
public void
testInstantiateVertexAI_builderWithCustomHeadersWithSdkReservedKey_shouldContainRightFields()
throws IOException {
Map<String, String> customHeadersWithSdkReservedKey = new HashMap<>();
customHeadersWithSdkReservedKey.put("user-agent", "test_value");

vertexAi =
new VertexAI.Builder()
.setProjectId(TEST_PROJECT)
.setLocation(TEST_LOCATION)
.setCustomHeaders(customHeadersWithSdkReservedKey)
.build();

assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT);
assertThat(vertexAi.getLocation()).isEqualTo(TEST_LOCATION);
// headers should include sdk reserved key with value of both the sdk header and the custom
// headers
Map<String, String> expectedHeaders = new HashMap<>();
expectedHeaders.put(
"user-agent",
String.format(
"%s/%s %s",
Constants.USER_AGENT_HEADER,
GaxProperties.getLibraryVersion(PredictionServiceSettings.class),
"test_value"));
assertThat(vertexAi.getHeaders()).isEqualTo(expectedHeaders);
}
}
Loading