Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor work pool client methods #16661

Merged
merged 3 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 9 additions & 255 deletions src/prefect/client/orchestration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import base64
import datetime
import ssl
import warnings
from collections.abc import Iterable
from contextlib import AsyncExitStack
from logging import Logger
Expand Down Expand Up @@ -48,6 +47,12 @@
AutomationClient,
AutomationAsyncClient,
)

from prefect.client.orchestration._work_pools.client import (
WorkPoolClient,
WorkPoolAsyncClient,
)

from prefect._experimental.sla.client import SlaClient, SlaAsyncClient

from prefect.client.orchestration._flows.client import (
Expand Down Expand Up @@ -85,8 +90,6 @@
FlowRunNotificationPolicyUpdate,
TaskRunCreate,
TaskRunUpdate,
WorkPoolCreate,
WorkPoolUpdate,
WorkQueueCreate,
WorkQueueUpdate,
)
Expand All @@ -96,8 +99,6 @@
FlowRunFilter,
FlowRunNotificationPolicyFilter,
TaskRunFilter,
WorkerFilter,
WorkPoolFilter,
WorkQueueFilter,
WorkQueueFilterName,
)
Expand All @@ -107,15 +108,10 @@
Parameter,
TaskRunPolicy,
TaskRunResult,
Worker,
WorkerMetadata,
WorkPool,
WorkQueue,
WorkQueueStatusDetail,
)
from prefect.client.schemas.responses import (
WorkerFlowRunResponse,
)

