diff --git a/conversation/openai/openai.go b/conversation/openai/openai.go index df6b76eabe..56304ce5b0 100644 --- a/conversation/openai/openai.go +++ b/conversation/openai/openai.go @@ -23,12 +23,12 @@ import ( "github.com/dapr/kit/logger" kmeta "github.com/dapr/kit/metadata" - "github.com/sashabaranov/go-openai" + openai "github.com/sashabaranov/go-openai" ) type OpenAI struct { - key string - model string + cilent *openai.Client + model string logger logger.Logger } @@ -48,7 +48,7 @@ func (o *OpenAI) Init(ctx context.Context, meta conversation.Metadata) error { return err } - o.key = r.Key + o.cilent = openai.NewClient(r.Key) o.model = r.Model return nil @@ -62,7 +62,6 @@ func (o *OpenAI) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) { // Note: OPENAI does not support load balance - client := openai.NewClient(o.key) messages := make([]openai.ChatCompletionMessage, 0, len(r.Inputs)) @@ -80,7 +79,7 @@ func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationReque // TODO: support ConversationContext - resp, err := client.CreateChatCompletion(ctx, req) + resp, err := o.cilent.CreateChatCompletion(ctx, req) if err != nil { o.logger.Error(err) return nil, err