Skip to content

Commit

Permalink
task engine typing (#16731)
Browse files Browse the repository at this point in the history
Co-authored-by: Alex Streed <[email protected]>
  • Loading branch information
zzstoatzz and desertaxle authored Jan 15, 2025
1 parent a19131c commit d966c36
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 47 deletions.
40 changes: 20 additions & 20 deletions src/prefect/cache_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from copy import deepcopy
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Literal, Optional, Union

from typing_extensions import Self

Expand Down Expand Up @@ -73,8 +73,8 @@ def configure(
def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
inputs: dict[str, Any],
flow_parameters: dict[str, Any],
**kwargs: Any,
) -> Optional[str]:
raise NotImplementedError
Expand Down Expand Up @@ -132,14 +132,14 @@ class CacheKeyFnPolicy(CachePolicy):

# making it optional for tests
cache_key_fn: Optional[
Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
Callable[["TaskRunContext", dict[str, Any]], Optional[str]]
] = None

def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
inputs: dict[str, Any],
flow_parameters: dict[str, Any],
**kwargs: Any,
) -> Optional[str]:
if self.cache_key_fn:
Expand All @@ -155,13 +155,13 @@ class CompoundCachePolicy(CachePolicy):
Any keys that return `None` will be ignored.
"""

policies: List[CachePolicy] = field(default_factory=list)
policies: list[CachePolicy] = field(default_factory=list)

def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
inputs: dict[str, Any],
flow_parameters: dict[str, Any],
**kwargs: Any,
) -> Optional[str]:
keys: list[str] = []
Expand Down Expand Up @@ -189,8 +189,8 @@ class _None(CachePolicy):
def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
inputs: dict[str, Any],
flow_parameters: dict[str, Any],
**kwargs: Any,
) -> Optional[str]:
return None
Expand All @@ -209,8 +209,8 @@ class TaskSource(CachePolicy):
def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Optional[Dict[str, Any]],
flow_parameters: Optional[Dict[str, Any]],
inputs: Optional[dict[str, Any]],
flow_parameters: Optional[dict[str, Any]],
**kwargs: Any,
) -> Optional[str]:
if not task_ctx:
Expand All @@ -236,8 +236,8 @@ class FlowParameters(CachePolicy):
def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
inputs: dict[str, Any],
flow_parameters: dict[str, Any],
**kwargs: Any,
) -> Optional[str]:
if not flow_parameters:
Expand All @@ -255,8 +255,8 @@ class RunId(CachePolicy):
def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
inputs: dict[str, Any],
flow_parameters: dict[str, Any],
**kwargs: Any,
) -> Optional[str]:
if not task_ctx:
Expand All @@ -273,13 +273,13 @@ class Inputs(CachePolicy):
Policy that computes a cache key based on a hash of the runtime inputs provided to the task.
"""

exclude: List[str] = field(default_factory=list)
exclude: list[str] = field(default_factory=list)

def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
inputs: dict[str, Any],
flow_parameters: dict[str, Any],
**kwargs: Any,
) -> Optional[str]:
hashed_inputs = {}
Expand Down
19 changes: 11 additions & 8 deletions src/prefect/logging/loggers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import io
import logging
import sys
from builtins import print
from contextlib import contextmanager
from functools import lru_cache
from logging import LogRecord
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union

from typing_extensions import Self

