Skip to content
This repository has been archived by the owner on Jan 9, 2025. It is now read-only.

Commit

Permalink
fix(redis): fix message retrieval and improve system message support (#…
Browse files Browse the repository at this point in the history
…83)

Because

- [Bug 🐛] Previously, the retrieval process would fetch the earliest
messages instead of the latest ones, which was a critical bug.

- [Improvement 🧹] Furthermore, the meaning of `latest_k` in the commit
has been updated. It now refers to the latest K conversation turns. In a
chat history, each conversation turn consists of one participant
speaking or sending a message, followed by responses from other
participants. For example:

```
User: Question 1
Assistant: Response 1
User: Question 2
Assistant: Response 2
```

With the previous implementation, if `latest_k` was set to 3, the
retrieved conversation would be incomplete:

```
Assistant: Response 1
User: Question 2
Assistant: Response 2
```

This issue prevented complete conversation turns from being retrieved.

- [Improvement 🧹] Additionally, this commit introduces support for
system messages, which are treated as a special type of message for the
LLM. Regardless of the chat history's length, the option to include the
system message in the retrieved messages is now available.

This commit:

- fixes the critical bug that caused the retrieval of the earliest messages
instead of the latest ones
- updates `latest_k` to refer to conversation turns
- adds support for system messages to ensure they can always be included
when needed
  • Loading branch information
xiaofei-du authored Dec 14, 2023
1 parent 578663e commit 0c19492
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 28 deletions.
119 changes: 93 additions & 26 deletions pkg/redis/chat_history.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (
)

var (
// DefaultLatestK is the default number of latest messages to retrieve
DefaultLatestK = 10
// DefaultLatestK is the default number of latest conversation turns to retrieve
DefaultLatestK = 5
)

