Skip to content

Commit

Permalink
Adds configuration to kill_infrastructure calls
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Apr 18, 2023
1 parent fe9963c commit 49aaef9
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 47 deletions.
14 changes: 10 additions & 4 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,10 @@ async def run(
)

async def kill_infrastructure(
self, infrastructure_pid: str, grace_seconds: int = 30
self,
infrastructure_pid: str,
configuration: BaseJobConfiguration,
grace_seconds: int = 30,
):
"""
Method for killing infrastructure created by a worker. Should be implemented by
Expand Down Expand Up @@ -505,8 +508,13 @@ async def cancel_run(self, flow_run: "FlowRun"):
)
return

configuration = await self._get_configuration(flow_run)

try:
await self.kill_infrastructure(flow_run.infrastructure_pid)
await self.kill_infrastructure(
infrastructure_pid=flow_run.infrastructure_pid,
configuration=configuration,
)
except NotImplementedError:
self._logger.error(
f"Worker type {self.type!r} does not support killing created "
Expand Down Expand Up @@ -711,8 +719,6 @@ async def _submit_run_and_capture_errors(
self, flow_run: "FlowRun", task_status: anyio.abc.TaskStatus = None
) -> Union[BaseWorkerResult, Exception]:
try:
# TODO: Add functionality to handle base job configuration and
# job configuration variables when kicking off a flow run
configuration = await self._get_configuration(flow_run)
submitted_event = self._emit_flow_run_submitted_event(configuration)
result = await self.run(
Expand Down
9 changes: 6 additions & 3 deletions src/prefect/workers/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def _parse_infrastructure_pid(infrastructure_pid: str) -> Tuple[str, int]:


class ProcessJobConfiguration(BaseJobConfiguration):
stream_output: bool
working_dir: Optional[Path]
stream_output: bool = Field(default=True)
working_dir: Optional[Path] = Field(default=None)

@validator("working_dir")
def validate_command(cls, v):
Expand Down Expand Up @@ -195,7 +195,10 @@ async def run(
)

async def kill_infrastructure(
self, infrastructure_pid: str, grace_seconds: int = 30
self,
infrastructure_pid: str,
configuration: ProcessJobConfiguration,
grace_seconds: int = 30,
):
hostname, pid = _parse_infrastructure_pid(infrastructure_pid)

Expand Down
Loading

0 comments on commit 49aaef9

Please sign in to comment.