diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index c875a14131c73..d89fd85fa484e 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -359,7 +359,7 @@ def test_should_response_200_with_update_mask(self): response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}?update_mask=is_paused", json=payload, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(response.status_code, 200) expected_response = { @@ -370,39 +370,32 @@ def test_should_response_200_with_update_mask(self): "is_subdag": False, "owners": [], "root_dag_id": None, - "schedule_interval": { - "__type": "CronExpression", - "value": "2 2 * * *", - }, + "schedule_interval": {"__type": "CronExpression", "value": "2 2 * * *",}, "tags": [], } self.assertEqual(response.json, expected_response) - @parameterized.expand([ - ( - { - "is_paused": True, - }, - "update_mask=description", - "Only `is_paused` field can be updated through the REST API" - ), - ( - { - "is_paused": True, - }, - "update_mask=schedule_interval, description", - "Only `is_paused` field can be updated through the REST API" - ) - ]) - def test_should_response_400_for_invalid_fields_in_update_mask( - self, payload, update_mask, error_message - ): + @parameterized.expand( + [ + ( + {"is_paused": True,}, + "update_mask=description", + "Only `is_paused` field can be updated through the REST API", + ), + ( + {"is_paused": True,}, + "update_mask=schedule_interval, description", + "Only `is_paused` field can be updated through the REST API", + ), + ] + ) + def test_should_response_400_for_invalid_fields_in_update_mask(self, payload, update_mask, error_message): dag_model = self._create_dag_model() response = self.client.patch( f"/api/v1/dags/{dag_model.dag_id}?{update_mask}", json=payload, - environ_overrides={'REMOTE_USER': "test"} + environ_overrides={'REMOTE_USER': "test"}, ) self.assertEqual(response.status_code, 400) self.assertEqual(response.json['detail'], error_message)