Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DockerOperator extra_hosts argument support added #10546

Merged
merged 2 commits into from
Aug 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airflow/providers/docker/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def __init__(
shm_size: Optional[int] = None,
tty: Optional[bool] = False,
cap_add: Optional[Iterable[str]] = None,
extra_hosts: Optional[Dict[str, str]] = None,
**kwargs,
) -> None:

Expand Down Expand Up @@ -200,6 +201,7 @@ def __init__(
self.shm_size = shm_size
self.tty = tty
self.cap_add = cap_add
self.extra_hosts = extra_hosts
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")

Expand Down Expand Up @@ -244,6 +246,7 @@ def _run_image(self) -> Optional[str]:
cpu_shares=int(round(self.cpus * 1024)),
mem_limit=self.mem_limit,
cap_add=self.cap_add,
extra_hosts=self.extra_hosts,
),
image=self.image,
user=self.user,
Expand Down
167 changes: 65 additions & 102 deletions tests/providers/docker/operators/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,29 @@


class TestDockerOperator(unittest.TestCase):
@mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute(self, client_class_mock, tempdir_mock):
host_config = mock.Mock()
tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'

client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = host_config
client_mock.images.return_value = []
client_mock.attach.return_value = ['container log']
client_mock.logs.return_value = ['container log']
client_mock.pull.return_value = {"status": "pull log"}
client_mock.wait.return_value = {"StatusCode": 0}
def setUp(self):
self.tempdir_patcher = mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
self.tempdir_mock = self.tempdir_patcher.start()
self.tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'

self.client_mock = mock.Mock(spec=APIClient)
self.client_mock.create_container.return_value = {'Id': 'some_id'}
self.client_mock.images.return_value = []
self.client_mock.attach.return_value = ['container log']
self.client_mock.logs.return_value = ['container log']
self.client_mock.pull.return_value = {"status": "pull log"}
self.client_mock.wait.return_value = {"StatusCode": 0}
self.client_mock.create_host_config.return_value = mock.Mock()
self.client_class_patcher = mock.patch(
'airflow.providers.docker.operators.docker.APIClient', return_value=self.client_mock,
)
self.client_class_mock = self.client_class_patcher.start()

client_class_mock.return_value = client_mock
def tearDown(self) -> None:
self.tempdir_patcher.stop()
self.client_class_patcher.stop()

def test_execute(self):
operator = DockerOperator(
api_version='1.19',
command='env',
Expand All @@ -67,21 +73,21 @@ def test_execute(self, client_class_mock, tempdir_mock):
)
operator.execute(None)

