forked from tmc/langchaingo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
memory: Add mongodb memory implementation (tmc#810)
* feat(internal): add mongodb client * feat(memory): add mongodb memory
- Loading branch information
Showing
7 changed files
with
347 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package mongodb | ||
|
||
import ( | ||
"context" | ||
|
||
"go.mongodb.org/mongo-driver/mongo" | ||
"go.mongodb.org/mongo-driver/mongo/options" | ||
"go.mongodb.org/mongo-driver/mongo/readpref" | ||
) | ||
|
||
func NewClient(ctx context.Context, url string) (*mongo.Client, error) { | ||
client, err := mongo.Connect(ctx, options.Client().ApplyURI(url)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
err = client.Ping(ctx, readpref.Primary()) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return client, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
package memory | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
|
||
"github.com/tmc/langchaingo/internal/mongodb" | ||
"github.com/tmc/langchaingo/llms" | ||
"github.com/tmc/langchaingo/schema" | ||
"go.mongodb.org/mongo-driver/bson" | ||
"go.mongodb.org/mongo-driver/mongo" | ||
) | ||
|
||
const ( | ||
// mongoSessionIDKey a unique identifier of the session, like user name, email, chat id etc. | ||
// same as langchain. | ||
mongoSessionIDKey = "SessionId" | ||
) | ||
|
||
type MongoDBChatMessageHistory struct { | ||
url string | ||
sessionID string | ||
databaseName string | ||
collectionName string | ||
client *mongo.Client | ||
collection *mongo.Collection | ||
} | ||
|
||
type chatMessageModel struct { | ||
SessionID string `bson:"SessionId" json:"SessionId"` | ||
History string `bson:"History" json:"History"` | ||
} | ||
|
||
// Statically assert that MongoDBChatMessageHistory implement the chat message history interface. | ||
var _ schema.ChatMessageHistory = &MongoDBChatMessageHistory{} | ||
|
||
// NewMongoDBChatMessageHistory creates a new MongoDBChatMessageHistory using chat message options. | ||
func NewMongoDBChatMessageHistory(ctx context.Context, options ...MongoDBChatMessageHistoryOption) (*MongoDBChatMessageHistory, error) { | ||
h, err := applyMongoDBChatOptions(options...) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
client, err := mongodb.NewClient(ctx, h.url) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
h.client = client | ||
|
||
h.collection = client.Database(h.databaseName).Collection(h.collectionName) | ||
// create session id index | ||
if _, err := h.collection.Indexes().CreateOne(ctx, mongo.IndexModel{Keys: bson.D{{Key: mongoSessionIDKey, Value: 1}}}); err != nil { | ||
return nil, err | ||
} | ||
|
||
return h, nil | ||
} | ||
|
||
// Messages returns all messages stored. | ||
func (h *MongoDBChatMessageHistory) Messages(ctx context.Context) ([]llms.ChatMessage, error) { | ||
messages := []llms.ChatMessage{} | ||
filter := bson.M{mongoSessionIDKey: h.sessionID} | ||
cursor, err := h.collection.Find(ctx, filter) | ||
if err != nil { | ||
return messages, err | ||
} | ||
|
||
_messages := []chatMessageModel{} | ||
if err := cursor.All(ctx, &_messages); err != nil { | ||
return messages, err | ||
} | ||
for _, message := range _messages { | ||
m := llms.ChatMessageModel{} | ||
if err := json.Unmarshal([]byte(message.History), &m); err != nil { | ||
return messages, err | ||
} | ||
messages = append(messages, m.ToChatMessage()) | ||
} | ||
|
||
return messages, nil | ||
} | ||
|
||
// AddAIMessage adds an AIMessage to the chat message history. | ||
func (h *MongoDBChatMessageHistory) AddAIMessage(ctx context.Context, text string) error { | ||
return h.AddMessage(ctx, llms.AIChatMessage{Content: text}) | ||
} | ||
|
||
// AddUserMessage adds a user to the chat message history. | ||
func (h *MongoDBChatMessageHistory) AddUserMessage(ctx context.Context, text string) error { | ||
return h.AddMessage(ctx, llms.HumanChatMessage{Content: text}) | ||
} | ||
|
||
// Clear clear session memory from MongoDB. | ||
func (h *MongoDBChatMessageHistory) Clear(ctx context.Context) error { | ||
filter := bson.M{mongoSessionIDKey: h.sessionID} | ||
_, err := h.collection.DeleteMany(ctx, filter) | ||
return err | ||
} | ||
|
||
// AddMessage adds a message to the store. | ||
func (h *MongoDBChatMessageHistory) AddMessage(ctx context.Context, message llms.ChatMessage) error { | ||
_message, err := json.Marshal(llms.ConvertChatMessageToModel(message)) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
_, err = h.collection.InsertOne(ctx, chatMessageModel{ | ||
SessionID: h.sessionID, | ||
History: string(_message), | ||
}) | ||
|
||
return err | ||
} | ||
|
||
// SetMessages replaces existing messages in the store. | ||
func (h *MongoDBChatMessageHistory) SetMessages(ctx context.Context, messages []llms.ChatMessage) error { | ||
_messages := []interface{}{} | ||
for _, message := range messages { | ||
_message, err := json.Marshal(llms.ConvertChatMessageToModel(message)) | ||
if err != nil { | ||
return err | ||
} | ||
_messages = append(_messages, chatMessageModel{ | ||
SessionID: h.sessionID, | ||
History: string(_message), | ||
}) | ||
} | ||
|
||
if err := h.Clear(ctx); err != nil { | ||
return err | ||
} | ||
|
||
_, err := h.collection.InsertMany(ctx, _messages) | ||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
package memory | ||
|
||
import ( | ||
"errors" | ||
) | ||
|
||
const ( | ||
mongoDefaultDBName = "chat_history" | ||
mongoDefaultCollectionName = "message_store" | ||
) | ||
|
||
var ( | ||
errMongoInvalidURL = errors.New("invalid mongo url option") | ||
errMongoInvalidSessionID = errors.New("invalid mongo session id option") | ||
) | ||
|
||
type MongoDBChatMessageHistoryOption func(m *MongoDBChatMessageHistory) | ||
|
||
func applyMongoDBChatOptions(options ...MongoDBChatMessageHistoryOption) (*MongoDBChatMessageHistory, error) { | ||
h := &MongoDBChatMessageHistory{ | ||
databaseName: mongoDefaultDBName, | ||
collectionName: mongoDefaultCollectionName, | ||
} | ||
|
||
for _, option := range options { | ||
option(h) | ||
} | ||
|
||
if h.url == "" { | ||
return nil, errMongoInvalidURL | ||
} | ||
if h.sessionID == "" { | ||
return nil, errMongoInvalidSessionID | ||
} | ||
|
||
return h, nil | ||
} | ||
|
||
// WithConnectionURL is an option for specifying the MongoDB connection URL. Must be set. | ||
func WithConnectionURL(connectionURL string) MongoDBChatMessageHistoryOption { | ||
return func(p *MongoDBChatMessageHistory) { | ||
p.url = connectionURL | ||
} | ||
} | ||
|
||
// WithSessionID is an arbitrary key that is used to store the messages of a single chat session, | ||
// like user name, email, chat id etc. Must be set. | ||
func WithSessionID(sessionID string) MongoDBChatMessageHistoryOption { | ||
return func(p *MongoDBChatMessageHistory) { | ||
p.sessionID = sessionID | ||
} | ||
} | ||
|
||
// WithCollectionName is an option for specifying the collection name. | ||
func WithCollectionName(name string) MongoDBChatMessageHistoryOption { | ||
return func(p *MongoDBChatMessageHistory) { | ||
p.collectionName = name | ||
} | ||
} | ||
|
||
// WithDataBaseName is an option for specifying the database name. | ||
func WithDataBaseName(name string) MongoDBChatMessageHistoryOption { | ||
return func(p *MongoDBChatMessageHistory) { | ||
p.databaseName = name | ||
} | ||
} |
Oops, something went wrong.