diff --git a/src/prefect/client/orchestration/__init__.py b/src/prefect/client/orchestration/__init__.py index 5735f01e06d9..45908ba8831d 100644 --- a/src/prefect/client/orchestration/__init__.py +++ b/src/prefect/client/orchestration/__init__.py @@ -2,7 +2,6 @@ import base64 import datetime import ssl -import warnings from collections.abc import Iterable from contextlib import AsyncExitStack from logging import Logger @@ -48,6 +47,12 @@ AutomationClient, AutomationAsyncClient, ) + +from prefect.client.orchestration._work_pools.client import ( + WorkPoolClient, + WorkPoolAsyncClient, +) + from prefect._experimental.sla.client import SlaClient, SlaAsyncClient from prefect.client.orchestration._flows.client import ( @@ -85,8 +90,6 @@ FlowRunNotificationPolicyUpdate, TaskRunCreate, TaskRunUpdate, - WorkPoolCreate, - WorkPoolUpdate, WorkQueueCreate, WorkQueueUpdate, ) @@ -96,8 +99,6 @@ FlowRunFilter, FlowRunNotificationPolicyFilter, TaskRunFilter, - WorkerFilter, - WorkPoolFilter, WorkQueueFilter, WorkQueueFilterName, ) @@ -107,15 +108,10 @@ Parameter, TaskRunPolicy, TaskRunResult, - Worker, - WorkerMetadata, - WorkPool, WorkQueue, WorkQueueStatusDetail, ) -from prefect.client.schemas.responses import ( - WorkerFlowRunResponse, -) + from prefect.client.schemas.sorting import ( TaskRunSort, ) @@ -133,7 +129,6 @@ PREFECT_CLOUD_API_URL, PREFECT_SERVER_ALLOW_EPHEMERAL_MODE, PREFECT_TESTING_UNIT_TEST_MODE, - get_current_settings, ) if TYPE_CHECKING: @@ -263,6 +258,7 @@ class PrefectClient( BlocksDocumentAsyncClient, BlocksSchemaAsyncClient, BlocksTypeAsyncClient, + WorkPoolAsyncClient, ): """ An asynchronous client for interacting with the [Prefect REST API](/api-ref/rest-api/). @@ -1100,214 +1096,6 @@ async def read_flow_run_notification_policies( response.json() ) - async def send_worker_heartbeat( - self, - work_pool_name: str, - worker_name: str, - heartbeat_interval_seconds: Optional[float] = None, - get_worker_id: bool = False, - worker_metadata: Optional[WorkerMetadata] = None, - ) -> Optional[UUID]: - """ - Sends a worker heartbeat for a given work pool. - - Args: - work_pool_name: The name of the work pool to heartbeat against. - worker_name: The name of the worker sending the heartbeat. - return_id: Whether to return the worker ID. Note: will return `None` if the connected server does not support returning worker IDs, even if `return_id` is `True`. - worker_metadata: Metadata about the worker to send to the server. - """ - params: dict[str, Any] = { - "name": worker_name, - "heartbeat_interval_seconds": heartbeat_interval_seconds, - } - if worker_metadata: - params["metadata"] = worker_metadata.model_dump(mode="json") - if get_worker_id: - params["return_id"] = get_worker_id - - resp = await self._client.post( - f"/work_pools/{work_pool_name}/workers/heartbeat", - json=params, - ) - - if ( - ( - self.server_type == ServerType.CLOUD - or get_current_settings().testing.test_mode - ) - and get_worker_id - and resp.status_code == 200 - ): - return UUID(resp.text) - else: - return None - - async def read_workers_for_work_pool( - self, - work_pool_name: str, - worker_filter: Optional[WorkerFilter] = None, - offset: Optional[int] = None, - limit: Optional[int] = None, - ) -> list[Worker]: - """ - Reads workers for a given work pool. - - Args: - work_pool_name: The name of the work pool for which to get - member workers. - worker_filter: Criteria by which to filter workers. - limit: Limit for the worker query. - offset: Limit for the worker query. - """ - response = await self._client.post( - f"/work_pools/{work_pool_name}/workers/filter", - json={ - "workers": ( - worker_filter.model_dump(mode="json", exclude_unset=True) - if worker_filter - else None - ), - "offset": offset, - "limit": limit, - }, - ) - - return pydantic.TypeAdapter(list[Worker]).validate_python(response.json()) - - async def read_work_pool(self, work_pool_name: str) -> WorkPool: - """ - Reads information for a given work pool - - Args: - work_pool_name: The name of the work pool to for which to get - information. - - Returns: - Information about the requested work pool. - """ - try: - response = await self._client.get(f"/work_pools/{work_pool_name}") - return WorkPool.model_validate(response.json()) - except httpx.HTTPStatusError as e: - if e.response.status_code == status.HTTP_404_NOT_FOUND: - raise prefect.exceptions.ObjectNotFound(http_exc=e) from e - else: - raise - - async def read_work_pools( - self, - limit: Optional[int] = None, - offset: int = 0, - work_pool_filter: Optional[WorkPoolFilter] = None, - ) -> list[WorkPool]: - """ - Reads work pools. - - Args: - limit: Limit for the work pool query. - offset: Offset for the work pool query. - work_pool_filter: Criteria by which to filter work pools. - - Returns: - A list of work pools. - """ - - body: dict[str, Any] = { - "limit": limit, - "offset": offset, - "work_pools": ( - work_pool_filter.model_dump(mode="json") if work_pool_filter else None - ), - } - response = await self._client.post("/work_pools/filter", json=body) - return pydantic.TypeAdapter(list[WorkPool]).validate_python(response.json()) - - async def create_work_pool( - self, - work_pool: WorkPoolCreate, - overwrite: bool = False, - ) -> WorkPool: - """ - Creates a work pool with the provided configuration. - - Args: - work_pool: Desired configuration for the new work pool. - - Returns: - Information about the newly created work pool. - """ - try: - response = await self._client.post( - "/work_pools/", - json=work_pool.model_dump(mode="json", exclude_unset=True), - ) - except httpx.HTTPStatusError as e: - if e.response.status_code == status.HTTP_409_CONFLICT: - if overwrite: - existing_work_pool = await self.read_work_pool( - work_pool_name=work_pool.name - ) - if existing_work_pool.type != work_pool.type: - warnings.warn( - "Overwriting work pool type is not supported. Ignoring provided type.", - category=UserWarning, - ) - await self.update_work_pool( - work_pool_name=work_pool.name, - work_pool=WorkPoolUpdate.model_validate( - work_pool.model_dump(exclude={"name", "type"}) - ), - ) - response = await self._client.get(f"/work_pools/{work_pool.name}") - else: - raise prefect.exceptions.ObjectAlreadyExists(http_exc=e) from e - else: - raise - - return WorkPool.model_validate(response.json()) - - async def update_work_pool( - self, - work_pool_name: str, - work_pool: WorkPoolUpdate, - ) -> None: - """ - Updates a work pool. - - Args: - work_pool_name: Name of the work pool to update. - work_pool: Fields to update in the work pool. - """ - try: - await self._client.patch( - f"/work_pools/{work_pool_name}", - json=work_pool.model_dump(mode="json", exclude_unset=True), - ) - except httpx.HTTPStatusError as e: - if e.response.status_code == status.HTTP_404_NOT_FOUND: - raise prefect.exceptions.ObjectNotFound(http_exc=e) from e - else: - raise - - async def delete_work_pool( - self, - work_pool_name: str, - ) -> None: - """ - Deletes a work pool. - - Args: - work_pool_name: Name of the work pool to delete. - """ - try: - await self._client.delete(f"/work_pools/{work_pool_name}") - except httpx.HTTPStatusError as e: - if e.response.status_code == status.HTTP_404_NOT_FOUND: - raise prefect.exceptions.ObjectNotFound(http_exc=e) from e - else: - raise - async def read_work_queues( self, work_pool_name: Optional[str] = None, @@ -1353,41 +1141,6 @@ async def read_work_queues( return pydantic.TypeAdapter(list[WorkQueue]).validate_python(response.json()) - async def get_scheduled_flow_runs_for_work_pool( - self, - work_pool_name: str, - work_queue_names: Optional[list[str]] = None, - scheduled_before: Optional[datetime.datetime] = None, - ) -> list[WorkerFlowRunResponse]: - """ - Retrieves scheduled flow runs for the provided set of work pool queues. - - Args: - work_pool_name: The name of the work pool that the work pool - queues are associated with. - work_queue_names: The names of the work pool queues from which - to get scheduled flow runs. - scheduled_before: Datetime used to filter returned flow runs. Flow runs - scheduled for after the given datetime string will not be returned. - - Returns: - A list of worker flow run responses containing information about the - retrieved flow runs. - """ - body: dict[str, Any] = {} - if work_queue_names is not None: - body["work_queue_names"] = list(work_queue_names) - if scheduled_before: - body["scheduled_before"] = str(scheduled_before) - - response = await self._client.post( - f"/work_pools/{work_pool_name}/get_scheduled_flow_runs", - json=body, - ) - return pydantic.TypeAdapter(list[WorkerFlowRunResponse]).validate_python( - response.json() - ) - async def read_worker_metadata(self) -> dict[str, Any]: """Reads worker metadata stored in Prefect collection registry.""" response = await self._client.get("collections/views/aggregate-worker-metadata") @@ -1504,6 +1257,7 @@ class SyncPrefectClient( BlocksDocumentClient, BlocksSchemaClient, BlocksTypeClient, + WorkPoolClient, ): """ A synchronous client for interacting with the [Prefect REST API](/api-ref/rest-api/). diff --git a/src/prefect/client/orchestration/_work_pools/__init__.py b/src/prefect/client/orchestration/_work_pools/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/prefect/client/orchestration/_work_pools/client.py b/src/prefect/client/orchestration/_work_pools/client.py new file mode 100644 index 000000000000..2c8d842319ce --- /dev/null +++ b/src/prefect/client/orchestration/_work_pools/client.py @@ -0,0 +1,598 @@ +from __future__ import annotations + +import warnings +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from httpx import HTTPStatusError + +from prefect.client.base import ServerType +from prefect.client.orchestration.base import BaseAsyncClient, BaseClient + +if TYPE_CHECKING: + from uuid import UUID + + from prefect.client.schemas.actions import ( + WorkPoolCreate, + WorkPoolUpdate, + ) + from prefect.client.schemas.filters import ( + WorkerFilter, + WorkPoolFilter, + ) + from prefect.client.schemas.objects import ( + Worker, + WorkerMetadata, + WorkPool, + ) + from prefect.client.schemas.responses import WorkerFlowRunResponse + +from prefect.exceptions import ObjectAlreadyExists, ObjectNotFound + + +class WorkPoolClient(BaseClient): + def send_worker_heartbeat( + self, + work_pool_name: str, + worker_name: str, + heartbeat_interval_seconds: float | None = None, + get_worker_id: bool = False, + worker_metadata: "WorkerMetadata | None" = None, + ) -> "UUID | None": + """ + Sends a worker heartbeat for a given work pool. + + Args: + work_pool_name: The name of the work pool to heartbeat against. + worker_name: The name of the worker sending the heartbeat. + return_id: Whether to return the worker ID. Note: will return `None` if the connected server does not support returning worker IDs, even if `return_id` is `True`. + worker_metadata: Metadata about the worker to send to the server. + """ + from uuid import UUID + + params: dict[str, Any] = { + "name": worker_name, + "heartbeat_interval_seconds": heartbeat_interval_seconds, + } + if worker_metadata: + params["metadata"] = worker_metadata.model_dump(mode="json") + if get_worker_id: + params["return_id"] = get_worker_id + + resp = self.request( + "POST", + "/work_pools/{work_pool_name}/workers/heartbeat", + path_params={"work_pool_name": work_pool_name}, + json=params, + ) + from prefect.settings import get_current_settings + + if ( + ( + self.server_type == ServerType.CLOUD + or get_current_settings().testing.test_mode + ) + and get_worker_id + and resp.status_code == 200 + ): + return UUID(resp.text) + else: + return None + + def read_workers_for_work_pool( + self, + work_pool_name: str, + worker_filter: "WorkerFilter | None" = None, + offset: int | None = None, + limit: int | None = None, + ) -> list["Worker"]: + """ + Reads workers for a given work pool. + + Args: + work_pool_name: The name of the work pool for which to get + member workers. + worker_filter: Criteria by which to filter workers. + limit: Limit for the worker query. + offset: Limit for the worker query. + """ + from prefect.client.schemas.objects import Worker + + response = self.request( + "POST", + "/work_pools/{work_pool_name}/workers/filter", + path_params={"work_pool_name": work_pool_name}, + json={ + "workers": ( + worker_filter.model_dump(mode="json", exclude_unset=True) + if worker_filter + else None + ), + "offset": offset, + "limit": limit, + }, + ) + + return Worker.model_validate_list(response.json()) + + def read_work_pool(self, work_pool_name: str) -> "WorkPool": + """ + Reads information for a given work pool + + Args: + work_pool_name: The name of the work pool to for which to get + information. + + Returns: + Information about the requested work pool. + """ + from prefect.client.schemas.objects import WorkPool + + try: + response = self.request( + "GET", + "/work_pools/{name}", + path_params={"name": work_pool_name}, + ) + return WorkPool.model_validate(response.json()) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + def read_work_pools( + self, + limit: int | None = None, + offset: int = 0, + work_pool_filter: "WorkPoolFilter | None" = None, + ) -> list["WorkPool"]: + """ + Reads work pools. + + Args: + limit: Limit for the work pool query. + offset: Offset for the work pool query. + work_pool_filter: Criteria by which to filter work pools. + + Returns: + A list of work pools. + """ + from prefect.client.schemas.objects import WorkPool + + body: dict[str, Any] = { + "limit": limit, + "offset": offset, + "work_pools": ( + work_pool_filter.model_dump(mode="json") if work_pool_filter else None + ), + } + response = self.request("POST", "/work_pools/filter", json=body) + return WorkPool.model_validate_list(response.json()) + + def create_work_pool( + self, + work_pool: "WorkPoolCreate", + overwrite: bool = False, + ) -> "WorkPool": + """ + Creates a work pool with the provided configuration. + + Args: + work_pool: Desired configuration for the new work pool. + + Returns: + Information about the newly created work pool. + """ + from prefect.client.schemas.actions import WorkPoolUpdate + from prefect.client.schemas.objects import WorkPool + + try: + response = self.request( + "POST", + "/work_pools/", + json=work_pool.model_dump(mode="json", exclude_unset=True), + ) + except HTTPStatusError as e: + if e.response.status_code == 409: + if overwrite: + existing_work_pool = self.read_work_pool( + work_pool_name=work_pool.name + ) + if existing_work_pool.type != work_pool.type: + warnings.warn( + "Overwriting work pool type is not supported. Ignoring provided type.", + category=UserWarning, + ) + self.update_work_pool( + work_pool_name=work_pool.name, + work_pool=WorkPoolUpdate.model_validate( + work_pool.model_dump(exclude={"name", "type"}) + ), + ) + response = self.request( + "GET", + "/work_pools/{name}", + path_params={"name": work_pool.name}, + ) + else: + raise ObjectAlreadyExists(http_exc=e) from e + else: + raise + + return WorkPool.model_validate(response.json()) + + def update_work_pool( + self, + work_pool_name: str, + work_pool: "WorkPoolUpdate", + ) -> None: + """ + Updates a work pool. + + Args: + work_pool_name: Name of the work pool to update. + work_pool: Fields to update in the work pool. + """ + try: + self.request( + "PATCH", + "/work_pools/{name}", + path_params={"name": work_pool_name}, + json=work_pool.model_dump(mode="json", exclude_unset=True), + ) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + def delete_work_pool( + self, + work_pool_name: str, + ) -> None: + """ + Deletes a work pool. + + Args: + work_pool_name: Name of the work pool to delete. + """ + try: + self.request( + "DELETE", + "/work_pools/{name}", + path_params={"name": work_pool_name}, + ) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + def get_scheduled_flow_runs_for_work_pool( + self, + work_pool_name: str, + work_queue_names: list[str] | None = None, + scheduled_before: datetime | None = None, + ) -> list["WorkerFlowRunResponse"]: + """ + Retrieves scheduled flow runs for the provided set of work pool queues. + + Args: + work_pool_name: The name of the work pool that the work pool + queues are associated with. + work_queue_names: The names of the work pool queues from which + to get scheduled flow runs. + scheduled_before: Datetime used to filter returned flow runs. Flow runs + scheduled for after the given datetime string will not be returned. + + Returns: + A list of worker flow run responses containing information about the + retrieved flow runs. + """ + from prefect.client.schemas.responses import WorkerFlowRunResponse + + body: dict[str, Any] = {} + if work_queue_names is not None: + body["work_queue_names"] = list(work_queue_names) + if scheduled_before: + body["scheduled_before"] = str(scheduled_before) + + try: + response = self.request( + "POST", + "/work_pools/{name}/get_scheduled_flow_runs", + path_params={"name": work_pool_name}, + json=body, + ) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + return WorkerFlowRunResponse.model_validate_list(response.json()) + + +class WorkPoolAsyncClient(BaseAsyncClient): + async def send_worker_heartbeat( + self, + work_pool_name: str, + worker_name: str, + heartbeat_interval_seconds: float | None = None, + get_worker_id: bool = False, + worker_metadata: "WorkerMetadata | None" = None, + ) -> "UUID | None": + """ + Sends a worker heartbeat for a given work pool. + + Args: + work_pool_name: The name of the work pool to heartbeat against. + worker_name: The name of the worker sending the heartbeat. + return_id: Whether to return the worker ID. Note: will return `None` if the connected server does not support returning worker IDs, even if `return_id` is `True`. + worker_metadata: Metadata about the worker to send to the server. + """ + from uuid import UUID + + params: dict[str, Any] = { + "name": worker_name, + "heartbeat_interval_seconds": heartbeat_interval_seconds, + } + if worker_metadata: + params["metadata"] = worker_metadata.model_dump(mode="json") + if get_worker_id: + params["return_id"] = get_worker_id + + resp = await self.request( + "POST", + "/work_pools/{work_pool_name}/workers/heartbeat", + path_params={"work_pool_name": work_pool_name}, + json=params, + ) + from prefect.settings import get_current_settings + + if ( + ( + self.server_type == ServerType.CLOUD + or get_current_settings().testing.test_mode + ) + and get_worker_id + and resp.status_code == 200 + ): + return UUID(resp.text) + else: + return None + + async def read_workers_for_work_pool( + self, + work_pool_name: str, + worker_filter: "WorkerFilter | None" = None, + offset: int | None = None, + limit: int | None = None, + ) -> list["Worker"]: + """ + Reads workers for a given work pool. + + Args: + work_pool_name: The name of the work pool for which to get + member workers. + worker_filter: Criteria by which to filter workers. + limit: Limit for the worker query. + offset: Limit for the worker query. + """ + from prefect.client.schemas.objects import Worker + + response = await self.request( + "POST", + "/work_pools/{work_pool_name}/workers/filter", + path_params={"work_pool_name": work_pool_name}, + json={ + "workers": ( + worker_filter.model_dump(mode="json", exclude_unset=True) + if worker_filter + else None + ), + "offset": offset, + "limit": limit, + }, + ) + + return Worker.model_validate_list(response.json()) + + async def read_work_pool(self, work_pool_name: str) -> "WorkPool": + """ + Reads information for a given work pool + + Args: + work_pool_name: The name of the work pool to for which to get + information. + + Returns: + Information about the requested work pool. + """ + from prefect.client.schemas.objects import WorkPool + + try: + response = await self.request( + "GET", + "/work_pools/{name}", + path_params={"name": work_pool_name}, + ) + return WorkPool.model_validate(response.json()) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + async def read_work_pools( + self, + limit: int | None = None, + offset: int = 0, + work_pool_filter: "WorkPoolFilter | None" = None, + ) -> list["WorkPool"]: + """ + Reads work pools. + + Args: + limit: Limit for the work pool query. + offset: Offset for the work pool query. + work_pool_filter: Criteria by which to filter work pools. + + Returns: + A list of work pools. + """ + from prefect.client.schemas.objects import WorkPool + + body: dict[str, Any] = { + "limit": limit, + "offset": offset, + "work_pools": ( + work_pool_filter.model_dump(mode="json") if work_pool_filter else None + ), + } + response = await self.request("POST", "/work_pools/filter", json=body) + return WorkPool.model_validate_list(response.json()) + + async def create_work_pool( + self, + work_pool: "WorkPoolCreate", + overwrite: bool = False, + ) -> "WorkPool": + """ + Creates a work pool with the provided configuration. + + Args: + work_pool: Desired configuration for the new work pool. + + Returns: + Information about the newly created work pool. + """ + from prefect.client.schemas.actions import WorkPoolUpdate + from prefect.client.schemas.objects import WorkPool + + try: + response = await self.request( + "POST", + "/work_pools/", + json=work_pool.model_dump(mode="json", exclude_unset=True), + ) + except HTTPStatusError as e: + if e.response.status_code == 409: + if overwrite: + existing_work_pool = await self.read_work_pool( + work_pool_name=work_pool.name + ) + if existing_work_pool.type != work_pool.type: + warnings.warn( + "Overwriting work pool type is not supported. Ignoring provided type.", + category=UserWarning, + ) + await self.update_work_pool( + work_pool_name=work_pool.name, + work_pool=WorkPoolUpdate.model_validate( + work_pool.model_dump(exclude={"name", "type"}) + ), + ) + response = await self.request( + "GET", + "/work_pools/{name}", + path_params={"name": work_pool.name}, + ) + else: + raise ObjectAlreadyExists(http_exc=e) from e + else: + raise + + return WorkPool.model_validate(response.json()) + + async def update_work_pool( + self, + work_pool_name: str, + work_pool: "WorkPoolUpdate", + ) -> None: + """ + Updates a work pool. + + Args: + work_pool_name: Name of the work pool to update. + work_pool: Fields to update in the work pool. + """ + try: + await self.request( + "PATCH", + "/work_pools/{name}", + path_params={"name": work_pool_name}, + json=work_pool.model_dump(mode="json", exclude_unset=True), + ) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + async def delete_work_pool( + self, + work_pool_name: str, + ) -> None: + """ + Deletes a work pool. + + Args: + work_pool_name: Name of the work pool to delete. + """ + try: + await self.request( + "DELETE", + "/work_pools/{name}", + path_params={"name": work_pool_name}, + ) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + async def get_scheduled_flow_runs_for_work_pool( + self, + work_pool_name: str, + work_queue_names: list[str] | None = None, + scheduled_before: datetime | None = None, + ) -> list["WorkerFlowRunResponse"]: + """ + Retrieves scheduled flow runs for the provided set of work pool queues. + + Args: + work_pool_name: The name of the work pool that the work pool + queues are associated with. + work_queue_names: The names of the work pool queues from which + to get scheduled flow runs. + scheduled_before: Datetime used to filter returned flow runs. Flow runs + scheduled for after the given datetime string will not be returned. + + Returns: + A list of worker flow run responses containing information about the + retrieved flow runs. + """ + from prefect.client.schemas.responses import WorkerFlowRunResponse + + body: dict[str, Any] = {} + if work_queue_names is not None: + body["work_queue_names"] = list(work_queue_names) + if scheduled_before: + body["scheduled_before"] = str(scheduled_before) + + try: + response = await self.request( + "POST", + "/work_pools/{name}/get_scheduled_flow_runs", + path_params={"name": work_pool_name}, + json=body, + ) + except HTTPStatusError as e: + if e.response.status_code == 404: + raise ObjectNotFound(http_exc=e) from e + else: + raise + + return WorkerFlowRunResponse.model_validate_list(response.json()) diff --git a/src/prefect/client/orchestration/base.py b/src/prefect/client/orchestration/base.py index 7cda11f7076b..8bbfbdf7fc87 100644 --- a/src/prefect/client/orchestration/base.py +++ b/src/prefect/client/orchestration/base.py @@ -7,12 +7,15 @@ if TYPE_CHECKING: from httpx import AsyncClient, Client, Response + from prefect.client.base import ServerType from prefect.client.orchestration.routes import ServerRoutes HTTP_METHODS: TypeAlias = Literal["GET", "POST", "PUT", "DELETE", "PATCH"] class BaseClient: + server_type: "ServerType" + def __init__(self, client: "Client"): self._client = client @@ -30,6 +33,8 @@ def request( class BaseAsyncClient: + server_type: "ServerType" + def __init__(self, client: "AsyncClient"): self._client = client diff --git a/tests/client/test_prefect_client.py b/tests/client/test_prefect_client.py index 2be38114d9ad..a49dca613141 100644 --- a/tests/client/test_prefect_client.py +++ b/tests/client/test_prefect_client.py @@ -2872,7 +2872,7 @@ async def test_worker_heartbeat_sends_metadata_if_passed( self, prefect_client: PrefectClient ): with mock.patch( - "prefect.client.orchestration.PrefectHttpxAsyncClient.post", + "prefect.client.orchestration.base.BaseAsyncClient.request", return_value=httpx.Response(status_code=204), ) as mock_post: await prefect_client.send_worker_heartbeat( @@ -2895,7 +2895,7 @@ async def test_worker_heartbeat_does_not_send_metadata_if_not_passed( self, prefect_client: PrefectClient ): with mock.patch( - "prefect.client.orchestration.PrefectHttpxAsyncClient.post", + "prefect.client.orchestration.base.BaseAsyncClient.request", return_value=httpx.Response(status_code=204), ) as mock_post: await prefect_client.send_worker_heartbeat( diff --git a/tests/workers/test_base_worker.py b/tests/workers/test_base_worker.py index 08cfbb004f04..781a84b884d9 100644 --- a/tests/workers/test_base_worker.py +++ b/tests/workers/test_base_worker.py @@ -16,7 +16,7 @@ from prefect.client.base import ServerType from prefect.client.orchestration import PrefectClient, get_client from prefect.client.schemas import FlowRun -from prefect.client.schemas.objects import StateType, WorkerMetadata +from prefect.client.schemas.objects import Integration, StateType, WorkerMetadata from prefect.exceptions import ( CrashedRun, ObjectNotFound, @@ -1990,7 +1990,7 @@ async def test_worker_heartbeat_sends_integrations( "prefect.workers.base.load_prefect_collections" ) as mock_load_prefect_collections, mock.patch( - "prefect.client.orchestration.PrefectHttpxAsyncClient.post" + "prefect.client.orchestration._work_pools.client.WorkPoolAsyncClient.send_worker_heartbeat", ) as mock_send_worker_heartbeat_post, mock.patch("prefect.workers.base.distributions") as mock_distributions, ): @@ -2010,17 +2010,13 @@ async def test_worker_heartbeat_sends_integrations( await worker.sync_with_backend() mock_send_worker_heartbeat_post.assert_called_once_with( - f"/work_pools/{work_pool.name}/workers/heartbeat", - json={ - "name": worker.name, - "heartbeat_interval_seconds": worker.heartbeat_interval_seconds, - "metadata": { - "integrations": [ - {"name": "prefect-aws", "version": "1.0.0"} - ] - }, - "return_id": True, - }, + work_pool_name=work_pool.name, + worker_name=worker.name, + heartbeat_interval_seconds=30.0, + get_worker_id=True, + worker_metadata=WorkerMetadata( + integrations=[Integration(name="prefect-aws", version="1.0.0")] + ), ) assert worker._worker_metadata_sent @@ -2050,7 +2046,7 @@ async def _worker_metadata(self) -> WorkerMetadata: "prefect.workers.base.load_prefect_collections" ) as mock_load_prefect_collections, mock.patch( - "prefect.client.orchestration.PrefectHttpxAsyncClient.post" + "prefect.client.orchestration._work_pools.client.WorkPoolAsyncClient.send_worker_heartbeat", ) as mock_send_worker_heartbeat_post, mock.patch("prefect.workers.base.distributions") as mock_distributions, ): @@ -2070,18 +2066,14 @@ async def _worker_metadata(self) -> WorkerMetadata: await worker.sync_with_backend() mock_send_worker_heartbeat_post.assert_called_once_with( - f"/work_pools/{work_pool.name}/workers/heartbeat", - json={ - "name": worker.name, - "heartbeat_interval_seconds": worker.heartbeat_interval_seconds, - "metadata": { - "integrations": [ - {"name": "prefect-aws", "version": "1.0.0"} - ], - "custom_field": "heya", - }, - "return_id": True, - }, + work_pool_name=work_pool.name, + worker_name=worker.name, + heartbeat_interval_seconds=30.0, + get_worker_id=True, + worker_metadata=WorkerMetadata( + integrations=[Integration(name="prefect-aws", version="1.0.0")], + custom_field="heya", + ), ) assert worker._worker_metadata_sent