Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: incorrect concurrent usage of connection and transaction #546

Merged
merged 20 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
90c33da
fix: incorrect concurrent usage of connection and transaction
zevisert Apr 7, 2023
bea6629
refactor: rename contextvar class attributes, add some explaination c…
zevisert Apr 10, 2023
c9e3464
fix: contextvar.get takes no keyword arguments
zevisert Apr 10, 2023
f3078aa
test: add concurrent task tests
zevisert Apr 11, 2023
75969d3
feat: use ContextVar[dict] to track connections and transactions per …
zevisert Apr 11, 2023
4cd7451
test: check multiple databases in the same task use independant conne…
zevisert Apr 11, 2023
e4c95a7
chore: changes for linting and typechecking
zevisert Apr 11, 2023
a38e135
chore: use typing.Tuple for lower python version compatibility
zevisert Apr 11, 2023
460f72e
docs: update comment on _connection_contextmap
zevisert Apr 11, 2023
2d4554d
Update `Connection` and `Transaction` to be robust to concurrent use
zanieb Apr 16, 2023
16403c3
Merge remote-tracking branch 'madkinsz/example/instance-safe' into fi…
zevisert Apr 17, 2023
8370299
chore: remove optional annotation on asyncio.Task
zevisert Apr 18, 2023
1d4896f
test: add new tests for upcoming contextvar inheritance/isolation and…
zevisert May 24, 2023
02a9acb
feat: reimplement concurrency system with contextvar and weakmap
zevisert May 24, 2023
0f93807
chore: apply corrections from linters
zevisert May 24, 2023
f091482
fix: quote WeakKeyDictionary typing for python<=3.7
zevisert May 24, 2023
6fb55a5
docs: add examples for async transaction context and nested transactions
zevisert May 25, 2023
6de4f60
Merge remote-tracking branch 'upstream/master' into fix-transaction-c…
zevisert May 25, 2023
b94f097
fix: remove connection inheritance, add more tests, update docs
zevisert May 26, 2023
0a9e9e5
Merge branch 'master' into fix-transaction-contextvar
zanieb Jul 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 78 additions & 14 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import logging
import typing
import weakref
from contextvars import ContextVar
from types import TracebackType
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
Expand All @@ -11,7 +12,7 @@
from sqlalchemy.sql import ClauseElement

from databases.importer import import_from_string
from databases.interfaces import DatabaseBackend, Record
from databases.interfaces import DatabaseBackend, Record, TransactionBackend

try: # pragma: no cover
import click
Expand All @@ -35,6 +36,11 @@
logger = logging.getLogger("databases")


_ACTIVE_TRANSACTIONS: ContextVar[
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
] = ContextVar("databases:active_transactions", default=None)


class Database:
SUPPORTED_BACKENDS = {
"postgresql": "databases.backends.postgres:PostgresBackend",
Expand All @@ -45,6 +51,8 @@ class Database:
"sqlite": "databases.backends.sqlite:SQLiteBackend",
}

_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"

