Vector search example
miguelgrinberg committed Apr 12, 2024
1 parent c650b62 commit 4cda418
Showing 2 changed files with 339 additions and 0 deletions.
170 changes: 170 additions & 0 deletions examples/async/vectors.py
@@ -0,0 +1,170 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
# Vector database example
Requirements:
$ pip install nltk sentence_transformers tqdm elasticsearch_dsl
To run the example:
$ python vectors.py "text to search"
The index will be created automatically if it does not exist. Add `--create` to
regenerate it.
The example dataset includes a selection of workplace documentation. The
following are good example queries to try out:
$ python vectors.py "work from home"
$ python vectors.py "vacation time"
$ python vectors.py "bring a bird to work"
"""

import argparse
import asyncio
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
AsyncDocument,
Date,
DenseVector,
InnerDoc,
Keyword,
Nested,
Text,
async_connections,
)

DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
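# this Sentence Transformers model produces 384-dimensional embeddings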
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize sentence tokenizer
nltk.download("punkt", quiet=True)


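# A passage is a single sentence extracted from a document, stored as a
# nested object together with its embedding vector.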
class Passage(InnerDoc):
content = Text()
embedding = DenseVector()


class WorkplaceDoc(AsyncDocument):
class Index:
name = "workplace_documents"

name = Text()
summary = Text()
content = Text()
created = Date()
updated = Date()
url = Keyword()
category = Keyword()
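    # passages are mapped as nested documents so that each
    # sentence/embedding pair is indexed and scored independently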
passages = Nested(Passage)

_model = None

@classmethod
def get_embedding_model(cls):
if cls._model is None:
cls._model = SentenceTransformer(MODEL_NAME)
return cls._model

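    # clean() is invoked automatically by save() as part of full_clean(),
    # so the passage embeddings are generated when the document is indexed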
def clean(self):
# split the content into sentences
passages = nltk.sent_tokenize(self.content)

# generate an embedding for each passage and save it as a nested document
model = self.get_embedding_model()
for passage in passages:
self.passages.append(
Passage(content=passage, embedding=list(model.encode(passage)))
)


async def create():

    # delete any existing index, then create a fresh one with the right mappings
await WorkplaceDoc._index.delete(ignore_unavailable=True)
await WorkplaceDoc.init()

# download the data
dataset = json.loads(urlopen(DATASET_URL).read())

# import the dataset
for data in tqdm(dataset, desc="Indexing documents..."):
doc = WorkplaceDoc(
name=data["name"],
summary=data["summary"],
content=data["content"],
created=data.get("created_on"),
updated=data.get("updated_at"),
url=data["url"],
category=data["category"],
)
await doc.save()


async def search(query):
model = WorkplaceDoc.get_embedding_model()
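    # run an approximate kNN query against the nested passage embeddings:
    # k is the number of neighbors to return, num_candidates is the size of
    # the per-shard candidate pool (larger is more accurate but slower), and
    # inner_hits returns the best-matching passages within each document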
search = WorkplaceDoc.search().knn(
field="passages.embedding",
k=5,
num_candidates=50,
query_vector=list(model.encode(query)),
inner_hits={"size": 3},
)
async for hit in search:
print(f"Document: {hit.name} (Category: {hit.category}")
for passage in hit.meta.inner_hits.passages:
print(f" - [Score: {passage.meta.score}] {passage.content!r}")
print("")


def parse_args():
parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
parser.add_argument(
"--create", action="store_true", help="Create and populate a new index"
)
parser.add_argument("query", action="store", help="The search query")
return parser.parse_args()


async def main():
args = parse_args()

    # initialize the default connection to Elasticsearch
async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

if args.create or not await WorkplaceDoc._index.exists():
await create()

await search(args.query)

# close the connection
await async_connections.get_connection().close()


if __name__ == "__main__":
asyncio.run(main())
169 changes: 169 additions & 0 deletions examples/vectors.py
@@ -0,0 +1,169 @@
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
# Vector database example
Requirements:
$ pip install nltk sentence_transformers tqdm elasticsearch_dsl
To run the example:
$ python vectors.py "text to search"
The index will be created automatically if it does not exist. Add `--create` to
regenerate it.
The example dataset includes a selection of workplace documentation. The
following are good example queries to try out:
$ python vectors.py "work from home"
$ python vectors.py "vacation time"
$ python vectors.py "bring a bird to work"
"""

import argparse
import json
import os
from urllib.request import urlopen

import nltk
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from elasticsearch_dsl import (
Date,
DenseVector,
Document,
InnerDoc,
Keyword,
Nested,
Text,
connections,
)

DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
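# this Sentence Transformers model produces 384-dimensional embeddings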
MODEL_NAME = "all-MiniLM-L6-v2"

# initialize sentence tokenizer
nltk.download("punkt", quiet=True)


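# A passage is a single sentence extracted from a document, stored as a
# nested object together with its embedding vector.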
class Passage(InnerDoc):
content = Text()
embedding = DenseVector()


class WorkplaceDoc(Document):
class Index:
name = "workplace_documents"

name = Text()
summary = Text()
content = Text()
created = Date()
updated = Date()
url = Keyword()
category = Keyword()
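    # passages are mapped as nested documents so that each
    # sentence/embedding pair is indexed and scored independently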
passages = Nested(Passage)

_model = None

@classmethod
def get_embedding_model(cls):
if cls._model is None:
cls._model = SentenceTransformer(MODEL_NAME)
return cls._model

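    # clean() is invoked automatically by save() as part of full_clean(),
    # so the passage embeddings are generated when the document is indexed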
def clean(self):
# split the content into sentences
passages = nltk.sent_tokenize(self.content)

# generate an embedding for each passage and save it as a nested document
model = self.get_embedding_model()
for passage in passages:
self.passages.append(
Passage(content=passage, embedding=list(model.encode(passage)))
)


def create():

    # delete any existing index, then create a fresh one with the right mappings
WorkplaceDoc._index.delete(ignore_unavailable=True)
WorkplaceDoc.init()

# download the data
dataset = json.loads(urlopen(DATASET_URL).read())

# import the dataset
for data in tqdm(dataset, desc="Indexing documents..."):
doc = WorkplaceDoc(
name=data["name"],
summary=data["summary"],
content=data["content"],
created=data.get("created_on"),
updated=data.get("updated_at"),
url=data["url"],
category=data["category"],
)
doc.save()


def search(query):
model = WorkplaceDoc.get_embedding_model()
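    # run an approximate kNN query against the nested passage embeddings:
    # k is the number of neighbors to return, num_candidates is the size of
    # the per-shard candidate pool (larger is more accurate but slower), and
    # inner_hits returns the best-matching passages within each document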
search = WorkplaceDoc.search().knn(
field="passages.embedding",
k=5,
num_candidates=50,
query_vector=list(model.encode(query)),
inner_hits={"size": 3},
)
for hit in search:
print(f"Document: {hit.name} (Category: {hit.category}")
for passage in hit.meta.inner_hits.passages:
print(f" - [Score: {passage.meta.score}] {passage.content!r}")
print("")


def parse_args():
parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
parser.add_argument(
"--create", action="store_true", help="Create and populate a new index"
)
parser.add_argument("query", action="store", help="The search query")
return parser.parse_args()


def main():
args = parse_args()

    # initialize the default connection to Elasticsearch
connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

if args.create or not WorkplaceDoc._index.exists():
create()

search(args.query)

# close the connection
connections.get_connection().close()


if __name__ == "__main__":
main()