client_class_mock.assert_called_once_with(
self.client_class_mock.assert_called_once_with(
base_url='unix://var/run/docker.sock', tls=None, version='1.19'
)

client_mock.create_container.assert_called_once_with(
self.client_mock.create_container.assert_called_once_with(
command='env',
name='test_container',
environment={'AIRFLOW_TMP_DIR': '/tmp/airflow', 'UNIT': 'TEST', 'PRIVATE': 'MESSAGE'},
host_config=host_config,
host_config=self.client_mock.create_host_config.return_value,
image='ubuntu:latest',
user=None,
working_dir='/container/path',
tty=True,
)
client_mock.create_host_config.assert_called_once_with(
self.client_mock.create_host_config.assert_called_once_with(
binds=['/host/path:/container/path', '/mkdtemp:/tmp/airflow'],
network_mode='bridge',
shm_size=1000,
Expand All @@ -91,14 +97,17 @@ def test_execute(self, client_class_mock, tempdir_mock):
dns=None,
dns_search=None,
cap_add=None,
extra_hosts=None,
)
tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp')
client_mock.images.assert_called_once_with(name='ubuntu:latest')
client_mock.attach.assert_called_once_with(container='some_id', stdout=True, stderr=True, stream=True)
client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
client_mock.wait.assert_called_once_with('some_id')
self.tempdir_mock.assert_called_once_with(dir='/host/airflow', prefix='airflowtmp')
self.client_mock.images.assert_called_once_with(name='ubuntu:latest')
self.client_mock.attach.assert_called_once_with(
container='some_id', stdout=True, stderr=True, stream=True
)
self.client_mock.pull.assert_called_once_with('ubuntu:latest', stream=True, decode=True)
self.client_mock.wait.assert_called_once_with('some_id')
self.assertEqual(
operator.cli.pull('ubuntu:latest', stream=True, decode=True), client_mock.pull.return_value
operator.cli.pull('ubuntu:latest', stream=True, decode=True), self.client_mock.pull.return_value
)

def test_private_environment_is_private(self):
Expand All @@ -112,17 +121,7 @@ def test_private_environment_is_private(self):
)

@mock.patch('airflow.providers.docker.operators.docker.tls.TLSConfig')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_tls(self, client_class_mock, tls_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}

client_class_mock.return_value = client_mock
def test_execute_tls(self, tls_class_mock):
tls_mock = mock.Mock()
tls_class_mock.return_value = tls_mock

Expand All @@ -145,21 +144,12 @@ def test_execute_tls(self, client_class_mock, tls_class_mock):
verify=True,
)

client_class_mock.assert_called_once_with(
self.client_class_mock.assert_called_once_with(
base_url='https://127.0.0.1:2376', tls=tls_mock, version=None
)

@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_unicode_logs(self, client_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = ['unicode container log 😁']
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}

client_class_mock.return_value = client_mock
def test_execute_unicode_logs(self):
self.client_mock.attach.return_value = ['unicode container log 😁']

originalRaiseExceptions = logging.raiseExceptions # pylint: disable=invalid-name
logging.raiseExceptions = True
Expand All @@ -171,20 +161,9 @@ def test_execute_unicode_logs(self, client_class_mock):
logging.raiseExceptions = originalRaiseExceptions
print_exception_mock.assert_not_called()

@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_container_fails(self, client_class_mock):
client_mock = mock.Mock(spec=APIClient)
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.create_host_config.return_value = mock.Mock()
client_mock.images.return_value = []
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 1}

client_class_mock.return_value = client_mock

def test_execute_container_fails(self):
self.client_mock.wait.return_value = {"StatusCode": 1}
operator = DockerOperator(image='ubuntu', owner='unittest', task_id='unittest')

with self.assertRaises(AirflowException):
operator.execute(None)

Expand All @@ -200,23 +179,13 @@ def test_on_kill():

client_mock.stop.assert_called_once_with('some_id')

@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
# Mock out a Docker client, so operations don't raise errors
client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
operator_client_mock.return_value = client_mock

def test_execute_no_docker_conn_id_no_hook(self):
# Create the DockerOperator
operator = DockerOperator(image='publicregistry/someimage', owner='unittest', task_id='unittest')

# Mock out the DockerHook
hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
hook_mock.get_conn.return_value = client_mock
hook_mock.get_conn.return_value = self.client_mock
operator.get_hook = mock.Mock(
name='DockerOperator.get_hook mock', spec=DockerOperator.get_hook, return_value=hook_mock
)
Expand All @@ -225,17 +194,7 @@ def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
self.assertEqual(operator.get_hook.call_count, 0, 'Hook called though no docker_conn_id configured')

@mock.patch('airflow.providers.docker.operators.docker.DockerHook')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operator_docker_hook):
# Mock out a Docker client, so operations don't raise errors
client_mock = mock.Mock(name='DockerOperator.APIClient mock', spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = []
client_mock.pull.return_value = []
client_mock.wait.return_value = {"StatusCode": 0}
operator_client_mock.return_value = client_mock

def test_execute_with_docker_conn_id_use_hook(self, hook_class_mock):
# Create the DockerOperator
operator = DockerOperator(
image='publicregistry/someimage',
Expand All @@ -246,32 +205,21 @@ def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock, operat

# Mock out the DockerHook
hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
hook_mock.get_conn.return_value = client_mock
operator_docker_hook.return_value = hook_mock
hook_mock.get_conn.return_value = self.client_mock
hook_class_mock.return_value = hook_mock

operator.execute(None)

self.assertEqual(
operator_client_mock.call_count, 0, 'Client was called on the operator instead of the hook'
self.client_class_mock.call_count, 0, 'Client was called on the operator instead of the hook'
)
self.assertEqual(
operator_docker_hook.call_count, 1, 'Hook was not called although docker_conn_id configured'
hook_class_mock.call_count, 1, 'Hook was not called although docker_conn_id configured'
)
self.assertEqual(client_mock.pull.call_count, 1, 'Image was not pulled using operator client')

@mock.patch('airflow.providers.docker.operators.docker.TemporaryDirectory')
@mock.patch('airflow.providers.docker.operators.docker.APIClient')
def test_execute_xcom_behavior(self, client_class_mock, tempdir_mock):
tempdir_mock.return_value.__enter__.return_value = '/mkdtemp'

client_mock = mock.Mock(spec=APIClient)
client_mock.images.return_value = []
client_mock.create_container.return_value = {'Id': 'some_id'}
client_mock.attach.return_value = ['container log']
client_mock.pull.return_value = [b'{"status":"pull log"}']
client_mock.wait.return_value = {"StatusCode": 0}
self.assertEqual(self.client_mock.pull.call_count, 1, 'Image was not pulled using operator client')

client_class_mock.return_value = client_mock
def test_execute_xcom_behavior(self):
self.client_mock.pull.return_value = [b'{"status":"pull log"}']

kwargs = {
'api_version': '1.19',
Expand All @@ -298,3 +246,18 @@ def test_execute_xcom_behavior(self, client_class_mock, tempdir_mock):

self.assertEqual(xcom_push_result, b'container log')
self.assertIs(no_xcom_push_result, None)

def test_extra_hosts(self):
hosts_obj = mock.Mock()
operator = DockerOperator(task_id='test', image='test', extra_hosts=hosts_obj)
operator.execute(None)
self.client_mock.create_container.assert_called_once()
self.assertIn(
'host_config', self.client_mock.create_container.call_args.kwargs,
)
self.assertIn(
'extra_hosts', self.client_mock.create_host_config.call_args.kwargs,
)
self.assertIs(
hosts_obj, self.client_mock.create_host_config.call_args.kwargs['extra_hosts'],
)