diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py index 34b5f644a7823..5394ad1c9e1df 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/llama_index/vector_stores/mariadb/base.py @@ -177,6 +177,17 @@ def _connect(self) -> Any: self.connection_string, connect_args=self.connection_args, echo=self.debug ) + def _validate_server_version(self) -> None: + """Validate that the MariaDB server version is supported.""" + with self._engine.connect() as connection: + result = connection.execute(sqlalchemy.text("SELECT VERSION()")) + version = result.fetchone()[0] + + if not _meets_min_server_version(version, "11.7.1"): + raise ValueError( + f"MariaDB version 11.7.1 or later is required, found version: {version}." + ) + def _create_table_if_not_exists(self) -> None: with self._engine.connect() as connection: # Note that we define the vector index with DISTANCE=cosine, because we use VEC_DISTANCE_COSINE. @@ -201,6 +212,7 @@ def _initialize(self) -> None: if not self._is_initialized: self._connect() if self.perform_setup: + self._validate_server_version() self._create_table_if_not_exists() self._is_initialized = True @@ -431,3 +443,19 @@ def clear(self) -> None: connection.execute(sqlalchemy.text(stmt)) connection.commit() + + +def _meets_min_server_version(version: str, min_version: str) -> bool: + """Check if a MariaDB server version meets minimum required version. + + Args: + version: Version string from MariaDB server (e.g. "11.7.1-MariaDB-ubu2404") + min_version: Minimum required version string (e.g. "11.7.1") + + Returns: + bool: True if version >= min_version, False otherwise + """ + version = version.split("-")[0] + version_parts = [int(x) for x in version.split(".")] + min_version_parts = [int(x) for x in min_version.split(".")] + return version_parts >= min_version_parts diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml index 3f0c06259fa1e..94443f379bba9 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/pyproject.toml @@ -69,3 +69,6 @@ include = "llama_index/" filterwarnings = [ "ignore::DeprecationWarning:", ] +markers = [ + "noautousefixtures: marks tests that should not run fixtures with autouse", +] diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py index cbcd24ee309d7..4dc9f1e2b1a8e 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mariadb/tests/test_mariadb.py @@ -14,6 +14,7 @@ VectorStoreQuery, ) from llama_index.vector_stores.mariadb import MariaDBVectorStore +from llama_index.vector_stores.mariadb.base import _meets_min_server_version TEST_NODES: List[TextNode] = [ TextNode( @@ -76,21 +77,45 @@ @pytest.fixture(autouse=True) -def teardown() -> Generator: +def teardown(request: pytest.FixtureRequest) -> Generator: """Clear the store after a test completion.""" yield + if "noautousefixtures" in request.keywords: + return + vector_store.clear() @pytest.fixture(scope="session", autouse=True) -def close_db_connection() -> Generator: +def close_db_connection(request: pytest.FixtureRequest) -> Generator: """Close the DB connections after the last test.""" yield + if "noautousefixtures" in request.keywords: + return + vector_store.close() +@pytest.mark.parametrize( + ("version", "supported"), + [ + ("11.7.2-MariaDB-ubu2504", True), + ("11.7.1-MariaDB-ubu2404", True), + ("11.8.0", True), + ("12.0.0", True), + ("11.7.0", False), + ("11.6.0-MariaDB-ubu2404", False), + ("10.11.7-MariaDB-1:10.11.7+maria~ubu2204", False), + ("8.4.3", False), + ], +) +@pytest.mark.noautousefixtures() +def test_meets_min_server_version(version: str, supported: bool) -> None: + assert _meets_min_server_version(version, "11.7.1") == supported + + @pytest.mark.skipif( run_integration_tests is False, reason="MariaDB instance required for integration tests",