From 3174ce092e0c51b0e90cfdd8a399d3a3e0b22044 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Tue, 17 Dec 2024 12:00:52 -0600 Subject: [PATCH] add async redis client utils to `prefect-redis` (#16417) --- .../prefect-redis/prefect_redis/client.py | 140 ++++++++++++++++++ .../prefect-redis/tests/test_client.py | 81 ++++++++++ src/prefect/settings/base.py | 5 +- 3 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 src/integrations/prefect-redis/prefect_redis/client.py create mode 100644 src/integrations/prefect-redis/tests/test_client.py diff --git a/src/integrations/prefect-redis/prefect_redis/client.py b/src/integrations/prefect-redis/prefect_redis/client.py new file mode 100644 index 000000000000..51f654c15600 --- /dev/null +++ b/src/integrations/prefect-redis/prefect_redis/client.py @@ -0,0 +1,140 @@ +import asyncio +import functools +from typing import Any, Callable, Union + +from pydantic import Field +from redis.asyncio import Redis +from typing_extensions import TypeAlias + +from prefect.settings.base import ( + PrefectBaseSettings, + _build_settings_config, # type: ignore[reportPrivateUsage] +) + + +class RedisMessagingSettings(PrefectBaseSettings): + model_config = _build_settings_config( + ( + "redis", + "messaging", + ), + frozen=True, + ) + + host: str = Field(default="localhost") + port: int = Field(default=6379) + db: int = Field(default=0) + username: str = Field(default="default") + password: str = Field(default="") + health_check_interval: int = Field( + default=20, + description="Health check interval for pinging the server; defaults to 20 seconds.", + ) + ssl: bool = Field( + default=False, + description="Whether to use SSL for the Redis connection", + ) + + +CacheKey: TypeAlias = tuple[ + Callable[..., Any], + tuple[Any, ...], + tuple[tuple[str, Any], ...], + Union[asyncio.AbstractEventLoop, None], +] + +_client_cache: dict[CacheKey, Redis] = {} + + +def _running_loop() -> Union[asyncio.AbstractEventLoop, None]: + try: + return asyncio.get_running_loop() + except RuntimeError as e: + if "no running event loop" in str(e): + return None + raise + + +def cached(fn: Callable[..., Any]) -> Callable[..., Any]: + @functools.wraps(fn) + def cached_fn(*args: Any, **kwargs: Any) -> Redis: + key = (fn, args, tuple(kwargs.items()), _running_loop()) + if key not in _client_cache: + _client_cache[key] = fn(*args, **kwargs) + return _client_cache[key] + + return cached_fn + + +def close_all_cached_connections() -> None: + """Close all cached Redis connections.""" + loop: Union[asyncio.AbstractEventLoop, None] + + for (_, _, _, loop), client in _client_cache.items(): + if loop and loop.is_closed(): + continue + + loop = loop or asyncio.get_event_loop() + loop.run_until_complete(client.connection_pool.disconnect()) # type: ignore + loop.run_until_complete(client.close(close_connection_pool=True)) + + +@cached +def get_async_redis_client( + host: Union[str, None] = None, + port: Union[int, None] = None, + db: Union[int, None] = None, + password: Union[str, None] = None, + username: Union[str, None] = None, + health_check_interval: Union[int, None] = None, + decode_responses: bool = True, + ssl: Union[bool, None] = None, +) -> Redis: + """Retrieves an async Redis client. + + Args: + host: The host location. + port: The port to connect to the host with. + db: The Redis database to interact with. + password: The password for the redis host + username: Username for the redis instance + decode_responses: Whether to decode binary responses from Redis to + unicode strings. + + Returns: + Redis: a Redis client + """ + settings = RedisMessagingSettings() + + return Redis( + host=host or settings.host, + port=port or settings.port, + db=db or settings.db, + password=password or settings.password, + username=username or settings.username, + health_check_interval=health_check_interval or settings.health_check_interval, + ssl=ssl or settings.ssl, + decode_responses=decode_responses, + retry_on_timeout=True, + ) + + +@cached +def async_redis_from_settings( + settings: RedisMessagingSettings, **options: Any +) -> Redis: + options = { + "retry_on_timeout": True, + "decode_responses": True, + **options, + } + return Redis( + host=settings.host, + port=settings.port, + db=settings.db, + password=settings.password, + username=settings.username, + health_check_interval=settings.health_check_interval, + ssl=settings.ssl, + **options, + ) diff --git a/src/integrations/prefect-redis/tests/test_client.py b/src/integrations/prefect-redis/tests/test_client.py new file mode 100644 index 000000000000..11d19c01d90d --- /dev/null +++ b/src/integrations/prefect-redis/tests/test_client.py @@ -0,0 +1,81 @@ +from unittest.mock import MagicMock, patch + +from prefect_redis.client import ( + RedisMessagingSettings, + async_redis_from_settings, + close_all_cached_connections, + get_async_redis_client, +) +from redis.asyncio import Redis + + +def test_redis_settings_defaults(): + """Test that RedisSettings has expected defaults""" + settings = RedisMessagingSettings() + assert settings.host == "localhost" + assert settings.port == 6379 + assert settings.db == 0 + assert settings.username == "default" + assert settings.password == "" + assert settings.health_check_interval == 20 + assert settings.ssl is False + + +async def test_get_async_redis_client_defaults(): + """Test that get_async_redis_client creates client with default settings""" + client = get_async_redis_client() + assert isinstance(client, Redis) + assert client.connection_pool.connection_kwargs["host"] == "localhost" + assert client.connection_pool.connection_kwargs["port"] == 6379 + await client.aclose() + + +async def test_get_async_redis_client_custom_params(): + """Test that get_async_redis_client respects custom parameters""" + client = get_async_redis_client( + host="custom.host", + port=6380, + db=1, + username="custom_user", + password="secret", + ) + conn_kwargs = client.connection_pool.connection_kwargs + assert conn_kwargs["host"] == "custom.host" + assert conn_kwargs["port"] == 6380 + assert conn_kwargs["db"] == 1 + assert conn_kwargs["username"] == "custom_user" + assert conn_kwargs["password"] == "secret" + await client.aclose() + + +async def test_async_redis_from_settings(): + """Test creating Redis client from settings object""" + settings = RedisMessagingSettings( + host="settings.host", + port=6381, + username="settings_user", + ) + client = async_redis_from_settings(settings) + conn_kwargs = client.connection_pool.connection_kwargs + assert conn_kwargs["host"] == "settings.host" + assert conn_kwargs["port"] == 6381 + assert conn_kwargs["username"] == "settings_user" + await client.aclose() + + +@patch("prefect_redis.client._client_cache") +def test_close_all_cached_connections(mock_cache): + """Test that close_all_cached_connections properly closes all clients""" + mock_client = MagicMock() + mock_loop = MagicMock() + mock_loop.is_closed.return_value = False + + # Mock the coroutines that would be awaited + mock_loop.run_until_complete.return_value = None + + mock_cache.items.return_value = [((None, None, None, mock_loop), mock_client)] + + close_all_cached_connections() + + # Verify run_until_complete was called twice (for disconnect and close) + assert mock_loop.run_until_complete.call_count == 2 diff --git a/src/prefect/settings/base.py b/src/prefect/settings/base.py index 3fa42020bcc6..54486cef8ebf 100644 --- a/src/prefect/settings/base.py +++ b/src/prefect/settings/base.py @@ -192,7 +192,7 @@ def _add_environment_variables( def _build_settings_config( - path: Tuple[str, ...] = tuple(), + path: Tuple[str, ...] = tuple(), frozen: bool = False ) -> PrefectSettingsConfigDict: env_prefix = f"PREFECT_{'_'.join(path).upper()}_" if path else "PREFECT_" return PrefectSettingsConfigDict( @@ -202,7 +202,8 @@ def _build_settings_config( toml_file="prefect.toml", prefect_toml_table_header=path, pyproject_toml_table_header=("tool", "prefect", *path), - json_schema_extra=_add_environment_variables, + json_schema_extra=_add_environment_variables, # type: ignore + frozen=frozen, )