Skip to content

Commit

Permalink
Add bedrock embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Oct 26, 2023
1 parent 4975d8d commit 15ea619
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 0 deletions.
102 changes: 102 additions & 0 deletions embedding/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package embedding

import (
	"context"
	"encoding/json"
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/schema"
)

// Compile-time assertion that *Bedrock implements schema.Embedder. The typed
// nil pointer is assigned to the blank identifier, so nothing is kept at
// runtime; the build fails on this line if the interface is not satisfied.
var _ schema.Embedder = (*Bedrock)(nil)

// amazonOutput mirrors the JSON document decoded from the model response body.
// Only the "embedding" array is captured; any other keys are ignored by
// encoding/json.
type amazonOutput struct {
	// Embedding is the vector stored under the "embedding" key.
	Embedding []float64 `json:"embedding"`
}

// BedrockRuntimeClient describes the single runtime operation the embedder
// relies on. Depending on this interface rather than a concrete client lets
// callers supply any implementation with a matching InvokeModel method, such
// as a test double.
type BedrockRuntimeClient interface {
	// InvokeModel submits params to a model and returns the model output or
	// an error. optFns are per-call option functions passed through to the
	// implementation.
	InvokeModel(
		ctx context.Context,
		params *bedrockruntime.InvokeModelInput,
		optFns ...func(*bedrockruntime.Options),
	) (*bedrockruntime.InvokeModelOutput, error)
}

// BedrockOptions holds the settings of a Bedrock embedder. NewBedrock fills in
// a default ModelID and then hands the value to each option function.
type BedrockOptions struct {
	// Embedded callback options and tokenizer. Both carry the tag `map:"-"`,
	// which presumably excludes them from map-based encoding of the options.
	// NOTE(review): confirm against whatever consumes the "map" tag.
	*schema.CallbackOptions `map:"-"`
	schema.Tokenizer        `map:"-"`

	// ModelID identifies the model passed to InvokeModel.
	ModelID string `map:"model_id,omitempty"`

	// ModelParams holds additional model parameters.
	// NOTE(review): EmbedQuery in this file does not read this field; verify
	// whether it is meant to be merged into the request body.
	ModelParams map[string]any `map:"model_params,omitempty"`
}

// Bedrock computes embeddings by invoking a model through a
// BedrockRuntimeClient. Create instances with NewBedrock.
type Bedrock struct {
	// client performs the InvokeModel calls.
	client BedrockRuntimeClient
	// opts is the configuration resolved by NewBedrock.
	opts BedrockOptions
}

// NewBedrock returns a Bedrock embedder that issues its requests through
// client.
//
// The options start out with ModelID set to "amazon.titan-embed-text-v1".
// Each function in optFns is then called in order with a pointer to those
// options, so later functions can override earlier ones and the default.
func NewBedrock(client BedrockRuntimeClient, optFns ...func(o *BedrockOptions)) *Bedrock {
	cfg := BedrockOptions{ModelID: "amazon.titan-embed-text-v1"}

	for i := range optFns {
		optFns[i](&cfg)
	}

	embedder := new(Bedrock)
	embedder.client = client
	embedder.opts = cfg

	return embedder
}

// EmbedDocuments computes one embedding per entry of texts by calling
// EmbedQuery on each text in order. The returned slice has the same length and
// ordering as texts. The first failing call stops the loop; its error is
// returned together with a nil result.
func (e *Bedrock) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
	vectors := make([][]float64, 0, len(texts))

	for idx := range texts {
		vec, err := e.EmbedQuery(ctx, texts[idx])
		if err != nil {
			return nil, err
		}

		vectors = append(vectors, vec)
	}

	return vectors, nil
}

