Skip to content

Commit

Permalink
support new dense vector quantization in 8.16
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 10, 2024
1 parent 0dd69f8 commit a5f4e80
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 4 deletions.
14 changes: 12 additions & 2 deletions elasticsearch_dsl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
return float(data)


class DenseVector(Float):
class DenseVector(Field):
name = "dense_vector"
_coerce = True

def __init__(self, **kwargs: Any):
kwargs["multi"] = True
self._element_type = kwargs.get("element_type", "float")
if self._element_type in ["float", "byte"]:
kwargs["multi"] = True
super().__init__(**kwargs)

def _deserialize(self, data: Any) -> Any:
if self._element_type == "float":
return float(data)
elif self._element_type == "byte":
return int(data)
return data


class SparseVector(Field):
name = "sparse_vector"
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ filterwarnings =
error
ignore:Legacy index templates are deprecated in favor of composable templates.:elasticsearch.exceptions.ElasticsearchWarning
ignore:datetime.datetime.utcfromtimestamp\(\) is deprecated and scheduled for removal in a future version..*:DeprecationWarning
default:enable_cleanup_closed ignored.*:DeprecationWarning
markers =
sync: mark a test as performing I/O without asyncio.
57 changes: 56 additions & 1 deletion tests/test_integration/_async/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datetime import datetime
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union

import pytest
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
Expand All @@ -37,6 +37,7 @@
Binary,
Boolean,
Date,
DenseVector,
Double,
InnerDoc,
Ip,
Expand Down Expand Up @@ -795,3 +796,57 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.asyncio
async def test_legacy_dense_vector(
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version >= (8, 16):
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector(dims=3))

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(float_vector=[1.0, 1.2, 2.3])
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector


@pytest.mark.asyncio
async def test_dense_vector(
async_client: AsyncElasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version < (8, 16):
pytest.skip("this test requires Elasticsearch 8.16 or newer")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector
55 changes: 54 additions & 1 deletion tests/test_integration/_sync/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datetime import datetime
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Union
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple, Union

import pytest
from elasticsearch import ConflictError, Elasticsearch, NotFoundError
Expand All @@ -35,6 +35,7 @@
Binary,
Boolean,
Date,
DenseVector,
Document,
Double,
InnerDoc,
Expand Down Expand Up @@ -789,3 +790,55 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.sync
def test_legacy_dense_vector(
client: Elasticsearch, es_version: Tuple[int, ...]
) -> None:
if es_version >= (8, 16):
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")

class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector(dims=3))

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

doc = Doc(float_vector=[1.0, 1.2, 2.3])
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector


@pytest.mark.sync
def test_dense_vector(client: Elasticsearch, es_version: Tuple[int, ...]) -> None:
if es_version < (8, 16):
pytest.skip("this test requires Elasticsearch 8.16 or newer")

class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector

0 comments on commit a5f4e80

Please sign in to comment.