-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add async redis client utils to
prefect-redis
(#16417)
- Loading branch information
Showing
3 changed files
with
224 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import asyncio | ||
import functools | ||
from typing import Any, Callable, Union | ||
|
||
from pydantic import Field | ||
from redis.asyncio import Redis | ||
from typing_extensions import TypeAlias | ||
|
||
from prefect.settings.base import ( | ||
PrefectBaseSettings, | ||
_build_settings_config, # type: ignore[reportPrivateUsage] | ||
) | ||
|
||
|
||
class RedisMessagingSettings(PrefectBaseSettings): | ||
model_config = _build_settings_config( | ||
( | ||
"redis", | ||
"messaging", | ||
), | ||
frozen=True, | ||
) | ||
|
||
host: str = Field(default="localhost") | ||
port: int = Field(default=6379) | ||
db: int = Field(default=0) | ||
username: str = Field(default="default") | ||
password: str = Field(default="") | ||
health_check_interval: int = Field( | ||
default=20, | ||
description="Health check interval for pinging the server; defaults to 20 seconds.", | ||
) | ||
ssl: bool = Field( | ||
default=False, | ||
description="Whether to use SSL for the Redis connection", | ||
) | ||
|
||
|
||
CacheKey: TypeAlias = tuple[ | ||
Callable[..., Any], | ||
tuple[Any, ...], | ||
tuple[tuple[str, Any], ...], | ||
Union[asyncio.AbstractEventLoop, None], | ||
] | ||
|
||
_client_cache: dict[CacheKey, Redis] = {} | ||
|
||
|
||
def _running_loop() -> Union[asyncio.AbstractEventLoop, None]: | ||
try: | ||
return asyncio.get_running_loop() | ||
except RuntimeError as e: | ||
if "no running event loop" in str(e): | ||
return None | ||
raise | ||
|
||
|
||
def cached(fn: Callable[..., Any]) -> Callable[..., Any]: | ||
@functools.wraps(fn) | ||
def cached_fn(*args: Any, **kwargs: Any) -> Redis: | ||
key = (fn, args, tuple(kwargs.items()), _running_loop()) | ||
if key not in _client_cache: | ||
_client_cache[key] = fn(*args, **kwargs) | ||
return _client_cache[key] | ||
|
||
return cached_fn | ||
|
||
|
||
def close_all_cached_connections() -> None: | ||
"""Close all cached Redis connections.""" | ||
loop: Union[asyncio.AbstractEventLoop, None] | ||
|
||
for (_, _, _, loop), client in _client_cache.items(): | ||
if loop and loop.is_closed(): | ||
continue | ||
|
||
loop = loop or asyncio.get_event_loop() | ||
loop.run_until_complete(client.connection_pool.disconnect()) # type: ignore | ||
loop.run_until_complete(client.close(close_connection_pool=True)) | ||
|
||
|
||
@cached | ||
def get_async_redis_client( | ||
host: Union[str, None] = None, | ||
port: Union[int, None] = None, | ||
db: Union[int, None] = None, | ||
password: Union[str, None] = None, | ||
username: Union[str, None] = None, | ||
health_check_interval: Union[int, None] = None, | ||
decode_responses: bool = True, | ||
ssl: Union[bool, None] = None, | ||
) -> Redis: | ||
"""Retrieves an async Redis client. | ||
Args: | ||
host: The host location. | ||
port: The port to connect to the host with. | ||
db: The Redis database to interact with. | ||
password: The password for the redis host | ||
username: Username for the redis instance | ||
decode_responses: Whether to decode binary responses from Redis to | ||
unicode strings. | ||
Returns: | ||
Redis: a Redis client | ||
""" | ||
settings = RedisMessagingSettings() | ||
|
||
return Redis( | ||
host=host or settings.host, | ||
port=port or settings.port, | ||
db=db or settings.db, | ||
password=password or settings.password, | ||
username=username or settings.username, | ||
health_check_interval=health_check_interval or settings.health_check_interval, | ||
ssl=ssl or settings.ssl, | ||
decode_responses=decode_responses, | ||
retry_on_timeout=True, | ||
) | ||
|
||
|
||
@cached | ||
def async_redis_from_settings( | ||
settings: RedisMessagingSettings, **options: Any | ||
) -> Redis: | ||
options = { | ||
"retry_on_timeout": True, | ||
"decode_responses": True, | ||
**options, | ||
} | ||
return Redis( | ||
host=settings.host, | ||
port=settings.port, | ||
db=settings.db, | ||
password=settings.password, | ||
username=settings.username, | ||
health_check_interval=settings.health_check_interval, | ||
ssl=settings.ssl, | ||
**options, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from unittest.mock import MagicMock, patch | ||
|
||
from prefect_redis.client import ( | ||
RedisMessagingSettings, | ||
async_redis_from_settings, | ||
close_all_cached_connections, | ||
get_async_redis_client, | ||
) | ||
from redis.asyncio import Redis | ||
|
||
|
||
def test_redis_settings_defaults(): | ||
"""Test that RedisSettings has expected defaults""" | ||
settings = RedisMessagingSettings() | ||
assert settings.host == "localhost" | ||
assert settings.port == 6379 | ||
assert settings.db == 0 | ||
assert settings.username == "default" | ||
assert settings.password == "" | ||
assert settings.health_check_interval == 20 | ||
assert settings.ssl is False | ||
|
||
|
||
async def test_get_async_redis_client_defaults(): | ||
"""Test that get_async_redis_client creates client with default settings""" | ||
client = get_async_redis_client() | ||
assert isinstance(client, Redis) | ||
assert client.connection_pool.connection_kwargs["host"] == "localhost" | ||
assert client.connection_pool.connection_kwargs["port"] == 6379 | ||
await client.aclose() | ||
|
||
|
||
async def test_get_async_redis_client_custom_params(): | ||
"""Test that get_async_redis_client respects custom parameters""" | ||
client = get_async_redis_client( | ||
host="custom.host", | ||
port=6380, | ||
db=1, | ||
username="custom_user", | ||
password="secret", | ||
) | ||
conn_kwargs = client.connection_pool.connection_kwargs | ||
assert conn_kwargs["host"] == "custom.host" | ||
assert conn_kwargs["port"] == 6380 | ||
assert conn_kwargs["db"] == 1 | ||
assert conn_kwargs["username"] == "custom_user" | ||
assert conn_kwargs["password"] == "secret" | ||
await client.aclose() | ||
|
||
|
||
async def test_async_redis_from_settings(): | ||
"""Test creating Redis client from settings object""" | ||
settings = RedisMessagingSettings( | ||
host="settings.host", | ||
port=6381, | ||
username="settings_user", | ||
) | ||
client = async_redis_from_settings(settings) | ||
conn_kwargs = client.connection_pool.connection_kwargs | ||
assert conn_kwargs["host"] == "settings.host" | ||
assert conn_kwargs["port"] == 6381 | ||
assert conn_kwargs["username"] == "settings_user" | ||
await client.aclose() | ||
|
||
|
||
@patch("prefect_redis.client._client_cache") | ||
def test_close_all_cached_connections(mock_cache): | ||
"""Test that close_all_cached_connections properly closes all clients""" | ||
mock_client = MagicMock() | ||
mock_loop = MagicMock() | ||
mock_loop.is_closed.return_value = False | ||
|
||
# Mock the coroutines that would be awaited | ||
mock_loop.run_until_complete.return_value = None | ||
|
||
mock_cache.items.return_value = [((None, None, None, mock_loop), mock_client)] | ||
|
||
close_all_cached_connections() | ||
|
||
# Verify run_until_complete was called twice (for disconnect and close) | ||
assert mock_loop.run_until_complete.call_count == 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters