-
Notifications
You must be signed in to change notification settings - Fork 16.5k
/
Copy pathelasticsearch.py
210 lines (182 loc) · 7.01 KB
/
elasticsearch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import json
import logging
from time import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain_core._api import deprecated
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import (
BaseMessage,
message_to_dict,
messages_from_dict,
)
if TYPE_CHECKING:
from elasticsearch import Elasticsearch
logger = logging.getLogger(__name__)
@deprecated("0.0.27", alternative="Use langchain-elasticsearch package", pending=True)
class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
    """Chat message history that stores history in Elasticsearch.

    Every message is written as one document containing the session id, a
    creation timestamp and the JSON-serialized message. Reading the history
    searches the index for the session id and sorts by that timestamp.

    Args:
        index: Name of the index to use.
        session_id: Arbitrary key that is used to store the messages
            of a single chat session.
        es_connection: Optional pre-existing Elasticsearch connection.
        es_url: URL of the Elasticsearch instance to connect to.
        es_cloud_id: Cloud ID of the Elasticsearch instance to connect to.
        es_user: Username to use when connecting to Elasticsearch.
        es_api_key: API key to use when connecting to Elasticsearch.
        es_password: Password to use when connecting to Elasticsearch.
        esnsure_ascii: Forwarded to ``json.dumps`` as ``ensure_ascii``; when
            truthy (the default) non-ASCII characters in the stored message
            are escaped. The misspelled parameter name is kept so existing
            callers keep working.

    Raises:
        ValueError: If neither ``es_connection`` nor ``es_url``/``es_cloud_id``
            is provided.
    """

    # Elasticsearch returns only 10 hits when ``size`` is omitted, which would
    # silently truncate longer sessions. 10,000 is the default
    # ``index.max_result_window``, i.e. the most a single plain search returns.
    _MAX_SEARCH_HITS = 10_000

    def __init__(
        self,
        index: str,
        session_id: str,
        *,
        es_connection: Optional["Elasticsearch"] = None,
        es_url: Optional[str] = None,
        es_cloud_id: Optional[str] = None,
        es_user: Optional[str] = None,
        es_api_key: Optional[str] = None,
        es_password: Optional[str] = None,
        esnsure_ascii: Optional[bool] = True,
    ) -> None:
        self.index: str = index
        self.session_id: str = session_id
        self.ensure_ascii = esnsure_ascii

        # Initialize Elasticsearch client from passed client arg or connection info
        if es_connection is not None:
            self.client = es_connection.options(
                headers={"user-agent": self.get_user_agent()}
            )
        elif es_url is not None or es_cloud_id is not None:
            self.client = ElasticsearchChatMessageHistory.connect_to_elasticsearch(
                es_url=es_url,
                username=es_user,
                password=es_password,
                cloud_id=es_cloud_id,
                api_key=es_api_key,
            )
        else:
            raise ValueError(
                "Either provide a pre-existing Elasticsearch connection, "
                "or valid credentials for creating a new connection."
            )

        # Create the backing index on first use; an existing index is reused
        # untouched.
        if self.client.indices.exists(index=index):
            logger.debug(
                "Chat history index %s already exists, skipping creation.", index
            )
        else:
            logger.debug("Creating index %s for storing chat history.", index)
            self.client.indices.create(
                index=index,
                mappings={
                    "properties": {
                        "session_id": {"type": "keyword"},
                        "created_at": {"type": "date"},
                        "history": {"type": "text"},
                    }
                },
            )

    @staticmethod
    def get_user_agent() -> str:
        """Return the user-agent string sent with Elasticsearch requests."""
        from langchain_community import __version__

        return f"langchain-py-ms/{__version__}"

    @staticmethod
    def connect_to_elasticsearch(
        *,
        es_url: Optional[str] = None,
        cloud_id: Optional[str] = None,
        api_key: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
    ) -> "Elasticsearch":
        """Create an Elasticsearch client and verify that it can connect.

        Args:
            es_url: URL of the Elasticsearch instance. Mutually exclusive
                with ``cloud_id``.
            cloud_id: Cloud ID of the Elasticsearch instance. Mutually
                exclusive with ``es_url``.
            api_key: API key; takes precedence over ``username``/``password``.
            username: Username for basic auth (used only with ``password``).
            password: Password for basic auth (used only with ``username``).

        Returns:
            A client on which ``info()`` has succeeded.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.
            ValueError: If both or neither of ``es_url`` and ``cloud_id``
                are given.
            Exception: Whatever ``info()`` raises when the connection check
                fails; it is logged and re-raised unchanged.
        """
        try:
            import elasticsearch
        except ImportError as err:
            raise ImportError(
                "Could not import elasticsearch python package. "
                "Please install it with `pip install elasticsearch`."
            ) from err

        if es_url and cloud_id:
            raise ValueError(
                "Both es_url and cloud_id are defined. Please provide only one."
            )

        connection_params: Dict[str, Any] = {}

        if es_url:
            connection_params["hosts"] = [es_url]
        elif cloud_id:
            connection_params["cloud_id"] = cloud_id
        else:
            raise ValueError("Please provide either elasticsearch_url or cloud_id.")

        if api_key:
            connection_params["api_key"] = api_key
        elif username and password:
            connection_params["basic_auth"] = (username, password)

        es_client = elasticsearch.Elasticsearch(
            **connection_params,
            headers={"user-agent": ElasticsearchChatMessageHistory.get_user_agent()},
        )

        # Fail fast: surface bad hosts/credentials here rather than on the
        # first read or write.
        try:
            es_client.info()
        except Exception as err:
            logger.error("Error connecting to Elasticsearch: %s", err)
            raise

        return es_client

    @property
    def messages(self) -> List[BaseMessage]:
        """Retrieve the messages from Elasticsearch.

        Returns:
            The session's messages ordered by ascending ``created_at``; an
            empty list when the session has no stored messages. At most
            ``_MAX_SEARCH_HITS`` messages are returned.

        Raises:
            elasticsearch.ApiError: If the search request fails (logged and
                re-raised).
        """
        # Imported outside the ``try`` so a missing package surfaces as an
        # ImportError instead of an unbound name in the ``except`` clause.
        from elasticsearch import ApiError

        try:
            result = self.client.search(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                sort="created_at:asc",
                # Without an explicit size only the first 10 hits come back.
                size=self._MAX_SEARCH_HITS,
            )
        except ApiError as err:
            logger.error("Could not retrieve messages from Elasticsearch: %s", err)
            raise

        hits = result["hits"]["hits"] if result else []
        items = [json.loads(document["_source"]["history"]) for document in hits]
        return messages_from_dict(items)

    @messages.setter
    def messages(self, messages: List[BaseMessage]) -> None:
        """Reject direct assignment; messages must be added one by one."""
        raise NotImplementedError(
            "Direct assignment to 'messages' is not allowed."
            " Use the 'add_messages' instead."
        )

    def add_message(self, message: BaseMessage) -> None:
        """Add a message to the chat session in Elasticsearch.

        Args:
            message: The message to serialize and store.

        Raises:
            elasticsearch.ApiError: If the index request fails (logged and
                re-raised).
        """
        from elasticsearch import ApiError

        try:
            self.client.index(
                index=self.index,
                document={
                    "session_id": self.session_id,
                    # Current time in milliseconds since the epoch; this is
                    # the field ``messages`` sorts on.
                    "created_at": round(time() * 1000),
                    "history": json.dumps(
                        message_to_dict(message),
                        ensure_ascii=bool(self.ensure_ascii),
                    ),
                },
                # Refresh so the new message is visible to the next search.
                refresh=True,
            )
        except ApiError as err:
            logger.error("Could not add message to Elasticsearch: %s", err)
            raise

    def clear(self) -> None:
        """Clear session memory in Elasticsearch.

        Deletes every document in the index whose ``session_id`` matches
        this session.

        Raises:
            elasticsearch.ApiError: If the delete request fails (logged and
                re-raised).
        """
        from elasticsearch import ApiError

        try:
            self.client.delete_by_query(
                index=self.index,
                query={"term": {"session_id": self.session_id}},
                refresh=True,
            )
        except ApiError as err:
            logger.error("Could not clear session memory in Elasticsearch: %s", err)
            raise