Skip to content

Commit

Permalink
Support MariaDB 11.7 in the MariaDB vector store integration (#17497)
Browse files Browse the repository at this point in the history
  • Loading branch information
karsov authored Jan 15, 2025
1 parent f8f3621 commit 0de7c36
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# LlamaIndex Vector_Stores Integration: MariaDB

With the release of MariaDB 11.6 Vector Preview, the MariaDB relational database introduced the long-awaited vector search functionality.
Starting with version `11.7.1`, the MariaDB relational database has vector search functionality integrated.
Thus now it can be used as a fully-functional vector store in LlamaIndex.
Please note, however, that the latest MariaDB version is only an Alpha release, which means that it may crash unexpectedly.

To learn more about the feature, check the [Vector Overview](https://mariadb.com/kb/en/vector-overview/) in the MariaDB docs.
To learn more about the feature in MariaDB, check its [Vector Overview documentation](https://mariadb.com/kb/en/vector-overview/).

Please note that versions before `0.3.0` of this package are not compatible with MariaDB 11.7 and later.
They are compatible only with the one-off `MariaDB 11.6 Vector` preview release which used a slightly different syntax.

## Installation

Expand Down Expand Up @@ -33,7 +35,7 @@ vector_store = MariaDBVectorStore.from_params(
### Running Integration Tests

A suite of integration tests is available to verify the MariaDB vector store integration.
The test suite needs a MariaDB database with vector search support up and running, if not found the tests are skipped.
The test suite needs a MariaDB database with vector search support up and running. If not found, the tests are skipped.
To facilitate that, a sample `docker-compose.yaml` file is provided, so you can simply do:

```shell
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,32 @@ def _connect(self) -> Any:
self.connection_string, connect_args=self.connection_args, echo=self.debug
)

def _validate_server_version(self) -> None:
"""Validate that the MariaDB server version is supported."""
with self._engine.connect() as connection:
result = connection.execute(sqlalchemy.text("SELECT VERSION()"))
version = result.fetchone()[0]

if not _meets_min_server_version(version, "11.7.1"):
raise ValueError(
f"MariaDB version 11.7.1 or later is required, found version: {version}."
)

def _create_table_if_not_exists(self) -> None:
with self._engine.connect() as connection:
# Note that we define the vector index with DISTANCE=cosine, because we use VEC_DISTANCE_COSINE.
# This is because searches using a different distance function do not use the vector index.
# Reference: https://mariadb.com/kb/en/create-table-with-vectors/
stmt = f"""
CREATE TABLE IF NOT EXISTS `{self.table_name}` (
id SERIAL PRIMARY KEY,
node_id VARCHAR(255) NOT NULL,
text TEXT,
metadata JSON,
embedding BLOB NOT NULL,
VECTOR INDEX (embedding)
);
embedding VECTOR({self.embed_dim}) NOT NULL,
INDEX `{self.table_name}_node_id_idx` (`node_id`),
VECTOR INDEX (embedding) DISTANCE=cosine
)
"""
connection.execute(sqlalchemy.text(stmt))

Expand All @@ -197,6 +212,7 @@ def _initialize(self) -> None:
if not self._is_initialized:
self._connect()
if self.perform_setup:
self._validate_server_version()
self._create_table_if_not_exists()
self._is_initialized = True

Expand Down Expand Up @@ -251,7 +267,7 @@ def add(
VALUES (
:node_id,
:text,
vec_fromtext(:embedding),
VEC_FromText(:embedding),
:metadata
)
"""
Expand Down Expand Up @@ -367,33 +383,17 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
text,
embedding,
metadata,
vec_distance(embedding, vec_fromtext('{query.query_embedding}')) AS distance
FROM `{self.table_name}`
ORDER BY distance
LIMIT {query.similarity_top_k}
"""
VEC_DISTANCE_COSINE(embedding, VEC_FromText('{query.query_embedding}')) AS distance
FROM `{self.table_name}`"""

if query.filters:
where = self._filters_to_where_clause(query.filters)
stmt += f"""
WHERE {self._filters_to_where_clause(query.filters)}"""

# We cannot use the query above when there is a WHERE clause,
# because of a bug in MariaDB: https://jira.mariadb.org/browse/MDEV-34774.
# The following query works around it.
stmt = f"""
SELECT * FROM (
SELECT
node_id,
text,
embedding,
metadata,
vec_distance(embedding, vec_fromtext('{query.query_embedding}')) AS distance
FROM `{self.table_name}`
WHERE {where}
LIMIT 1000000
) AS unordered
ORDER BY distance
LIMIT {query.similarity_top_k}
"""
stmt += f"""
ORDER BY distance
LIMIT {query.similarity_top_k}
"""

with self._engine.connect() as connection:
result = connection.execute(sqlalchemy.text(stmt))
Expand Down Expand Up @@ -443,3 +443,19 @@ def clear(self) -> None:
connection.execute(sqlalchemy.text(stmt))

connection.commit()


def _meets_min_server_version(version: str, min_version: str) -> bool:
"""Check if a MariaDB server version meets minimum required version.
Args:
version: Version string from MariaDB server (e.g. "11.7.1-MariaDB-ubu2404")
min_version: Minimum required version string (e.g. "11.7.1")
Returns:
bool: True if version >= min_version, False otherwise
"""
version = version.split("-")[0]
version_parts = [int(x) for x in version.split(".")]
min_version_parts = [int(x) for x in min_version.split(".")]
return version_parts >= min_version_parts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-vector-stores-mariadb"
readme = "README.md"
version = "0.2.0"
version = "0.3.0"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
Expand Down Expand Up @@ -69,3 +69,6 @@ include = "llama_index/"
filterwarnings = [
"ignore::DeprecationWarning:",
]
markers = [
"noautousefixtures: marks tests that should not run fixtures with autouse",
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
services:
mariadb:
image: "quay.io/mariadb-foundation/mariadb-devel:11.6-vector-preview"
image: mariadb:11.7.1-rc
environment:
MARIADB_DATABASE: test
MARIADB_ROOT_PASSWORD: test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
VectorStoreQuery,
)
from llama_index.vector_stores.mariadb import MariaDBVectorStore
from llama_index.vector_stores.mariadb.base import _meets_min_server_version

TEST_NODES: List[TextNode] = [
TextNode(
Expand Down Expand Up @@ -51,6 +52,7 @@
vector_store = MariaDBVectorStore.from_params(
database="test",
table_name="vector_store_test",
embed_dim=3,
host="127.0.0.1",
user="root",
password="test",
Expand All @@ -75,21 +77,45 @@


@pytest.fixture(autouse=True)
def teardown() -> Generator:
def teardown(request: pytest.FixtureRequest) -> Generator:
"""Clear the store after a test completion."""
yield

if "noautousefixtures" in request.keywords:
return

vector_store.clear()


@pytest.fixture(scope="session", autouse=True)
def close_db_connection() -> Generator:
def close_db_connection(request: pytest.FixtureRequest) -> Generator:
"""Close the DB connections after the last test."""
yield

if "noautousefixtures" in request.keywords:
return

vector_store.close()


@pytest.mark.parametrize(
("version", "supported"),
[
("11.7.2-MariaDB-ubu2504", True),
("11.7.1-MariaDB-ubu2404", True),
("11.8.0", True),
("12.0.0", True),
("11.7.0", False),
("11.6.0-MariaDB-ubu2404", False),
("10.11.7-MariaDB-1:10.11.7+maria~ubu2204", False),
("8.4.3", False),
],
)
@pytest.mark.noautousefixtures()
def test_meets_min_server_version(version: str, supported: bool) -> None:
assert _meets_min_server_version(version, "11.7.1") == supported


@pytest.mark.skipif(
run_integration_tests is False,
reason="MariaDB instance required for integration tests",
Expand Down

0 comments on commit 0de7c36

Please sign in to comment.