Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invalidate session&tx on YDB errors #16

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ version = "0.1.5" # AUTOVERSION
description = "YDB Python DBAPI which complies with PEP 249"
authors = ["Yandex LLC <[email protected]>"]
readme = "README.md"

[project.urls]
Homepage = "https://github.com/ydb-platform/ydb-python-dbapi/"
repository = "https://github.com/ydb-platform/ydb-python-dbapi/"

[tool.poetry.dependencies]
python = "^3.8"
Expand Down
39 changes: 39 additions & 0 deletions tests/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,34 @@ def _test_bulk_upsert(self, connection: dbapi.Connection) -> None:

maybe_await(cursor.execute_scheme("DROP TABLE pet"))

def _test_error_with_interactive_tx(
self,
connection: dbapi.Connection,
) -> None:

cur = connection.cursor()
maybe_await(cur.execute_scheme(
"""
DROP TABLE IF EXISTS test;
CREATE TABLE test (
id Int64 NOT NULL,
val Int64,
PRIMARY KEY(id)
)
"""
))

connection.set_isolation_level(dbapi.IsolationLevel.SERIALIZABLE)
maybe_await(connection.begin())

cur = connection.cursor()
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))
with pytest.raises(dbapi.Error):
maybe_await(cur.execute("INSERT INTO test(id, val) VALUES (1,1)"))

maybe_await(cur.close())
maybe_await(connection.rollback())


class TestConnection(BaseDBApiTestSuit):
@pytest.fixture
Expand Down Expand Up @@ -245,6 +273,11 @@ def test_errors(self, connection: dbapi.Connection) -> None:
def test_bulk_upsert(self, connection: dbapi.Connection) -> None:
self._test_bulk_upsert(connection)

def test_errors_with_interactive_tx(
self, connection: dbapi.Connection
) -> None:
self._test_error_with_interactive_tx(connection)


class TestAsyncConnection(BaseDBApiTestSuit):
@pytest_asyncio.fixture
Expand Down Expand Up @@ -304,3 +337,9 @@ async def test_bulk_upsert(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_bulk_upsert, connection)

@pytest.mark.asyncio
async def test_errors_with_interactive_tx(
self, connection: dbapi.AsyncConnection
) -> None:
await greenlet_spawn(self._test_error_with_interactive_tx, connection)
34 changes: 34 additions & 0 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import pytest
import ydb
import ydb_dbapi
from sqlalchemy.util import await_only
from sqlalchemy.util import greenlet_spawn
from ydb_dbapi import AsyncCursor
from ydb_dbapi import Cursor
from ydb_dbapi.utils import CursorStatus


def maybe_await(obj: callable) -> any:
Expand All @@ -22,6 +24,14 @@ def maybe_await(obj: callable) -> any:
RESULT_SET_COUNT = 3


class FakeSyncConnection:
def _invalidate_session(self) -> None: ...


class FakeAsyncConnection:
async def _invalidate_session(self) -> None: ...