Expand Down Expand Up @@ -37,27 +39,28 @@ class PrefectLogAdapter(LoggingAdapter):
not a bug in the LoggingAdapter and subclassing is the intended workaround.
"""

def process(self, msg, kwargs):
extra: Mapping[str, object] | None

def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]: # type: ignore[incompatibleMethodOverride]
kwargs["extra"] = {**(self.extra or {}), **(kwargs.get("extra") or {})}
return (msg, kwargs)

def getChild(
self, suffix: str, extra: Optional[Dict[str, str]] = None
self, suffix: str, extra: dict[str, Any] | None = None
) -> "PrefectLogAdapter":
if extra is None:
extra = {}
_extra: Mapping[str, object] = extra or {}

return PrefectLogAdapter(
self.logger.getChild(suffix),
extra={
**self.extra,
**extra,
**(self.extra or {}),
**_extra,
},
)


@lru_cache()
def get_logger(name: Optional[str] = None) -> logging.Logger:
def get_logger(name: str | None = None) -> logging.Logger:
"""
Get a `prefect` logger. These loggers are intended for internal use within the
`prefect` package.
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ async def exception_to_failed_state(
exc: Optional[BaseException] = None,
result_store: Optional["ResultStore"] = None,
write_result: bool = False,
**kwargs,
**kwargs: Any,
) -> State:
"""
Convenience function for creating `Failed` states from exceptions
Expand Down
38 changes: 20 additions & 18 deletions src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from opentelemetry import trace
from typing_extensions import ParamSpec

from prefect.cache_policies import CachePolicy
from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
from prefect.client.schemas import TaskRun
from prefect.client.schemas.objects import State, TaskRunInput
Expand All @@ -55,7 +56,7 @@
from prefect.logging.loggers import get_logger, patch_print, task_run_logger
from prefect.results import (
ResultRecord,
_format_user_supplied_storage_key,
_format_user_supplied_storage_key, # type: ignore[reportPrivateUsage]
get_result_store,
should_persist_result,
)
Expand Down Expand Up @@ -115,7 +116,7 @@ class BaseTaskRunEngine(Generic[P, R]):
# holds the return value from the user code
_return_value: Union[R, Type[NotSet]] = NotSet
# holds the exception raised by the user code, if any
_raised: Union[Exception, Type[NotSet]] = NotSet
_raised: Union[Exception, BaseException, Type[NotSet]] = NotSet
_initial_run_context: Optional[TaskRunContext] = None
_is_started: bool = False
_task_name_set: bool = False
Expand All @@ -128,7 +129,7 @@ def __post_init__(self):

@property
def state(self) -> State:
if not self.task_run:
if not self.task_run or not self.task_run.state:
raise ValueError("Task run is not set")
return self.task_run.state

Expand All @@ -142,8 +143,8 @@ def is_cancelled(self) -> bool:
return False

def compute_transaction_key(self) -> Optional[str]:
key = None
if self.task.cache_policy:
key: Optional[str] = None
if self.task.cache_policy and isinstance(self.task.cache_policy, CachePolicy):
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

Expand All @@ -153,10 +154,12 @@ def compute_transaction_key(self) -> Optional[str]:
parameters = None

try:
if not task_run_context:
raise ValueError("Task run context is not set")
key = self.task.cache_policy.compute_key(
task_ctx=task_run_context,
inputs=self.parameters,
flow_parameters=parameters,
inputs=self.parameters or {},
flow_parameters=parameters or {},
)
except Exception:
self.logger.exception(
Expand All @@ -169,7 +172,7 @@ def compute_transaction_key(self) -> Optional[str]:

def _resolve_parameters(self):
if not self.parameters:
return {}
return None

resolved_parameters = {}
for parameter, value in self.parameters.items():
Expand Down Expand Up @@ -227,10 +230,8 @@ def record_terminal_state_timing(self, state: State) -> None:
if self.task_run and self.task_run.start_time and not self.task_run.end_time:
self.task_run.end_time = state.timestamp

if self.task_run.state.is_running():
self.task_run.total_run_time += (
state.timestamp - self.task_run.state.timestamp
)
if self.state.is_running():
self.task_run.total_run_time += state.timestamp - self.state.timestamp

def is_running(self) -> bool:
"""Whether or not the engine is currently running a task."""
Expand Down Expand Up @@ -390,6 +391,7 @@ def begin_run(self):

new_state = Running()

assert self.task_run is not None, "Task run is not set"
self.task_run.start_time = new_state.timestamp

flow_run_context = FlowRunContext.get()
Expand All @@ -406,7 +408,7 @@ def begin_run(self):
# result reference that no longer exists
if state.is_completed():
try:
state.result(retry_result_failure=False, _sync=True)
state.result(retry_result_failure=False, _sync=True) # type: ignore[reportCallIssue]
except Exception:
state = self.set_state(new_state, force=True)

Expand All @@ -422,7 +424,7 @@ def begin_run(self):
time.sleep(interval)
state = self.set_state(new_state)

def set_state(self, state: State, force: bool = False) -> State:
def set_state(self, state: State[R], force: bool = False) -> State[R]:
last_state = self.state
if not self.task_run:
raise ValueError("Task run is not set")
Expand Down Expand Up @@ -537,7 +539,7 @@ def handle_retry(self, exc: Exception) -> bool:
new_state = Retrying()

self.logger.info(
"Task run failed with exception: %r - " "Retry %s/%s will start %s",
"Task run failed with exception: %r - Retry %s/%s will start %s",
exc,
self.retries + 1,
self.task.retries,
Expand Down Expand Up @@ -1067,7 +1069,7 @@ async def handle_retry(self, exc: Exception) -> bool:
new_state = Retrying()

self.logger.info(
"Task run failed with exception: %r - " "Retry %s/%s will start %s",
"Task run failed with exception: %r - Retry %s/%s will start %s",
exc,
self.retries + 1,
self.task.retries,
Expand Down Expand Up @@ -1341,7 +1343,7 @@ async def call_task_fn(
if transaction.is_committed():
result = transaction.read()
else:
if self.task_run.tags:
if self.task_run and self.task_run.tags:
# Acquire a concurrency slot for each tag, but only if a limit
# matching the tag already exists.
async with aconcurrency(list(self.task_run.tags), self.task_run.id):
Expand Down Expand Up @@ -1546,7 +1548,7 @@ def run_task(
Returns:
The result of the task run
"""
kwargs = dict(
kwargs: dict[str, Any] = dict(
task=task,
task_run_id=task_run_id,
task_run=task_run,
Expand Down

0 comments on commit d966c36

Please sign in to comment.