Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Jul 8, 2024
1 parent 7a675ed commit b98b55b
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 22 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
WorkflowRunMetadata,
)
from airflow.providers.databricks.plugins.databricks_workflow import (
WorkflowJobRepairSingleFailedLink,
WorkflowJobRepairSingleTaskLink,
WorkflowJobRunLink,
)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
Expand Down Expand Up @@ -965,7 +965,7 @@ def __init__(
if self._databricks_workflow_task_group is not None:
self.operator_extra_links = (
WorkflowJobRunLink(),
WorkflowJobRepairSingleFailedLink(),
WorkflowJobRepairSingleTaskLink(),
)
else:
# Databricks does not support repair for non-workflow tasks, hence do not show the repair link.
Expand Down
40 changes: 22 additions & 18 deletions airflow/providers/databricks/plugins/databricks_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from operator import itemgetter
from typing import TYPE_CHECKING, Any, cast

from flask import current_app, flash, redirect, request
from flask import current_app, flash, redirect, request, url_for
from flask_appbuilder.api import expose

from airflow.configuration import conf
Expand Down Expand Up @@ -279,13 +279,16 @@ def get_link(

tasks_str = self.get_tasks_to_run(ti_key, operator, self.log)
self.log.debug("tasks to rerun: %s", tasks_str)
return (
f"/repair_databricks_job?dag_id={ti_key.dag_id}&"
f"databricks_conn_id={metadata.conn_id}&"
f"databricks_run_id={metadata.run_id}&"
f"run_id={ti_key.run_id}&"
f"tasks_to_repair={tasks_str}"
)

query_params = {
"dag_id": ti_key.dag_id,
"databricks_conn_id": metadata.conn_id,
"databricks_run_id": metadata.run_id,
"run_id": ti_key.run_id,
"tasks_to_repair": tasks_str,
}

return url_for("RepairDatabricksTasks.repair", **query_params)

@classmethod
def get_task_group_children(cls, task_group: TaskGroup) -> dict[str, BaseOperator]:
Expand Down Expand Up @@ -344,10 +347,10 @@ def _get_failed_and_skipped_tasks(dr: DagRun) -> list[str]:
]


class WorkflowJobRepairSingleFailedLink(BaseOperatorLink, LoggingMixin):
class WorkflowJobRepairSingleTaskLink(BaseOperatorLink, LoggingMixin):
"""Construct a link to send a repair request for a single databricks task."""

name = "Repair a single failed task"
name = "Repair a single task"

def get_link(
self,
Expand Down Expand Up @@ -378,13 +381,14 @@ def get_link(
ti_key = _get_launch_task_key(ti_key, task_id=launch_task_id)
metadata = get_xcom_result(ti_key, "return_value", ti)

return (
f"/repair_databricks_job?dag_id={ti_key.dag_id}&"
f"databricks_conn_id={metadata.conn_id}&"
f"databricks_run_id={metadata.run_id}&"
f"tasks_to_repair={_get_databricks_task_id(task)}&"
f"run_id={ti_key.run_id}"
)
query_params = {
"dag_id": ti_key.dag_id,
"databricks_conn_id": metadata.conn_id,
"databricks_run_id": metadata.run_id,
"run_id": ti_key.run_id,
"tasks_to_repair": _get_databricks_task_id(task),
}
return url_for("RepairDatabricksTasks.repair", **query_params)


class RepairDatabricksTasks(AirflowBaseView, LoggingMixin):
Expand Down Expand Up @@ -447,7 +451,7 @@ class DatabricksWorkflowPlugin(AirflowPlugin):
name = "databricks_workflow"
operator_extra_links = [
WorkflowJobRepairAllFailedLink(),
WorkflowJobRepairSingleFailedLink(),
WorkflowJobRepairSingleTaskLink(),
WorkflowJobRunLink(),
]
appbuilder_views = [repair_databricks_package]
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from airflow.providers.databricks.plugins.databricks_workflow import (
DatabricksWorkflowPlugin,
RepairDatabricksTasks,
WorkflowJobRepairSingleFailedLink,
WorkflowJobRepairSingleTaskLink,
WorkflowJobRunLink,
_get_dagrun,
_get_databricks_task_id,
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_workflow_job_run_link(app):

def test_workflow_job_repair_single_failed_link(app):
with app.app_context():
link = WorkflowJobRepairSingleFailedLink()
link = WorkflowJobRepairSingleTaskLink()
operator = Mock()
operator.task_group = Mock()
operator.task_group.group_id = "group_id"
Expand Down

0 comments on commit b98b55b

Please sign in to comment.