Skip to content

Commit

Permalink
Replace unittests in providers tests by pure pytest [Wave-3] (apa…
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Nov 26, 2022
1 parent b6013c0 commit 518fd84
Show file tree
Hide file tree
Showing 118 changed files with 615 additions and 922 deletions.
35 changes: 16 additions & 19 deletions tests/providers/airbyte/hooks/test_airbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,17 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import pytest
import requests_mock

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
from airflow.utils import db


class TestAirbyteHook(unittest.TestCase):
class TestAirbyteHook:
"""
Test all functions from Airbyte Hook
"""
Expand All @@ -46,7 +44,7 @@ class TestAirbyteHook(unittest.TestCase):
_mock_job_status_success_response_body = {"job": {"status": "succeeded"}}
_mock_job_cancel_status = "cancelled"

def setUp(self):
def setup_method(self):
db.merge_conn(
Connection(
conn_id="airbyte_conn_id_test", conn_type="airbyte", host="http://test-airbyte", port=8001
Expand All @@ -59,25 +57,26 @@ def return_value_get_job(self, status):
response.json.return_value = {"job": {"status": status}}
return response

@requests_mock.mock()
def test_submit_sync_connection(self, m):
m.post(
def test_submit_sync_connection(self, requests_mock):
requests_mock.post(
self.sync_connection_endpoint, status_code=200, json=self._mock_sync_conn_success_response_body
)
resp = self.hook.submit_sync_connection(connection_id=self.connection_id)
assert resp.status_code == 200
assert resp.json() == self._mock_sync_conn_success_response_body

@requests_mock.mock()
def test_get_job_status(self, m):
m.post(self.get_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body)
def test_get_job_status(self, requests_mock):
requests_mock.post(
self.get_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body
)
resp = self.hook.get_job(job_id=self.job_id)
assert resp.status_code == 200
assert resp.json() == self._mock_job_status_success_response_body

@requests_mock.mock()
def test_cancel_job(self, m):
m.post(self.cancel_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body)
def test_cancel_job(self, requests_mock):
requests_mock.post(
self.cancel_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body
)
resp = self.hook.cancel_job(job_id=self.job_id)
assert resp.status_code == 200

Expand Down Expand Up @@ -147,9 +146,8 @@ def test_wait_for_job_cancelled(self, mock_get_job):
calls = [mock.call(job_id=self.job_id), mock.call(job_id=self.job_id)]
mock_get_job.assert_has_calls(calls)

@requests_mock.mock()
def test_connection_success(self, m):
m.get(
def test_connection_success(self, requests_mock):
requests_mock.get(
self.health_endpoint,
status_code=200,
)
Expand All @@ -158,9 +156,8 @@ def test_connection_success(self, m):
assert status is True
assert msg == "Connection successfully tested"

@requests_mock.mock()
def test_connection_failure(self, m):
m.get(self.health_endpoint, status_code=500, json={"message": "internal server error"})
def test_connection_failure(self, requests_mock):
requests_mock.get(self.health_endpoint, status_code=500, json={"message": "internal server error"})

status, msg = self.hook.test_connection()
assert status is False
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/airbyte/operators/test_airbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator


class TestAirbyteTriggerSyncOp(unittest.TestCase):
class TestAirbyteTriggerSyncOp:
"""
Test execute function from Airbyte Operator
"""
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/airbyte/sensors/test_airbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import pytest
Expand All @@ -25,7 +24,7 @@
from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor


class TestAirbyteJobSensor(unittest.TestCase):
class TestAirbyteJobSensor:

task_id = "task-id"
airbyte_conn_id = "airbyte-conn-test"
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/alibaba/cloud/hooks/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

from airflow.providers.alibaba.cloud.hooks.oss import OSSHook
Expand All @@ -32,8 +31,8 @@
MOCK_FILE_PATH = "mock_file_path"


class TestOSSHook(unittest.TestCase):
def setUp(self):
class TestOSSHook:
def setup_method(self):
with mock.patch(
OSS_STRING.format("OSSHook.__init__"),
new=mock_oss_hook_default_project_id,
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/alibaba/cloud/log/test_oss_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock
from unittest.mock import PropertyMock

Expand All @@ -32,8 +31,8 @@
MOCK_FILE_PATH = "mock_file_path"


class TestOSSTaskHandler(unittest.TestCase):
def setUp(self):
class TestOSSTaskHandler:
def setup_method(self):
self.base_log_folder = "local/airflow/logs/1.log"
self.oss_log_folder = f"oss://{MOCK_BUCKET_NAME}/airflow/logs"
self.oss_task_handler = OSSTaskHandler(self.base_log_folder, self.oss_log_folder)
Expand Down
13 changes: 6 additions & 7 deletions tests/providers/alibaba/cloud/operators/test_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

from airflow.providers.alibaba.cloud.operators.oss import (
Expand All @@ -38,7 +37,7 @@
MOCK_CONTENT = "mock_content"


class TestOSSCreateBucketOperator(unittest.TestCase):
class TestOSSCreateBucketOperator:
@mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook")
def test_execute(self, mock_hook):
operator = OSSCreateBucketOperator(
Expand All @@ -49,7 +48,7 @@ def test_execute(self, mock_hook):
mock_hook.return_value.create_bucket.assert_called_once_with(bucket_name=MOCK_BUCKET)


class TestOSSDeleteBucketOperator(unittest.TestCase):
class TestOSSDeleteBucketOperator:
@mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook")
def test_execute(self, mock_hook):
operator = OSSDeleteBucketOperator(
Expand All @@ -60,7 +59,7 @@ def test_execute(self, mock_hook):
mock_hook.return_value.delete_bucket.assert_called_once_with(bucket_name=MOCK_BUCKET)


class TestOSSUploadObjectOperator(unittest.TestCase):
class TestOSSUploadObjectOperator:
@mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook")
def test_execute(self, mock_hook):
operator = OSSUploadObjectOperator(
Expand All @@ -78,7 +77,7 @@ def test_execute(self, mock_hook):
)


class TestOSSDownloadObjectOperator(unittest.TestCase):
class TestOSSDownloadObjectOperator:
@mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook")
def test_execute(self, mock_hook):
operator = OSSDownloadObjectOperator(
Expand All @@ -96,7 +95,7 @@ def test_execute(self, mock_hook):
)


class TestOSSDeleteBatchObjectOperator(unittest.TestCase):
class TestOSSDeleteBatchObjectOperator:
@mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook")
def test_execute(self, mock_hook):
operator = OSSDeleteBatchObjectOperator(
Expand All @@ -111,7 +110,7 @@ def test_execute(self, mock_hook):
mock_hook.return_value.delete_objects.assert_called_once_with(bucket_name=MOCK_BUCKET, key=MOCK_KEYS)


class TestOSSDeleteObjectOperator(unittest.TestCase):
class TestOSSDeleteObjectOperator:
@mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook")
def test_execute(self, mock_hook):
operator = OSSDeleteObjectOperator(
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/alibaba/cloud/sensors/test_oss_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock
from unittest.mock import PropertyMock

Expand All @@ -33,8 +32,8 @@
MOCK_CONTENT = "mock_content"


class TestOSSKeySensor(unittest.TestCase):
def setUp(self):
class TestOSSKeySensor:
def setup_method(self):
self.sensor = OSSKeySensor(
bucket_key=MOCK_KEY,
oss_conn_id=MOCK_OSS_CONN_ID,
Expand Down
6 changes: 2 additions & 4 deletions tests/providers/arangodb/hooks/test_arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest.mock import Mock, patch

from airflow.models import Connection
Expand All @@ -26,9 +25,8 @@
arangodb_client_mock = Mock(name="arangodb_client_for_test")


class TestArangoDBHook(unittest.TestCase):
def setUp(self):
super().setUp()
class TestArangoDBHook:
def setup_method(self):
db.merge_conn(
Connection(
conn_id="arangodb_default",
Expand Down
3 changes: 1 addition & 2 deletions tests/providers/arangodb/operators/test_arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

from airflow.providers.arangodb.operators.arangodb import AQLOperator


class TestAQLOperator(unittest.TestCase):
class TestAQLOperator:
@mock.patch("airflow.providers.arangodb.operators.arangodb.ArangoDBHook")
def test_arangodb_operator_test(self, mock_hook):

Expand Down
5 changes: 2 additions & 3 deletions tests/providers/arangodb/sensors/test_arangodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest.mock import Mock, patch

from airflow.models import Connection
Expand All @@ -29,8 +28,8 @@
arangodb_hook_mock = Mock(name="arangodb_hook_for_test", **{"query.return_value.count.return_value": 1})


class TestAQLSensor(unittest.TestCase):
def setUp(self):
class TestAQLSensor:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG("test_dag_id", default_args=args)
self.dag = dag
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/asana/operators/test_asana_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest.mock import Mock, patch

from airflow.models import Connection
Expand All @@ -34,12 +33,12 @@
asana_client_mock = Mock(name="asana_client_for_test")


class TestAsanaTaskOperators(unittest.TestCase):
class TestAsanaTaskOperators:
"""
Test that the AsanaTaskOperators are using the python-asana methods as expected.
"""

def setUp(self):
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG(TEST_DAG_ID, default_args=args)
self.dag = dag
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/atlassian/jira/hooks/test_jira.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest.mock import Mock, patch

from airflow.models import Connection
Expand All @@ -27,8 +26,8 @@
jira_client_mock = Mock(name="jira_client")


class TestJiraHook(unittest.TestCase):
def setUp(self):
class TestJiraHook:
def setup_method(self):
db.merge_conn(
Connection(
conn_id="jira_default",
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/atlassian/jira/operators/test_jira.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest.mock import Mock, patch

from airflow.models import Connection
Expand All @@ -39,8 +38,8 @@
}


class TestJiraOperator(unittest.TestCase):
def setUp(self):
class TestJiraOperator:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG("test_dag_id", default_args=args)
self.dag = dag
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/atlassian/jira/sensors/test_jira.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest.mock import Mock, patch

from airflow.models import Connection
Expand Down Expand Up @@ -46,8 +45,8 @@ class _TicketFields:
)


class TestJiraSensor(unittest.TestCase):
def setUp(self):
class TestJiraSensor:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
dag = DAG("test_dag_id", default_args=args)
self.dag = dag
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/celery/sensors/test_celery_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
# under the License.
from __future__ import annotations

import unittest
from unittest.mock import patch

from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor


class TestCeleryQueueSensor(unittest.TestCase):
def setUp(self):
class TestCeleryQueueSensor:
def setup_method(self):
class TestCeleryqueueSensor(CeleryQueueSensor):
def _check_task_id(self, context):
return True
Expand Down
Loading

0 comments on commit 518fd84

Please sign in to comment.