Skip to content

Commit

Permalink
style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Oct 23, 2024
1 parent be55108 commit 47cce3a
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 54 deletions.
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ydb-dbapi"
version = "0.1.0"
version = "0.0.1"
description = ""
authors = ["Oleg Ovcharuk <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -58,6 +58,8 @@ ignore = [
# Ignores below could be deleted
"EM101", # Allow to use string literals in exceptions
"TRY003", # Allow specifying long messages outside the exception class
"SLF001", # Allow access private member,
"PGH003", # Allow not to specify rule codes
]
select = ["ALL"]

Expand All @@ -72,7 +74,8 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
force-single-line = true

[tool.ruff.lint.per-file-ignores]
"**/test_*.py" = ["S", "SLF", "ANN201", "ARG", "PLR2004"]
"**/test_*.py" = ["S", "SLF", "ANN201", "ARG", "PLR2004", "PT012"]
"conftest.py" = ["ARG001"]
"__init__.py" = ["F401", "F403"]

[tool.pytest.ini_options]
Expand Down
18 changes: 14 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from asyncio import AbstractEventLoop
from collections.abc import AsyncGenerator
from collections.abc import Generator
from typing import Any
from typing import Callable
Expand All @@ -9,6 +11,7 @@
from testcontainers.core.generic import DbContainer
from testcontainers.core.generic import wait_container_is_ready
from testcontainers.core.utils import setup_logger
from typing_extensions import Self

logger = setup_logger(__name__)

