From b5800d48139c99ca5de3540f74f7863c8d51adf9 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Thu, 18 Apr 2024 14:30:27 +0100 Subject: [PATCH] Vector search example (#1778) * Vector search example * addressed feedback * vectors example tests * skip vectors integration test for stacks < 8.11 --------- Co-authored-by: Quentin Pradet --- examples/async/vectors.py | 186 ++++++++++++++++++ examples/vectors.py | 185 +++++++++++++++++ setup.py | 4 + .../test_examples/_async/test_vectors.py | 31 +++ .../test_examples/_sync/test_vectors.py | 31 +++ utils/run-unasync.py | 8 + 6 files changed, 445 insertions(+) create mode 100644 examples/async/vectors.py create mode 100644 examples/vectors.py create mode 100644 tests/test_integration/test_examples/_async/test_vectors.py create mode 100644 tests/test_integration/test_examples/_sync/test_vectors.py diff --git a/examples/async/vectors.py b/examples/async/vectors.py new file mode 100644 index 00000000..620ea45f --- /dev/null +++ b/examples/async/vectors.py @@ -0,0 +1,186 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +# Vector database example + +Requirements: + +$ pip install nltk sentence_transformers tqdm elasticsearch-dsl[async] + +To run the example: + +$ python vectors.py "text to search" + +The index will be created automatically if it does not exist. Add +`--recreate-index` to regenerate it. + +The example dataset includes a selection of workplace documents. The +following are good example queries to try out with this dataset: + +$ python vectors.py "work from home" +$ python vectors.py "vacation time" +$ python vectors.py "can I bring a bird to work?" + +When the index is created, the documents are split into short passages, and for +each passage an embedding is generated using the open source +"all-MiniLM-L6-v2" model. The documents that are returned as search results are +those that have the highest scored passages. Add `--show-inner-hits` to the +command to see individual passage results as well. +""" + +import argparse +import asyncio +import json +import os +from urllib.request import urlopen + +import nltk +from sentence_transformers import SentenceTransformer +from tqdm import tqdm + +from elasticsearch_dsl import ( + AsyncDocument, + Date, + DenseVector, + InnerDoc, + Keyword, + Nested, + Text, + async_connections, +) + +DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json" +MODEL_NAME = "all-MiniLM-L6-v2" + +# initialize sentence tokenizer +nltk.download("punkt", quiet=True) + + +class Passage(InnerDoc): + content = Text() + embedding = DenseVector() + + +class WorkplaceDoc(AsyncDocument): + class Index: + name = "workplace_documents" + + name = Text() + summary = Text() + content = Text() + created = Date() + updated = Date() + url = Keyword() + category = Keyword() + passages = Nested(Passage) + + _model = None + + @classmethod + def get_embedding_model(cls): + if cls._model is None: + cls._model = SentenceTransformer(MODEL_NAME) + return cls._model + + def clean(self): + # split the content into sentences + passages = nltk.sent_tokenize(self.content) + + # generate an embedding for each passage and save it as a nested document + model = self.get_embedding_model() + for passage in passages: + self.passages.append( + Passage(content=passage, embedding=list(model.encode(passage))) + ) + + +async def create(): + + # create the index + await WorkplaceDoc._index.delete(ignore_unavailable=True) + await WorkplaceDoc.init() + + # download the data + dataset = json.loads(urlopen(DATASET_URL).read()) + + # import the dataset + for data in tqdm(dataset, desc="Indexing documents..."): + doc = WorkplaceDoc( + name=data["name"], + summary=data["summary"], + content=data["content"], + created=data.get("created_on"), + updated=data.get("updated_at"), + url=data["url"], + category=data["category"], + ) + await doc.save() + + +async def search(query): + model = WorkplaceDoc.get_embedding_model() + return WorkplaceDoc.search().knn( + field="passages.embedding", + k=5, + num_candidates=50, + query_vector=list(model.encode(query)), + inner_hits={"size": 2}, + ) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") + parser.add_argument( + "--recreate-index", action="store_true", help="Recreate and populate the index" + ) + parser.add_argument( + "--show-inner-hits", + action="store_true", + help="Show results for individual passages", + ) + parser.add_argument("query", action="store", help="The search query") + return parser.parse_args() + + +async def main(): + args = parse_args() + + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + if args.recreate_index or not await WorkplaceDoc._index.exists(): + await create() + + results = await search(args.query) + + async for hit in results: + print( + f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]" + ) + print(f"Summary: {hit.summary}") + if args.show_inner_hits: + for passage in hit.meta.inner_hits.passages: + print(f" - [Score: {passage.meta.score}] {passage.content!r}") + print("") + + # close the connection + await async_connections.get_connection().close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/vectors.py b/examples/vectors.py new file mode 100644 index 00000000..c204cb61 --- /dev/null +++ b/examples/vectors.py @@ -0,0 +1,185 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +# Vector database example + +Requirements: + +$ pip install nltk sentence_transformers tqdm elasticsearch-dsl + +To run the example: + +$ python vectors.py "text to search" + +The index will be created automatically if it does not exist. Add +`--recreate-index` to regenerate it. + +The example dataset includes a selection of workplace documents. The +following are good example queries to try out with this dataset: + +$ python vectors.py "work from home" +$ python vectors.py "vacation time" +$ python vectors.py "can I bring a bird to work?" + +When the index is created, the documents are split into short passages, and for +each passage an embedding is generated using the open source +"all-MiniLM-L6-v2" model. The documents that are returned as search results are +those that have the highest scored passages. Add `--show-inner-hits` to the +command to see individual passage results as well. +""" + +import argparse +import json +import os +from urllib.request import urlopen + +import nltk +from sentence_transformers import SentenceTransformer +from tqdm import tqdm + +from elasticsearch_dsl import ( + Date, + DenseVector, + Document, + InnerDoc, + Keyword, + Nested, + Text, + connections, +) + +DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json" +MODEL_NAME = "all-MiniLM-L6-v2" + +# initialize sentence tokenizer +nltk.download("punkt", quiet=True) + + +class Passage(InnerDoc): + content = Text() + embedding = DenseVector() + + +class WorkplaceDoc(Document): + class Index: + name = "workplace_documents" + + name = Text() + summary = Text() + content = Text() + created = Date() + updated = Date() + url = Keyword() + category = Keyword() + passages = Nested(Passage) + + _model = None + + @classmethod + def get_embedding_model(cls): + if cls._model is None: + cls._model = SentenceTransformer(MODEL_NAME) + return cls._model + + def clean(self): + # split the content into sentences + passages = nltk.sent_tokenize(self.content) + + # generate an embedding for each passage and save it as a nested document + model = self.get_embedding_model() + for passage in passages: + self.passages.append( + Passage(content=passage, embedding=list(model.encode(passage))) + ) + + +def create(): + + # create the index + WorkplaceDoc._index.delete(ignore_unavailable=True) + WorkplaceDoc.init() + + # download the data + dataset = json.loads(urlopen(DATASET_URL).read()) + + # import the dataset + for data in tqdm(dataset, desc="Indexing documents..."): + doc = WorkplaceDoc( + name=data["name"], + summary=data["summary"], + content=data["content"], + created=data.get("created_on"), + updated=data.get("updated_at"), + url=data["url"], + category=data["category"], + ) + doc.save() + + +def search(query): + model = WorkplaceDoc.get_embedding_model() + return WorkplaceDoc.search().knn( + field="passages.embedding", + k=5, + num_candidates=50, + query_vector=list(model.encode(query)), + inner_hits={"size": 2}, + ) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") + parser.add_argument( + "--recreate-index", action="store_true", help="Recreate and populate the index" + ) + parser.add_argument( + "--show-inner-hits", + action="store_true", + help="Show results for individual passages", + ) + parser.add_argument("query", action="store", help="The search query") + return parser.parse_args() + + +def main(): + args = parse_args() + + # initiate the default connection to elasticsearch + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + if args.recreate_index or not WorkplaceDoc._index.exists(): + create() + + results = search(args.query) + + for hit in results: + print( + f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]" + ) + print(f"Summary: {hit.summary}") + if args.show_inner_hits: + for passage in hit.meta.inner_hits.passages: + print(f" - [Score: {passage.meta.score}] {passage.content!r}") + print("") + + # close the connection + connections.get_connection().close() + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index dfbc0025..15b78a06 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,10 @@ "pytest-asyncio", "pytz", "coverage", + # the following three are used by the vectors example and its tests + "nltk", + "sentence_transformers", + "tqdm", # Override Read the Docs default (sphinx<2 and sphinx-rtd-theme<0.5) "sphinx>2", "sphinx-rtd-theme>0.5", diff --git a/tests/test_integration/test_examples/_async/test_vectors.py b/tests/test_integration/test_examples/_async/test_vectors.py new file mode 100644 index 00000000..b4b4b210 --- /dev/null +++ b/tests/test_integration/test_examples/_async/test_vectors.py @@ -0,0 +1,31 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import SkipTest + +from ..async_examples.vectors import create, search + + +async def test_vector_search(async_write_client, es_version): + # this test only runs on Elasticsearch >= 8.11 because the example uses + # a dense vector without giving them an explicit size + if es_version < (8, 11): + raise SkipTest("This test requires Elasticsearch 8.11 or newer") + + await create() + results = await (await search("work from home")).execute() + assert results[0].name == "Work From Home Policy" diff --git a/tests/test_integration/test_examples/_sync/test_vectors.py b/tests/test_integration/test_examples/_sync/test_vectors.py new file mode 100644 index 00000000..1b015688 --- /dev/null +++ b/tests/test_integration/test_examples/_sync/test_vectors.py @@ -0,0 +1,31 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest import SkipTest + +from ..examples.vectors import create, search + + +def test_vector_search(write_client, es_version): + # this test only runs on Elasticsearch >= 8.11 because the example uses + # a dense vector without giving them an explicit size + if es_version < (8, 11): + raise SkipTest("This test requires Elasticsearch 8.11 or newer") + + create() + results = (search("work from home")).execute() + assert results[0].name == "Work From Home Policy" diff --git a/utils/run-unasync.py b/utils/run-unasync.py index cd8a3ab5..e212f306 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -107,6 +107,14 @@ def main(check=False): f"{output_dir}{file}", ] ) + subprocess.check_call( + [ + "sed", + "-i.bak", + "s/elasticsearch-dsl\\[async\\]/elasticsearch-dsl/", + f"{output_dir}{file}", + ] + ) subprocess.check_call(["rm", f"{output_dir}{file}.bak"]) if check: