Skip to content

Commit

Permalink
DockerOperator extra_hosts argument support added (apache#10546)
Browse files Browse the repository at this point in the history
  • Loading branch information
bryzgaloff authored Aug 27, 2020
1 parent 91ff31a commit 2e56ee7
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 102 deletions.
3 changes: 3 additions & 0 deletions airflow/providers/docker/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
shm_size: Optional[int] = None,
tty: Optional[bool] = False,
cap_add: Optional[Iterable[str]] = None,
extra_hosts: Optional[Dict[str, str]] = None,
**kwargs,
) -> None:

Expand Down Expand Up @@ -200,6 +201,7 @@ def __init__(
self.shm_size = shm_size
self.tty = tty
self.cap_add = cap_add
self.extra_hosts = extra_hosts
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")

Expand Down Expand Up @@ -244,6 +246,7 @@ def _run_image(self) -> Optional[str]:
cpu_shares=int(round(self.cpus * 1024)),
mem_limit=self.mem_limit,
cap_add=self.cap_add,
extra_hosts=self.extra_hosts,
),
image=self.image,
user=self.user,
Expand Down
167 changes: 65 additions & 102 deletions tests/providers/docker/operators/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,29 @@


class TestDockerOperator(unittest.TestCase):
@mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute(self, client_class_mock, tempdir_mock):
host_config = mock.Mock()
tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'

client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = host_config
client_mock.images.return_value = []
client_mock.attach.return_value = ['container log']
client_mock.logs.return_value = ['container log']
client_mock.pull.return_value = {"status": "pull log"}
client_mock.wait.return_value = {"StatusCode": 0}
def setUp(self):
    """Start the patchers shared by every test and build the fake Docker client.

    ``TemporaryDirectory`` and ``APIClient`` are both patched inside the
    operator module; the patchers are kept on ``self`` so that ``tearDown``
    can stop them again.
    """
    # Fake temporary directory: entering the context manager yields a fixed path.
    self.tempdir_patcher = mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
    self.tempdir_mock = self.tempdir_patcher.start()
    self.tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'

    # Fake Docker client with canned return values; individual tests may
    # override any of them before calling ``execute``.
    self.client_mock = mock.Mock(spec=APIClient)
    canned_results = {
        'create_container': {'Id': 'some_id'},
        'images': [],
        'attach': ['container log'],
        'logs': ['container log'],
        'pull': {"status": "pull log"},
        'wait': {"StatusCode": 0},
        'create_host_config': mock.Mock(),
    }
    for method_name, result in canned_results.items():
        getattr(self.client_mock, method_name).return_value = result

    # Instantiating APIClient inside the operator module now returns the fake client.
    self.client_class_patcher = mock.patch(
        'airflow.providers.docker.operators.docker.APIClient',
        return_value=self.client_mock,
    )
    self.client_class_mock = self.client_class_patcher.start()

client_class_mock.return_value = client_mock
def tearDown(self) -> None:
    """Stop the patchers kept on ``self``: temp-dir patcher first, then the client-class patcher."""
    for patcher in (self.tempdir_patcher, self.client_class_patcher):
        patcher.stop()

def test_execute(self):
operator = DockerOperator(
api_version='1.19',
command='env',
Expand All @@ -67,21 +73,21 @@ def test_execute(self, client_class_mock, tempdir_mock):
)
operator.execute(None)