Expand All @@ -33,7 +36,7 @@ def __init__(
self._name = name
self._database_name = "local"

def start(self):
def start(self) -> Self:
self._maybe_stop_old_container()
super().start()
return self
Expand Down Expand Up @@ -115,7 +118,9 @@ def connection_kwargs(ydb_container: YDBContainer) -> dict:


@pytest.fixture
async def driver(ydb_container, event_loop):
async def driver(
ydb_container: YDBContainer, event_loop: AbstractEventLoop
) -> AsyncGenerator[ydb.aio.Driver]:
driver = ydb.aio.Driver(
connection_string=ydb_container.get_connection_string()
)
Expand All @@ -128,7 +133,9 @@ async def driver(ydb_container, event_loop):


@pytest.fixture
async def session_pool(driver: ydb.aio.Driver):
async def session_pool(
driver: ydb.aio.Driver,
) -> AsyncGenerator[ydb.aio.QuerySessionPool]:
session_pool = ydb.aio.QuerySessionPool(driver)
async with session_pool:
await session_pool.execute_with_retries(
Expand All @@ -146,8 +153,11 @@ async def session_pool(driver: ydb.aio.Driver):

yield session_pool


@pytest.fixture
async def session(session_pool: ydb.aio.QuerySessionPool):
async def session(
session_pool: ydb.aio.QuerySessionPool,
) -> AsyncGenerator[ydb.aio.QuerySession]:
session = await session_pool.acquire()

yield session
Expand Down
30 changes: 21 additions & 9 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

from collections.abc import AsyncGenerator
from contextlib import suppress

import pytest
Expand Down Expand Up @@ -37,7 +40,7 @@ async def _test_isolation_level_read_only(
await connection.rollback()

async with connection.cursor() as cursor:
cursor.execute("DROP TABLE foo")
await cursor.execute("DROP TABLE foo")

async def _test_connection(self, connection: dbapi.Connection) -> None:
await connection.commit()
Expand Down Expand Up @@ -66,7 +69,9 @@ async def _test_connection(self, connection: dbapi.Connection) -> None:
await cur.execute("DROP TABLE foo")
await cur.close()

async def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None:
async def _test_cursor_raw_query(
self, connection: dbapi.Connection
) -> None:
cur = connection.cursor()
assert cur

Expand Down Expand Up @@ -107,7 +112,10 @@ async def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None:

async def _test_errors(self, connection: dbapi.Connection) -> None:
with pytest.raises(dbapi.InterfaceError):
await dbapi.connect("localhost:2136", database="/local666")
await dbapi.connect(
"localhost:2136", # type: ignore
database="/local666", # type: ignore
)

cur = connection.cursor()

Expand Down Expand Up @@ -142,8 +150,10 @@ async def _test_errors(self, connection: dbapi.Connection) -> None:

class TestAsyncConnection(BaseDBApiTestSuit):
@pytest_asyncio.fixture
async def connection(self, connection_kwargs):
conn = await dbapi.connect(**connection_kwargs)
async def connection(
self, connection_kwargs: dict
) -> AsyncGenerator[dbapi.Connection]:
conn = await dbapi.connect(**connection_kwargs) # ignore: typing
try:
yield conn
finally:
Expand All @@ -166,19 +176,21 @@ async def test_isolation_level_read_only(
isolation_level: str,
read_only: bool,
connection: dbapi.Connection,
):
) -> None:
await self._test_isolation_level_read_only(
connection, isolation_level, read_only
)

@pytest.mark.asyncio
async def test_connection(self, connection: dbapi.Connection):
async def test_connection(self, connection: dbapi.Connection) -> None:
await self._test_connection(connection)

@pytest.mark.asyncio
async def test_cursor_raw_query(self, connection: dbapi.Connection):
async def test_cursor_raw_query(
self, connection: dbapi.Connection
) -> None:
await self._test_cursor_raw_query(connection)

@pytest.mark.asyncio
async def test_errors(self, connection: dbapi.Connection):
async def test_errors(self, connection: dbapi.Connection) -> None:
await self._test_errors(connection)
22 changes: 16 additions & 6 deletions tests/test_cursor.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
import ydb_dbapi
from ydb.aio import QuerySession


@pytest.mark.asyncio
async def test_cursor_ddl(session):
async def test_cursor_ddl(session: QuerySession) -> None:
cursor = ydb_dbapi.Cursor(session=session)

yql = """
Expand All @@ -27,7 +28,7 @@ async def test_cursor_ddl(session):


@pytest.mark.asyncio
async def test_cursor_dml(session):
async def test_cursor_dml(session: QuerySession) -> None:
cursor = ydb_dbapi.Cursor(session=session)
yql_text = """
INSERT INTO table (id, val) VALUES
Expand All @@ -48,12 +49,13 @@ async def test_cursor_dml(session):
await cursor.execute(query=yql_text)

res = await cursor.fetchone()
assert res is not None
assert len(res) == 1
assert res[0] == 3


@pytest.mark.asyncio
async def test_cursor_fetch_one(session):
async def test_cursor_fetch_one(session: QuerySession) -> None:
cursor = ydb_dbapi.Cursor(session=session)
yql_text = """
INSERT INTO table (id, val) VALUES
Expand All @@ -73,16 +75,18 @@ async def test_cursor_fetch_one(session):
await cursor.execute(query=yql_text)

res = await cursor.fetchone()
assert res is not None
assert res[0] == 1

res = await cursor.fetchone()
assert res is not None
assert res[0] == 2

assert await cursor.fetchone() is None


@pytest.mark.asyncio
async def test_cursor_fetch_many(session):
async def test_cursor_fetch_many(session: QuerySession) -> None:
cursor = ydb_dbapi.Cursor(session=session)
yql_text = """
INSERT INTO table (id, val) VALUES
Expand All @@ -104,23 +108,26 @@ async def test_cursor_fetch_many(session):
await cursor.execute(query=yql_text)

res = await cursor.fetchmany()
assert res is not None
assert len(res) == 1
assert res[0][0] == 1

res = await cursor.fetchmany(size=2)
assert res is not None
assert len(res) == 2
assert res[0][0] == 2
assert res[1][0] == 3

res = await cursor.fetchmany(size=2)
assert res is not None
assert len(res) == 1
assert res[0][0] == 4

assert await cursor.fetchmany(size=2) is None


@pytest.mark.asyncio
async def test_cursor_fetch_all(session):
async def test_cursor_fetch_all(session: QuerySession) -> None:
cursor = ydb_dbapi.Cursor(session=session)
yql_text = """
INSERT INTO table (id, val) VALUES
Expand All @@ -143,6 +150,7 @@ async def test_cursor_fetch_all(session):
assert cursor.rowcount == 3

res = await cursor.fetchall()
assert res is not None
assert len(res) == 3
assert res[0][0] == 1
assert res[1][0] == 2
Expand All @@ -152,20 +160,22 @@ async def test_cursor_fetch_all(session):


@pytest.mark.asyncio
async def test_cursor_next_set(session):
async def test_cursor_next_set(session: QuerySession) -> None:
cursor = ydb_dbapi.Cursor(session=session)
yql_text = """SELECT 1 as val; SELECT 2 as val;"""

await cursor.execute(query=yql_text)

res = await cursor.fetchall()
assert res is not None
assert len(res) == 1
assert res[0][0] == 1

nextset = await cursor.nextset()
assert nextset

res = await cursor.fetchall()
assert res is not None
assert len(res) == 1
assert res[0][0] == 2

Expand Down
6 changes: 3 additions & 3 deletions ydb_dbapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .connection import Connection
from .connection import IsolationLevel
from .connection import connect
from .connections import Connection
from .connections import IsolationLevel
from .connections import connect
from .cursors import Cursor
from .errors import *
42 changes: 26 additions & 16 deletions ydb_dbapi/connection.py → ydb_dbapi/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ def __init__(
"ydb_session_pool" in self.conn_kwargs
): # Use session pool managed manually
self._shared_session_pool = True
self._session_pool = self.conn_kwargs.pop(
"ydb_session_pool"
)
self._session_pool = self.conn_kwargs.pop("ydb_session_pool")
self._driver = self._session_pool._driver
else:
self._shared_session_pool = False
Expand Down Expand Up @@ -127,13 +125,24 @@ class Connection(BaseYDBConnection):
_ydb_driver_class = ydb.aio.Driver
_ydb_session_pool_class = ydb.aio.QuerySessionPool

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def __init__(
self,
host: str = "",
port: str = "",
database: str = "",
**conn_kwargs: Any,
) -> None:
super().__init__(
host,
port,
database,
**conn_kwargs,
)

self._session: ydb.aio.QuerySession | None = None
self._tx_context: ydb.QueryTxContext | None = None
self._tx_context: ydb.aio.QueryTxContext | None = None

async def _wait(self, timeout: int = 5) -> None:
async def wait_ready(self, timeout: int = 5) -> None:
try:
await self._driver.wait(timeout, fail_fast=True)
except ydb.Error as e:
Expand All @@ -144,13 +153,13 @@ async def _wait(self, timeout: int = 5) -> None:
"Failed to connect to YDB, details "
f"{self._driver.discovery_debug_details()}"
)
raise InterfaceError(
msg
) from e
raise InterfaceError(msg) from e

self._session = await self._session_pool.acquire()

def cursor(self):
def cursor(self) -> Cursor:
if self._session is None:
raise RuntimeError("Connection is not ready, use wait_ready.")
if self._current_cursor and not self._current_cursor._closed:
raise RuntimeError(
"Unable to create new Cursor before closing existing one."
Expand Down Expand Up @@ -218,12 +227,13 @@ async def callee() -> None:
await self._driver.scheme_client.describe_path(table_path)

await retry_operation_async(callee)
return True
except ydb.SchemeError:
return False
else:
return True

async def _get_table_names(self, abs_dir_path: str) -> list[str]:
async def callee():
async def callee() -> ydb.Directory:
return await self._driver.scheme_client.list_directory(
abs_dir_path
)
Expand All @@ -239,7 +249,7 @@ async def callee():
return result


async def connect(*args, **kwargs) -> Connection:
conn = Connection(*args, **kwargs)
await conn._wait()
async def connect(*args: tuple, **kwargs: dict) -> Connection:
conn = Connection(*args, **kwargs) # type: ignore
await conn.wait_ready()
return conn
Loading

0 comments on commit 47cce3a

Please sign in to comment.