Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
refactor KubernetesJob and KubernetesJobRun to use existing task func…
Browse files Browse the repository at this point in the history
…tions. (#43)

* refactor KubernetesJob and KubernetesJobRun to use existing task
functions.

    - Use existing tasks to create, read, delete or fetch job status.
    - Use existing pods tasks to list pods and read pod logs.
    - Added read_namespaced_job_status task.

* Update prefect_kubernetes/jobs.py

Correct docstring grammar.

Co-authored-by: Alexander Streed <[email protected]>

* update CHANGELOG.md

* move summary to unreleased section

* Update CHANGELOG.md

Co-authored-by: Alexander Streed <[email protected]>

---------

Co-authored-by: Alexander Streed <[email protected]>
  • Loading branch information
tardunge and desertaxle authored Mar 27, 2023
1 parent 5b116f8 commit 6f7c6ec
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 88 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,7 @@ dmypy.json
.vscode

# Jupyter notebook
*.ipynb
*.ipynb

#direnv .envrc
.envrc
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Custom objects crud tasks for kubernetes custom resource definitions. - [#45](https://github.com/PrefectHQ/prefect-kubernetes/pull/45)

### Changed
- Refactor KubernetesJob and KubernetesJobRun to use existing task functions - [#43](https://github.com/PrefectHQ/prefect-kubernetes/pull/43)

### Deprecated

Expand Down
202 changes: 116 additions & 86 deletions prefect_kubernetes/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from prefect_kubernetes.credentials import KubernetesCredentials
from prefect_kubernetes.exceptions import KubernetesJobTimeoutError
from prefect_kubernetes.pods import list_namespaced_pod, read_namespaced_pod_log
from prefect_kubernetes.utilities import convert_manifest_to_model

KubernetesManifest = Union[Dict, Path, str]
Expand Down Expand Up @@ -56,7 +57,6 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("batch") as batch_v1_client:

return await run_sync_in_worker_thread(
batch_v1_client.create_namespaced_job,
namespace=namespace,
Expand Down Expand Up @@ -107,7 +107,6 @@ def kubernetes_orchestrator():
"""

with kubernetes_credentials.get_client("batch") as batch_v1_client:

return await run_sync_in_worker_thread(
batch_v1_client.delete_namespaced_job,
name=job_name,
Expand Down Expand Up @@ -151,7 +150,6 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("batch") as batch_v1_client:

return await run_sync_in_worker_thread(
batch_v1_client.list_namespaced_job,
namespace=namespace,
Expand Down Expand Up @@ -204,7 +202,6 @@ def kubernetes_orchestrator():
"""

with kubernetes_credentials.get_client("batch") as batch_v1_client:

return await run_sync_in_worker_thread(
batch_v1_client.patch_namespaced_job,
name=job_name,
Expand Down Expand Up @@ -253,7 +250,6 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("batch") as batch_v1_client:

return await run_sync_in_worker_thread(
batch_v1_client.read_namespaced_job,
name=job_name,
Expand Down Expand Up @@ -301,7 +297,6 @@ def kubernetes_orchestrator():
```
"""
with kubernetes_credentials.get_client("batch") as batch_v1_client:

return await run_sync_in_worker_thread(
batch_v1_client.replace_namespaced_job,
name=job_name,
Expand All @@ -311,6 +306,50 @@ def kubernetes_orchestrator():
)


@task
async def read_namespaced_job_status(
    kubernetes_credentials: KubernetesCredentials,
    job_name: str,
    namespace: Optional[str] = "default",
    **kube_kwargs: Dict[str, Any],
) -> V1Job:
    """Task for fetching status of a namespaced Kubernetes job.

    Args:
        kubernetes_credentials: `KubernetesCredentials` block
            holding authentication needed to generate the required API client.
        job_name: The name of a job to fetch status for.
        namespace: The Kubernetes namespace to fetch status of job in.
        **kube_kwargs: Optional extra keyword arguments to pass to the
            Kubernetes API (e.g. `{"pretty": "...", "dry_run": "..."}`).

    Returns:
        A Kubernetes `V1Job` object; the job's status is available on its
        `status` attribute.

    Example:
        Fetch status of a job in the default namespace:
        ```python
        from prefect import flow
        from prefect_kubernetes.credentials import KubernetesCredentials
        from prefect_kubernetes.jobs import read_namespaced_job_status

        @flow
        def kubernetes_orchestrator():
            v1_job = read_namespaced_job_status(
                kubernetes_credentials=KubernetesCredentials.load("k8s-creds"),
                job_name="my-job",
            )
        ```
    """
    with kubernetes_credentials.get_client("batch") as batch_v1_client:
        # The client call is synchronous, so it is handed to a worker thread
        # via `run_sync_in_worker_thread` and awaited from this coroutine.
        return await run_sync_in_worker_thread(
            batch_v1_client.read_namespaced_job_status,
            name=job_name,
            namespace=namespace,
            **kube_kwargs,
        )


class KubernetesJobRun(JobRun[Dict[str, Any]]):
"""A container representing a run of a Kubernetes job."""

Expand All @@ -327,17 +366,20 @@ def __init__(

async def _cleanup(self):
"""Deletes the Kubernetes job resource."""
with self._kubernetes_job.credentials.get_client("batch") as batch_v1_client:
deleted_v1_job = await run_sync_in_worker_thread(
batch_v1_client.delete_namespaced_job,
namespace=self._kubernetes_job.namespace,
name=self._v1_job_model.metadata.name,
**self._kubernetes_job.api_kwargs,
)
self.logger.info(
f"Job {self._v1_job_model.metadata.name} deleted "
f"with {deleted_v1_job.status!r}."
)

delete_options = V1DeleteOptions(propagation_policy="Foreground")

deleted_v1_job = await delete_namespaced_job.fn(
kubernetes_credentials=self._kubernetes_job.credentials,
job_name=self._v1_job_model.metadata.name,
delete_options=delete_options,
namespace=self._kubernetes_job.namespace,
**self._kubernetes_job.api_kwargs,
)
self.logger.info(
f"Job {self._v1_job_model.metadata.name} deleted "
f"with {deleted_v1_job.status!r}."
)

@sync_compatible
async def wait_for_completion(self):
Expand All @@ -354,74 +396,63 @@ async def wait_for_completion(self):
"""
self.pod_logs = {}

with self._kubernetes_job.credentials.get_client(
"batch"
) as batch_v1_client, self._kubernetes_job.credentials.get_client(
"core"
) as core_v1_client:

elapsed_time = 0
elapsed_time = 0

while not self._completed:
job_expired = (
elapsed_time > self._kubernetes_job.timeout_seconds
if self._kubernetes_job.timeout_seconds
else False
while not self._completed:
job_expired = (
elapsed_time > self._kubernetes_job.timeout_seconds
if self._kubernetes_job.timeout_seconds
else False
)
if job_expired:
raise KubernetesJobTimeoutError(
f"Job timed out after {elapsed_time} seconds."
)
if job_expired:
raise KubernetesJobTimeoutError(
f"Job timed out after {elapsed_time} seconds."
)

latest_v1_job = await run_sync_in_worker_thread(
batch_v1_client.read_namespaced_job_status,
name=self._v1_job_model.metadata.name,

v1_job_status = await read_namespaced_job_status.fn(
kubernetes_credentials=self._kubernetes_job.credentials,
job_name=self._v1_job_model.metadata.name,
namespace=self._kubernetes_job.namespace,
**self._kubernetes_job.api_kwargs,
)
pod_selector = (
"controller-uid=" f"{v1_job_status.metadata.labels['controller-uid']}"
)
v1_pod_list = await list_namespaced_pod.fn(
kubernetes_credentials=self._kubernetes_job.credentials,
namespace=self._kubernetes_job.namespace,
label_selector=pod_selector,
**self._kubernetes_job.api_kwargs,
)

for pod in v1_pod_list.items:
pod_name = pod.metadata.name

if pod.status.phase == "Pending" or pod_name in self.pod_logs.keys():
continue

self.logger.info(f"Capturing logs for pod {pod_name!r}.")

self.pod_logs[pod_name] = await read_namespaced_pod_log.fn(
kubernetes_credentials=self._kubernetes_job.credentials,
pod_name=pod_name,
container=v1_job_status.spec.template.spec.containers[0].name,
namespace=self._kubernetes_job.namespace,
**self._kubernetes_job.api_kwargs,
)
pod_selector = (
"controller-uid="
f"{latest_v1_job.metadata.labels['controller-uid']}"
)
v1_pod_list = await run_sync_in_worker_thread(
core_v1_client.list_namespaced_pod,
namespace=self._kubernetes_job.namespace,
label_selector=pod_selector,
**self._kubernetes_job.api_kwargs,

if v1_job_status.status.active:
await sleep(self._kubernetes_job.interval_seconds)
if self._kubernetes_job.timeout_seconds:
elapsed_time += self._kubernetes_job.interval_seconds
elif v1_job_status.status.failed:
raise RuntimeError(
f"Job {v1_job_status.metadata.name!r} failed, check the "
"Kubernetes pod logs for more information."
)
for pod in v1_pod_list.items:
pod_name = pod.metadata.name

if (
pod.status.phase == "Pending"
or pod_name in self.pod_logs.keys()
):
continue

self.logger.info(f"Capturing logs for pod {pod_name!r}.")

self.pod_logs[pod_name] = await run_sync_in_worker_thread(
core_v1_client.read_namespaced_pod_log,
namespace=self._kubernetes_job.namespace,
name=pod_name,
container=latest_v1_job.spec.template.spec.containers[0].name,
**self._kubernetes_job.api_kwargs,
)

if latest_v1_job.status.active:
await sleep(self._kubernetes_job.interval_seconds)
if self._kubernetes_job.timeout_seconds:
elapsed_time += self._kubernetes_job.interval_seconds
elif latest_v1_job.status.failed:
raise RuntimeError(
f"Job {latest_v1_job.metadata.name!r} failed, check the "
"Kubernetes pod logs for more information."
)
elif latest_v1_job.status.succeeded:
self._completed = True
self.logger.info(
f"Job {latest_v1_job.metadata.name!r} has completed."
)
elif v1_job_status.status.succeeded:
self._completed = True
self.logger.info(f"Job {v1_job_status.metadata.name!r} has completed.")

if self._kubernetes_job.delete_after_completion:
await self._cleanup()
Expand Down Expand Up @@ -495,13 +526,12 @@ async def trigger(self):

v1_job_model = convert_manifest_to_model(self.v1_job, "V1Job")

with self.credentials.get_client("batch") as batch_v1_client:
await run_sync_in_worker_thread(
batch_v1_client.create_namespaced_job,
body=v1_job_model,
namespace=self.namespace,
**self.api_kwargs,
)
await create_namespaced_job.fn(
kubernetes_credentials=self.credentials,
new_job=v1_job_model,
namespace=self.namespace,
**self.api_kwargs,
)

return KubernetesJobRun(kubernetes_job=self, v1_job_model=v1_job_model)

Expand Down
22 changes: 21 additions & 1 deletion tests/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
list_namespaced_job,
patch_namespaced_job,
read_namespaced_job,
read_namespaced_job_status,
replace_namespaced_job,
)

Expand Down Expand Up @@ -111,6 +112,26 @@ async def test_replace_namespaced_job(kubernetes_credentials, _mock_api_batch_cl
assert _mock_api_batch_client.replace_namespaced_job.call_args[1]["a"] == "test"


async def test_read_namespaced_job_status(
    kubernetes_credentials, _mock_api_batch_client
):
    """Check that the task forwards name, namespace and extra kwargs to the client."""
    await read_namespaced_job_status.fn(
        kubernetes_credentials=kubernetes_credentials,
        job_name="test-job",
        namespace="ns",
        a="test",
    )

    # call_args[1] holds the keyword arguments of the most recent mock call.
    forwarded_kwargs = _mock_api_batch_client.read_namespaced_job_status.call_args[1]
    expected_kwargs = {"name": "test-job", "namespace": "ns", "a": "test"}
    for key, expected_value in expected_kwargs.items():
        assert forwarded_kwargs[key] == expected_value


async def test_job_block_from_job_yaml(kubernetes_credentials):
job = KubernetesJob.from_yaml_file(
credentials=kubernetes_credentials,
Expand All @@ -125,7 +146,6 @@ async def test_job_block_wait_never_called_raises(
mock_create_namespaced_job,
mock_delete_namespaced_job,
):

job_run = await valid_kubernetes_job_block.trigger()

with pytest.raises(
Expand Down

0 comments on commit 6f7c6ec

Please sign in to comment.