Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional argument for overriding get_tls_context() parameters #8275

Merged
merged 11 commits into from
Jan 7, 2021
25 changes: 21 additions & 4 deletions datadog_checks_base/datadog_checks/base/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,20 @@
import unicodedata
from collections import defaultdict, deque
from os.path import basename
from typing import TYPE_CHECKING, Any, Callable, DefaultDict, Deque, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
DefaultDict,
Deque,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

import yaml
from six import binary_type, iteritems, text_type
Expand Down Expand Up @@ -309,15 +322,19 @@ def http(self):

return self._http

def get_tls_context(self, refresh=False):
# type: (bool) -> ssl.SSLContext
def get_tls_context(self, refresh=False, overrides=None):
# type: (bool, Dict[AnyStr, Any]) -> ssl.SSLContext
"""
Creates and cache an SSLContext instance based on user configuration.
Note that user configuration can be overridden by using `overrides`.
This should only be applied to older integration that manually set config values.

Since: Agent 7.24
"""
if not hasattr(self, '_tls_context_wrapper'):
self._tls_context_wrapper = TlsContextWrapper(self.instance or {}, self.TLS_CONFIG_REMAPPER)
self._tls_context_wrapper = TlsContextWrapper(
self.instance or {}, self.TLS_CONFIG_REMAPPER, overrides=overrides
)

if refresh:
self._tls_context_wrapper.refresh_tls_context()
Expand Down
10 changes: 8 additions & 2 deletions datadog_checks_base/datadog_checks/base/utils/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,16 @@
class TlsContextWrapper(object):
__slots__ = ('logger', 'config', 'tls_context')

def __init__(self, instance, remapper=None):
# type: (InstanceType, Dict[AnyStr, Dict[AnyStr, Any]]) -> None
def __init__(self, instance, remapper=None, overrides=None):
# type: (InstanceType, Dict[AnyStr, Dict[AnyStr, Any]], Dict[AnyStr, Any]) -> None
default_fields = dict(STANDARD_FIELDS)

# Override existing config options if there exists any overrides
if overrides:
for overridden_field, data in iteritems(overrides):
if instance.get(overridden_field):
instance[overridden_field] = data

# Populate with the default values
config = {field: instance.get(field, value) for field, value in iteritems(default_fields)}
for field in STANDARD_FIELDS:
Expand Down
44 changes: 44 additions & 0 deletions datadog_checks_base/tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,47 @@ def test_client_key_expanded(self):
with patch('ssl.SSLContext'), patch('os.path.expanduser') as mock_expand:
check.get_tls_context()
mock_expand.assert_called_with('~/foo')


class TestTLSContextOverrides:
def test_override_context(self):
instance = {'tls_cert': 'foo', 'tls_private_key': 'bar'}
check = AgentCheck('test', {}, [instance])

overrides = {'tls_cert': 'not_foo'}
with patch('ssl.SSLContext'):
context = check.get_tls_context(overrides=overrides) # type: MagicMock
context.load_cert_chain.assert_called_with('not_foo', keyfile='bar', password=None)

def test_override_context_empty(self):
instance = {'tls_cert': 'foo', 'tls_private_key': 'bar'}
check = AgentCheck('test', {}, [instance])

overrides = {}
with patch('ssl.SSLContext'):
context = check.get_tls_context(overrides=overrides) # type: MagicMock
context.load_cert_chain.assert_called_with('foo', keyfile='bar', password=None)

def test_override_context_wrapper_config(self):
instance = {'tls_verify': True}
overrides = {'tls_verify': False}
tls = TlsContextWrapper(instance, overrides=overrides)
assert tls.config['tls_verify'] is False

def test_override_context_wrapper_config_empty(self):
instance = {'tls_verify': True}
overrides = {}
tls = TlsContextWrapper(instance, overrides=overrides)
assert tls.config['tls_verify'] is True

def test_override_instance_config(self):
instance = {'tls_verify': True}
overrides = {'tls_verify': False}
tls = TlsContextWrapper(instance, overrides=overrides)
assert instance['tls_verify'] is False

def test_override_non_exist_instance_config(self):
instance = {'tls_verify': True}
overrides = {'fake_config': 'foo'}
tls = TlsContextWrapper(instance, overrides=overrides)
assert instance.get('fake_config') is None