Skip to content

Commit

Permalink
Add missing options
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Sep 16, 2023
1 parent afb918b commit 6475298
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions embedding/hugging_face_hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import (
// Compile time check to ensure HuggingFaceHub satisfies the Embedder interface.
var _ schema.Embedder = (*HuggingFaceHub)(nil)

// HuggingFaceHubClient represents a client for interacting with Hugging Face Hub.
type HuggingFaceHubClient interface {
// FeatureExtractionWithAutomaticReduction performs feature extraction with automatic reduction.
// It returns the extraction response or an error if the operation fails.
FeatureExtractionWithAutomaticReduction(ctx context.Context, req *huggingface.FeatureExtractionRequest) (huggingface.FeatureExtractionWithAutomaticReductionResponse, error)
}

Expand All @@ -22,20 +25,24 @@ type HuggingFaceHubOptions struct {
Options huggingface.Options
}

// HuggingFaceHub represents an embedder for Hugging Face Hub models.
type HuggingFaceHub struct {
client HuggingFaceHubClient
opts HuggingFaceHubOptions
}

// NewHuggingFaceHub creates a new instance of the HuggingFaceHub embedder.
func NewHuggingFaceHub(token string, optFns ...func(o *HuggingFaceHubOptions)) *HuggingFaceHub {
client := huggingface.NewInferenceClient(token)

return NewHuggingFaceHubFromClient(client, optFns...)
}

// NewHuggingFaceHubFromClient creates a new instance of the HuggingFaceHub embedder from a custom client.
func NewHuggingFaceHubFromClient(client HuggingFaceHubClient, optFns ...func(o *HuggingFaceHubOptions)) *HuggingFaceHub {
opts := HuggingFaceHubOptions{
Model: "sentence-transformers/all-mpnet-base-v2",
Model: "sentence-transformers/all-mpnet-base-v2",
Options: huggingface.Options{},
}

for _, fn := range optFns {
Expand All @@ -51,8 +58,9 @@ func NewHuggingFaceHubFromClient(client HuggingFaceHubClient, optFns ...func(o *
// EmbedDocuments embeds a list of documents and returns their embeddings.
func (e *HuggingFaceHub) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
res, err := e.client.FeatureExtractionWithAutomaticReduction(ctx, &huggingface.FeatureExtractionRequest{
Inputs: texts,
Model: e.opts.Model,
Inputs: texts,
Model: e.opts.Model,
Options: e.opts.Options,
})
if err != nil {
return nil, err
Expand All @@ -64,8 +72,9 @@ func (e *HuggingFaceHub) EmbedDocuments(ctx context.Context, texts []string) ([]
// EmbedQuery embeds a single query and returns its embedding.
func (e *HuggingFaceHub) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
res, err := e.client.FeatureExtractionWithAutomaticReduction(ctx, &huggingface.FeatureExtractionRequest{
Inputs: []string{text},
Model: e.opts.Model,
Inputs: []string{text},
Model: e.opts.Model,
Options: e.opts.Options,
})
if err != nil {
return nil, err
Expand Down

0 comments on commit 6475298

Please sign in to comment.