Skip to content

Commit

Permalink
add async redis client utils to prefect-redis (#16417)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Dec 17, 2024
1 parent 8e37680 commit 3174ce0
Show file tree
Hide file tree
Showing 3 changed files with 224 additions and 2 deletions.
140 changes: 140 additions & 0 deletions src/integrations/prefect-redis/prefect_redis/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import asyncio
import functools
from typing import Any, Callable, Union

from pydantic import Field
from redis.asyncio import Redis
from typing_extensions import TypeAlias

from prefect.settings.base import (
PrefectBaseSettings,
_build_settings_config, # type: ignore[reportPrivateUsage]
)


class RedisMessagingSettings(PrefectBaseSettings):
model_config = _build_settings_config(
(
"redis",
"messaging",
),
frozen=True,
)

host: str = Field(default="localhost")
port: int = Field(default=6379)
db: int = Field(default=0)
username: str = Field(default="default")
password: str = Field(default="")
health_check_interval: int = Field(
default=20,
description="Health check interval for pinging the server; defaults to 20 seconds.",
)
ssl: bool = Field(
default=False,
description="Whether to use SSL for the Redis connection",
)


CacheKey: TypeAlias = tuple[
Callable[..., Any],
tuple[Any, ...],
tuple[tuple[str, Any], ...],
Union[asyncio.AbstractEventLoop, None],
]

_client_cache: dict[CacheKey, Redis] = {}


def _running_loop() -> Union[asyncio.AbstractEventLoop, None]:
try:
return asyncio.get_running_loop()
except RuntimeError as e:
if "no running event loop" in str(e):
return None
raise


def cached(fn: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(fn)
def cached_fn(*args: Any, **kwargs: Any) -> Redis:
key = (fn, args, tuple(kwargs.items()), _running_loop())
if key not in _client_cache:
_client_cache[key] = fn(*args, **kwargs)
return _client_cache[key]

return cached_fn


def close_all_cached_connections() -> None:
"""Close all cached Redis connections."""
loop: Union[asyncio.AbstractEventLoop, None]

for (_, _, _, loop), client in _client_cache.items():
if loop and loop.is_closed():
continue

loop = loop or asyncio.get_event_loop()
loop.run_until_complete(client.connection_pool.disconnect()) # type: ignore
loop.run_until_complete(client.close(close_connection_pool=True))


@cached
def get_async_redis_client(
host: Union[str, None] = None,
port: Union[int, None] = None,
db: Union[int, None] = None,
password: Union[str, None] = None,
username: Union[str, None] = None,
health_check_interval: Union[int, None] = None,
decode_responses: bool = True,
ssl: Union[bool, None] = None,
) -> Redis:
"""Retrieves an async Redis client.
Args:
host: The host location.
port: The port to connect to the host with.
db: The Redis database to interact with.
password: The password for the redis host
username: Username for the redis instance
decode_responses: Whether to decode binary responses from Redis to
unicode strings.
Returns:
Redis: a Redis client
"""
settings = RedisMessagingSettings()

return Redis(
host=host or settings.host,
port=port or settings.port,
db=db or settings.db,
password=password or settings.password,
username=username or settings.username,
health_check_interval=health_check_interval or settings.health_check_interval,
ssl=ssl or settings.ssl,
decode_responses=decode_responses,
retry_on_timeout=True,
)


@cached
def async_redis_from_settings(
settings: RedisMessagingSettings, **options: Any
) -> Redis:
options = {
"retry_on_timeout": True,
"decode_responses": True,
**options,
}
return Redis(
host=settings.host,
port=settings.port,
db=settings.db,
password=settings.password,
username=settings.username,
health_check_interval=settings.health_check_interval,
ssl=settings.ssl,
**options,
)
81 changes: 81 additions & 0 deletions src/integrations/prefect-redis/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from unittest.mock import MagicMock, patch

from prefect_redis.client import (
RedisMessagingSettings,
async_redis_from_settings,
close_all_cached_connections,
get_async_redis_client,
)
from redis.asyncio import Redis


def test_redis_settings_defaults():
"""Test that RedisSettings has expected defaults"""
settings = RedisMessagingSettings()
assert settings.host == "localhost"
assert settings.port == 6379
assert settings.db == 0
assert settings.username == "default"
assert settings.password == ""
assert settings.health_check_interval == 20
assert settings.ssl is False


async def test_get_async_redis_client_defaults():
"""Test that get_async_redis_client creates client with default settings"""
client = get_async_redis_client()
assert isinstance(client, Redis)
assert client.connection_pool.connection_kwargs["host"] == "localhost"
assert client.connection_pool.connection_kwargs["port"] == 6379
await client.aclose()


async def test_get_async_redis_client_custom_params():
"""Test that get_async_redis_client respects custom parameters"""
client = get_async_redis_client(
host="custom.host",
port=6380,
db=1,
username="custom_user",
password="secret",
)
conn_kwargs = client.connection_pool.connection_kwargs
assert conn_kwargs["host"] == "custom.host"
assert conn_kwargs["port"] == 6380
assert conn_kwargs["db"] == 1
assert conn_kwargs["username"] == "custom_user"
assert conn_kwargs["password"] == "secret"
await client.aclose()


async def test_async_redis_from_settings():
"""Test creating Redis client from settings object"""
settings = RedisMessagingSettings(
host="settings.host",
port=6381,
username="settings_user",
)
client = async_redis_from_settings(settings)
conn_kwargs = client.connection_pool.connection_kwargs
assert conn_kwargs["host"] == "settings.host"
assert conn_kwargs["port"] == 6381
assert conn_kwargs["username"] == "settings_user"
await client.aclose()


@patch("prefect_redis.client._client_cache")
def test_close_all_cached_connections(mock_cache):
"""Test that close_all_cached_connections properly closes all clients"""
mock_client = MagicMock()
mock_loop = MagicMock()
mock_loop.is_closed.return_value = False

# Mock the coroutines that would be awaited
mock_loop.run_until_complete.return_value = None

mock_cache.items.return_value = [((None, None, None, mock_loop), mock_client)]

close_all_cached_connections()

# Verify run_until_complete was called twice (for disconnect and close)
assert mock_loop.run_until_complete.call_count == 2
5 changes: 3 additions & 2 deletions src/prefect/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _add_environment_variables(


def _build_settings_config(
path: Tuple[str, ...] = tuple(),
path: Tuple[str, ...] = tuple(), frozen: bool = False
) -> PrefectSettingsConfigDict:
env_prefix = f"PREFECT_{'_'.join(path).upper()}_" if path else "PREFECT_"
return PrefectSettingsConfigDict(
Expand All @@ -202,7 +202,8 @@ def _build_settings_config(
toml_file="prefect.toml",
prefect_toml_table_header=path,
pyproject_toml_table_header=("tool", "prefect", *path),
json_schema_extra=_add_environment_variables,
json_schema_extra=_add_environment_variables, # type: ignore
frozen=frozen,
)


Expand Down

0 comments on commit 3174ce0

Please sign in to comment.