from prefect.client.schemas.sorting import (
TaskRunSort,
)
Expand All @@ -133,7 +129,6 @@
PREFECT_CLOUD_API_URL,
PREFECT_SERVER_ALLOW_EPHEMERAL_MODE,
PREFECT_TESTING_UNIT_TEST_MODE,
get_current_settings,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -263,6 +258,7 @@ class PrefectClient(
BlocksDocumentAsyncClient,
BlocksSchemaAsyncClient,
BlocksTypeAsyncClient,
WorkPoolAsyncClient,
):
"""
An asynchronous client for interacting with the [Prefect REST API](/api-ref/rest-api/).
Expand Down Expand Up @@ -1100,214 +1096,6 @@ async def read_flow_run_notification_policies(
response.json()
)

async def send_worker_heartbeat(
    self,
    work_pool_name: str,
    worker_name: str,
    heartbeat_interval_seconds: Optional[float] = None,
    get_worker_id: bool = False,
    worker_metadata: Optional[WorkerMetadata] = None,
) -> Optional[UUID]:
    """
    Sends a worker heartbeat for a given work pool.

    Args:
        work_pool_name: The name of the work pool to heartbeat against.
        worker_name: The name of the worker sending the heartbeat.
        heartbeat_interval_seconds: The worker's configured heartbeat
            interval, reported to the server.
        get_worker_id: Whether to return the worker ID. Note: will return
            `None` if the connected server does not support returning worker
            IDs, even if `get_worker_id` is `True`.
        worker_metadata: Metadata about the worker to send to the server.

    Returns:
        The worker's ID when `get_worker_id` is `True` and the server
        supports it (Cloud, or unit-test mode); otherwise `None`.
    """
    params: dict[str, Any] = {
        "name": worker_name,
        "heartbeat_interval_seconds": heartbeat_interval_seconds,
    }
    if worker_metadata:
        params["metadata"] = worker_metadata.model_dump(mode="json")
    if get_worker_id:
        params["return_id"] = get_worker_id

    resp = await self._client.post(
        f"/work_pools/{work_pool_name}/workers/heartbeat",
        json=params,
    )

    # Only Cloud (or test-mode) servers return a worker ID in the response
    # body; other servers respond with an empty body we must not parse.
    if (
        (
            self.server_type == ServerType.CLOUD
            or get_current_settings().testing.test_mode
        )
        and get_worker_id
        and resp.status_code == 200
    ):
        return UUID(resp.text)
    else:
        return None

async def read_workers_for_work_pool(
    self,
    work_pool_name: str,
    worker_filter: Optional[WorkerFilter] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
) -> list[Worker]:
    """
    Reads workers for a given work pool.

    Args:
        work_pool_name: The name of the work pool for which to get
            member workers.
        worker_filter: Criteria by which to filter workers.
        offset: Offset for the worker query.
        limit: Limit for the worker query.

    Returns:
        A list of workers belonging to the work pool.
    """
    response = await self._client.post(
        f"/work_pools/{work_pool_name}/workers/filter",
        json={
            "workers": (
                worker_filter.model_dump(mode="json", exclude_unset=True)
                if worker_filter
                else None
            ),
            "offset": offset,
            "limit": limit,
        },
    )

    return pydantic.TypeAdapter(list[Worker]).validate_python(response.json())

async def read_work_pool(self, work_pool_name: str) -> WorkPool:
    """
    Reads information for a given work pool.

    Args:
        work_pool_name: The name of the work pool for which to get
            information.

    Returns:
        Information about the requested work pool.

    Raises:
        prefect.exceptions.ObjectNotFound: If no work pool with the given
            name exists.
    """
    try:
        response = await self._client.get(f"/work_pools/{work_pool_name}")
        return WorkPool.model_validate(response.json())
    except httpx.HTTPStatusError as e:
        # Translate a 404 into the richer Prefect exception; let every
        # other HTTP error propagate unchanged.
        if e.response.status_code == status.HTTP_404_NOT_FOUND:
            raise prefect.exceptions.ObjectNotFound(http_exc=e) from e
        else:
            raise

async def read_work_pools(
    self,
    limit: Optional[int] = None,
    offset: int = 0,
    work_pool_filter: Optional[WorkPoolFilter] = None,
) -> list[WorkPool]:
    """
    Reads work pools.

    Args:
        limit: Limit for the work pool query.
        offset: Offset for the work pool query.
        work_pool_filter: Criteria by which to filter work pools.

    Returns:
        A list of work pools.
    """
    filter_payload = (
        work_pool_filter.model_dump(mode="json") if work_pool_filter else None
    )
    request_body: dict[str, Any] = {
        "limit": limit,
        "offset": offset,
        "work_pools": filter_payload,
    }
    response = await self._client.post("/work_pools/filter", json=request_body)
    adapter = pydantic.TypeAdapter(list[WorkPool])
    return adapter.validate_python(response.json())

async def create_work_pool(
    self,
    work_pool: WorkPoolCreate,
    overwrite: bool = False,
) -> WorkPool:
    """
    Creates a work pool with the provided configuration.

    Args:
        work_pool: Desired configuration for the new work pool.
        overwrite: If `True` and a work pool with the same name already
            exists, update the existing pool with the provided
            configuration instead of raising. The existing pool's type is
            never changed; a differing type is ignored with a warning.

    Returns:
        Information about the newly created (or updated) work pool.

    Raises:
        prefect.exceptions.ObjectAlreadyExists: If a work pool with the
            same name exists and `overwrite` is `False`.
    """
    try:
        response = await self._client.post(
            "/work_pools/",
            json=work_pool.model_dump(mode="json", exclude_unset=True),
        )
    except httpx.HTTPStatusError as e:
        if e.response.status_code == status.HTTP_409_CONFLICT:
            if overwrite:
                existing_work_pool = await self.read_work_pool(
                    work_pool_name=work_pool.name
                )
                if existing_work_pool.type != work_pool.type:
                    warnings.warn(
                        "Overwriting work pool type is not supported. Ignoring provided type.",
                        category=UserWarning,
                    )
                # Apply the new configuration to the existing pool; name
                # and type are excluded because they cannot be updated.
                await self.update_work_pool(
                    work_pool_name=work_pool.name,
                    work_pool=WorkPoolUpdate.model_validate(
                        work_pool.model_dump(exclude={"name", "type"})
                    ),
                )
                # Re-read so the caller gets the post-update state.
                response = await self._client.get(f"/work_pools/{work_pool.name}")
            else:
                raise prefect.exceptions.ObjectAlreadyExists(http_exc=e) from e
        else:
            raise

    return WorkPool.model_validate(response.json())

async def update_work_pool(
    self,
    work_pool_name: str,
    work_pool: WorkPoolUpdate,
) -> None:
    """
    Updates a work pool.

    Args:
        work_pool_name: Name of the work pool to update.
        work_pool: Fields to update in the work pool.

    Raises:
        prefect.exceptions.ObjectNotFound: If the work pool does not exist.
    """
    payload = work_pool.model_dump(mode="json", exclude_unset=True)
    try:
        await self._client.patch(f"/work_pools/{work_pool_name}", json=payload)
    except httpx.HTTPStatusError as exc:
        if exc.response.status_code != status.HTTP_404_NOT_FOUND:
            raise
        raise prefect.exceptions.ObjectNotFound(http_exc=exc) from exc

async def delete_work_pool(
    self,
    work_pool_name: str,
) -> None:
    """
    Deletes a work pool.

    Args:
        work_pool_name: Name of the work pool to delete.

    Raises:
        prefect.exceptions.ObjectNotFound: If the work pool does not exist.
    """
    try:
        await self._client.delete(f"/work_pools/{work_pool_name}")
    except httpx.HTTPStatusError as exc:
        if exc.response.status_code != status.HTTP_404_NOT_FOUND:
            raise
        raise prefect.exceptions.ObjectNotFound(http_exc=exc) from exc

async def read_work_queues(
self,
work_pool_name: Optional[str] = None,
Expand Down Expand Up @@ -1353,41 +1141,6 @@ async def read_work_queues(

return pydantic.TypeAdapter(list[WorkQueue]).validate_python(response.json())

async def get_scheduled_flow_runs_for_work_pool(
    self,
    work_pool_name: str,
    work_queue_names: Optional[list[str]] = None,
    scheduled_before: Optional[datetime.datetime] = None,
) -> list[WorkerFlowRunResponse]:
    """
    Retrieves scheduled flow runs for the provided set of work pool queues.

    Args:
        work_pool_name: The name of the work pool that the work pool
            queues are associated with.
        work_queue_names: The names of the work pool queues from which
            to get scheduled flow runs.
        scheduled_before: Datetime used to filter returned flow runs. Flow runs
            scheduled for after the given datetime string will not be returned.

    Returns:
        A list of worker flow run responses containing information about the
        retrieved flow runs.
    """
    # Only include filters the caller actually supplied.
    request_body: dict[str, Any] = {}
    if work_queue_names is not None:
        request_body["work_queue_names"] = list(work_queue_names)
    if scheduled_before:
        request_body["scheduled_before"] = str(scheduled_before)

    response = await self._client.post(
        f"/work_pools/{work_pool_name}/get_scheduled_flow_runs",
        json=request_body,
    )
    adapter = pydantic.TypeAdapter(list[WorkerFlowRunResponse])
    return adapter.validate_python(response.json())

async def read_worker_metadata(self) -> dict[str, Any]:
"""Reads worker metadata stored in Prefect collection registry."""
response = await self._client.get("collections/views/aggregate-worker-metadata")
Expand Down Expand Up @@ -1504,6 +1257,7 @@ class SyncPrefectClient(
BlocksDocumentClient,
BlocksSchemaClient,
BlocksTypeClient,
WorkPoolClient,
):
"""
A synchronous client for interacting with the [Prefect REST API](/api-ref/rest-api/).
Expand Down
Empty file.
Loading