class BaseCursorTestSuit:
def _test_cursor_fetch_one(self, cursor: Cursor | AsyncCursor) -> None:
yql_text = """
Expand Down Expand Up @@ -136,13 +146,24 @@ def _test_cursor_fetch_all_multiple_result_sets(
assert maybe_await(cursor.fetchall()) == []
assert not maybe_await(cursor.nextset())

def _test_cursor_state_after_error(
self, cursor: Cursor | AsyncCursor
) -> None:
query = "INSERT INTO table (id, val) VALUES (0,0)"
with pytest.raises(ydb_dbapi.Error):
maybe_await(cursor.execute(query=query))

assert cursor._state == CursorStatus.finished


class TestCursor(BaseCursorTestSuit):
@pytest.fixture
def sync_cursor(
self, session_pool_sync: ydb.QuerySessionPool
) -> Generator[Cursor]:

cursor = Cursor(
FakeSyncConnection(),
session_pool_sync,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
Expand Down Expand Up @@ -174,6 +195,10 @@ def test_cursor_fetch_all_multiple_result_sets(
) -> None:
self._test_cursor_fetch_all_multiple_result_sets(sync_cursor)

def test_cursor_state_after_error(
self, sync_cursor: Cursor
) -> None:
self._test_cursor_state_after_error(sync_cursor)


class TestAsyncCursor(BaseCursorTestSuit):
Expand All @@ -182,6 +207,7 @@ async def async_cursor(
self, session_pool: ydb.aio.QuerySessionPool
) -> AsyncGenerator[Cursor]:
cursor = AsyncCursor(
FakeAsyncConnection(),
session_pool,
ydb.QuerySerializableReadWrite(),
request_settings=ydb.BaseRequestSettings(),
Expand Down Expand Up @@ -224,3 +250,11 @@ async def test_cursor_fetch_all_multiple_result_sets(
await greenlet_spawn(
self._test_cursor_fetch_all_multiple_result_sets, async_cursor
)

@pytest.mark.asyncio
async def test_cursor_state_after_error(
self, async_cursor: AsyncCursor
) -> None:
await greenlet_spawn(
self._test_cursor_state_after_error, async_cursor
)
16 changes: 16 additions & 0 deletions ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(

def cursor(self) -> Cursor:
return self._cursor_cls(
connection=self,
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
Expand Down Expand Up @@ -326,6 +327,13 @@ def bulk_upsert(
settings=settings,
)

def _invalidate_session(self) -> None:
if self._tx_context:
self._tx_context = None
if self._session:
self._session_pool.release(self._session)
self._session = None


class AsyncConnection(BaseConnection):
_driver_cls = ydb.aio.Driver
Expand Down Expand Up @@ -357,6 +365,7 @@ def __init__(

def cursor(self) -> AsyncCursor:
return self._cursor_cls(
connection=self,
session_pool=self._session_pool,
tx_mode=self._tx_mode,
tx_context=self._tx_context,
Expand Down Expand Up @@ -492,6 +501,13 @@ async def bulk_upsert(
settings=settings,
)

async def _invalidate_session(self) -> None:
if self._tx_context:
self._tx_context = None
if self._session:
await self._session_pool.release(self._session)
self._session = None


def connect(*args: tuple, **kwargs: dict) -> Connection:
conn = Connection(*args, **kwargs) # type: ignore
Expand Down
46 changes: 46 additions & 0 deletions ydb_dbapi/cursors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import functools
import itertools
from collections.abc import AsyncIterator
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Union

import ydb
Expand All @@ -20,6 +23,9 @@
from .utils import maybe_get_current_trace_id

if TYPE_CHECKING:
from .connections import AsyncConnection
from .connections import Connection

ParametersType = dict[
str,
Union[
Expand All @@ -34,6 +40,34 @@ def _get_column_type(type_obj: Any) -> str:
return str(ydb.convert.type_to_native(type_obj))


def invalidate_cursor_on_ydb_error(func: Callable) -> Callable:
if iscoroutinefunction(func):

@functools.wraps(func)
async def awrapper(
self: AsyncCursor, *args: tuple, **kwargs: dict
) -> Any:
try:
return await func(self, *args, **kwargs)
except ydb.Error:
self._state = CursorStatus.finished
await self._connection._invalidate_session()
raise

return awrapper

@functools.wraps(func)
def wrapper(self: Cursor, *args: tuple, **kwargs: dict) -> Any:
try:
return func(self, *args, **kwargs)
except ydb.Error:
self._state = CursorStatus.finished
self._connection._invalidate_session()
raise

return wrapper


class BufferedCursor:
def __init__(self) -> None:
self.arraysize: int = 1
Expand Down Expand Up @@ -154,13 +188,15 @@ def _append_table_path_prefix(self, query: str) -> str:
class Cursor(BufferedCursor):
def __init__(
self,
connection: Connection,
session_pool: ydb.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._connection = connection
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
Expand Down Expand Up @@ -188,6 +224,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
return settings

@handle_ydb_errors
@invalidate_cursor_on_ydb_error

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: This method doesn't use session, it uses pool. So if the error occur here, this decorator will delete session and transaction, but shouldn't. Example:

with connection.begin():
  connection.execute(DML)
  try:
    connection.execute(DDL) <--- exception occur and session invalidated
  except:
    pass
  connection.execute(DML) <--- fail, because session is lost

def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> Iterator[ydb.convert.ResultSet]:
Expand All @@ -205,6 +242,7 @@ def callee(
return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_session_query(
self,
query: str,
Expand All @@ -225,6 +263,7 @@ def callee(
return self._session_pool.retry_operation_sync(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def _execute_transactional_query(
self,
tx_context: ydb.QueryTxContext,
Expand Down Expand Up @@ -283,6 +322,7 @@ def executemany(
self.execute(query, parameters)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
def nextset(self, replace_current: bool = True) -> bool:
if self._stream is None:
return False
Expand Down Expand Up @@ -328,13 +368,15 @@ def __exit__(
class AsyncCursor(BufferedCursor):
def __init__(
self,
connection: AsyncConnection,
session_pool: ydb.aio.QuerySessionPool,
tx_mode: ydb.BaseQueryTxMode,
request_settings: ydb.BaseRequestSettings,
tx_context: ydb.aio.QueryTxContext | None = None,
table_path_prefix: str = "",
) -> None:
super().__init__()
self._connection = connection
self._session_pool = session_pool
self._tx_mode = tx_mode
self._request_settings = request_settings
Expand Down Expand Up @@ -362,6 +404,7 @@ def _get_request_settings(self) -> ydb.BaseRequestSettings:
return settings

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_generic_query(
self, query: str, parameters: ParametersType | None = None
) -> AsyncIterator[ydb.convert.ResultSet]:
Expand All @@ -379,6 +422,7 @@ async def callee(
return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_session_query(
self,
query: str,
Expand All @@ -399,6 +443,7 @@ async def callee(
return await self._session_pool.retry_operation_async(callee)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def _execute_transactional_query(
self,
tx_context: ydb.aio.QueryTxContext,
Expand Down Expand Up @@ -457,6 +502,7 @@ async def executemany(
await self.execute(query, parameters)

@handle_ydb_errors
@invalidate_cursor_on_ydb_error
async def nextset(self, replace_current: bool = True) -> bool:
if self._stream is None:
return False
Expand Down
Loading