From bb99b1c11170f4e6e4e529b52c81d6f19dd13fb9 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Thu, 15 Aug 2024 17:02:03 +0100 Subject: [PATCH] Added support for the `semantic_text` field and `semantic` query type --- elasticsearch_dsl/field.py | 4 + elasticsearch_dsl/query.py | 4 + examples/async/semantic_text.py | 150 ++++++++++++++++++++++++++++++++ examples/semantic_text.py | 149 +++++++++++++++++++++++++++++++ 4 files changed, 307 insertions(+) create mode 100644 examples/async/semantic_text.py create mode 100644 examples/semantic_text.py diff --git a/elasticsearch_dsl/field.py b/elasticsearch_dsl/field.py index 7896fe5f..26f2336b 100644 --- a/elasticsearch_dsl/field.py +++ b/elasticsearch_dsl/field.py @@ -560,3 +560,7 @@ class TokenCount(Field): class Murmur3(Field): name = "murmur3" + + +class SemanticText(Field): + name = "semantic_text" diff --git a/elasticsearch_dsl/query.py b/elasticsearch_dsl/query.py index ce445216..993213c6 100644 --- a/elasticsearch_dsl/query.py +++ b/elasticsearch_dsl/query.py @@ -527,6 +527,10 @@ class Shape(Query): name = "shape" +class Semantic(Query): + name = "semantic" + + class SimpleQueryString(Query): name = "simple_query_string" diff --git a/examples/async/semantic_text.py b/examples/async/semantic_text.py new file mode 100644 index 00000000..00965f30 --- /dev/null +++ b/examples/async/semantic_text.py @@ -0,0 +1,150 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +# Semantic Text example + +Requirements: + +$ pip install "elasticsearch-dsl[async]" tqdm + +Before running this example, an ELSER inference endpoint must be created in the +Elasticsearch cluster. This can be done manually from Kibana, or with the +following curl command from a terminal: + +curl -X PUT \ + "$ELASTICSEARCH_URL/_inference/sparse_embedding/my-elser-endpoint" \ + -H "Content-Type: application/json" \ + -d '{"service":"elser","service_settings":{"num_allocations":1,"num_threads":1}}' + +To run the example: + +$ python semantic_text.py "text to search" + +The index will be created automatically if it does not exist. Add +`--recreate-index` to the command to regenerate it. + +The example dataset includes a selection of workplace documents. The +following are good example queries to try out with this dataset: + +$ python semantic_text.py "work from home" +$ python semantic_text.py "vacation time" +$ python semantic_text.py "can I bring a bird to work?" + +When the index is created, the inference service will split the documents into +short passages, and for each passage a sparse embedding will be generated using +Elastic's ELSER v2 model. +""" + +import argparse +import asyncio +import json +import os +from datetime import datetime +from typing import Any, Optional +from urllib.request import urlopen + +from tqdm import tqdm + +import elasticsearch_dsl as dsl + +DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json" + + +class WorkplaceDoc(dsl.AsyncDocument): + class Index: + name = "workplace_documents_semantic" + + name: str + summary: str + content: Any = dsl.mapped_field( + dsl.field.SemanticText(inference_id="my-elser-endpoint") + ) + created: datetime + updated: Optional[datetime] + url: str = dsl.mapped_field(dsl.Keyword()) + category: str = dsl.mapped_field(dsl.Keyword()) + + +async def create() -> None: + + # create the index + await WorkplaceDoc._index.delete(ignore_unavailable=True) + await WorkplaceDoc.init() + + # download the data + dataset = json.loads(urlopen(DATASET_URL).read()) + + # import the dataset + for data in tqdm(dataset, desc="Indexing documents..."): + doc = WorkplaceDoc( + name=data["name"], + summary=data["summary"], + content=data["content"], + created=data.get("created_on"), + updated=data.get("updated_at"), + url=data["url"], + category=data["category"], + ) + await doc.save() + + # refresh the index + await WorkplaceDoc._index.refresh() + + +async def search(query: str) -> dsl.AsyncSearch[WorkplaceDoc]: + return WorkplaceDoc.search()[:5].query( + "semantic", + field=WorkplaceDoc.content, + query=query, + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") + parser.add_argument( + "--recreate-index", action="store_true", help="Recreate and populate the index" + ) + parser.add_argument("query", action="store", help="The search query") + return parser.parse_args() + + +async def main() -> None: + args = parse_args() + + # initiate the default connection to elasticsearch + dsl.async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + if args.recreate_index or not await WorkplaceDoc._index.exists(): + await create() + + results = await search(args.query) + + async for hit in results: + print( + f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]" + ) + print(f"Content: {hit.content.text}") + print("--------------------\n") + + # close the connection + await dsl.async_connections.get_connection().close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/semantic_text.py b/examples/semantic_text.py new file mode 100644 index 00000000..7f935b80 --- /dev/null +++ b/examples/semantic_text.py @@ -0,0 +1,149 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +# Semantic Text example + +Requirements: + +$ pip install "elasticsearch-dsl" tqdm + +Before running this example, an ELSER inference endpoint must be created in the +Elasticsearch cluster. This can be done manually from Kibana, or with the +following curl command from a terminal: + +curl -X PUT \ + "$ELASTICSEARCH_URL/_inference/sparse_embedding/my-elser-endpoint" \ + -H "Content-Type: application/json" \ + -d '{"service":"elser","service_settings":{"num_allocations":1,"num_threads":1}}' + +To run the example: + +$ python semantic_text.py "text to search" + +The index will be created automatically if it does not exist. Add +`--recreate-index` to the command to regenerate it. + +The example dataset includes a selection of workplace documents. The +following are good example queries to try out with this dataset: + +$ python semantic_text.py "work from home" +$ python semantic_text.py "vacation time" +$ python semantic_text.py "can I bring a bird to work?" + +When the index is created, the inference service will split the documents into +short passages, and for each passage a sparse embedding will be generated using +Elastic's ELSER v2 model. +""" + +import argparse +import json +import os +from datetime import datetime +from typing import Any, Optional +from urllib.request import urlopen + +from tqdm import tqdm + +import elasticsearch_dsl as dsl + +DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json" + + +class WorkplaceDoc(dsl.Document): + class Index: + name = "workplace_documents_semantic" + + name: str + summary: str + content: Any = dsl.mapped_field( + dsl.field.SemanticText(inference_id="my-elser-endpoint") + ) + created: datetime + updated: Optional[datetime] + url: str = dsl.mapped_field(dsl.Keyword()) + category: str = dsl.mapped_field(dsl.Keyword()) + + +def create() -> None: + + # create the index + WorkplaceDoc._index.delete(ignore_unavailable=True) + WorkplaceDoc.init() + + # download the data + dataset = json.loads(urlopen(DATASET_URL).read()) + + # import the dataset + for data in tqdm(dataset, desc="Indexing documents..."): + doc = WorkplaceDoc( + name=data["name"], + summary=data["summary"], + content=data["content"], + created=data.get("created_on"), + updated=data.get("updated_at"), + url=data["url"], + category=data["category"], + ) + doc.save() + + # refresh the index + WorkplaceDoc._index.refresh() + + +def search(query: str) -> dsl.Search[WorkplaceDoc]: + return WorkplaceDoc.search()[:5].query( + "semantic", + field=WorkplaceDoc.content, + query=query, + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Vector database with Elasticsearch") + parser.add_argument( + "--recreate-index", action="store_true", help="Recreate and populate the index" + ) + parser.add_argument("query", action="store", help="The search query") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + # initiate the default connection to elasticsearch + dsl.connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]]) + + if args.recreate_index or not WorkplaceDoc._index.exists(): + create() + + results = search(args.query) + + for hit in results: + print( + f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]" + ) + print(f"Content: {hit.content.text}") + print("--------------------\n") + + # close the connection + dsl.connections.get_connection().close() + + +if __name__ == "__main__": + main()