type Message struct {
Expand All @@ -35,8 +35,9 @@ type ChatMessageWriteOutput struct {
}

type ChatHistoryRetrieveInput struct {
SessionID string `json:"session_id"`
LatestK *int `json:"latest_k,omitempty"`
SessionID string `json:"session_id"`
LatestK *int `json:"latest_k,omitempty"`
IncludeSystemMessage bool `json:"include_system_message"`
}

// ChatHistoryRetrieveOutput is a wrapper struct for the messages associated with a session ID
Expand All @@ -45,10 +46,55 @@ type ChatHistoryRetrieveOutput struct {
Status bool `json:"status"`
}

// WriteSystemMessage stores the system message of a session.
//
// The message is JSON-encoded and written to the "system_messages" hash,
// using sessionID as the field name. It returns the marshalling error or the
// error reported by the Redis HSET command, if any.
func WriteSystemMessage(client *goredis.Client, sessionID string, message MessageWithTime) error {
	payload, err := json.Marshal(message)
	if err != nil {
		return err
	}

	// All system messages live in one hash; the session ID is the field key.
	cmd := client.HSet(context.Background(), "system_messages", sessionID, payload)
	return cmd.Err()
}

// WriteNonSystemMessage appends a non-system chat message to a session.
//
// The message is JSON-encoded and added to the sorted set stored under the
// key sessionID+":timestamps", scored by message.Timestamp. It returns the
// marshalling error or the error reported by the Redis ZADD command, if any.
//
// NOTE(review): sorted-set members are unique, so two byte-identical payloads
// (same fields and same timestamp) would occupy a single entry — confirm this
// is acceptable for the callers.
func WriteNonSystemMessage(client *goredis.Client, sessionID string, message MessageWithTime) error {
	encoded, err := json.Marshal(message)
	if err != nil {
		return err
	}

	// The timestamp is the score, so range queries return messages in time order.
	entry := goredis.Z{
		Score:  float64(message.Timestamp),
		Member: string(encoded),
	}
	return client.ZAdd(context.Background(), sessionID+":timestamps", entry).Err()
}

// RetrieveSystemMessage looks up the system message of a session in the
// "system_messages" hash.
//
// Results:
//   - (true, message, nil) when a message is stored and decodes successfully;
//   - (false, nil, nil) when no message is stored for sessionID;
//   - (false, nil, err) when the Redis lookup or the JSON decoding fails.
func RetrieveSystemMessage(client *goredis.Client, sessionID string) (bool, *MessageWithTime, error) {
	raw, err := client.HGet(context.Background(), "system_messages", sessionID).Result()
	switch {
	case err == goredis.Nil:
		// A missing field is not an error: the session simply has no system message.
		return false, nil, nil
	case err != nil:
		return false, nil, err
	}

	message := new(MessageWithTime)
	if err := json.Unmarshal([]byte(raw), message); err != nil {
		return false, nil, err
	}
	return true, message, nil
}

func WriteMessage(client *goredis.Client, input ChatMessageWriteInput) ChatMessageWriteOutput {
// Current time
currTime := time.Now().Unix()
key := input.SessionID

// Create a MessageWithTime struct with the provided input and timestamp
messageWithTime := MessageWithTime{
Expand All @@ -60,50 +106,53 @@ func WriteMessage(client *goredis.Client, input ChatMessageWriteInput) ChatMessa
Timestamp: currTime,
}

// Marshal the MessageWithTime struct to JSON
messageJSON, err := json.Marshal(messageWithTime)
if err != nil {
return ChatMessageWriteOutput{Status: false}
// Treat system message differently
if input.Role == "system" {
err := WriteSystemMessage(client, input.SessionID, messageWithTime)
if err != nil {
return ChatMessageWriteOutput{Status: false}
} else {
return ChatMessageWriteOutput{Status: true}
}
}

// Append chat message to the Redis list
err = client.RPush(context.Background(), key, messageJSON).Err()
err := WriteNonSystemMessage(client, input.SessionID, messageWithTime)
if err != nil {
return ChatMessageWriteOutput{Status: false}
} else {
return ChatMessageWriteOutput{Status: true}
}

return ChatMessageWriteOutput{Status: true}
}

// RetrieveSessionMessages retrieves the latest K messages from the Redis list for the given session ID
// RetrieveSessionMessages retrieves the latest K conversation turns from the Redis sorted set for the given session ID
func RetrieveSessionMessages(client *goredis.Client, input ChatHistoryRetrieveInput) ChatHistoryRetrieveOutput {
if input.LatestK == nil || *input.LatestK <= 0 {
input.LatestK = &DefaultLatestK
}
key := input.SessionID

messagesWithTime := []MessageWithTime{}
messages := []*Message{}
ctx := context.Background()

// Determine the start and stop indexes for retrieving the latest k messages
startIndex := int64(0)
stopIndex := int64(*input.LatestK - 1) // The stop index is k-1 to fetch the latest k messages

// Retrieve the latest k messages associated with the sessionID
messageWithTimeJSONs, err := client.LRange(context.Background(), input.SessionID, startIndex, stopIndex).Result()
// Retrieve the latest K conversation turns associated with the session ID by descending timestamp order
messagesNum := *input.LatestK * 2
timestampMessages, err := client.ZRevRange(ctx, key+":timestamps", 0, int64(messagesNum-1)).Result()
if err != nil {
// Handle the error, e.g., log it or return an error response
return ChatHistoryRetrieveOutput{
Messages: messages,
Status: false,
}
}

// Unmarshal retrieved JSON messages into MessageWithTime structs
for _, m := range messageWithTimeJSONs {
// Iterate through the members and deserialize them into MessageWithTime
for _, member := range timestampMessages {
var messageWithTime MessageWithTime
if err := json.Unmarshal([]byte(m), &messageWithTime); err != nil {
// Handle the error, e.g., log it or skip the invalid message
continue
if err := json.Unmarshal([]byte(member), &messageWithTime); err != nil {
return ChatHistoryRetrieveOutput{
Messages: messages,
Status: false,
}
}
messagesWithTime = append(messagesWithTime, messageWithTime)
}
Expand All @@ -113,6 +162,24 @@ func RetrieveSessionMessages(client *goredis.Client, input ChatHistoryRetrieveIn
return messagesWithTime[i].Timestamp < messagesWithTime[j].Timestamp
})

// Add the system message if it exists
if input.IncludeSystemMessage {
exist, sysMessage, err := RetrieveSystemMessage(client, input.SessionID)
if err != nil {
return ChatHistoryRetrieveOutput{
Messages: messages,
Status: false,
}
}
if exist {
messages = append(messages, &Message{
Role: sysMessage.Role,
Content: sysMessage.Content,
Metadata: sysMessage.Metadata,
})
}
}

// Convert the MessageWithTime structs to Message structs
for _, m := range messagesWithTime {
messages = append(messages, &Message{
Expand Down
18 changes: 16 additions & 2 deletions pkg/redis/config/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,23 @@
"input": {
"instillUIOrder": 0,
"properties": {
"include_system_message": {
"default": true,
"description": "Include system message in the retrieved conversation turns if exists",
"instillAcceptFormats": [
"boolean"
],
"instillUIOrder": 2,
"instillUpstreamTypes": [
"value",
"reference"
],
"title": "Include System Message If Exists",
"type": "boolean"
},
"latest_k": {
"default": 10,
"description": "The number of latest messages to retrieve",
"default": 5,
"description": "The number of latest conversation turns to retrieve. A conversation turn typically includes one participant speaking or sending a message, and the other participant(s) responding to it.",
"instillAcceptFormats": [
"integer"
],
Expand Down

0 comments on commit 0c19492

Please sign in to comment.