-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
193 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
package embedding | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
// Compile time check to ensure Bedrock satisfies the Embedder interface. | ||
var _ schema.Embedder = (*Bedrock)(nil) | ||
|
||
// amazonOutput represents the expected JSON output structure from the Bedrock model. | ||
type amazonOutput struct { | ||
Embedding []float64 `json:"embedding"` | ||
} | ||
|
||
// BedrockRuntimeClient is an interface for the Bedrock model runtime client. | ||
type BedrockRuntimeClient interface { | ||
InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) | ||
} | ||
|
||
// BedrockOptions contains options for configuring the Bedrock model. | ||
type BedrockOptions struct { | ||
*schema.CallbackOptions `map:"-"` | ||
schema.Tokenizer `map:"-"` | ||
|
||
// Model id to use. | ||
ModelID string `map:"model_id,omitempty"` | ||
|
||
// Model params to use. | ||
ModelParams map[string]any `map:"model_params,omitempty"` | ||
} | ||
|
||
// Bedrock is a struct representing the Bedrock model embedding functionality. | ||
type Bedrock struct { | ||
client BedrockRuntimeClient | ||
opts BedrockOptions | ||
} | ||
|
||
// NewBedrock creates a new instance of Bedrock with the provided BedrockRuntimeClient and optional configuration. | ||
func NewBedrock(client BedrockRuntimeClient, optFns ...func(o *BedrockOptions)) *Bedrock { | ||
opts := BedrockOptions{ | ||
ModelID: "amazon.titan-embed-text-v1", | ||
} | ||
|
||
for _, fn := range optFns { | ||
fn(&opts) | ||
} | ||
|
||
return &Bedrock{ | ||
client: client, | ||
opts: opts, | ||
} | ||
} | ||
|
||
// EmbedDocuments embeds a list of documents and returns their embeddings. | ||
func (e *Bedrock) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) { | ||
embeddings := make([][]float64, len(texts)) | ||
|
||
for i, text := range texts { | ||
embedding, err := e.EmbedQuery(ctx, text) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
embeddings[i] = embedding | ||
} | ||
|
||
return embeddings, nil | ||
} | ||
|
||
// EmbedQuery embeds a single query and returns its embedding. | ||
func (e *Bedrock) EmbedQuery(ctx context.Context, text string) ([]float64, error) { | ||
jsonBody := map[string]string{ | ||
"inputText": text, | ||
} | ||
|
||
body, err := json.Marshal(jsonBody) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
res, err := e.client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{ | ||
ModelId: aws.String(e.opts.ModelID), | ||
Body: body, | ||
Accept: aws.String("application/json"), | ||
ContentType: aws.String("application/json"), | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
output := &amazonOutput{} | ||
if err := json.Unmarshal(res.Body, output); err != nil { | ||
return nil, err | ||
} | ||
|
||
return output.Embedding, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
package embedding | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"testing" | ||
|
||
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestBedrock(t *testing.T) { | ||
t.Run("TestEmbedDocuments", func(t *testing.T) { | ||
t.Run("Successful embedding of documents", func(t *testing.T) { | ||
// Create an instance of the Bedrock struct with a mock client. | ||
client := &mockBedrockRuntimeClient{ | ||
response: &bedrockruntime.InvokeModelOutput{ | ||
Body: []byte(`{"embedding": [1.0, 2.0, 3.0]}`), | ||
}, | ||
} | ||
embedder := NewBedrock(client) | ||
|
||
// Define a list of texts to embed. | ||
texts := []string{"text1", "text2"} | ||
|
||
// Embed the documents. | ||
embeddings, err := embedder.EmbedDocuments(context.Background(), texts) | ||
|
||
// Add your assertions using testify | ||
assert.NoError(t, err, "Expected no error") | ||
assert.NotNil(t, embeddings, "Expected non-nil embeddings") | ||
assert.Len(t, embeddings, 2, "Expected 2 embeddings") | ||
assert.Len(t, embeddings[0], 3, "Expected 3 values in the embedding") | ||
}) | ||
}) | ||
|
||
t.Run("TestEmbedQuery", func(t *testing.T) { | ||
t.Run("Successful embedding of a single query", func(t *testing.T) { | ||
// Create an instance of the Bedrock struct with a mock client. | ||
client := &mockBedrockRuntimeClient{ | ||
response: &bedrockruntime.InvokeModelOutput{ | ||
Body: []byte(`{"embedding": [1.0, 2.0, 3.0]}`), | ||
}, | ||
} | ||
embedder := NewBedrock(client) | ||
|
||
// Define a query text. | ||
query := "query text" | ||
|
||
// Embed the query. | ||
embedding, err := embedder.EmbedQuery(context.Background(), query) | ||
|
||
// Add your assertions using testify | ||
assert.NoError(t, err, "Expected no error") | ||
assert.NotNil(t, embedding, "Expected non-nil embedding") | ||
assert.Len(t, embedding, 3, "Expected 3 values in the embedding") | ||
}) | ||
|
||
t.Run("Embedding error", func(t *testing.T) { | ||
// Create an instance of the Bedrock struct with a mock client. | ||
client := &mockBedrockRuntimeClient{ | ||
err: errors.New("Embedding error"), | ||
} | ||
embedder := NewBedrock(client) | ||
|
||
// Define a query text. | ||
query := "query text" | ||
|
||
// Embed the query. | ||
embedding, err := embedder.EmbedQuery(context.Background(), query) | ||
|
||
// Add your assertions using testify | ||
assert.Error(t, err, "Expected an error") | ||
assert.Nil(t, embedding, "Expected nil embedding") | ||
}) | ||
}) | ||
} | ||
|
||
// mockBedrockRuntimeClient is a mock implementation of BedrockRuntimeClient for testing. | ||
type mockBedrockRuntimeClient struct { | ||
response *bedrockruntime.InvokeModelOutput | ||
err error | ||
} | ||
|
||
func (m *mockBedrockRuntimeClient) InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error) { | ||
if m.err != nil { | ||
return nil, m.err | ||
} | ||
|
||
return m.response, nil | ||
} |