Skip to content

Commit

Permalink
Add update mask to patch dag endpoint (apache#10535)
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy authored Aug 25, 2020
1 parent d760265 commit d6ce8c8
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
14 changes: 11 additions & 3 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_dags(session, limit, offset=0):

@security.requires_authentication
@provide_session
def patch_dag(session, dag_id):
def patch_dag(session, dag_id, update_mask=None):
"""
Update the specific DAG
"""
Expand All @@ -83,7 +83,15 @@ def patch_dag(session, dag_id):
patch_body = dag_schema.load(request.json, session=session)
except ValidationError as err:
raise BadRequest("Invalid Dag schema", detail=str(err.messages))
for key, value in patch_body.items():
setattr(dag, key, value)
if update_mask:
patch_body_ = {}
if len(update_mask) > 1:
raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
update_mask = update_mask[0]
if update_mask != 'is_paused':
raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
patch_body_[update_mask] = patch_body[update_mask]
patch_body = patch_body_
setattr(dag, 'is_paused', patch_body['is_paused'])
session.commit()
return dag_schema.dump(dag)
2 changes: 2 additions & 0 deletions airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ paths:
summary: Update a DAG
x-openapi-router-controller: airflow.api_connexion.endpoints.dag_endpoint
operationId: patch_dag
parameters:
- $ref: '#/components/parameters/UpdateMask'
tags: [DAG]
requestBody:
required: true
Expand Down
56 changes: 56 additions & 0 deletions tests/api_connexion/endpoints/test_dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,59 @@ def test_should_raises_401_unauthenticated(self):
)

assert_401(response)

def test_should_response_200_with_update_mask(self):
dag_model = self._create_dag_model()
payload = {
"is_paused": False,
}
response = self.client.patch(
f"/api/v1/dags/{dag_model.dag_id}?update_mask=is_paused",
json=payload,
environ_overrides={'REMOTE_USER': "test"}
)
self.assertEqual(response.status_code, 200)
expected_response = {
"dag_id": "TEST_DAG_1",
"description": None,
"fileloc": "/tmp/dag_1.py",
"is_paused": False,
"is_subdag": False,
"owners": [],
"root_dag_id": None,
"schedule_interval": {
"__type": "CronExpression",
"value": "2 2 * * *",
},
"tags": [],
}
self.assertEqual(response.json, expected_response)

@parameterized.expand([
(
{
"is_paused": True,
},
"update_mask=description",
"Only `is_paused` field can be updated through the REST API"
),
(
{
"is_paused": True,
},
"update_mask=schedule_interval, description",
"Only `is_paused` field can be updated through the REST API"
)
])
def test_should_response_400_for_invalid_fields_in_update_mask(
self, payload, update_mask, error_message
):
dag_model = self._create_dag_model()

response = self.client.patch(
f"/api/v1/dags/{dag_model.dag_id}?{update_mask}",
json=payload,
environ_overrides={'REMOTE_USER': "test"}
)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.json['detail'], error_message)

0 comments on commit d6ce8c8

Please sign in to comment.