Skip to content

Commit

Permalink
Add chat completion with retry
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 27, 2023
1 parent 7719d62 commit 4ef09c4
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 22 deletions.
83 changes: 62 additions & 21 deletions model/chatmodel/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package chatmodel

import (
"context"
"errors"
"fmt"

"github.com/avast/retry-go"
"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/schema"
Expand Down Expand Up @@ -44,6 +46,8 @@ type OpenAIOptions struct {
BaseURL string
// OrgID is the organization ID for accessing the OpenAI service.
OrgID string
// MaxRetries represents the maximum number of retries to make when generating.
MaxRetries uint `map:"max_retries,omitempty"`
}

var DefaultOpenAIOptions = OpenAIOptions{
Expand All @@ -55,6 +59,7 @@ var DefaultOpenAIOptions = OpenAIOptions{
TopP: 1,
PresencePenalty: 0,
FrequencyPenalty: 0,
MaxRetries: 3,
}

// OpenAI represents the OpenAI chat model.
Expand Down Expand Up @@ -153,7 +158,7 @@ func (cm *OpenAI) Generate(ctx context.Context, messages schema.ChatMessages, op
})
}

res, err := cm.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
res, err := cm.createChatCompletionWithRetry(ctx, openai.ChatCompletionRequest{
Model: cm.opts.ModelName,
Temperature: cm.opts.Temperature,
Messages: openAIMessages,
Expand All @@ -180,6 +185,62 @@ func (cm *OpenAI) Generate(ctx context.Context, messages schema.ChatMessages, op
}, nil
}

func (cm *OpenAI) createChatCompletionWithRetry(ctx context.Context, request openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
retryOpts := []retry.Option{
retry.Attempts(cm.opts.MaxRetries),
retry.DelayType(retry.FixedDelay),
retry.RetryIf(func(err error) bool {
e := &openai.APIError{}
if errors.As(err, &e) {
switch e.HTTPStatusCode {
case 429, 500:
return true
default:
return false
}
}

return false
}),
}

var res openai.ChatCompletionResponse

err := retry.Do(
func() error {
r, cErr := cm.client.CreateChatCompletion(ctx, request)
if cErr != nil {
return cErr
}
res = r
return nil
},
retryOpts...,
)

return res, err
}

// Type returns the type of the model.
func (cm *OpenAI) Type() string {
return "chatmodel.OpenAI"
}

// Verbose returns the verbosity setting of the model.
func (cm *OpenAI) Verbose() bool {
return cm.opts.CallbackOptions.Verbose
}

// Callbacks returns the registered callbacks of the model.
func (cm *OpenAI) Callbacks() []schema.Callback {
return cm.opts.CallbackOptions.Callbacks
}

// InvocationParams returns the parameters used in the model invocation.
func (cm *OpenAI) InvocationParams() map[string]any {
return nil
}

// messageTypeToOpenAIRole converts a schema.ChatMessageType to the corresponding OpenAI role string.
func messageTypeToOpenAIRole(mType schema.ChatMessageType) (string, error) {
switch mType { // nolint exhaustive
Expand Down Expand Up @@ -220,23 +281,3 @@ func openAIResponseToChatMessage(msg openai.ChatCompletionMessage) schema.ChatMe

return schema.NewGenericChatMessage(msg.Content, "unknown")
}

// Type returns the type of the model.
func (cm *OpenAI) Type() string {
return "chatmodel.OpenAI"
}

// Verbose returns the verbosity setting of the model.
func (cm *OpenAI) Verbose() bool {
return cm.opts.CallbackOptions.Verbose
}

// Callbacks returns the registered callbacks of the model.
func (cm *OpenAI) Callbacks() []schema.Callback {
return cm.opts.CallbackOptions.Callbacks
}

// InvocationParams returns the parameters used in the model invocation.
func (cm *OpenAI) InvocationParams() map[string]any {
return nil
}
2 changes: 1 addition & 1 deletion model/chatmodel/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestOpenAI_Generate(t *testing.T) {

result, err := openAI.Generate(ctx, messages)
assert.Error(t, err)
assert.EqualError(t, err, mockError.Error())
assert.EqualError(t, errors.New("All attempts fail:\n#1: generation error"), err.Error())
assert.Nil(t, result)
})
// Test case for Type method
Expand Down

0 comments on commit 4ef09c4

Please sign in to comment.