Skip to content

Commit

Permalink
support new dense vector quantization in 8.16
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Dec 10, 2024
1 parent 0dd69f8 commit 9b21ff7
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 3 deletions.
14 changes: 12 additions & 2 deletions elasticsearch_dsl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,23 @@ def _deserialize(self, data: Any) -> float:
return float(data)


class DenseVector(Float):
class DenseVector(Field):
name = "dense_vector"
_coerce = True

def __init__(self, **kwargs: Any):
kwargs["multi"] = True
self._element_type = kwargs.get("element_type", "float")
if self._element_type in ["float", "byte"]:
kwargs["multi"] = True
super().__init__(**kwargs)

def _deserialize(self, data: Any) -> Any:
if self._element_type == "float":
return float(data)
elif self._element_type == "byte":
return int(data)
return data


class SparseVector(Field):
name = "sparse_vector"
Expand Down
55 changes: 54 additions & 1 deletion tests/test_integration/_async/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from datetime import datetime
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Union
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Tuple, Union

import pytest
from elasticsearch import AsyncElasticsearch, ConflictError, NotFoundError
Expand All @@ -37,6 +37,7 @@
Binary,
Boolean,
Date,
DenseVector,
Double,
InnerDoc,
Ip,
Expand Down Expand Up @@ -795,3 +796,55 @@ async def gen3() -> AsyncIterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.asyncio
async def test_float_dense_vector(async_client: AsyncElasticsearch) -> None:
if es_version >= (8, 16):
pytest.skip("this test is a legacy version for Elasticsearch 8.15 or older")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector())

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3]
)
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector


@pytest.mark.asyncio
async def test_dense_vector(async_client: AsyncElasticsearch, es_version: Tuple[int, ...]) -> None:
if es_version < (8, 16):
pytest.skip("this test requires Elasticsearch 8.16 or newer")

class Doc(AsyncDocument):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

await Doc._index.delete(ignore_unavailable=True)
await Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
await doc.save(refresh=True)

docs = await Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector
26 changes: 26 additions & 0 deletions tests/test_integration/_sync/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Binary,
Boolean,
Date,
DenseVector,
Document,
Double,
InnerDoc,
Expand Down Expand Up @@ -789,3 +790,28 @@ def gen3() -> Iterator[Union[Doc, Dict[str, Any]]]:
"age": 45,
"languages": ["es"],
}


@pytest.mark.sync
def test_dense_vector_quantization(client: Elasticsearch) -> None:
class Doc(Document):
float_vector: List[float] = mapped_field(DenseVector())
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
bit_vector: str = mapped_field(DenseVector(element_type="bit"))

class Index:
name = "vectors"

Doc._index.delete(ignore_unavailable=True)
Doc.init()

doc = Doc(
float_vector=[1.0, 1.2, 2.3], byte_vector=[12, 23, 34, 45], bit_vector="12abf0"
)
doc.save(refresh=True)

docs = Doc.search().execute()
assert len(docs) == 1
assert docs[0].float_vector == doc.float_vector
assert docs[0].byte_vector == doc.byte_vector
assert docs[0].bit_vector == doc.bit_vector

0 comments on commit 9b21ff7

Please sign in to comment.