client_class_mock.assert_called_once_with(
self.client_class_mock.assert_called_once_with(
base_url='unix://var/run/docker.sock', tls=None, version='1.19'
)

client_mock.create_container.assert_called_once_with(
self.client_mock.create_container.assert_called_once_with(
command='env',
name='test_container',
environment={'AIRFLOW_TMP_DIR': '/tmp/airflow', 'UNIT': 'TEST', 'PRIVATE': 'MESSAGE'},
host_config=host_config,
host_config=self.client_mock.create_host_config.return_value,
image='ubuntu:latest',
user=None,
working_dir='/container/path',
tty=True,
)
client_mock.create_host_config.assert_called_once_with(
self.client_mock.create_host_config.assert_called_once_with(
binds=['/host/path:/container/path', '/mkdtemp:/tmp/airflow'],
network_mode='bridge',
shm_size=1000,
Expand All @@ -91,14 +97,17 @@ def test_execute(self, client_class_mock, tempdir_mock):
dns=None,
dns_search=None,
cap_add=None,
extra_hosts=None,
)
tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp')
client_mock.images.assert_called_once_with(name='ubuntu:latest')
client_mock.attach.assert_called_once_with(container='some_id', stdout=True, stderr=True, stream=True)
client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
client_mock.wait.assert_called_once_with('some_id')
self.tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp')
self.client_mock.images.assert_called_once_with(name='ubuntu:latest')
self.client_mock.attach.assert_called_once_with(
container='some_id', stdout=True, stderr=True, stream=True
)
self.client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
self.client_mock.wait.assert_called_once_with('some_id')
self.assertEqual(
operator.cli.pull('ubuntu:latest', stream=True, decode=True), client_mock.pull.return_value
operator.cli.pull('ubuntu:latest', stream=True, decode=True), self.client_mock.pull.return_value
)

def test_private_environment_is_private(self):
Expand All @@ -112,17 +121,7 @@ def test_private_environment_is_private(self):
)

@mock.patch('airflow.providers.docker.operators.docker.tls.TLSConfig')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_tls(self, client_class_mock, tls_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}

client_class_mock.return_value = client_mock
def test_execute_tls(self, tls_class_mock):
tls_mock = mock.Mock()
tls_class_mock.return_value = tls_mock

Expand All @@ -145,21 +144,12 @@ def test_execute_tls(self, client_class_mock, tls_class_mock):
verify=True,
)

client_class_mock.assert_called_once_with(
self.client_class_mock.assert_called_once_with(
base_url='https://127.0.0.1:2376', tls=tls_mock, version=None
)

@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_unicode_logs(self, client_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = ['unicode container log 😁']
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}

client_class_mock.return_value = client_mock
def test_execute_unicode_logs(self):
self.client_mock.attach.return_value = ['unicode container log 😁']

originalRaiseExceptions = logging.raiseExceptions # pylint: disable=invalid-name
logging.raiseExceptions = True
Expand All @@ -171,20 +161,9 @@ def test_execute_unicode_logs(self, client_class_mock):
logging.raiseExceptions = originalRaiseExceptions
print_exception_mock.assert_not_called()

@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_container_fails(self, client_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 1}

client_class_mock.return_value = client_mock

def test_execute_container_fails(self):
    """``execute`` must raise ``AirflowException`` when ``wait`` reports status code 1."""
    failing_status = {"StatusCode": 1}
    self.client_mock.wait.return_value = failing_status

    operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest')

    self.assertRaises(AirflowException, operator.execute, None)

Expand All @@ -200,23 +179,13 @@ def test_on_kill():

client_mock.stop.assert_called_once_with('some_id')

@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
# Mock out a Docker client, so operations don't raise errors
client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
operator_client_mock.return_value = client_mock

def test_execute_no_docker_conn_id_no_hook(self):
# Create the DockerOperator
operator = DockerOperator(image='publicregistry/someimage', owner='unittest', task_id='unittest')

# Mock out the DockerHook
hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
hook_mock.get_conn.return_value = client_mock
hook_mock.get_conn.return_value = self.client_mock
operator.get_hook = mock.Mock(
name='DockerOperator.get_hook mock', spec=DockerOperator.get_hook, return_value=hook_mock
)
Expand All @@ -225,17 +194,7 @@ def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
self.assertEqual(operator.get_hook.call_count, 0, 'Hook called though no docker_conn_id configured')

@mock.patch('airflow.providers.docker.operators.docker.DockerHook')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operator_docker_hook):
# Mock out a Docker client, so operations don't raise errors
client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
operator_client_mock.return_value = client_mock

def test_execute_with_docker_conn_id_use_hook(self, hook_class_mock):
# Create the DockerOperator
operator = DockerOperator(
image='publicregistry/someimage',
Expand All @@ -246,32 +205,21 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operat

# Mock out the DockerHook
hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
hook_mock.get_conn.return_value = client_mock
operator_docker_hook.return_value = hook_mock
hook_mock.get_conn.return_value = self.client_mock
hook_class_mock.return_value = hook_mock

operator.execute(None)

self.assertEqual(
operator_client_mock.call_count, 0, 'Client was called on the operator instead of the hook'
self.client_class_mock.call_count, 0, 'Client was called on the operator instead of the hook'
)
self.assertEqual(
operator_docker_hook.call_count, 1, 'Hook was not called although docker_conn_id configured'
hook_class_mock.call_count, 1, 'Hook was not called although docker_conn_id configured'
)
self.assertEqual(client_mock.pull.call_count, 1, 'Image was not pulled using operator client')

@mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_xcom_behavior(self, client_class_mock, tempdir_mock):
tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'

client_mock = mock.Mock(spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = ['container log']
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.wait.return_value = {"StatusCode": 0}
self.assertEqual(self.client_mock.pull.call_count, 1, 'Image was not pulled using operator client')

client_class_mock.return_value = client_mock
def test_execute_xcom_behavior(self):
self.client_mock.pull.return_value = [b'{"status":"pull log"}']

kwargs = {
'api_version': '1.19',
Expand All @@ -298,3 +246,18 @@ def test_execute_xcom_behavior(self, client_class_mock, tempdir_mock):

self.assertEqual(xcom_push_result, b'container log')
self.assertIs(no_xcom_push_result, None)

def test_extra_hosts(self):
    """Check that ``extra_hosts`` is handed unchanged to ``create_host_config``.

    A sentinel mock is passed as ``extra_hosts`` so the identity assertion
    proves the very same object is forwarded.
    """
    hosts_obj = mock.Mock()
    operator = DockerOperator(task_id='test', image='test', extra_hosts=hosts_obj)
    operator.execute(None)

    self.client_mock.create_container.assert_called_once()

    # ``call_args`` is an ``(args, kwargs)`` pair. Index it instead of using
    # the ``.kwargs`` attribute, which only exists on Python 3.8+; on older
    # interpreters the attribute access yields another call object and the
    # membership checks below would fail.
    container_kwargs = self.client_mock.create_container.call_args[1]
    self.assertIn('host_config', container_kwargs)

    host_config_kwargs = self.client_mock.create_host_config.call_args[1]
    self.assertIn('extra_hosts', host_config_kwargs)
    self.assertIs(hosts_obj, host_config_kwargs['extra_hosts'])

0 comments on commit 2e56ee7

Please sign in to comment.