Skip to content

Commit

Permalink
Improve anthropic prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Sep 7, 2023
1 parent ed25557 commit 9f3a3fd
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 1 deletion.
35 changes: 34 additions & 1 deletion model/chatmodel/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package chatmodel

import (
"context"
"fmt"
"strings"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
Expand All @@ -11,6 +13,11 @@ import (
"github.com/hupe1980/golc/util"
)

const (
humanPromptPrefix = "\n\nHuman:"
aiPromptPrefix = "\n\nAssistant:"
)

// Compile time check to ensure Anthropic satisfies the ChatModel interface.
var _ schema.ChatModel = (*Anthropic)(nil)

Expand Down Expand Up @@ -82,7 +89,7 @@ func (cm *Anthropic) Generate(ctx context.Context, messages schema.ChatMessages,
fn(&opts)
}

prompt, err := messages.Format()
prompt, err := convertMessagesToAnthropicPrompt(messages)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -125,3 +132,29 @@ func (cm *Anthropic) Callbacks() []schema.Callback {
func (cm *Anthropic) InvocationParams() map[string]any {
return util.StructToMap(cm.opts)
}

func convertMessagesToAnthropicPrompt(messages schema.ChatMessages) (string, error) {
if len(messages) > 0 {
msg := messages[len(messages)-1]
if msg.Type() != schema.ChatMessageTypeAI {
messages = append(messages, schema.NewAIChatMessage(""))
}
}

prompt := ""

for _, message := range messages {
switch message.Type() {
case schema.ChatMessageTypeSystem:
prompt += fmt.Sprintf("%s <admin>%s</admin>", humanPromptPrefix, message.Content())
case schema.ChatMessageTypeAI:
prompt += fmt.Sprintf("%s %s", aiPromptPrefix, message.Content())
case schema.ChatMessageTypeHuman:
prompt += fmt.Sprintf("%s %s", humanPromptPrefix, message.Content())
default:
return "", fmt.Errorf("unsupported message type: %s", message.Type())
}
}

return strings.TrimRight(prompt, " "), nil
}
44 changes: 44 additions & 0 deletions model/chatmodel/anthropic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package chatmodel

import (
"testing"

"github.com/hupe1980/golc/schema"
"github.com/stretchr/testify/assert"
)

func TestConvertMessagesToAnthropicPrompt(t *testing.T) {
t.Run("Empty input messages", func(t *testing.T) {
emptyMessages := schema.ChatMessages{}
emptyPrompt, emptyErr := convertMessagesToAnthropicPrompt(emptyMessages)
assert.Equal(t, "", emptyPrompt)
assert.Nil(t, emptyErr)
})

t.Run("Messages with a single system message", func(t *testing.T) {
systemMessage := schema.NewSystemChatMessage("System message")
messagesWithSystem := schema.ChatMessages{systemMessage}
systemPrompt, systemErr := convertMessagesToAnthropicPrompt(messagesWithSystem)
expectedSystemPrompt := "\n\nHuman: <admin>System message</admin>\n\nAssistant:"
assert.Equal(t, expectedSystemPrompt, systemPrompt)
assert.Nil(t, systemErr)
})

t.Run("Messages with a single AI message", func(t *testing.T) {
aiMessage := schema.NewAIChatMessage("AI message")
messagesWithAI := schema.ChatMessages{aiMessage}
aiPrompt, aiErr := convertMessagesToAnthropicPrompt(messagesWithAI)
expectedAIPrompt := "\n\nAssistant: AI message"
assert.Equal(t, expectedAIPrompt, aiPrompt)
assert.Nil(t, aiErr)
})

t.Run("Messages with a single human message", func(t *testing.T) {
humanMessage := schema.NewHumanChatMessage("Human message")
messagesWithHuman := schema.ChatMessages{humanMessage}
humanPrompt, humanErr := convertMessagesToAnthropicPrompt(messagesWithHuman)
expectedHumanPrompt := "\n\nHuman: Human message\n\nAssistant:"
assert.Equal(t, expectedHumanPrompt, humanPrompt)
assert.Nil(t, humanErr)
})
}

0 comments on commit 9f3a3fd

Please sign in to comment.