Skip to content

Commit

Permalink
llms/googleai: Support more option.ClientOption and update google a…
Browse files Browse the repository at this point in the history
…i sdk (tmc#841)

* llms/googleai: support more `option.ClientOption` and update google ai sdk

`option.ClientOption` list:

- WithCredentialsJSON
- WithCredentialsFile
- WithHttpClient

* llms/googleai: fix palmclient and add test case for WithHTTPClient

* test: update test case for `VERTEX_CREDENTIALS`
  • Loading branch information
wangjiancn authored Jun 15, 2024
1 parent 587d3ce commit 9817356
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 139 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ cover.cov
# dev
.env
vendor/*
service-account.json

embeddings/cybertron/models/*
examples/cybertron-embedding-example/models/*
61 changes: 32 additions & 29 deletions examples/vertex-completion-example/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,44 +7,47 @@ toolchain go1.22.1
require github.com/tmc/langchaingo v0.1.10

require (
cloud.google.com/go v0.112.1 // indirect
cloud.google.com/go/ai v0.3.5-0.20240409161017-ce55ad694f21 // indirect
cloud.google.com/go/aiplatform v1.60.0 // indirect
cloud.google.com/go/compute v1.24.0 // indirect
cloud.google.com/go/compute/metadata v0.2.3 // indirect
cloud.google.com/go/iam v1.1.6 // indirect
cloud.google.com/go/longrunning v0.5.6 // indirect
cloud.google.com/go/vertexai v0.7.1 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
cloud.google.com/go v0.113.0 // indirect
cloud.google.com/go/ai v0.5.0 // indirect
cloud.google.com/go/aiplatform v1.67.0 // indirect
cloud.google.com/go/auth v0.4.1 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect
cloud.google.com/go/compute/metadata v0.3.0 // indirect
cloud.google.com/go/iam v1.1.7 // indirect
cloud.google.com/go/longrunning v0.5.7 // indirect
cloud.google.com/go/vertexai v0.9.0 // indirect
github.com/dlclark/regexp2 v1.11.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/go-logr/logr v1.4.1 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/generative-ai-go v0.11.0 // indirect
github.com/google/generative-ai-go v0.12.0 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
github.com/pkoukk/tiktoken-go v0.1.6 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
go.opentelemetry.io/otel v1.24.0 // indirect
go.opentelemetry.io/otel/metric v1.24.0 // indirect
go.opentelemetry.io/otel/trace v1.24.0 // indirect
golang.org/x/crypto v0.21.0 // indirect
golang.org/x/net v0.22.0 // indirect
golang.org/x/oauth2 v0.18.0 // indirect
golang.org/x/sync v0.6.0 // indirect
golang.org/x/sys v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.51.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel v1.26.0 // indirect
go.opentelemetry.io/otel/metric v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/oauth2 v0.20.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/text v0.15.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/api v0.172.0 // indirect
google.golang.org/appengine v1.6.8 // indirect
google.golang.org/genproto v0.0.0-20240221002015-b0ce06bbee7c // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240401170217-c3f982113cda // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240325203815-454cdb8f5daa // indirect
google.golang.org/grpc v1.62.1 // indirect
google.golang.org/protobuf v1.33.0 // indirect
google.golang.org/api v0.180.0 // indirect
google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240515191416-fc5f0ca64291 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect
google.golang.org/grpc v1.64.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)

// test for new pkg version
replace github.com/tmc/langchaingo => ../..
141 changes: 58 additions & 83 deletions examples/vertex-completion-example/go.sum

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion examples/vertex-completion-example/vertex-completion-example.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Set the VERTEX_PROJECT env var to your GCP project with Vertex AI APIs
// enabled. Set VERTEX_LOCATION to a GCP location (region); if you're not sure
// about the location, set us-central1
// Set the VERTEX_CREDENTIALS env var to the path of your GCP service account
// credentials JSON file.
package main

import (
Expand All @@ -18,7 +20,13 @@ func main() {
ctx := context.Background()
project := os.Getenv("VERTEX_PROJECT")
location := os.Getenv("VERTEX_LOCATION")
llm, err := vertex.New(ctx, googleai.WithCloudProject(project), googleai.WithCloudLocation(location))
credentialsJSONFile := os.Getenv("VERTEX_CREDENTIALS")
llm, err := vertex.New(
ctx,
googleai.WithCloudProject(project),
googleai.WithCloudLocation(location),
googleai.WithCredentialsFile(credentialsJSONFile),
)
if err != nil {
log.Fatal(err)
}
Expand Down
17 changes: 6 additions & 11 deletions llms/googleai/internal/palmclient/palmclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

const (
defaultAPIEndpoint = "us-central1-aiplatform.googleapis.com:443"
defaultLocation = "us-central1"
defaultPublisher = "google"
)

var (
// ErrMissingValue is returned when a value is missing.
ErrMissingValue = errors.New("missing value")
Expand Down Expand Up @@ -48,19 +42,20 @@ type PaLMClient struct {
}

// New returns a new Vertex AI based PaLM API client.
func New(projectID string, opts ...option.ClientOption) (*PaLMClient, error) {
func New(ctx context.Context, projectID, location string, opts ...option.ClientOption) (*PaLMClient, error) {
numConns := runtime.GOMAXPROCS(0)
if numConns > defaultMaxConns {
numConns = defaultMaxConns
}
o := []option.ClientOption{
option.WithGRPCConnectionPool(numConns),
option.WithEndpoint(defaultAPIEndpoint),
option.WithEndpoint(fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)),
}
o = append(o, opts...)
opts = append(o, opts...)
// PredictionClient only support GRPC.
opts = append(opts, option.WithHTTPClient(nil))

ctx := context.Background()
client, err := aiplatform.NewPredictionClient(ctx, o...)
client, err := aiplatform.NewPredictionClient(ctx, opts...)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions llms/googleai/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/google/generative-ai-go/genai"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/llms"
"google.golang.org/api/option"
)

// GoogleAI is a type that represents a Google AI API client.
Expand All @@ -31,7 +30,7 @@ func New(ctx context.Context, opts ...Option) (*GoogleAI, error) {
opts: clientOptions,
}

client, err := genai.NewClient(ctx, option.WithAPIKey(clientOptions.APIKey))
client, err := genai.NewClient(ctx, clientOptions.ClientOptions...)
if err != nil {
return gi, err
}
Expand Down
53 changes: 50 additions & 3 deletions llms/googleai/option.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
package googleai

import (
"net/http"

"cloud.google.com/go/vertexai/genai"
"google.golang.org/api/option"
)

// Options is a set of options for GoogleAI and Vertex clients.
type Options struct {
APIKey string
CloudProject string
CloudLocation string
DefaultModel string
Expand All @@ -13,11 +19,12 @@ type Options struct {
DefaultTopK int
DefaultTopP float64
HarmThreshold HarmBlockThreshold

ClientOptions []option.ClientOption
}

func DefaultOptions() Options {
return Options{
APIKey: "",
CloudProject: "",
CloudLocation: "",
DefaultModel: "gemini-pro",
Expand All @@ -37,7 +44,47 @@ type Option func(*Options)
// googleai clients.
func WithAPIKey(apiKey string) Option {
return func(opts *Options) {
opts.APIKey = apiKey
opts.ClientOptions = append(opts.ClientOptions, option.WithAPIKey(apiKey))
}
}

// WithCredentialsJSON append a ClientOption that authenticates
// API calls with the given service account or refresh token JSON
// credentials.
func WithCredentialsJSON(credentialsJSON []byte) Option {
return func(opts *Options) {
if len(credentialsJSON) == 0 {
return
}
opts.ClientOptions = append(opts.ClientOptions, option.WithCredentialsJSON(credentialsJSON))
}
}

// WithCredentialsFile append a ClientOption that authenticates
// API calls with the given service account or refresh token JSON
// credentials file.
func WithCredentialsFile(credentialsFile string) Option {
return func(opts *Options) {
if credentialsFile == "" {
return
}
opts.ClientOptions = append(opts.ClientOptions, option.WithCredentialsFile(credentialsFile))
}
}

// WithRest configures the client to use the REST API.
func WithRest() Option {
return func(opts *Options) {
opts.ClientOptions = append(opts.ClientOptions, genai.WithREST())
}
}

// WithHTTPClient append a ClientOption that uses the provided HTTP client to
// make requests.
// This is useful for vertex clients.
func WithHTTPClient(httpClient *http.Client) Option {
return func(opts *Options) {
opts.ClientOptions = append(opts.ClientOptions, option.WithHTTPClient(httpClient))
}
}

Expand Down
6 changes: 5 additions & 1 deletion llms/googleai/palm/palm_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
var (
ErrEmptyResponse = errors.New("no response")
ErrMissingProjectID = errors.New("missing the GCP Project ID, set it in the GOOGLE_CLOUD_PROJECT environment variable") //nolint:lll
ErrMissingLocation = errors.New("missing the GCP Location, set it in the GOOGLE_CLOUD_LOCATION environment variable") //nolint:lll
ErrUnexpectedResponseLength = errors.New("unexpected length of response")
ErrNotImplemented = errors.New("not implemented")
)
Expand Down Expand Up @@ -111,6 +112,9 @@ func newClient(opts ...Option) (*palmclient.PaLMClient, error) {
if len(options.projectID) == 0 {
return nil, ErrMissingProjectID
}
if len(options.location) == 0 {
return nil, ErrMissingLocation
}

return palmclient.New(options.projectID, options.clientOptions...)
return palmclient.New(context.TODO(), options.projectID, options.location, options.clientOptions...)
}
8 changes: 8 additions & 0 deletions llms/googleai/palm/palm_llm_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var (

type options struct {
projectID string
location string
clientOptions []option.ClientOption
}

Expand All @@ -44,6 +45,13 @@ func WithProjectID(projectID string) Option {
}
}

// WithLocation passes the Google Cloud location to the client.
func WithLocation(location string) Option {
return func(opts *options) {
opts.location = location
}
}

// WithAPIKey returns a ClientOption that specifies an API key to be used
// as the basis for authentication.
func WithAPIKey(apiKey string) Option {
Expand Down
Loading

0 comments on commit 9817356

Please sign in to comment.