Skip to content

Commit

Permalink
Add optional argument for overriding get_tls_context() parameters (#8275
Browse files Browse the repository at this point in the history
)

* Add parameter for overriding get_tls_context()

* Add comments to update

* Add sample for override

* Undo sample

* Add check for overrides = None

* Fix style

* Add tests for tls overrides

* Add conditional to not add non-existent config options

* Create new overridden_instance instead of modifying instance in place

* Use deepcopy()

* Update to use instance instead of overridden_instance
  • Loading branch information
yzhan289 authored Jan 7, 2021
1 parent 24adfa2 commit 3f3b108
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 6 deletions.
25 changes: 21 additions & 4 deletions datadog_checks_base/datadog_checks/base/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,20 @@
import unicodedata
from collections import defaultdict, deque
from os.path import basename
from typing import TYPE_CHECKING, Any, Callable, DefaultDict, Deque, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
DefaultDict,
Deque,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
)

import yaml
from six import binary_type, iteritems, text_type
Expand Down Expand Up @@ -309,15 +322,19 @@ def http(self):

return self._http

def get_tls_context(self, refresh=False):
# type: (bool) -> ssl.SSLContext
def get_tls_context(self, refresh=False, overrides=None):
# type: (bool, Dict[AnyStr, Any]) -> ssl.SSLContext
"""
Creates and cache an SSLContext instance based on user configuration.
Note that user configuration can be overridden by using `overrides`.
This should only be applied to older integration that manually set config values.
Since: Agent 7.24
"""
if not hasattr(self, '_tls_context_wrapper'):
self._tls_context_wrapper = TlsContextWrapper(self.instance or {}, self.TLS_CONFIG_REMAPPER)
self._tls_context_wrapper = TlsContextWrapper(
self.instance or {}, self.TLS_CONFIG_REMAPPER, overrides=overrides
)

if refresh:
self._tls_context_wrapper.refresh_tls_context()
Expand Down
13 changes: 11 additions & 2 deletions datadog_checks_base/datadog_checks/base/utils/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import ssl
from copy import deepcopy
from typing import TYPE_CHECKING, Any, AnyStr, Dict

from six import iteritems
Expand Down Expand Up @@ -35,10 +36,18 @@
class TlsContextWrapper(object):
__slots__ = ('logger', 'config', 'tls_context')

def __init__(self, instance, remapper=None):
# type: (InstanceType, Dict[AnyStr, Dict[AnyStr, Any]]) -> None
def __init__(self, instance, remapper=None, overrides=None):
# type: (InstanceType, Dict[AnyStr, Dict[AnyStr, Any]], Dict[AnyStr, Any]) -> None
default_fields = dict(STANDARD_FIELDS)

# Override existing config options if there exists any overrides
instance = deepcopy(instance)

if overrides:
for overridden_field, data in iteritems(overrides):
if instance.get(overridden_field):
instance[overridden_field] = data

# Populate with the default values
config = {field: instance.get(field, value) for field, value in iteritems(default_fields)}
for field in STANDARD_FIELDS:
Expand Down
40 changes: 40 additions & 0 deletions datadog_checks_base/tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,43 @@ def test_client_key_expanded(self):
with patch('ssl.SSLContext'), patch('os.path.expanduser') as mock_expand:
check.get_tls_context()
mock_expand.assert_called_with('~/foo')


class TestTLSContextOverrides:
def test_override_context(self):
instance = {'tls_cert': 'foo', 'tls_private_key': 'bar'}
check = AgentCheck('test', {}, [instance])

overrides = {'tls_cert': 'not_foo'}
with patch('ssl.SSLContext'):
context = check.get_tls_context(overrides=overrides) # type: MagicMock
context.load_cert_chain.assert_called_with('not_foo', keyfile='bar', password=None)

def test_override_context_empty(self):
instance = {'tls_cert': 'foo', 'tls_private_key': 'bar'}
check = AgentCheck('test', {}, [instance])

overrides = {}
with patch('ssl.SSLContext'):
context = check.get_tls_context(overrides=overrides) # type: MagicMock
context.load_cert_chain.assert_called_with('foo', keyfile='bar', password=None)

def test_override_context_wrapper_config(self):
instance = {'tls_verify': True}
overrides = {'tls_verify': False}
tls = TlsContextWrapper(instance, overrides=overrides)
assert tls.config['tls_verify'] is False
assert instance['tls_verify'] is True # Overrides should not affect the original instance

def test_override_context_wrapper_config_empty(self):
instance = {'tls_verify': True}
overrides = {}
tls = TlsContextWrapper(instance, overrides=overrides)
assert tls.config['tls_verify'] is True

def test_override_non_exist_instance_config(self):
instance = {'tls_verify': True}
overrides = {'fake_config': 'foo'}
tls = TlsContextWrapper(instance, overrides=overrides)
assert instance.get('fake_config') is None
assert tls.config['tls_verify'] is True

0 comments on commit 3f3b108

Please sign in to comment.