def __init__(
self,
url: typing.Union[str, "DatabaseURL"],
Expand All @@ -55,6 +63,7 @@ def __init__(
self.url = DatabaseURL(url)
self.options = options
self.is_connected = False
self._connection_map = weakref.WeakKeyDictionary()

self._force_rollback = force_rollback

Expand All @@ -63,14 +72,35 @@ def __init__(
assert issubclass(backend_cls, DatabaseBackend)
self._backend = backend_cls(self.url, **self.options)

# Connections are stored as task-local state.
self._connection_context: ContextVar = ContextVar("connection_context")

# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
self._global_connection: typing.Optional[Connection] = None
self._global_transaction: typing.Optional[Transaction] = None

@property
def _current_task(self) -> asyncio.Task:
task = asyncio.current_task()
if not task:
raise RuntimeError("No currently active asyncio.Task found")
return task

@property
def _connection(self) -> typing.Optional["Connection"]:
return self._connection_map.get(self._current_task)

@_connection.setter
def _connection(
self, connection: typing.Optional["Connection"]
) -> typing.Optional["Connection"]:
task = self._current_task

if connection is None:
self._connection_map.pop(task, None)
else:
self._connection_map[task] = connection

return self._connection

async def connect(self) -> None:
"""
Establish the connection pool.
Expand All @@ -89,7 +119,7 @@ async def connect(self) -> None:
assert self._global_connection is None
assert self._global_transaction is None

self._global_connection = Connection(self._backend)
self._global_connection = Connection(self, self._backend)
self._global_transaction = self._global_connection.transaction(
force_rollback=True
)
Expand All @@ -113,7 +143,7 @@ async def disconnect(self) -> None:
self._global_transaction = None
self._global_connection = None
else:
self._connection_context = ContextVar("connection_context")
self._connection = None

await self._backend.disconnect()
logger.info(
Expand Down Expand Up @@ -187,12 +217,10 @@ def connection(self) -> "Connection":
if self._global_connection is not None:
return self._global_connection

try:
return self._connection_context.get()
except LookupError:
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection
if not self._connection:
self._connection = Connection(self, self._backend)

return self._connection

def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
Expand All @@ -215,7 +243,8 @@ def _get_backend(self) -> str:


class Connection:
def __init__(self, backend: DatabaseBackend) -> None:
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
self._database = database
self._backend = backend

self._connection_lock = asyncio.Lock()
Expand Down Expand Up @@ -249,6 +278,7 @@ async def __aexit__(
self._connection_counter -= 1
if self._connection_counter == 0:
await self._connection.release()
self._database._connection = None

async def fetch_all(
self,
Expand Down Expand Up @@ -345,6 +375,37 @@ def __init__(
self._force_rollback = force_rollback
self._extra_options = kwargs

@property
def _connection(self) -> "Connection":
# Returns the same connection if called multiple times
return self._connection_callable()

@property
def _transaction(self) -> typing.Optional["TransactionBackend"]:
transactions = _ACTIVE_TRANSACTIONS.get()
if transactions is None:
return None

return transactions.get(self, None)

@_transaction.setter
def _transaction(
self, transaction: typing.Optional["TransactionBackend"]
) -> typing.Optional["TransactionBackend"]:
transactions = _ACTIVE_TRANSACTIONS.get()
if transactions is None:
transactions = weakref.WeakKeyDictionary()
else:
transactions = transactions.copy()

if transaction is None:
transactions.pop(self, None)
else:
transactions[self] = transaction

_ACTIVE_TRANSACTIONS.set(transactions)
return transactions.get(self, None)

async def __aenter__(self) -> "Transaction":
"""
Called when entering `async with database.transaction()`
Expand Down Expand Up @@ -385,7 +446,6 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return wrapper # type: ignore

async def start(self) -> "Transaction":
self._connection = self._connection_callable()
self._transaction = self._connection._connection.transaction()

async with self._connection._transaction_lock:
Expand All @@ -401,15 +461,19 @@ async def commit(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
assert self._transaction is not None
await self._transaction.commit()
await self._connection.__aexit__()
self._transaction = None

async def rollback(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
assert self._transaction is not None
await self._transaction.rollback()
await self._connection.__aexit__()
self._transaction = None


class _EmptyNetloc(str):
Expand Down
54 changes: 50 additions & 4 deletions docs/connections_and_transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints.

## Connecting and disconnecting

You can control the database connect/disconnect, by using it as a async context manager.
You can control the database connection pool with an async context manager:

```python
async with Database(DATABASE_URL) as database:
...
```

Or by using explicit connection and disconnection:
Or by using the explicit `.connect()` and `.disconnect()` methods:

```python
database = Database(DATABASE_URL)
Expand All @@ -23,6 +23,8 @@ await database.connect()
await database.disconnect()
```

Connections within this connection pool are acquired for each new `asyncio.Task`.

If you're integrating against a web framework, then you'll probably want
to hook into framework startup or shutdown events. For example, with
[Starlette][starlette] you would use the following:
Expand Down Expand Up @@ -67,6 +69,7 @@ A transaction can be acquired from the database connection pool:
async with database.transaction():
...
```

It can also be acquired from a specific database connection:

```python
Expand Down Expand Up @@ -95,8 +98,51 @@ async def create_users(request):
...
```

Transaction blocks are managed as task-local state. Nested transactions
are fully supported, and are implemented using database savepoints.
Transaction state is tied to the connection used in the currently executing asynchronous task.
If you would like to influence an active transaction from another task, the connection must be
shared. This state is _inherited_ by tasks that are share the same connection:

```python
async def add_excitement(connnection: databases.core.Connection, id: int):
await connection.execute(
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
{"id": id}
)


async with Database(database_url) as database:
async with database.transaction():
# This note won't exist until the transaction closes...
await database.execute(
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
)
# ...but child tasks can use this connection now!
await asyncio.create_task(add_excitement(database.connection(), id=1))

await database.fetch_val("SELECT text FROM notes WHERE id=1")
# ^ returns: "databases is cool!!!"
```

Nested transactions are fully supported, and are implemented using database savepoints:

```python
async with databases.Database(database_url) as db:
async with db.transaction() as outer:
# Do something in the outer transaction
...

# Suppress to prevent influence on the outer transaction
with contextlib.suppress(ValueError):
async with db.transaction():
# Do something in the inner transaction
...

raise ValueError('Abort the inner transaction')

# Observe the results of the outer transaction,
# without effects from the inner transaction.
await db.fetch_all('SELECT * FROM ...')
```

Transaction isolation-level can be specified if the driver backend supports that:

Expand Down
Loading