From 4cda4186b9e72e1015ab46ec0a5a4d24c44ba721 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Fri, 12 Apr 2024 19:13:23 +0100 Subject: [PATCH 1/4] Vector search example --- examples/async/vectors.py | 170 ++++++++++++++++++++++++++++++++++++++ examples/vectors.py | 169 +++++++++++++++++++++++++++++++++++++ 2 files changed, 339 insertions(+) create mode 100644 examples/async/vectors.py create mode 100644 examples/vectors.py diff --git a/examples/async/vectors.py b/examples/async/vectors.py new file mode 100644 index 00000000..fcdca459 --- /dev/null +++ b/examples/async/vectors.py @@ -0,0 +1,170 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +# Vector database example + +Requirements: + +$ pip install nltk sentence_transformers tqdm elasticsearch_dsl + +To run the example: + +$ python vectors.py "text to search" + +The index will be created automatically if it does not exist. Add `--create` to +regenerate it. + +The example dataset includes a selection of workplace documentation. The +following are good example queries to try out: + +$ python vectors.py "work from home" +$ python vectors.py "vacation time" +$ python vectors.py "bring a bird to work" +""" + +import argparse +import asyncio +import json +import os +from urllib.request import urlopen + +import nltk +from sentence_transformers import SentenceTransformer +from tqdm import tqdm + +from elasticsearch_dsl import ( + AsyncDocument, + Date, + DenseVector, + InnerDoc, + Keyword, + Nested, + Text, + async_connections, +) + +DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json" +MODEL_NAME = "all-MiniLM-L6-v2" + +# initialize sentence tokenizer +nltk.download("punkt", quiet=True) + + +class Passage(InnerDoc): + content = Text() + embedding = DenseVector() + + +class WorkplaceDoc(AsyncDocument): + class Index: + name = "workplace_documents" + + name = Text() + summary = Text() + content = Text() + created = Date() + updated = Date() + url = Keyword() + category = Keyword() + passages = Nested(Passage) + + _model = None + + @classmethod + def get_embedding_model(cls): + if cls._model is None: + cls._model = SentenceTransformer(MODEL_NAME) + return cls._model + + def clean(self): + # split the content into sentences + passages = nltk.sent_tokenize(self.content) + + # generate an embedding for each passage and save it as a nested document + model = self.get_embedding_model() + for passage in passages: + self.passages.append( + Passage(content=passage, embedding=list(model.encode(passage))) + ) + + +async def create(): + + # create the index + await WorkplaceDoc._index.delete(ignore_unavailable=True) + await WorkplaceDoc.init() + + # download the data + dataset = json.loads(urlopen(DATASET_URL).read()) + + # import the dataset + for data in 
tqdm(dataset, desc="Indexing documents..."): + doc = WorkplaceDoc( + name=data["name"], + summary=data["summary"], + content=data["content"], + created=data.get("created_on"), + updated=data.get("updated_at"), + url=data["url"], + category=data["category"], + ) + await doc.save() + + +async def search(query): + model = WorkplaceDoc.get_embedding_model() + search = WorkplaceDoc.search().knn( + field="passages.embedding", + k=5, + num_candidates=50, + query_vector=list(model.encode(query)), + inner_hits={"size": 3}, + ) + async for hit in search: + print(f"Document: {hit.name} (Category: {hit.category}") + for passage in hit.meta.inner_hits.passages: + print(f" - [Score: {passage.meta.score}] {passage.content!r}") + print("") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") + parser.add_argument( + "--create", action="store_true", help="Create and populate a new index" + ) + parser.add_argument("query", action="store", help="The search query") + return parser.parse_args() + + +async def main(): + args = parse_args() + + # initiate the default connection to elasticsearch + async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + if args.create or not await WorkplaceDoc._index.exists(): + await create() + + await search(args.query) + + # close the connection + await async_connections.get_connection().close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/vectors.py b/examples/vectors.py new file mode 100644 index 00000000..32258475 --- /dev/null +++ b/examples/vectors.py @@ -0,0 +1,169 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +# Vector database example + +Requirements: + +$ pip install nltk sentence_transformers tqdm elasticsearch_dsl + +To run the example: + +$ python vectors.py "text to search" + +The index will be created automatically if it does not exist. Add `--create` to +regenerate it. + +The example dataset includes a selection of workplace documentation. 
The +following are good example queries to try out: + +$ python vectors.py "work from home" +$ python vectors.py "vacation time" +$ python vectors.py "bring a bird to work" +""" + +import argparse +import json +import os +from urllib.request import urlopen + +import nltk +from sentence_transformers import SentenceTransformer +from tqdm import tqdm + +from elasticsearch_dsl import ( + Date, + DenseVector, + Document, + InnerDoc, + Keyword, + Nested, + Text, + connections, +) + +DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json" +MODEL_NAME = "all-MiniLM-L6-v2" + +# initialize sentence tokenizer +nltk.download("punkt", quiet=True) + + +class Passage(InnerDoc): + content = Text() + embedding = DenseVector() + + +class WorkplaceDoc(Document): + class Index: + name = "workplace_documents" + + name = Text() + summary = Text() + content = Text() + created = Date() + updated = Date() + url = Keyword() + category = Keyword() + passages = Nested(Passage) + + _model = None + + @classmethod + def get_embedding_model(cls): + if cls._model is None: + cls._model = SentenceTransformer(MODEL_NAME) + return cls._model + + def clean(self): + # split the content into sentences + passages = nltk.sent_tokenize(self.content) + + # generate an embedding for each passage and save it as a nested document + model = self.get_embedding_model() + for passage in passages: + self.passages.append( + Passage(content=passage, embedding=list(model.encode(passage))) + ) + + +def create(): + + # create the index + WorkplaceDoc._index.delete(ignore_unavailable=True) + WorkplaceDoc.init() + + # download the data + dataset = json.loads(urlopen(DATASET_URL).read()) + + # import the dataset + for data in tqdm(dataset, desc="Indexing documents..."): + doc = WorkplaceDoc( + name=data["name"], + summary=data["summary"], + content=data["content"], + created=data.get("created_on"), + updated=data.get("updated_at"), + url=data["url"], + category=data["category"], + ) + doc.save() + + +def search(query): + model = WorkplaceDoc.get_embedding_model() + search = WorkplaceDoc.search().knn( + field="passages.embedding", + k=5, + num_candidates=50, + query_vector=list(model.encode(query)), + inner_hits={"size": 3}, + ) + for hit in search: + print(f"Document: {hit.name} (Category: {hit.category}") + for passage in hit.meta.inner_hits.passages: + print(f" - [Score: {passage.meta.score}] {passage.content!r}") + print("") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") + parser.add_argument( + "--create", action="store_true", help="Create and populate a new index" + ) + parser.add_argument("query", action="store", help="The search query") + return parser.parse_args() + + +def main(): + args = parse_args() + + # initiate the default connection to elasticsearch + connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + if args.create or not WorkplaceDoc._index.exists(): + create() + + search(args.query) + + # close the connection + connections.get_connection().close() + + +if __name__ == "__main__": + main() From 34432f218520cc95ba876ec88b5d1522e44500be Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Thu, 18 Apr 2024 12:16:51 +0100 Subject: [PATCH 2/4] addressed feedback --- examples/async/vectors.py | 48 ++++++++++++++++++++++++++------------- examples/vectors.py | 48 ++++++++++++++++++++++++++------------- utils/run-unasync.py | 8 +++++++ 3 files changed, 72 insertions(+), 32 deletions(-) diff --git 
a/examples/async/vectors.py b/examples/async/vectors.py index fcdca459..620ea45f 100644 --- a/examples/async/vectors.py +++ b/examples/async/vectors.py @@ -20,21 +20,27 @@ Requirements: -$ pip install nltk sentence_transformers tqdm elasticsearch_dsl +$ pip install nltk sentence_transformers tqdm elasticsearch-dsl[async] To run the example: $ python vectors.py "text to search" -The index will be created automatically if it does not exist. Add `--create` to -regenerate it. +The index will be created automatically if it does not exist. Add +`--recreate-index` to regenerate it. -The example dataset includes a selection of workplace documentation. The -following are good example queries to try out: +The example dataset includes a selection of workplace documents. The +following are good example queries to try out with this dataset: $ python vectors.py "work from home" $ python vectors.py "vacation time" -$ python vectors.py "bring a bird to work" +$ python vectors.py "can I bring a bird to work?" + +When the index is created, the documents are split into short passages, and for +each passage an embedding is generated using the open source +"all-MiniLM-L6-v2" model. The documents that are returned as search results are +those that have the highest scored passages. Add `--show-inner-hits` to the +command to see individual passage results as well. """ import argparse @@ -128,24 +134,24 @@ async def create(): async def search(query): model = WorkplaceDoc.get_embedding_model() - search = WorkplaceDoc.search().knn( + return WorkplaceDoc.search().knn( field="passages.embedding", k=5, num_candidates=50, query_vector=list(model.encode(query)), - inner_hits={"size": 3}, + inner_hits={"size": 2}, ) - async for hit in search: - print(f"Document: {hit.name} (Category: {hit.category}") - for passage in hit.meta.inner_hits.passages: - print(f" - [Score: {passage.meta.score}] {passage.content!r}") - print("") def parse_args(): parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") parser.add_argument( - "--create", action="store_true", help="Create and populate a new index" + "--recreate-index", action="store_true", help="Recreate and populate the index" + ) + parser.add_argument( + "--show-inner-hits", + action="store_true", + help="Show results for individual passages", ) parser.add_argument("query", action="store", help="The search query") return parser.parse_args() @@ -157,10 +163,20 @@ async def main(): # initiate the default connection to elasticsearch async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) - if args.create or not await WorkplaceDoc._index.exists(): + if args.recreate_index or not await WorkplaceDoc._index.exists(): await create() - await search(args.query) + results = await search(args.query) + + async for hit in results: + print( + f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]" + ) + print(f"Summary: {hit.summary}") + if args.show_inner_hits: + for passage in hit.meta.inner_hits.passages: + print(f" - [Score: {passage.meta.score}] {passage.content!r}") + print("") # close the connection await async_connections.get_connection().close() diff --git a/examples/vectors.py b/examples/vectors.py index 32258475..c204cb61 100644 --- a/examples/vectors.py +++ b/examples/vectors.py @@ -20,21 +20,27 @@ Requirements: -$ pip install nltk sentence_transformers tqdm elasticsearch_dsl +$ pip install nltk sentence_transformers tqdm elasticsearch-dsl To run the example: $ python vectors.py "text to search" -The index will be 
created automatically if it does not exist. Add `--create` to -regenerate it. +The index will be created automatically if it does not exist. Add +`--recreate-index` to regenerate it. -The example dataset includes a selection of workplace documentation. The -following are good example queries to try out: +The example dataset includes a selection of workplace documents. The +following are good example queries to try out with this dataset: $ python vectors.py "work from home" $ python vectors.py "vacation time" -$ python vectors.py "bring a bird to work" +$ python vectors.py "can I bring a bird to work?" + +When the index is created, the documents are split into short passages, and for +each passage an embedding is generated using the open source +"all-MiniLM-L6-v2" model. The documents that are returned as search results are +those that have the highest scored passages. Add `--show-inner-hits` to the +command to see individual passage results as well. """ import argparse @@ -127,24 +133,24 @@ def create(): def search(query): model = WorkplaceDoc.get_embedding_model() - search = WorkplaceDoc.search().knn( + return WorkplaceDoc.search().knn( field="passages.embedding", k=5, num_candidates=50, query_vector=list(model.encode(query)), - inner_hits={"size": 3}, + inner_hits={"size": 2}, ) - for hit in search: - print(f"Document: {hit.name} (Category: {hit.category}") - for passage in hit.meta.inner_hits.passages: - print(f" - [Score: {passage.meta.score}] {passage.content!r}") - print("") def parse_args(): parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") parser.add_argument( - "--create", action="store_true", help="Create and populate a new index" + "--recreate-index", action="store_true", help="Recreate and populate the index" + ) + parser.add_argument( + "--show-inner-hits", + action="store_true", + help="Show results for individual passages", ) parser.add_argument("query", action="store", help="The search query") return parser.parse_args() @@ -156,10 +162,20 @@ def main(): # initiate the default connection to elasticsearch connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) - if args.create or not WorkplaceDoc._index.exists(): + if args.recreate_index or not WorkplaceDoc._index.exists(): create() - search(args.query) + results = search(args.query) + + for hit in results: + print( + f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]" + ) + print(f"Summary: {hit.summary}") + if args.show_inner_hits: + for passage in hit.meta.inner_hits.passages: + print(f" - [Score: {passage.meta.score}] {passage.content!r}") + print("") # close the connection connections.get_connection().close() diff --git a/utils/run-unasync.py b/utils/run-unasync.py index cd8a3ab5..e212f306 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -107,6 +107,14 @@ def main(check=False): f"{output_dir}{file}", ] ) + subprocess.check_call( + [ + "sed", + "-i.bak", + "s/elasticsearch-dsl\\[async\\]/elasticsearch-dsl/", + f"{output_dir}{file}", + ] + ) subprocess.check_call(["rm", f"{output_dir}{file}.bak"]) if check: From e7cb85635edb097afeed76ddb906192255eac2f2 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Thu, 18 Apr 2024 12:42:57 +0100 Subject: [PATCH 3/4] vectors example tests --- setup.py | 4 +++ .../test_examples/_async/test_vectors.py | 25 +++++++++++++++++++ .../test_examples/_sync/test_vectors.py | 25 +++++++++++++++++++ 3 files changed, 54 insertions(+) create mode 100644 
tests/test_integration/test_examples/_async/test_vectors.py create mode 100644 tests/test_integration/test_examples/_sync/test_vectors.py diff --git a/setup.py b/setup.py index dfbc0025..15b78a06 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,10 @@ "pytest-asyncio", "pytz", "coverage", + # the following three are used by the vectors example and its tests + "nltk", + "sentence_transformers", + "tqdm", # Override Read the Docs default (sphinx<2 and sphinx-rtd-theme<0.5) "sphinx>2", "sphinx-rtd-theme>0.5", diff --git a/tests/test_integration/test_examples/_async/test_vectors.py b/tests/test_integration/test_examples/_async/test_vectors.py new file mode 100644 index 00000000..8d8253a6 --- /dev/null +++ b/tests/test_integration/test_examples/_async/test_vectors.py @@ -0,0 +1,25 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from ..async_examples.vectors import create, search + + +async def test_vector_search(async_write_client): + await create() + results = await (await search("work from home")).execute() + assert results[0].name == "Work From Home Policy" diff --git a/tests/test_integration/test_examples/_sync/test_vectors.py b/tests/test_integration/test_examples/_sync/test_vectors.py new file mode 100644 index 00000000..a7719020 --- /dev/null +++ b/tests/test_integration/test_examples/_sync/test_vectors.py @@ -0,0 +1,25 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+
+from ..examples.vectors import create, search
+
+
+def test_vector_search(write_client):
+    create()
+    results = (search("work from home")).execute()
+    assert results[0].name == "Work From Home Policy"

From 0bdfab36041db0bd7e3fe30f03826ea23289290b Mon Sep 17 00:00:00 2001
From: Miguel Grinberg
Date: Thu, 18 Apr 2024 14:15:06 +0100
Subject: [PATCH 4/4] skip vectors integration test for stacks < 8.11

---
 .../test_integration/test_examples/_async/test_vectors.py | 8 +++++++-
 .../test_integration/test_examples/_sync/test_vectors.py  | 8 +++++++-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/tests/test_integration/test_examples/_async/test_vectors.py b/tests/test_integration/test_examples/_async/test_vectors.py
index 8d8253a6..b4b4b210 100644
--- a/tests/test_integration/test_examples/_async/test_vectors.py
+++ b/tests/test_integration/test_examples/_async/test_vectors.py
@@ -15,11 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from unittest import SkipTest
 
 from ..async_examples.vectors import create, search
 
 
-async def test_vector_search(async_write_client):
+async def test_vector_search(async_write_client, es_version):
+    # this test only runs on Elasticsearch >= 8.11 because the example uses
+    # a dense vector without giving it an explicit size
+    if es_version < (8, 11):
+        raise SkipTest("This test requires Elasticsearch 8.11 or newer")
+
     await create()
     results = await (await search("work from home")).execute()
     assert results[0].name == "Work From Home Policy"
diff --git a/tests/test_integration/test_examples/_sync/test_vectors.py b/tests/test_integration/test_examples/_sync/test_vectors.py
index a7719020..1b015688 100644
--- a/tests/test_integration/test_examples/_sync/test_vectors.py
+++ b/tests/test_integration/test_examples/_sync/test_vectors.py
@@ -15,11 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from unittest import SkipTest
 
 from ..examples.vectors import create, search
 
 
-def test_vector_search(write_client):
+def test_vector_search(write_client, es_version):
+    # this test only runs on Elasticsearch >= 8.11 because the example uses
+    # a dense vector without giving it an explicit size
+    if es_version < (8, 11):
+        raise SkipTest("This test requires Elasticsearch 8.11 or newer")
+
     create()
     results = (search("work from home")).execute()
     assert results[0].name == "Work From Home Policy"
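
The 8.11 requirement above comes from the examples declaring the embedding
field as `DenseVector()` with no explicit size: Elasticsearch only infers the
`dims` of a `dense_vector` field from the first indexed document starting with
release 8.11. On older 8.x stacks the mapping can still be created up front by
pinning the size to the embedding model's output. The following is a sketch,
not part of this patch series; it assumes that all-MiniLM-L6-v2 produces
384-dimensional embeddings and that keyword arguments given to a field are
passed through to its `dense_vector` mapping:

from elasticsearch_dsl import DenseVector, InnerDoc, Text


class Passage(InnerDoc):
    content = Text()
    # all-MiniLM-L6-v2 encodes each sentence into a vector of 384 floats;
    # declaring dims up front allows WorkplaceDoc.init() to create the
    # mapping on Elasticsearch versions older than 8.11
    embedding = DenseVector(dims=384)

Note that this only addresses mapping creation; the version check in the tests
above follows the patch's stated reason (the implicit vector size), and other
features the example relies on, such as kNN search over nested fields, may
still require a recent stack.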