From 99f101c55a011724a887a91069b718f0f8d977e5 Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Wed, 15 Jan 2025 18:04:48 -0600 Subject: [PATCH 1/2] Raise type completeness on `prefect.inputs` to 100% (#16740) --- src/prefect/client/utilities.py | 6 +- src/prefect/input/actions.py | 6 +- src/prefect/input/run_input.py | 200 +++++++++++++++++++++++++------- 3 files changed, 165 insertions(+), 47 deletions(-) diff --git a/src/prefect/client/utilities.py b/src/prefect/client/utilities.py index 4622a7d6fe32..3aa7043333d3 100644 --- a/src/prefect/client/utilities.py +++ b/src/prefect/client/utilities.py @@ -5,7 +5,7 @@ # This module must not import from `prefect.client` when it is imported to avoid # circular imports for decorators such as `inject_client` which are widely used. -from collections.abc import Awaitable, Coroutine +from collections.abc import Coroutine from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Optional, Union @@ -61,8 +61,8 @@ def get_or_create_client( def client_injector( - func: Callable[Concatenate["PrefectClient", P], Awaitable[R]], -) -> Callable[P, Awaitable[R]]: + func: Callable[Concatenate["PrefectClient", P], Coroutine[Any, Any, R]], +) -> Callable[P, Coroutine[Any, Any, R]]: @wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: client, _ = get_or_create_client() diff --git a/src/prefect/input/actions.py b/src/prefect/input/actions.py index 9e88e5c59b69..e771ca491c39 100644 --- a/src/prefect/input/actions.py +++ b/src/prefect/input/actions.py @@ -1,3 +1,4 @@ +import inspect from typing import TYPE_CHECKING, Any, Optional, Set from uuid import UUID @@ -44,9 +45,12 @@ async def create_flow_run_input_from_model( else: json_safe = orjson.loads(model_instance.json()) - await create_flow_run_input( + coro = create_flow_run_input( key=key, value=json_safe, flow_run_id=flow_run_id, sender=sender ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + await coro @sync_compatible diff --git a/src/prefect/input/run_input.py b/src/prefect/input/run_input.py index 75434eb4cec4..1567b0d2ff72 100644 --- a/src/prefect/input/run_input.py +++ b/src/prefect/input/run_input.py @@ -60,11 +60,15 @@ async def receiver_flow(): ``` """ +from __future__ import annotations + +import inspect from inspect import isclass from typing import ( TYPE_CHECKING, Any, ClassVar, + Coroutine, Dict, Generic, Literal, @@ -81,6 +85,7 @@ async def receiver_flow(): import anyio import pydantic from pydantic import ConfigDict +from typing_extensions import Self from prefect.input.actions import ( create_flow_run_input, @@ -144,7 +149,7 @@ class RunInputMetadata(pydantic.BaseModel): receiver: UUID -class RunInput(pydantic.BaseModel): +class BaseRunInput(pydantic.BaseModel): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid") _description: Optional[str] = pydantic.PrivateAttr(default=None) @@ -172,23 +177,29 @@ async def save(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None): if is_v2_model(cls): schema = create_v2_schema(cls.__name__, model_base=cls) else: - schema = cls.schema(by_alias=True) + schema = cls.model_json_schema(by_alias=True) - await create_flow_run_input( + coro = create_flow_run_input( key=keyset["schema"], value=schema, flow_run_id=flow_run_id ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + await coro description = cls._description if isinstance(cls._description, str) else None if description: - await create_flow_run_input( + coro = create_flow_run_input( key=keyset["description"], value=description, flow_run_id=flow_run_id, ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + await coro @classmethod @sync_compatible - async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None): + async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> Self: """ Load the run input response from the given key. @@ -208,7 +219,7 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None): return instance @classmethod - def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput"): + def load_from_flow_run_input(cls, flow_run_input: "FlowRunInput") -> Self: """ Load the run input from a FlowRunInput object. @@ -284,6 +295,8 @@ async def send_to( key_prefix=key_prefix, ) + +class RunInput(BaseRunInput): @classmethod def receive( cls, @@ -293,7 +306,7 @@ def receive( exclude_keys: Optional[Set[str]] = None, key_prefix: Optional[str] = None, flow_run_id: Optional[UUID] = None, - ): + ) -> GetInputHandler[Self]: if key_prefix is None: key_prefix = f"{cls.__name__.lower()}-auto" @@ -322,12 +335,12 @@ def subclass_from_base_model_type( return type(f"{model_cls.__name__}RunInput", (RunInput, model_cls), {}) # type: ignore -class AutomaticRunInput(RunInput, Generic[T]): +class AutomaticRunInput(BaseRunInput, Generic[T]): value: T @classmethod @sync_compatible - async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T: + async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> Self: """ Load the run input response from the given key. @@ -335,7 +348,10 @@ async def load(cls, keyset: Keyset, flow_run_id: Optional[UUID] = None) -> T: - keyset (Keyset): the keyset to load the input for - flow_run_id (UUID, optional): the flow run ID to load the input for """ - instance = await super().load(keyset, flow_run_id=flow_run_id) + instance_coro = super().load(keyset, flow_run_id=flow_run_id) + if TYPE_CHECKING: + assert inspect.iscoroutine(instance_coro) + instance = await instance_coro return instance.value @classmethod @@ -370,17 +386,34 @@ def subclass_from_type(cls, _type: Type[T]) -> Type["AutomaticRunInput[T]"]: # Creating a new Pydantic model class dynamically with the name based # on the type prefix. - new_cls: Type["AutomaticRunInput"] = pydantic.create_model( + new_cls: Type["AutomaticRunInput[T]"] = pydantic.create_model( class_name, **fields, __base__=AutomaticRunInput ) return new_cls @classmethod - def receive(cls, *args, **kwargs): - if kwargs.get("key_prefix") is None: - kwargs["key_prefix"] = f"{cls.__name__.lower()}-auto" + def receive( + cls, + timeout: Optional[float] = 3600, + poll_interval: float = 10, + raise_timeout_error: bool = False, + exclude_keys: Optional[Set[str]] = None, + key_prefix: Optional[str] = None, + flow_run_id: Optional[UUID] = None, + with_metadata: bool = False, + ) -> GetAutomaticInputHandler[T]: + key_prefix = key_prefix or f"{cls.__name__.lower()}-auto" - return GetAutomaticInputHandler(run_input_cls=cls, *args, **kwargs) + return GetAutomaticInputHandler( + run_input_cls=cls, + key_prefix=key_prefix, + timeout=timeout, + poll_interval=poll_interval, + raise_timeout_error=raise_timeout_error, + exclude_keys=exclude_keys, + flow_run_id=flow_run_id, + with_metadata=with_metadata, + ) def run_input_subclass_from_type( @@ -409,24 +442,24 @@ def __init__( self, run_input_cls: Type[R], key_prefix: str, - timeout: Optional[float] = 3600, + timeout: float | None = 3600, poll_interval: float = 10, raise_timeout_error: bool = False, exclude_keys: Optional[Set[str]] = None, flow_run_id: Optional[UUID] = None, ): - self.run_input_cls = run_input_cls - self.key_prefix = key_prefix - self.timeout = timeout - self.poll_interval = poll_interval - self.exclude_keys = set() - self.raise_timeout_error = raise_timeout_error - self.flow_run_id = ensure_flow_run_id(flow_run_id) + self.run_input_cls: Type[R] = run_input_cls + self.key_prefix: str = key_prefix + self.timeout: float | None = timeout + self.poll_interval: float = poll_interval + self.exclude_keys: set[str] = set() + self.raise_timeout_error: bool = raise_timeout_error + self.flow_run_id: UUID = ensure_flow_run_id(flow_run_id) if exclude_keys is not None: self.exclude_keys.update(exclude_keys) - def __iter__(self): + def __iter__(self) -> Self: return self def __next__(self) -> R: @@ -437,24 +470,31 @@ def __next__(self) -> R: raise raise StopIteration - def __aiter__(self): + def __aiter__(self) -> Self: return self async def __anext__(self) -> R: try: - return await self.next() + coro = self.next() + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + return await coro except TimeoutError: if self.raise_timeout_error: raise raise StopAsyncIteration - async def filter_for_inputs(self): - flow_run_inputs = await filter_flow_run_input( + async def filter_for_inputs(self) -> list["FlowRunInput"]: + flow_run_inputs_coro = filter_flow_run_input( key_prefix=self.key_prefix, limit=1, exclude_keys=self.exclude_keys, flow_run_id=self.flow_run_id, ) + if TYPE_CHECKING: + assert inspect.iscoroutine(flow_run_inputs_coro) + + flow_run_inputs = await flow_run_inputs_coro if flow_run_inputs: self.exclude_keys.add(*[i.key for i in flow_run_inputs]) @@ -478,22 +518,91 @@ async def next(self) -> R: return self.to_instance(flow_run_inputs[0]) -class GetAutomaticInputHandler(GetInputHandler, Generic[T]): - def __init__(self, *args, **kwargs): - self.with_metadata = kwargs.pop("with_metadata", False) - super().__init__(*args, **kwargs) +class GetAutomaticInputHandler(Generic[T]): + def __init__( + self, + run_input_cls: Type[AutomaticRunInput[T]], + key_prefix: str, + timeout: float | None = 3600, + poll_interval: float = 10, + raise_timeout_error: bool = False, + exclude_keys: Optional[Set[str]] = None, + flow_run_id: Optional[UUID] = None, + with_metadata: bool = False, + ): + self.run_input_cls: Type[AutomaticRunInput[T]] = run_input_cls + self.key_prefix: str = key_prefix + self.timeout: float | None = timeout + self.poll_interval: float = poll_interval + self.exclude_keys: set[str] = set() + self.raise_timeout_error: bool = raise_timeout_error + self.flow_run_id: UUID = ensure_flow_run_id(flow_run_id) + self.with_metadata = with_metadata - def __next__(self) -> T: - return cast(T, super().__next__()) + if exclude_keys is not None: + self.exclude_keys.update(exclude_keys) - async def __anext__(self) -> T: - return cast(T, await super().__anext__()) + def __iter__(self) -> Self: + return self + + def __next__(self) -> T | AutomaticRunInput[T]: + try: + not_coro = self.next() + if TYPE_CHECKING: + assert not isinstance(not_coro, Coroutine) + return not_coro + except TimeoutError: + if self.raise_timeout_error: + raise + raise StopIteration + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> Union[T, AutomaticRunInput[T]]: + try: + coro = self.next() + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + return cast(Union[T, AutomaticRunInput[T]], await coro) + except TimeoutError: + if self.raise_timeout_error: + raise + raise StopAsyncIteration + + async def filter_for_inputs(self) -> list["FlowRunInput"]: + flow_run_inputs_coro = filter_flow_run_input( + key_prefix=self.key_prefix, + limit=1, + exclude_keys=self.exclude_keys, + flow_run_id=self.flow_run_id, + ) + if TYPE_CHECKING: + assert inspect.iscoroutine(flow_run_inputs_coro) + + flow_run_inputs = await flow_run_inputs_coro + + if flow_run_inputs: + self.exclude_keys.add(*[i.key for i in flow_run_inputs]) + + return flow_run_inputs @sync_compatible - async def next(self) -> T: - return cast(T, await super().next()) + async def next(self) -> Union[T, AutomaticRunInput[T]]: + flow_run_inputs = await self.filter_for_inputs() + if flow_run_inputs: + return self.to_instance(flow_run_inputs[0]) - def to_instance(self, flow_run_input: "FlowRunInput") -> T: + with anyio.fail_after(self.timeout): + while True: + await anyio.sleep(self.poll_interval) + flow_run_inputs = await self.filter_for_inputs() + if flow_run_inputs: + return self.to_instance(flow_run_inputs[0]) + + def to_instance( + self, flow_run_input: "FlowRunInput" + ) -> Union[T, AutomaticRunInput[T]]: run_input = self.run_input_cls.load_from_flow_run_input(flow_run_input) if self.with_metadata: @@ -503,14 +612,15 @@ def to_instance(self, flow_run_input: "FlowRunInput") -> T: async def _send_input( flow_run_id: UUID, - run_input: Any, + run_input: RunInput | pydantic.BaseModel, sender: Optional[str] = None, key_prefix: Optional[str] = None, ): + _run_input: Union[RunInput, AutomaticRunInput[Any]] if isinstance(run_input, RunInput): - _run_input: RunInput = run_input + _run_input = run_input else: - input_cls: Type[AutomaticRunInput] = run_input_subclass_from_type( + input_cls: Type[AutomaticRunInput[Any]] = run_input_subclass_from_type( type(run_input) ) _run_input = input_cls(value=run_input) @@ -520,9 +630,13 @@ async def _send_input( key = f"{key_prefix}-{uuid4()}" - await create_flow_run_input_from_model( + coro = create_flow_run_input_from_model( key=key, flow_run_id=flow_run_id, model_instance=_run_input, sender=sender ) + if TYPE_CHECKING: + assert inspect.iscoroutine(coro) + + await coro @sync_compatible From e18e438fd8ef283c09b7892837411b4678e4d9a3 Mon Sep 17 00:00:00 2001 From: Chris White Date: Wed, 15 Jan 2025 16:05:33 -0800 Subject: [PATCH 2/2] Fixes to Prefect Basic Auth (#16735) --- src/prefect/server/api/server.py | 3 +++ ui/src/pages/Unauthenticated.vue | 23 +++++++++++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/prefect/server/api/server.py b/src/prefect/server/api/server.py index 908ae5c3afc8..cc2b99b02e85 100644 --- a/src/prefect/server/api/server.py +++ b/src/prefect/server/api/server.py @@ -353,6 +353,9 @@ async def server_version() -> str: # type: ignore[reportUnusedFunction] async def token_validation(request: Request, call_next: Any): # type: ignore[reportUnusedFunction] header_token = request.headers.get("Authorization") + # used for probes in k8s and such + if request.url.path in ["/api/health", "/api/ready"]: + return await call_next(request) try: if header_token is None: return JSONResponse( diff --git a/ui/src/pages/Unauthenticated.vue b/ui/src/pages/Unauthenticated.vue index faa79ee9fa3f..5b3d4bf7e1a7 100644 --- a/ui/src/pages/Unauthenticated.vue +++ b/ui/src/pages/Unauthenticated.vue @@ -8,7 +8,7 @@ => { try { localStorage.setItem('prefect-password', btoa(password.value)) - router.push(props.redirect || '/') + api.admin.authCheck().then(status_code => { + if (status_code == 401) { + localStorage.removeItem('prefect-pasword') + showToast('Authentication failed.', 'error', { timeout: false }) + if (router.currentRoute.value.name !== 'login') { + router.push({ + name: 'login', + query: { redirect: router.currentRoute.value.fullPath } + }) + } + } else { + api.health.isHealthy().then(healthy => { + if (!healthy) { + showToast(`Can't connect to Server API at ${config.baseUrl}. Check that it's accessible from your machine.`, 'error', { timeout: false }) + } + router.push(props.redirect || '/') + }) + } + }) } catch (e) { localStorage.removeItem('prefect-password') error.value = 'Invalid password'