// EmbedQuery embeds a single query and returns its embedding.
//
// The text is sent as the "inputText" field of a JSON request body to the
// model named by the configured ModelID, and the "embedding" array of the JSON
// response body is returned.
//
// An error is returned when the request cannot be encoded, when the client
// call fails or yields no output, or when the response body cannot be decoded.
// Underlying errors are wrapped with %w, so they remain reachable through
// errors.Is and errors.As.
func (e *Bedrock) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
	body, err := json.Marshal(map[string]string{
		"inputText": text,
	})
	if err != nil {
		return nil, fmt.Errorf("encoding bedrock embedding request: %w", err)
	}

	res, err := e.client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
		ModelId:     aws.String(e.opts.ModelID),
		Body:        body,
		Accept:      aws.String("application/json"),
		ContentType: aws.String("application/json"),
	})
	if err != nil {
		return nil, fmt.Errorf("invoking bedrock model %q: %w", e.opts.ModelID, err)
	}

	// BedrockRuntimeClient is an interface, so an implementation may hand back
	// a nil output together with a nil error; guard before dereferencing it.
	if res == nil {
		return nil, fmt.Errorf("invoking bedrock model %q: empty response", e.opts.ModelID)
	}

	var output amazonOutput
	if err := json.Unmarshal(res.Body, &output); err != nil {
		return nil, fmt.Errorf("decoding bedrock embedding response: %w", err)
	}

	return output.Embedding, nil
}
91 changes: 91 additions & 0 deletions embedding/bedrock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package embedding

import (
"context"
"errors"
"testing"

"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/stretchr/testify/assert"
)

// TestBedrock exercises the Bedrock embedder against mockBedrockRuntimeClient,
// covering successful document and query embedding as well as propagation of
// a client error.
func TestBedrock(t *testing.T) {
	t.Run("TestEmbedDocuments", func(t *testing.T) {
		t.Run("Successful embedding of documents", func(t *testing.T) {
			// The mock replays the same payload for every InvokeModel call,
			// so each document is expected to receive the same vector.
			client := &mockBedrockRuntimeClient{
				response: &bedrockruntime.InvokeModelOutput{
					Body: []byte(`{"embedding": [1.0, 2.0, 3.0]}`),
				},
			}
			embedder := NewBedrock(client)

			texts := []string{"text1", "text2"}

			embeddings, err := embedder.EmbedDocuments(context.Background(), texts)

			assert.NoError(t, err, "Expected no error")
			assert.NotNil(t, embeddings, "Expected non-nil embeddings")

			// assert does not abort the test on failure, so stop here rather
			// than inspecting elements of a slice with an unexpected length.
			if !assert.Len(t, embeddings, 2, "Expected 2 embeddings") {
				return
			}

			// Check every document, and the values rather than only the size.
			for _, embedding := range embeddings {
				assert.Equal(t, []float64{1.0, 2.0, 3.0}, embedding, "Expected the embedding returned by the model")
			}
		})
	})

	t.Run("TestEmbedQuery", func(t *testing.T) {
		t.Run("Successful embedding of a single query", func(t *testing.T) {
			client := &mockBedrockRuntimeClient{
				response: &bedrockruntime.InvokeModelOutput{
					Body: []byte(`{"embedding": [1.0, 2.0, 3.0]}`),
				},
			}
			embedder := NewBedrock(client)

			query := "query text"

			embedding, err := embedder.EmbedQuery(context.Background(), query)

			assert.NoError(t, err, "Expected no error")
			assert.Equal(t, []float64{1.0, 2.0, 3.0}, embedding, "Expected the embedding returned by the model")
		})

		t.Run("Embedding error", func(t *testing.T) {
			// Keep a handle on the error so the test can verify that this
			// exact error, and not some other failure, reaches the caller.
			wantErr := errors.New("Embedding error")
			client := &mockBedrockRuntimeClient{
				err: wantErr,
			}
			embedder := NewBedrock(client)

			query := "query text"

			embedding, err := embedder.EmbedQuery(context.Background(), query)

			assert.ErrorIs(t, err, wantErr, "Expected the client error to be propagated")
			assert.Nil(t, embedding, "Expected nil embedding")
		})
	})
}

// mockBedrockRuntimeClient is a test double for BedrockRuntimeClient that
// replays a fixed result on every InvokeModel call.
type mockBedrockRuntimeClient struct {
	// response is what InvokeModel returns when err is nil.
	response *bedrockruntime.InvokeModelOutput
	// err, when non-nil, is returned by InvokeModel in place of response.
	err error
}

// InvokeModel ignores its arguments and replays the configured result: the
// stored response with a nil error when no error is configured, otherwise a
// nil output together with the stored error.
func (m *mockBedrockRuntimeClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) {
	if m.err == nil {
		return m.response, nil
	}

	return nil, m.err
}

0 comments on commit 15ea619

Please sign in to comment.