diff --git a/go.mod b/go.mod index d000ab03a..c625e5c32 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/testcontainers/testcontainers-go v0.29.1 github.com/testcontainers/testcontainers-go/modules/chroma v0.29.1 github.com/testcontainers/testcontainers-go/modules/milvus v0.29.1 + github.com/testcontainers/testcontainers-go/modules/mongodb v0.29.1 github.com/testcontainers/testcontainers-go/modules/mysql v0.29.1 github.com/testcontainers/testcontainers-go/modules/opensearch v0.29.1 github.com/testcontainers/testcontainers-go/modules/postgres v0.29.1 @@ -83,6 +84,7 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect + github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v23.5.26+incompatible // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.7 // indirect @@ -117,6 +119,7 @@ require ( github.com/moby/term v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/nlpodyssey/gopickle v0.2.0 // indirect github.com/nlpodyssey/gotokenizers v0.2.0 // indirect @@ -143,13 +146,16 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect github.com/yargevad/filepathx v1.0.0 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect gitlab.com/golang-commonmark/html v0.0.0-20191124015941-a22733972181 // indirect gitlab.com/golang-commonmark/linkify v0.0.0-20191026162114-a0c2df6c8f82 // indirect gitlab.com/golang-commonmark/mdurl v0.0.0-20191124015652-932350d1cb84 // indirect gitlab.com/golang-commonmark/puny v0.0.0-20191124015043-9f83538fa04f // indirect - go.mongodb.org/mongo-driver v1.11.3 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect @@ -205,6 +211,7 @@ require ( github.com/weaviate/weaviate v1.23.9 github.com/weaviate/weaviate-go-client/v4 v4.12.1 gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a + go.mongodb.org/mongo-driver v1.13.1 go.starlark.net v0.0.0-20230302034142-4b1e35fe2254 golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 golang.org/x/tools v0.14.0 diff --git a/go.sum b/go.sum index ad922198e..cd59b64f3 100644 --- a/go.sum +++ b/go.sum @@ -329,6 +329,8 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.7.1-0.20190724094224-574c33c3df38/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/flatbuffers v23.5.26+incompatible h1:M9dgRyhJemaM4Sw8+66GHBu8ioaQmyPLg1b8VwK5WJg= github.com/google/flatbuffers v23.5.26+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= @@ -521,6 +523,7 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= @@ -654,6 +657,8 @@ github.com/testcontainers/testcontainers-go/modules/chroma v0.29.1 h1:xm2LnnrPcK github.com/testcontainers/testcontainers-go/modules/chroma v0.29.1/go.mod h1:R6duRa3bVpkDsTSMffrfRW6wyXtKK2jqRNtDjDLW59Y= github.com/testcontainers/testcontainers-go/modules/milvus v0.29.1 h1:KYlzcq8BF6cg+XoLUx6LgkwP4zkVpQ+33ZBLfm/9pmo= github.com/testcontainers/testcontainers-go/modules/milvus v0.29.1/go.mod h1:IQ6CpkAaf2bYmOnr44obiLjyoGQQhaAhq2QfQ9iBM7Q= +github.com/testcontainers/testcontainers-go/modules/mongodb v0.29.1 h1:UEU6STi5h1A0TcVyAI8MtAPxnLD6DrDogZpTQ6TZ4qs= +github.com/testcontainers/testcontainers-go/modules/mongodb v0.29.1/go.mod h1:OanSjytpk9EwgnJwoDC7vx9fIuCiOdTP8TsW1sIrjEY= github.com/testcontainers/testcontainers-go/modules/mysql v0.29.1 h1:SnJtZNcskgxOMyVAT7M+MQjpveP59nwKzlBw2ItX+C8= github.com/testcontainers/testcontainers-go/modules/mysql v0.29.1/go.mod h1:VhA5dV+O19sx3Y9u9bfO+fbJfP3E7RiMq0nDMEGjslw= github.com/testcontainers/testcontainers-go/modules/opensearch v0.29.1 h1:QoSRd5e+XAJo6sVv7pREf6cgHJ5I5+0aAT9IK2INVaM= @@ -712,11 +717,16 @@ github.com/weaviate/weaviate-go-client/v4 v4.12.1 h1:XFKL49BgSOcxrFs5IV+Q5pydLTs github.com/weaviate/weaviate-go-client/v4 v4.12.1/go.mod h1:r1PlU5sAZKFvAPgymEHQj0hjSAuEV9X77PJ/ffZ6cEo= github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= github.com/x-cray/logrus-prefixed-formatter v0.5.2/go.mod h1:2duySbKsL6M18s5GU7VPsoEPHyzalCE06qoARUCeBBE= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs= github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQluxsYJ78Id3Y= @@ -724,6 +734,7 @@ github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1: github.com/yalp/jsonpath v0.0.0-20180802001716-5cc68e5049a0/go.mod h1:/LWChgwKmvncFJFHJ7Gvn9wZArjbV5/FppcK2fKk/tI= github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82/go.mod h1:lgjkn3NuSvDfVJdfcVVdX+jpBxNmX4rDAzaS45IcYoM= @@ -749,8 +760,8 @@ gitlab.com/opennota/wd v0.0.0-20180912061657-c5d65f63c638/go.mod h1:EGRJaqe2eO9X go.mongodb.org/mongo-driver v1.7.3/go.mod h1:NqaYOwnXWr5Pm7AOpO5QFxKJ503nbMse/R79oO62zWg= go.mongodb.org/mongo-driver v1.7.5/go.mod h1:VXEWRZ6URJIkUq2SCAyapmhH0ZLRBP+FT4xhp5Zvxng= go.mongodb.org/mongo-driver v1.10.0/go.mod h1:wsihk0Kdgv8Kqu1Anit4sfK+22vSFbUrAVEYRhCXrA8= -go.mongodb.org/mongo-driver v1.11.3 h1:Ql6K6qYHEzB6xvu4+AU0BoRoqf9vFPcc4o7MUIdPW8Y= -go.mongodb.org/mongo-driver v1.11.3/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g= +go.mongodb.org/mongo-driver v1.13.1 h1:YIc7HTYsKndGK4RFzJ3covLz1byri52x0IoMB0Pt/vk= +go.mongodb.org/mongo-driver v1.13.1/go.mod h1:wcDf1JBCXy2mOW0bWHwO/IOYqdca1MPCwDtFu/Z9+eo= go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 h1:4Pp6oUg3+e/6M4C0A/3kJ2VYa++dsWVTtGgLVj5xtHg= diff --git a/internal/mongodb/client.go b/internal/mongodb/client.go new file mode 100644 index 000000000..c519b22f6 --- /dev/null +++ b/internal/mongodb/client.go @@ -0,0 +1,23 @@ +package mongodb + +import ( + "context" + + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" +) + +func NewClient(ctx context.Context, url string) (*mongo.Client, error) { + client, err := mongo.Connect(ctx, options.Client().ApplyURI(url)) + if err != nil { + return nil, err + } + + err = client.Ping(ctx, readpref.Primary()) + if err != nil { + return nil, err + } + + return client, nil +} diff --git a/llms/chat_messages.go b/llms/chat_messages.go index 12946e9b9..9cdaf4faf 100644 --- a/llms/chat_messages.go +++ b/llms/chat_messages.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "log/slog" "strings" ) @@ -165,3 +166,36 @@ func getMessageRole(m ChatMessage, humanPrefix, aiPrefix string) (string, error) } return role, nil } + +type ChatMessageModelData struct { + Content string `bson:"content" json:"content"` + Type string `bson:"type" json:"type"` +} + +type ChatMessageModel struct { + Type string `bson:"type" json:"type"` + Data ChatMessageModelData `bson:"data" json:"data"` +} + +func (c ChatMessageModel) ToChatMessage() ChatMessage { + switch c.Type { + case string(ChatMessageTypeAI): + return AIChatMessage{Content: c.Data.Content} + case string(ChatMessageTypeHuman): + return HumanChatMessage{Content: c.Data.Content} + default: + slog.Warn("convert to chat message failed with invalid message type", "type", c.Type) + return nil + } +} + +// ConvertChatMessageToModel Convert a ChatMessage to a ChatMessageModel. +func ConvertChatMessageToModel(m ChatMessage) ChatMessageModel { + return ChatMessageModel{ + Type: string(m.GetType()), + Data: ChatMessageModelData{ + Type: string(m.GetType()), + Content: m.GetContent(), + }, + } +} diff --git a/memory/mongo_chat_history.go b/memory/mongo_chat_history.go new file mode 100644 index 000000000..92aca228a --- /dev/null +++ b/memory/mongo_chat_history.go @@ -0,0 +1,136 @@ +package memory + +import ( + "context" + "encoding/json" + + "github.com/tmc/langchaingo/internal/mongodb" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" +) + +const ( + // mongoSessionIDKey a unique identifier of the session, like user name, email, chat id etc. + // same as langchain. + mongoSessionIDKey = "SessionId" +) + +type MongoDBChatMessageHistory struct { + url string + sessionID string + databaseName string + collectionName string + client *mongo.Client + collection *mongo.Collection +} + +type chatMessageModel struct { + SessionID string `bson:"SessionId" json:"SessionId"` + History string `bson:"History" json:"History"` +} + +// Statically assert that MongoDBChatMessageHistory implement the chat message history interface. +var _ schema.ChatMessageHistory = &MongoDBChatMessageHistory{} + +// NewMongoDBChatMessageHistory creates a new MongoDBChatMessageHistory using chat message options. +func NewMongoDBChatMessageHistory(ctx context.Context, options ...MongoDBChatMessageHistoryOption) (*MongoDBChatMessageHistory, error) { + h, err := applyMongoDBChatOptions(options...) + if err != nil { + return nil, err + } + + client, err := mongodb.NewClient(ctx, h.url) + if err != nil { + return nil, err + } + + h.client = client + + h.collection = client.Database(h.databaseName).Collection(h.collectionName) + // create session id index + if _, err := h.collection.Indexes().CreateOne(ctx, mongo.IndexModel{Keys: bson.D{{Key: mongoSessionIDKey, Value: 1}}}); err != nil { + return nil, err + } + + return h, nil +} + +// Messages returns all messages stored. +func (h *MongoDBChatMessageHistory) Messages(ctx context.Context) ([]llms.ChatMessage, error) { + messages := []llms.ChatMessage{} + filter := bson.M{mongoSessionIDKey: h.sessionID} + cursor, err := h.collection.Find(ctx, filter) + if err != nil { + return messages, err + } + + _messages := []chatMessageModel{} + if err := cursor.All(ctx, &_messages); err != nil { + return messages, err + } + for _, message := range _messages { + m := llms.ChatMessageModel{} + if err := json.Unmarshal([]byte(message.History), &m); err != nil { + return messages, err + } + messages = append(messages, m.ToChatMessage()) + } + + return messages, nil +} + +// AddAIMessage adds an AIMessage to the chat message history. +func (h *MongoDBChatMessageHistory) AddAIMessage(ctx context.Context, text string) error { + return h.AddMessage(ctx, llms.AIChatMessage{Content: text}) +} + +// AddUserMessage adds a user to the chat message history. +func (h *MongoDBChatMessageHistory) AddUserMessage(ctx context.Context, text string) error { + return h.AddMessage(ctx, llms.HumanChatMessage{Content: text}) +} + +// Clear clear session memory from MongoDB. +func (h *MongoDBChatMessageHistory) Clear(ctx context.Context) error { + filter := bson.M{mongoSessionIDKey: h.sessionID} + _, err := h.collection.DeleteMany(ctx, filter) + return err +} + +// AddMessage adds a message to the store. +func (h *MongoDBChatMessageHistory) AddMessage(ctx context.Context, message llms.ChatMessage) error { + _message, err := json.Marshal(llms.ConvertChatMessageToModel(message)) + if err != nil { + return err + } + + _, err = h.collection.InsertOne(ctx, chatMessageModel{ + SessionID: h.sessionID, + History: string(_message), + }) + + return err +} + +// SetMessages replaces existing messages in the store. +func (h *MongoDBChatMessageHistory) SetMessages(ctx context.Context, messages []llms.ChatMessage) error { + _messages := []interface{}{} + for _, message := range messages { + _message, err := json.Marshal(llms.ConvertChatMessageToModel(message)) + if err != nil { + return err + } + _messages = append(_messages, chatMessageModel{ + SessionID: h.sessionID, + History: string(_message), + }) + } + + if err := h.Clear(ctx); err != nil { + return err + } + + _, err := h.collection.InsertMany(ctx, _messages) + return err +} diff --git a/memory/mongo_chat_history_options.go b/memory/mongo_chat_history_options.go new file mode 100644 index 000000000..89da8b988 --- /dev/null +++ b/memory/mongo_chat_history_options.go @@ -0,0 +1,66 @@ +package memory + +import ( + "errors" +) + +const ( + mongoDefaultDBName = "chat_history" + mongoDefaultCollectionName = "message_store" +) + +var ( + errMongoInvalidURL = errors.New("invalid mongo url option") + errMongoInvalidSessionID = errors.New("invalid mongo session id option") +) + +type MongoDBChatMessageHistoryOption func(m *MongoDBChatMessageHistory) + +func applyMongoDBChatOptions(options ...MongoDBChatMessageHistoryOption) (*MongoDBChatMessageHistory, error) { + h := &MongoDBChatMessageHistory{ + databaseName: mongoDefaultDBName, + collectionName: mongoDefaultCollectionName, + } + + for _, option := range options { + option(h) + } + + if h.url == "" { + return nil, errMongoInvalidURL + } + if h.sessionID == "" { + return nil, errMongoInvalidSessionID + } + + return h, nil +} + +// WithConnectionURL is an option for specifying the MongoDB connection URL. Must be set. +func WithConnectionURL(connectionURL string) MongoDBChatMessageHistoryOption { + return func(p *MongoDBChatMessageHistory) { + p.url = connectionURL + } +} + +// WithSessionID is an arbitrary key that is used to store the messages of a single chat session, +// like user name, email, chat id etc. Must be set. +func WithSessionID(sessionID string) MongoDBChatMessageHistoryOption { + return func(p *MongoDBChatMessageHistory) { + p.sessionID = sessionID + } +} + +// WithCollectionName is an option for specifying the collection name. +func WithCollectionName(name string) MongoDBChatMessageHistoryOption { + return func(p *MongoDBChatMessageHistory) { + p.collectionName = name + } +} + +// WithDataBaseName is an option for specifying the database name. +func WithDataBaseName(name string) MongoDBChatMessageHistoryOption { + return func(p *MongoDBChatMessageHistory) { + p.databaseName = name + } +} diff --git a/memory/mongo_chat_history_test.go b/memory/mongo_chat_history_test.go new file mode 100644 index 000000000..73826c16d --- /dev/null +++ b/memory/mongo_chat_history_test.go @@ -0,0 +1,67 @@ +package memory + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mongodb" + "github.com/tmc/langchaingo/llms" +) + +func runTestContainer() (string, error) { + ctx := context.Background() + + mongoContainer, err := mongodb.RunContainer( + ctx, + testcontainers.WithImage("mongo:7.0.8"), + mongodb.WithUsername("test"), + mongodb.WithPassword("test"), + ) + if err != nil { + return "", err + } + + url, err := mongoContainer.ConnectionString(ctx) + if err != nil { + return "", err + } + return url, nil +} + +func TestMongoDBChatMessageHistory(t *testing.T) { + t.Parallel() + + url, err := runTestContainer() + require.NoError(t, err) + + ctx := context.Background() + _, err = NewMongoDBChatMessageHistory(ctx, WithSessionID("test")) + assert.Equal(t, errMongoInvalidURL, err) + + _, err = NewMongoDBChatMessageHistory(ctx, WithConnectionURL(url)) + assert.Equal(t, errMongoInvalidSessionID, err) + + history, err := NewMongoDBChatMessageHistory(ctx, WithConnectionURL(url), WithSessionID("testSessionXX")) + require.NoError(t, err) + + err = history.AddAIMessage(ctx, "Hi") + require.NoError(t, err) + + err = history.AddUserMessage(ctx, "Hello") + require.NoError(t, err) + + messages, err := history.Messages(ctx) + require.NoError(t, err) + + assert.Len(t, messages, 2) + assert.Equal(t, llms.ChatMessageTypeAI, messages[0].GetType()) + assert.Equal(t, "Hi", messages[0].GetContent()) + assert.Equal(t, llms.ChatMessageTypeHuman, messages[1].GetType()) + assert.Equal(t, "Hello", messages[1].GetContent()) + t.Cleanup(func() { + require.NoError(t, history.Clear(ctx)) + }) +}