Skip to content

Commit

Permalink
Validate that the server version is supported on MariaDB vector store…
Browse files Browse the repository at this point in the history
… init
  • Loading branch information
karsov committed Jan 15, 2025
1 parent 08010f6 commit db6ac68
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,17 @@ def _connect(self) -> Any:
self.connection_string, connect_args=self.connection_args, echo=self.debug
)

def _validate_server_version(self) -> None:
"""Validate that the MariaDB server version is supported."""
with self._engine.connect() as connection:
result = connection.execute(sqlalchemy.text("SELECT VERSION()"))
version = result.fetchone()[0]

if not _meets_min_server_version(version, "11.7.1"):
raise ValueError(
f"MariaDB version 11.7.1 or later is required, found version: {version}."
)

def _create_table_if_not_exists(self) -> None:
with self._engine.connect() as connection:
# Note that we define the vector index with DISTANCE=cosine, because we use VEC_DISTANCE_COSINE.
Expand All @@ -201,6 +212,7 @@ def _initialize(self) -> None:
if not self._is_initialized:
self._connect()
if self.perform_setup:
self._validate_server_version()
self._create_table_if_not_exists()
self._is_initialized = True

Expand Down Expand Up @@ -431,3 +443,19 @@ def clear(self) -> None:
connection.execute(sqlalchemy.text(stmt))

connection.commit()


def _meets_min_server_version(version: str, min_version: str) -> bool:
"""Check if a MariaDB server version meets minimum required version.
Args:
version: Version string from MariaDB server (e.g. "11.7.1-MariaDB-ubu2404")
min_version: Minimum required version string (e.g. "11.7.1")
Returns:
bool: True if version >= min_version, False otherwise
"""
version = version.split("-")[0]
version_parts = [int(x) for x in version.split(".")]
min_version_parts = [int(x) for x in min_version.split(".")]
return version_parts >= min_version_parts
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ include = "llama_index/"
filterwarnings = [
"ignore::DeprecationWarning:",
]
markers = [
"noautousefixtures: marks tests that should not run fixtures with autouse",
]
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
VectorStoreQuery,
)
from llama_index.vector_stores.mariadb import MariaDBVectorStore
from llama_index.vector_stores.mariadb.base import _meets_min_server_version

TEST_NODES: List[TextNode] = [
TextNode(
Expand Down Expand Up @@ -76,21 +77,45 @@


@pytest.fixture(autouse=True)
def teardown() -> Generator:
def teardown(request: pytest.FixtureRequest) -> Generator:
"""Clear the store after a test completion."""
yield

if "noautousefixtures" in request.keywords:
return

vector_store.clear()


@pytest.fixture(scope="session", autouse=True)
def close_db_connection() -> Generator:
def close_db_connection(request: pytest.FixtureRequest) -> Generator:
"""Close the DB connections after the last test."""
yield

if "noautousefixtures" in request.keywords:
return

vector_store.close()


@pytest.mark.parametrize(
("version", "supported"),
[
("11.7.2-MariaDB-ubu2504", True),
("11.7.1-MariaDB-ubu2404", True),
("11.8.0", True),
("12.0.0", True),
("11.7.0", False),
("11.6.0-MariaDB-ubu2404", False),
("10.11.7-MariaDB-1:10.11.7+maria~ubu2204", False),
("8.4.3", False),
],
)
@pytest.mark.noautousefixtures()
def test_meets_min_server_version(version: str, supported: bool) -> None:
assert _meets_min_server_version(version, "11.7.1") == supported


@pytest.mark.skipif(
run_integration_tests is False,
reason="MariaDB instance required for integration tests",
Expand Down

0 comments on commit db6ac68

Please sign in to comment.