diff --git a/airflow/providers/docker/operators/docker.py b/airflow/providers/docker/operators/docker.py index 6148f5b367804..10f2153af90c7 100644 --- a/airflow/providers/docker/operators/docker.py +++ b/airflow/providers/docker/operators/docker.py @@ -167,6 +167,7 @@ def __init__( shm_size: Optional[int] = None, tty: Optional[bool] = False, cap_add: Optional[Iterable[str]] = None, + extra_hosts: Optional[Dict[str, str]] = None, **kwargs, ) -> None: @@ -200,6 +201,7 @@ def __init__( self.shm_size = shm_size self.tty = tty self.cap_add = cap_add + self.extra_hosts = extra_hosts if kwargs.get('xcom_push') is not None: raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead") @@ -244,6 +246,7 @@ def _run_image(self) -> Optional[str]: cpu_shares=int(round(self.cpus * 1024)), mem_limit=self.mem_limit, cap_add=self.cap_add, + extra_hosts=self.extra_hosts, ), image=self.image, user=self.user, diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py index b3258ba0d4227..76d74fc8971d0 100644 --- a/tests/providers/docker/operators/test_docker.py +++ b/tests/providers/docker/operators/test_docker.py @@ -32,23 +32,29 @@ class TestDockerOperator(unittest.TestCase): - @mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory') - @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute(self, client_class_mock, tempdir_mock): - host_config = mock.Mock() - tempdir_mock.return_value.__enter__.return_value = '/mkdtemp' - - client_mock = mock.Mock(spec=APIClient) - client_mock.create_container.return_value = {'Id': 'some_id'} - client_mock.create_host_config.return_value = host_config - client_mock.images.return_value = [] - client_mock.attach.return_value = ['container log'] - client_mock.logs.return_value = ['container log'] - client_mock.pull.return_value = {"status": "pull log"} - client_mock.wait.return_value = {"StatusCode": 0} + def setUp(self): + self.tempdir_patcher = mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory') + self.tempdir_mock = self.tempdir_patcher.start() + self.tempdir_mock.return_value.__enter__.return_value = '/mkdtemp' + + self.client_mock = mock.Mock(spec=APIClient) + self.client_mock.create_container.return_value = {'Id': 'some_id'} + self.client_mock.images.return_value = [] + self.client_mock.attach.return_value = ['container log'] + self.client_mock.logs.return_value = ['container log'] + self.client_mock.pull.return_value = {"status": "pull log"} + self.client_mock.wait.return_value = {"StatusCode": 0} + self.client_mock.create_host_config.return_value = mock.Mock() + self.client_class_patcher = mock.patch( + 'airflow.providers.docker.operators.docker.APIClient', return_value=self.client_mock, + ) + self.client_class_mock = self.client_class_patcher.start() - client_class_mock.return_value = client_mock + def tearDown(self) -> None: + self.tempdir_patcher.stop() + self.client_class_patcher.stop() + def test_execute(self): operator = DockerOperator( api_version='1.19', command='env', @@ -67,21 +73,21 @@ def test_execute(self, client_class_mock, tempdir_mock): ) operator.execute(None) - client_class_mock.assert_called_once_with( + self.client_class_mock.assert_called_once_with( base_url='unix://var/run/docker.sock', tls=None, version='1.19' ) - client_mock.create_container.assert_called_once_with( + self.client_mock.create_container.assert_called_once_with( command='env', name='test_container', environment={'AIRFLOW_TMP_DIR': '/tmp/airflow', 'UNIT': 'TEST', 'PRIVATE': 'MESSAGE'}, - host_config=host_config, + host_config=self.client_mock.create_host_config.return_value, image='ubuntu:latest', user=None, working_dir='/container/path', tty=True, ) - client_mock.create_host_config.assert_called_once_with( + self.client_mock.create_host_config.assert_called_once_with( binds=['/host/path:/container/path', '/mkdtemp:/tmp/airflow'], network_mode='bridge', shm_size=1000, @@ -91,14 +97,17 @@ def test_execute(self, client_class_mock, tempdir_mock): dns=None, dns_search=None, cap_add=None, + extra_hosts=None, ) - tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp') - client_mock.images.assert_called_once_with(name='ubuntu:latest') - client_mock.attach.assert_called_once_with(container='some_id', stdout=True, stderr=True, stream=True) - client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True) - client_mock.wait.assert_called_once_with('some_id') + self.tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp') + self.client_mock.images.assert_called_once_with(name='ubuntu:latest') + self.client_mock.attach.assert_called_once_with( + container='some_id', stdout=True, stderr=True, stream=True + ) + self.client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True) + self.client_mock.wait.assert_called_once_with('some_id') self.assertEqual( - operator.cli.pull('ubuntu:latest', stream=True, decode=True), client_mock.pull.return_value + operator.cli.pull('ubuntu:latest', stream=True, decode=True), self.client_mock.pull.return_value ) def test_private_environment_is_private(self): @@ -112,17 +121,7 @@ def test_private_environment_is_private(self): ) @mock.patch('airflow.providers.docker.operators.docker.tls.TLSConfig') - @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute_tls(self, client_class_mock, tls_class_mock): - client_mock = mock.Mock(spec=APIClient) - client_mock.create_container.return_value = {'Id': 'some_id'} - client_mock.create_host_config.return_value = mock.Mock() - client_mock.images.return_value = [] - client_mock.attach.return_value = [] - client_mock.pull.return_value = [] - client_mock.wait.return_value = {"StatusCode": 0} - - client_class_mock.return_value = client_mock + def test_execute_tls(self, tls_class_mock): tls_mock = mock.Mock() tls_class_mock.return_value = tls_mock @@ -145,21 +144,12 @@ def test_execute_tls(self, client_class_mock, tls_class_mock): verify=True, ) - client_class_mock.assert_called_once_with( + self.client_class_mock.assert_called_once_with( base_url='https://127.0.0.1:2376', tls=tls_mock, version=None ) - @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute_unicode_logs(self, client_class_mock): - client_mock = mock.Mock(spec=APIClient) - client_mock.create_container.return_value = {'Id': 'some_id'} - client_mock.create_host_config.return_value = mock.Mock() - client_mock.images.return_value = [] - client_mock.attach.return_value = ['unicode container log 😁'] - client_mock.pull.return_value = [] - client_mock.wait.return_value = {"StatusCode": 0} - - client_class_mock.return_value = client_mock + def test_execute_unicode_logs(self): + self.client_mock.attach.return_value = ['unicode container log 😁'] originalRaiseExceptions = logging.raiseExceptions # pylint: disable=invalid-name logging.raiseExceptions = True @@ -171,20 +161,9 @@ def test_execute_unicode_logs(self, client_class_mock): logging.raiseExceptions = originalRaiseExceptions print_exception_mock.assert_not_called() - @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute_container_fails(self, client_class_mock): - client_mock = mock.Mock(spec=APIClient) - client_mock.create_container.return_value = {'Id': 'some_id'} - client_mock.create_host_config.return_value = mock.Mock() - client_mock.images.return_value = [] - client_mock.attach.return_value = [] - client_mock.pull.return_value = [] - client_mock.wait.return_value = {"StatusCode": 1} - - client_class_mock.return_value = client_mock - + def test_execute_container_fails(self): + self.client_mock.wait.return_value = {"StatusCode": 1} operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest') - with self.assertRaises(AirflowException): operator.execute(None) @@ -200,23 +179,13 @@ def test_on_kill(): client_mock.stop.assert_called_once_with('some_id') - @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock): - # Mock out a Docker client, so operations don't raise errors - client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient) - client_mock.images.return_value = [] - client_mock.create_container.return_value = {'Id': 'some_id'} - client_mock.attach.return_value = [] - client_mock.pull.return_value = [] - client_mock.wait.return_value = {"StatusCode": 0} - operator_client_mock.return_value = client_mock - + def test_execute_no_docker_conn_id_no_hook(self): # Create the DockerOperator operator = DockerOperator(image='publicregistry/someimage', owner='unittest', task_id='unittest') # Mock out the DockerHook hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook) - hook_mock.get_conn.return_value = client_mock + hook_mock.get_conn.return_value = self.client_mock operator.get_hook = mock.Mock( name='DockerOperator.get_hook mock', spec=DockerOperator.get_hook, return_value=hook_mock ) @@ -225,17 +194,7 @@ def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock): self.assertEqual(operator.get_hook.call_count, 0, 'Hook called though no docker_conn_id configured') @mock.patch('airflow.providers.docker.operators.docker.DockerHook') - @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operator_docker_hook): - # Mock out a Docker client, so operations don't raise errors - client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient) - client_mock.images.return_value = [] - client_mock.create_container.return_value = {'Id': 'some_id'} - client_mock.attach.return_value = [] - client_mock.pull.return_value = [] - client_mock.wait.return_value = {"StatusCode": 0} - operator_client_mock.return_value = client_mock - + def test_execute_with_docker_conn_id_use_hook(self, hook_class_mock): # Create the DockerOperator operator = DockerOperator( image='publicregistry/someimage', @@ -246,32 +205,21 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operat # Mock out the DockerHook hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook) - hook_mock.get_conn.return_value = client_mock - operator_docker_hook.return_value = hook_mock + hook_mock.get_conn.return_value = self.client_mock + hook_class_mock.return_value = hook_mock operator.execute(None) self.assertEqual( - operator_client_mock.call_count, 0, 'Client was called on the operator instead of the hook' + self.client_class_mock.call_count, 0, 'Client was called on the operator instead of the hook' ) self.assertEqual( - operator_docker_hook.call_count, 1, 'Hook was not called although docker_conn_id configured' + hook_class_mock.call_count, 1, 'Hook was not called although docker_conn_id configured' ) - self.assertEqual(client_mock.pull.call_count, 1, 'Image was not pulled using operator client') - - @mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory') - @mock.patch('airflow.providers.docker.operators.docker.APIClient') - def test_execute_xcom_behavior(self, client_class_mock, tempdir_mock): - tempdir_mock.return_value.__enter__.return_value = '/mkdtemp' - - client_mock = mock.Mock(spec=APIClient) - client_mock.images.return_value = [] - client_mock.create_container.return_value = {'Id': 'some_id'} - client_mock.attach.return_value = ['container log'] - client_mock.pull.return_value = [b'{"status":"pull log"}'] - client_mock.wait.return_value = {"StatusCode": 0} + self.assertEqual(self.client_mock.pull.call_count, 1, 'Image was not pulled using operator client') - client_class_mock.return_value = client_mock + def test_execute_xcom_behavior(self): + self.client_mock.pull.return_value = [b'{"status":"pull log"}'] kwargs = { 'api_version': '1.19', @@ -298,3 +246,18 @@ def test_execute_xcom_behavior(self, client_class_mock, tempdir_mock): self.assertEqual(xcom_push_result, b'container log') self.assertIs(no_xcom_push_result, None) + + def test_extra_hosts(self): + hosts_obj = mock.Mock() + operator = DockerOperator(task_id='test', image='test', extra_hosts=hosts_obj) + operator.execute(None) + self.client_mock.create_container.assert_called_once() + self.assertIn( + 'host_config', self.client_mock.create_container.call_args.kwargs, + ) + self.assertIn( + 'extra_hosts', self.client_mock.create_host_config.call_args.kwargs, + ) + self.assertIs( + hosts_obj, self.client_mock.create_host_config.call_args.kwargs['extra_hosts'], + )