-
-
Notifications
You must be signed in to change notification settings - Fork 372
/
Copy pathembeddings.py
162 lines (140 loc) · 4.75 KB
/
embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from .models import EmbeddingModel
from .embeddings_migrations import embeddings_migrations
import json
from sqlite_utils import Database
from typing import Any, Dict, List, Tuple, Optional, Union
class Collection:
    """
    A named collection of embeddings stored in a SQLite database.

    A collection is one row in the ``collections`` table (its ``name`` and
    the ``model`` ID used to embed its items). Embedded items are rows in
    the ``embeddings`` table that point back at it through ``collection_id``.
    """

    def __init__(
        self,
        db: Database,
        name: str,
        *,
        model: Optional[EmbeddingModel] = None,
        model_id: Optional[str] = None,
    ) -> None:
        """
        Bind a collection name to a database and, optionally, a model.

        Nothing is written to the database here; the collection row is
        created lazily by id().

        Args:
            db (Database): Database holding the collection
            name (str): Name of the collection
            model (EmbeddingModel, optional): Embedding model to use
            model_id (str, optional): ID of the model, resolved with
                llm.get_embedding_model() when no model is passed

        Raises:
            ValueError: If model and model_id are both given but disagree
        """
        # Imported at call time rather than module level; presumably to
        # avoid a circular import with the llm package -- TODO confirm.
        from llm import get_embedding_model

        self.db = db
        self.name = name
        if model and model_id and model.model_id != model_id:
            raise ValueError("model_id does not match model.model_id")
        if model_id and not model:
            model = get_embedding_model(model_id)
        self.model = model
        # Cached primary key of the collections row; filled in by id()
        self._id: Optional[int] = None

    def _require_model(self) -> EmbeddingModel:
        """
        Return the embedding model, failing clearly if none was configured.

        Returns:
            EmbeddingModel: The model attached to this collection

        Raises:
            ValueError: If neither model nor model_id was supplied
        """
        if self.model is None:
            raise ValueError(
                f"Collection {self.name!r} has no embedding model; "
                "pass model= or model_id="
            )
        return self.model

    def id(self) -> int:
        """
        Get the ID of the collection, creating it in the DB if necessary.

        The ID is cached on the instance after the first lookup.

        Returns:
            int: ID of the collection

        Raises:
            ValueError: If the collection has to be created but no
                embedding model was configured
        """
        if self._id is not None:
            return self._id
        if not self.db["collections"].exists():
            embeddings_migrations.apply(self.db)
        rows = self.db["collections"].rows_where("name = ?", [self.name])
        row = next(rows, None)
        if row is not None:
            self._id = row["id"]
        else:
            # No row yet: create it, recording which model embeds its items
            model = self._require_model()
            self._id = (
                self.db["collections"]
                .insert(
                    {
                        "name": self.name,
                        "model": model.model_id,
                    }
                )
                .last_pk
            )
        return self._id

    def exists(self) -> bool:
        """
        Check if the collection exists in the DB.

        Returns:
            bool: True if exists, False otherwise (including when the
                collections table has not been created yet)
        """
        # Without this guard the query below fails on a fresh database
        if not self.db["collections"].exists():
            return False
        matches = list(
            self.db.query("select 1 from collections where name = ?", (self.name,))
        )
        return bool(matches)

    def count(self) -> int:
        """
        Count the number of items in the collection.

        Returns:
            int: Number of items in the collection; 0 when the embeddings
                or collections table has not been created yet
        """
        # The query reads both tables, so both must be present
        if not (self.db["embeddings"].exists() and self.db["collections"].exists()):
            return 0
        return next(
            self.db.query(
                """
                select count(*) as c from embeddings where collection_id = (
                    select id from collections where name = ?
                )
                """,
                (self.name,),
            )
        )["c"]

    def embed(
        self,
        id: str,
        text: str,
        metadata: Optional[Dict[str, Any]] = None,
        store: bool = False,
    ) -> None:
        """
        Embed a text and store it in the collection with a given ID.

        Args:
            id (str): ID for the text
            text (str): Text to be embedded
            metadata (dict, optional): Metadata to be stored
            store (bool, optional): Whether to store the text in the content column

        Raises:
            ValueError: If no embedding model was configured
        """
        # Check the model first so a misconfigured collection fails before
        # any other work is attempted
        model = self._require_model()

        from llm import encode

        embedding = model.embed(text)
        self.db["embeddings"].insert(
            {
                "collection_id": self.id(),
                "id": id,
                "embedding": encode(embedding),
                "content": text if store else None,
                # Empty or missing metadata is stored as NULL, not "{}"
                "metadata": json.dumps(metadata) if metadata else None,
            }
        )

    def embed_multi(self, id_text_map: Dict[str, str], store: bool = False) -> None:
        """
        Embed multiple texts and store them in the collection with given IDs.

        Not implemented yet.

        Args:
            id_text_map (dict): Dictionary mapping IDs to texts
            store (bool, optional): Whether to store the text in the content column
        """
        raise NotImplementedError

    def embed_multi_with_metadata(
        self,
        id_text_metadata_map: Dict[str, Tuple[str, Dict[str, Union[str, int, float]]]],
    ) -> None:
        """
        Embed multiple texts along with metadata and store them in the collection with given IDs.

        Not implemented yet.

        Args:
            id_text_metadata_map (dict): Dictionary mapping IDs to (text, metadata) tuples
        """
        raise NotImplementedError

    def similar_by_id(self, id: str, number: int = 5) -> List[Tuple[str, float]]:
        """
        Find similar items in the collection by a given ID.

        Not implemented yet.

        Args:
            id (str): ID to search by
            number (int, optional): Number of similar items to return

        Returns:
            list: List of (id, score) tuples
        """
        raise NotImplementedError

    def similar(self, text: str, number: int = 5) -> List[Tuple[str, float]]:
        """
        Find similar items in the collection by a given text.

        Not implemented yet.

        Args:
            text (str): Text to search by
            number (int, optional): Number of similar items to return

        Returns:
            list: List of (id, score) tuples
        """
        raise NotImplementedError