Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the async for row in cursor: infinite loop error #112

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

### 0.2.5

- Fix infinite iteration case when a cursor object is put in the `async for` loop. By @stankudrow in #112.
- Fix pool connection management (the discussion #108 by @DFilyushin) by @stankudrow in #109:

- add the asynchronous context manager support to the `Pool` class with the pool "startup()" as `__aenter__` and "shutdown()" as `__aexit__` methods;
5 changes: 3 additions & 2 deletions asynch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .connection import connect # noqa:F401
from .pool import create_pool # noqa:F401
from asynch.connection import Connection, connect # noqa: F401
from asynch.cursors import Cursor, DictCursor # noqa: F401
from asynch.pool import Pool, create_async_pool, create_pool # noqa: F401
48 changes: 33 additions & 15 deletions asynch/connection.py
Original file line number Diff line number Diff line change
@@ -83,7 +83,10 @@ def connected(self) -> Optional[bool]:
"""

warn(
"consider using `connection.opened` attribute",
(
"Please consider using the `connection.opened` property. "
"This property may be removed in the version 0.2.6 or a later release."
),
DeprecationWarning,
)
return self._opened
@@ -130,7 +133,7 @@ def status(self) -> str:
and the `conn.opened` is False.
:raise ConnectionError: unknown connection state
:return: connection status
:return: the connection status
:rtype: str (ConnectionStatuses StrEnum)
"""

@@ -167,6 +170,8 @@ def echo(self) -> bool:
return self._echo

async def close(self) -> None:
"""Close the connection."""

if self._opened:
await self._connection.disconnect()
self._opened = False
@@ -186,14 +191,27 @@ async def connect(self) -> None:
self._closed = False

def cursor(self, cursor: Optional[Cursor] = None, *, echo: bool = False) -> Cursor:
"""Return the cursor object for the connection.
When a parameter is interpreted as True,
it takes precedence over the corresponding default value.
If cursor is None, but echo is True, then an instance
of a default `Cursor` class will be created with echoing
set to True even if the `self.echo` property returns False.
:param cursor None | Cursor: a Cursor factory class
:param echo bool:
:return: the cursor from a connection
:rtype: Cursor
"""

cursor_cls = cursor or self._cursor_cls
return cursor_cls(self, self._echo or echo)
return cursor_cls(self, echo or self._echo)

async def ping(self) -> None:
"""Check the connection liveliness.
:raises ConnectionError: if ping() has failed
:return: None
"""

@@ -219,17 +237,17 @@ async def connect(
1. conn = Connection(...) # init a Connection instance
2. conn.connect() # connect to a ClickHouse instance
:param dsn: DSN/connection string (if None -> constructed from default dsn parts)
:param user: user string ("default" by default)
:param password: password string ("" by default)
:param host: host string ("127.0.0.1" by default)
:param port: port integer (9000 by default)
:param database: database string ("default" by default)
:param cursor_cls: Cursor class (asynch.Cursor by default)
:param echo: connection echo mode (False by default)
:param kwargs: connection settings
:return: the open connection
:param dsn str: DSN/connection string (if None -> constructed from default dsn parts)
:param user str: user string ("default" by default)
:param password str: password string ("" by default)
:param host str: host string ("127.0.0.1" by default)
:param port int: port integer (9000 by default)
:param database str: database string ("default" by default)
:param cursor_cls Cursor: Cursor class (asynch.Cursor by default)
:param echo bool: echo mode flag (False by default)
:param kwargs dict: connection settings
:return: an opened connection
:rtype: Connection
"""

47 changes: 34 additions & 13 deletions asynch/cursors.py
Original file line number Diff line number Diff line change
@@ -246,7 +246,7 @@ def __aiter__(self):
async def __anext__(self):
while True:
one = await self.fetchone()
if one is None:
if not one:
raise StopAsyncIteration
return one

@@ -349,23 +349,44 @@ def set_query_id(self, query_id=""):


class DictCursor(Cursor):
async def fetchone(self):
row = await super(DictCursor, self).fetchone()
async def fetchone(self) -> dict:
"""Fetch exactly one row from the last executed query.
:raises AttributeError: columns mismatch
:return: one row from the query
:rtype: dict
"""

row = await super().fetchone()
if self._columns:
return dict(zip(self._columns, row)) if row else {}
else:
raise AttributeError("Invalid columns.")
raise AttributeError("Invalid columns.")

async def fetchmany(self, size: int):
rows = await super(DictCursor, self).fetchmany(size)
async def fetchmany(self, size: int) -> list[dict]:
"""Fetch no more than `size` rows from the last executed query.
:raises AttributeError: columns mismatch
:return: the list of rows from the query
:rtype: list[dict]
"""

rows = await super().fetchmany(size)
if self._columns:
return [dict(zip(self._columns, item)) for item in rows] if rows else []
else:
raise AttributeError("Invalid columns.")
raise AttributeError("Invalid columns.")

async def fetchall(self):
rows = await super(DictCursor, self).fetchall()
async def fetchall(self) -> list[dict]:
"""Fetch all resulting rows from the last executed query.
:raises AttributeError: columns mismatch
:return: the list of all possible rows from the query
:rtype: list[dict]
"""

rows = await super().fetchall()
if self._columns:
return [dict(zip(self._columns, item)) for item in rows] if rows else []
else:
raise AttributeError("Invalid columns.")
raise AttributeError("Invalid columns.")
4 changes: 2 additions & 2 deletions asynch/pool.py
Original file line number Diff line number Diff line change
@@ -133,7 +133,7 @@ async def __aenter__(self) -> "Pool":
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
await self.shutdown()

def __repr__(self):
def __repr__(self) -> str:
cls_name = self.__class__.__name__
status = self.status
return (
@@ -155,7 +155,7 @@ def status(self) -> str:
and the `pool.opened` is False.
:raise PoolError: unresolved pool state.
:return: pool status
:return: the pool status
:rtype: str (PoolStatuses StrEnum)
"""

43 changes: 35 additions & 8 deletions tests/test_cursors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,38 @@
from typing import Any

import pytest

from asynch.connection import Connection
from asynch.cursors import DictCursor
from asynch.proto import constants


@pytest.mark.asyncio
async def test_fetchone(conn):
@pytest.mark.parametrize(
("stmt", "answer"),
[
("SELECT 42", [{"42": 42}]),
("SELECT -21 WHERE 1 != 1", []),
],
)
async def test_cursor_async_for(
stmt: str,
answer: list[dict[str, Any]],
conn: Connection,
):
result: list[dict[str, Any]] = []

async with conn:
async with conn.cursor(cursor=DictCursor) as cursor:
cursor.set_stream_results(stream_results=True, max_row_buffer=1000)
await cursor.execute(stmt)
result = [row async for row in cursor]

assert result == answer


@pytest.mark.asyncio
async def test_fetchone(conn: Connection):
async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")
ret = await cursor.fetchone()
@@ -17,23 +44,23 @@ async def test_fetchone(conn):


@pytest.mark.asyncio
async def test_fetchall(conn):
async def test_fetchall(conn: Connection):
async with conn.cursor() as cursor:
await cursor.execute("SELECT 1")
ret = await cursor.fetchall()
assert ret == [(1,)]


@pytest.mark.asyncio
async def test_dict_cursor(conn):
async def test_dict_cursor(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
await cursor.execute("SELECT 1")
ret = await cursor.fetchall()
assert ret == [{"1": 1}]


@pytest.mark.asyncio
async def test_insert_dict(conn):
async def test_insert_dict(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
rows = await cursor.execute(
"""INSERT INTO test.asynch(id,decimal,date,datetime,float,uuid,string,ipv4,ipv6,bool) VALUES""",
@@ -56,7 +83,7 @@ async def test_insert_dict(conn):


@pytest.mark.asyncio
async def test_insert_tuple(conn):
async def test_insert_tuple(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
rows = await cursor.execute(
"""INSERT INTO test.asynch(id,decimal,date,datetime,float,uuid,string,ipv4,ipv6,bool) VALUES""",
@@ -79,7 +106,7 @@ async def test_insert_tuple(conn):


@pytest.mark.asyncio
async def test_executemany(conn):
async def test_executemany(conn: Connection):
async with conn.cursor(cursor=DictCursor) as cursor:
rows = await cursor.executemany(
"""INSERT INTO test.asynch(id,decimal,date,datetime,float,uuid,string,ipv4,ipv6,bool) VALUES""",
@@ -114,7 +141,7 @@ async def test_executemany(conn):


@pytest.mark.asyncio
async def test_table_ddl(conn):
async def test_table_ddl(conn: Connection):
async with conn.cursor() as cursor:
await cursor.execute("drop table if exists test.alter_table")
create_table_sql = """
@@ -137,7 +164,7 @@ async def test_table_ddl(conn):


@pytest.mark.asyncio
async def test_insert_buffer_overflow(conn):
async def test_insert_buffer_overflow(conn: Connection):
old_buffer_size = constants.BUFFER_SIZE
constants.BUFFER_SIZE = 2**6 + 1