From d6ce8c8561284834c5417290c49e461141111f1d Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 25 Aug 2020 11:56:19 +0100 Subject: [PATCH] Add update mask to patch dag endpoint (#10535) --- .../api_connexion/endpoints/dag_endpoint.py | 14 ++++- airflow/api_connexion/openapi/v1.yaml | 2 + .../endpoints/test_dag_endpoint.py | 56 +++++++++++++++++++ 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 27634885f8d11..0192f96e12b2e 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -72,7 +72,7 @@ def get_dags(session, limit, offset=0): @security.requires_authentication @provide_session -def patch_dag(session, dag_id): +def patch_dag(session, dag_id, update_mask=None): """ Update the specific DAG """ @@ -83,7 +83,15 @@ def patch_dag(session, dag_id): patch_body = dag_schema.load(request.json, session=session) except ValidationError as err: raise BadRequest("Invalid Dag schema", detail=str(err.messages)) - for key, value in patch_body.items(): - setattr(dag, key, value) + if update_mask: + patch_body_ = {} + if len(update_mask) > 1: + raise BadRequest(detail="Only `is_paused` field can be updated through the REST API") + update_mask = update_mask[0] + if update_mask != 'is_paused': + raise BadRequest(detail="Only `is_paused` field can be updated through the REST API") + patch_body_[update_mask] = patch_body[update_mask] + patch_body = patch_body_ + setattr(dag, 'is_paused', patch_body['is_paused']) session.commit() return dag_schema.dump(dag) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index 798b275d60a45..7970538cbc6e1 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -203,6 +203,8 @@ paths: summary: Update a DAG x-openapi-router-controller: airflow.api_connexion.endpoints.dag_endpoint operationId: patch_dag + parameters: + - $ref: '#/components/parameters/UpdateMask' tags: [DAG] requestBody: required: true diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 66c7cf7ec8ea8..bc5fb2e03be77 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -391,3 +391,59 @@ def test_should_raises_401_unauthenticated(self): ) assert_401(response) + + def test_should_response_200_with_update_mask(self): + dag_model = self._create_dag_model() + payload = { + "is_paused": False, + } + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}?update_mask=is_paused", + json=payload, + environ_overrides={'REMOTE_USER': "test"} + ) + self.assertEqual(response.status_code, 200) + expected_response = { + "dag_id": "TEST_DAG_1", + "description": None, + "fileloc": "/tmp/dag_1.py", + "is_paused": False, + "is_subdag": False, + "owners": [], + "root_dag_id": None, + "schedule_interval": { + "__type": "CronExpression", + "value": "2 2 * * *", + }, + "tags": [], + } + self.assertEqual(response.json, expected_response) + + @parameterized.expand([ + ( + { + "is_paused": True, + }, + "update_mask=description", + "Only `is_paused` field can be updated through the REST API" + ), + ( + { + "is_paused": True, + }, + "update_mask=schedule_interval, description", + "Only `is_paused` field can be updated through the REST API" + ) + ]) + def test_should_response_400_for_invalid_fields_in_update_mask( + self, payload, update_mask, error_message + ): + dag_model = self._create_dag_model() + + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}?{update_mask}", + json=payload, + environ_overrides={'REMOTE_USER': "test"} + ) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.json['detail'], error_message)