Skip to content

Commit

Permalink
Merge pull request #46 from Duke-GCB/38-cleanup-on-terminate
Browse files Browse the repository at this point in the history
Installs a handler to delete active pods on termination
  • Loading branch information
dleehr authored Feb 12, 2019
2 parents 3c3199c + aaf132a commit a2fcc7a
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 22 deletions.
78 changes: 73 additions & 5 deletions calrissian/k8s.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from kubernetes import client, config, watch
import threading
import logging
import os

Expand All @@ -25,7 +26,7 @@ def load_config_get_namespace():
try:
config.load_incluster_config() # raises if not in cluster
namespace = read_file(K8S_NAMESPACE_FILE)
except config.ConfigException:
except config.config_exception.ConfigException:
config.load_kube_config()
namespace = K8S_FALLBACK_NAMESPACE
return namespace
Expand All @@ -36,6 +37,15 @@ class CalrissianJobException(Exception):


class KubernetesClient(object):
"""
Instances of this class are created by `calrissian.job.CalrissianCommandLineJob` objects,
which are often running in background threads (spawned by `calrissian.executor.MultithreadedJobExecutor`).
This class uses a PodMonitor to keep track of the pods it submits in a single, shared list.
KubernetesClient is responsible for telling PodMonitor after it has submitted a pod and when it knows that pod
is terminated. Using PodMonitor as a context manager (with PodMonitor() as p) acquires a lock for thread safety.
"""
def __init__(self):
self.pod = None
# load_config must happen before instantiating client
Expand All @@ -44,9 +54,11 @@ def __init__(self):
self.core_api_instance = client.CoreV1Api()

def submit_pod(self, pod_body):
pod = self.core_api_instance.create_namespaced_pod(self.namespace, pod_body)
log.info('Created k8s pod name {} with id {}'.format(pod.metadata.name, pod.metadata.uid))
self._set_pod(pod)
with PodMonitor() as monitor:
pod = self.core_api_instance.create_namespaced_pod(self.namespace, pod_body)
log.info('Created k8s pod name {} with id {}'.format(pod.metadata.name, pod.metadata.uid))
monitor.add(pod)
self._set_pod(pod)

def should_delete_pod(self):
"""
Expand All @@ -60,6 +72,12 @@ def should_delete_pod(self):
else:
return True

def delete_pod_name(self, pod_name):
    """
    Delete the pod called pod_name from this client's namespace.
    An ApiException from the Kubernetes API is re-raised as a CalrissianJobException
    whose arguments are the error message and the original exception.
    """
    try:
        delete_options = client.V1DeleteOptions()
        self.core_api_instance.delete_namespaced_pod(pod_name, self.namespace, delete_options)
    except client.rest.ApiException as e:
        message = 'Error deleting pod named {}'.format(pod_name)
        raise CalrissianJobException(message, e)

def wait_for_completion(self):
w = watch.Watch()
for event in w.stream(self.core_api_instance.list_namespaced_pod, self.namespace, field_selector=self._get_pod_field_selector()):
Expand All @@ -73,7 +91,9 @@ def wait_for_completion(self):
log.info('Handling terminated pod name {} with id {}'.format(pod.metadata.name, pod.metadata.uid))
self._handle_terminated_state(status.state)
if self.should_delete_pod():
self.core_api_instance.delete_namespaced_pod(self.pod.metadata.name, self.namespace, client.V1DeleteOptions())
with PodMonitor() as monitor:
self.delete_pod_name(pod.metadata.name)
monitor.remove(pod)
self._clear_pod()
# stop watching for events, our pod is done. Causes wait loop to exit
w.stop()
Expand Down Expand Up @@ -151,3 +171,51 @@ def get_current_pod(self):
if not pod_name:
raise CalrissianJobException("Missing required environment variable ${}".format(POD_NAME_ENV_VARIABLE))
return self.get_pod_for_name(pod_name)


class PodMonitor(object):
    """
    This class is designed to track pods submitted by KubernetesClient across different background threads,
    and provide a static cleanup() method to attempt to delete those pods on termination.
    Pods are tracked by name in a single list shared by all instances (a class attribute).
    Instances of this class are used as context managers, and acquire the shared lock on entry.
    The add and remove methods should only be called from inside the context block while the lock is acquired.
    The static cleanup() method also acquires the lock and attempts to delete all outstanding pods.
    """
    pod_names = []
    # The lock is re-entrant on purpose: cleanup() may be called from a signal handler, and Python runs
    # signal handlers on the main thread. If that thread is already inside a `with PodMonitor()` block when
    # the signal arrives, a plain Lock would block forever; an RLock lets the same thread acquire it again.
    lock = threading.RLock()

    def __enter__(self):
        PodMonitor.lock.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the block propagate after the lock is released
        PodMonitor.lock.release()

    # add and remove methods should be called with the lock acquired, e.g. inside PodMonitor():
    def add(self, pod):
        """Record the name of a submitted pod (pod.metadata.name)."""
        log.info('PodMonitor adding {}'.format(pod.metadata.name))
        PodMonitor.pod_names.append(pod.metadata.name)

    def remove(self, pod):
        """Forget a pod by name. Raises ValueError if the name is not currently tracked."""
        log.info('PodMonitor removing {}'.format(pod.metadata.name))
        # This has to look up the pod by something unique
        PodMonitor.pod_names.remove(pod.metadata.name)

    @staticmethod
    def cleanup():
        """
        Attempt to delete every tracked pod, then reset the tracked list.
        Deletion is best-effort: a failure for one pod is logged and the remaining pods are still attempted.
        """
        with PodMonitor():
            if not PodMonitor.pod_names:
                # Nothing was submitted (or everything was already removed), so there is no reason to
                # construct a KubernetesClient, which would load cluster configuration and could raise.
                return
            k8s_client = KubernetesClient()
            for pod_name in PodMonitor.pod_names:
                log.info('PodMonitor deleting pod {}'.format(pod_name))
                try:
                    k8s_client.delete_pod_name(pod_name)
                except Exception:
                    # Keep going so one failure does not leave the other pods behind, but record the cause
                    log.error('Error deleting pod named {}, ignoring'.format(pod_name), exc_info=True)
            PodMonitor.pod_names = []


def delete_pods():
    """
    Module-level entry point for pod cleanup.
    Delegates to PodMonitor.cleanup(), which attempts to delete every pod PodMonitor is tracking.
    """
    PodMonitor.cleanup()
37 changes: 30 additions & 7 deletions calrissian/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from calrissian.executor import CalrissianExecutor
from calrissian.context import CalrissianLoadingContext
from calrissian.version import version
from calrissian.k8s import delete_pods
from cwltool.main import main as cwlmain
from cwltool.argparser import arg_parser
from cwltool.context import RuntimeContext
import logging
import sys
import signal


def activate_logging():
Expand Down Expand Up @@ -36,22 +38,43 @@ def parse_arguments(parser):
return args


def handle_sigterm(signum, frame):
    """
    Signal handler that cleans up submitted pods before exiting.
    Prints a notice, calls delete_pods(), then exits the process using the signal number as the exit status.
    The frame argument is part of the signal handler signature and is not used.
    """
    notice = 'Received signal {}, deleting pods'.format(signum)
    print(notice)
    delete_pods()
    sys.exit(signum)


def install_signal_handler():
    """
    Installs handle_sigterm as the SIGTERM handler, to clean up submitted pods on termination.
    The handler is installed from the main thread and is called there when the signal arrives.
    The CalrissianExecutor is multi-threaded and will submit jobs from other threads.
    """
    signal.signal(signal.SIGTERM, handle_sigterm)


def main():
    """
    Command-line entry point: parse arguments, run the workflow through cwltool and return its result.
    The SIGTERM handler is installed before the workflow starts, and delete_pods() is always called once
    cwlmain finishes (normally or with an exception) so that submitted pods are not left behind.
    """
    activate_logging()
    parser = arg_parser()
    add_arguments(parser)
    parsed_args = parse_arguments(parser)
    executor = CalrissianExecutor(parsed_args.max_ram, parsed_args.max_cores)
    runtimeContext = RuntimeContext(vars(parsed_args))
    runtimeContext.select_resources = executor.select_resources
    # Install the handler before any pods can be submitted
    install_signal_handler()
    # cwlmain is invoked exactly once, inside the try block, so cleanup covers the whole run
    try:
        result = cwlmain(args=parsed_args,
                         executor=executor,
                         loadingContext=CalrissianLoadingContext(),
                         runtimeContext=runtimeContext,
                         versionfunc=version,
                         )
    finally:
        # Always clean up after cwlmain
        delete_pods()

    return result


if __name__ == '__main__':
activate_logging()
sys.exit(main())
2 changes: 2 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
import logging
logging.disable(logging.CRITICAL)
67 changes: 62 additions & 5 deletions tests/test_k8s.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from unittest import TestCase
from unittest.mock import Mock, patch, call, PropertyMock
from calrissian.k8s import load_config_get_namespace, KubernetesClient, CalrissianJobException

from calrissian.k8s import load_config_get_namespace, KubernetesClient, CalrissianJobException, PodMonitor, delete_pods

@patch('calrissian.k8s.read_file')
@patch('calrissian.k8s.config')
Expand All @@ -17,7 +16,7 @@ def test_load_config_get_namespace_incluster(self, mock_config, mock_read_file):

def test_load_config_get_namespace_external(self, mock_config, mock_read_file):
# When load_incluster_config raises an exception, call load_kube_config and assume 'default'
mock_config.ConfigException = Exception
mock_config.config_exception.ConfigException = Exception
mock_config.load_incluster_config.side_effect = Exception
namespace = load_config_get_namespace()
self.assertEqual(namespace, 'default')
Expand All @@ -37,7 +36,8 @@ def test_init(self, mock_get_namespace, mock_client):
self.assertIsNone(kc.pod)
self.assertIsNone(kc.process_exit_code)

def test_submit_pod(self, mock_get_namespace, mock_client):
@patch('calrissian.k8s.PodMonitor')
def test_submit_pod(self, mock_podmonitor, mock_get_namespace, mock_client):
mock_get_namespace.return_value = 'namespace'
mock_create_namespaced_pod = Mock()
mock_create_namespaced_pod.return_value = Mock(metadata=Mock(uid='123'))
Expand All @@ -47,6 +47,8 @@ def test_submit_pod(self, mock_get_namespace, mock_client):
kc.submit_pod(mock_body)
self.assertEqual(kc.pod.metadata.uid, '123')
self.assertEqual(mock_create_namespaced_pod.call_args, call('namespace', mock_body))
# This is to inspect `with PodMonitor() as monitor`:
self.assertTrue(mock_podmonitor.return_value.__enter__.return_value.add.called)

def setup_mock_watch(self, mock_watch, event_objects=[]):
mock_stream = Mock()
Expand Down Expand Up @@ -98,7 +100,8 @@ def test_wait_skips_pod_when_state_is_running(self, mock_watch, mock_get_namespa
self.assertIsNotNone(kc.pod)

@patch('calrissian.k8s.watch')
def test_wait_finishes_when_pod_state_is_terminated(self, mock_watch, mock_get_namespace, mock_client):
@patch('calrissian.k8s.PodMonitor')
def test_wait_finishes_when_pod_state_is_terminated(self, mock_podmonitor, mock_watch, mock_get_namespace, mock_client):
mock_pod = Mock(status=Mock(container_statuses=[Mock(state=Mock(running=None, terminated=Mock(exit_code=123), waiting=None))]))
self.setup_mock_watch(mock_watch, [mock_pod])
kc = KubernetesClient()
Expand All @@ -108,6 +111,8 @@ def test_wait_finishes_when_pod_state_is_terminated(self, mock_watch, mock_get_n
self.assertTrue(mock_watch.Watch.return_value.stop.called)
self.assertTrue(mock_client.CoreV1Api.return_value.delete_namespaced_pod.called)
self.assertIsNone(kc.pod)
# This is to inspect `with PodMonitor() as monitor`:
self.assertTrue(mock_podmonitor.return_value.__enter__.return_value.remove.called)

@patch('calrissian.k8s.watch')
@patch('calrissian.k8s.KubernetesClient.should_delete_pod')
Expand Down Expand Up @@ -196,6 +201,19 @@ def test_should_delete_pod_reads_env(self, mock_os, mock_get_namespace, mock_cli
self.assertFalse(kc.should_delete_pod())
self.assertEqual(mock_os.getenv.call_args, call('CALRISSIAN_DELETE_PODS', ''))

def test_delete_pod_name_calls_api(self, mock_get_namespace, mock_client):
    """delete_pod_name passes the pod name as the first positional argument of delete_namespaced_pod."""
    kc = KubernetesClient()
    kc.delete_pod_name('pod-123')
    mock_delete_namespaced_pod = mock_client.CoreV1Api.return_value.delete_namespaced_pod
    positional_args = mock_delete_namespaced_pod.call_args[0]
    self.assertEqual('pod-123', positional_args[0])

def test_delete_pod_name_raises(self, mock_get_namespace, mock_client):
    """An ApiException from the API surfaces as a CalrissianJobException that names the pod."""
    # The client module is mocked, so give it a real exception class for the except clause to match
    api_exception = Exception
    mock_client.rest.ApiException = api_exception
    mock_client.CoreV1Api.return_value.delete_namespaced_pod.side_effect = api_exception
    kc = KubernetesClient()
    with self.assertRaises(CalrissianJobException) as context:
        kc.delete_pod_name('pod-123')
    raised_text = str(context.exception)
    self.assertIn('Error deleting pod named pod-123', raised_text)


class KubernetesClientStateTestCase(TestCase):

Expand Down Expand Up @@ -236,3 +254,42 @@ def test_multiple_statuses_raises(self):
with self.assertRaises(CalrissianJobException) as context:
KubernetesClient.get_first_status_or_none(self.multiple_statuses)
self.assertIn('Expected 0 or 1 container statuses, found 2', str(context.exception))


class PodMonitorTestCase(TestCase):
    """Tests for PodMonitor's shared pod-name list and for the delete_pods() entry point."""

    def make_mock_pod(self, name):
        """Build a mock pod whose metadata.name returns the given name."""
        metadata = Mock()
        # Cannot mock name attribute without a propertymock
        type(metadata).name = PropertyMock(return_value=name)
        return Mock(metadata=metadata)

    def setUp(self):
        # pod_names is shared class state, so start every test from an empty list
        PodMonitor.pod_names = []

    def test_add(self):
        """add() appends the pod's name to the shared list."""
        self.assertEqual(len(PodMonitor.pod_names), 0)
        pod = self.make_mock_pod('pod-123')
        with PodMonitor() as monitor:
            monitor.add(pod)
        self.assertEqual(PodMonitor.pod_names, ['pod-123'])

    def test_remove(self):
        """remove() drops only the named pod from the shared list."""
        PodMonitor.pod_names = ['pod1', 'pod2']
        second_pod = self.make_mock_pod('pod2')
        with PodMonitor() as monitor:
            monitor.remove(second_pod)
        self.assertEqual(PodMonitor.pod_names, ['pod1'])

    @patch('calrissian.k8s.KubernetesClient')
    def test_cleanup(self, mock_client):
        """cleanup() asks a KubernetesClient to delete each tracked pod by name."""
        PodMonitor.pod_names = ['cleanup-pod']
        PodMonitor.cleanup()
        mock_delete_pod_name = mock_client.return_value.delete_pod_name
        self.assertEqual(mock_delete_pod_name.call_args, call('cleanup-pod'))

    @patch('calrissian.k8s.PodMonitor')
    def test_delete_pods_calls_podmonitor(self, mock_pod_monitor):
        """delete_pods() delegates to PodMonitor.cleanup()."""
        delete_pods()
        self.assertTrue(mock_pod_monitor.cleanup.called)
34 changes: 29 additions & 5 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import TestCase
from unittest.mock import patch, call, Mock
from calrissian.main import main, add_arguments, parse_arguments

from calrissian.main import main, add_arguments, parse_arguments, handle_sigterm, install_signal_handler


class CalrissianMainTestCase(TestCase):
Expand All @@ -13,22 +14,30 @@ class CalrissianMainTestCase(TestCase):
@patch('calrissian.main.version')
@patch('calrissian.main.parse_arguments')
@patch('calrissian.main.add_arguments')
def test_main_calls_cwlmain_returns_exit_code(self, mock_add_arguments, mock_parse_arguments, mock_version, mock_runtime_context, mock_loading_context, mock_executor, mock_arg_parser, mock_cwlmain):
@patch('calrissian.main.delete_pods')
@patch('calrissian.main.install_signal_handler')
def test_main_calls_cwlmain_returns_exit_code(self, mock_install_signal_handler, mock_delete_pods,
mock_add_arguments, mock_parse_arguments, mock_version,
mock_runtime_context, mock_loading_context, mock_executor,
mock_arg_parser, mock_cwlmain):
mock_exit_code = Mock()
mock_cwlmain.return_value = mock_exit_code
mock_cwlmain.return_value = mock_exit_code # not called before main
result = main()
self.assertTrue(mock_arg_parser.called)
self.assertEqual(mock_add_arguments.call_args, call(mock_arg_parser.return_value))
self.assertEqual(mock_parse_arguments.call_args, call(mock_arg_parser.return_value))
self.assertEqual(mock_executor.call_args, call(mock_parse_arguments.return_value.max_ram, mock_parse_arguments.return_value.max_cores))
self.assertEqual(mock_executor.call_args,
call(mock_parse_arguments.return_value.max_ram, mock_parse_arguments.return_value.max_cores))
self.assertTrue(mock_runtime_context.called)
self.assertEqual(mock_cwlmain.call_args, call(args=mock_parse_arguments.return_value,
executor=mock_executor.return_value,
loadingContext=mock_loading_context.return_value,
runtimeContext=mock_runtime_context.return_value,
versionfunc=mock_version))
self.assertEqual(mock_runtime_context.return_value.select_resources, mock_executor.return_value.select_resources)
self.assertEqual(mock_runtime_context.return_value.select_resources,
mock_executor.return_value.select_resources)
self.assertEqual(result, mock_exit_code)
self.assertTrue(mock_delete_pods.called) # called after main()

def test_add_arguments(self):
mock_parser = Mock()
Expand Down Expand Up @@ -66,3 +75,18 @@ def test_parse_arguments_exits_with_version(self, mock_print_version, mock_sys):
self.assertEqual(parsed, mock_parser.parse_args.return_value)
self.assertTrue(mock_print_version.called)
self.assertEqual(mock_sys.exit.call_args, call(0))

@patch('calrissian.main.sys')
@patch('calrissian.main.delete_pods')
def test_handle_sigterm_exits_with_signal(self, mock_delete_pods, mock_sys):
    """handle_sigterm deletes pods and exits with the signal number it received."""
    signum = 15
    unused_frame = Mock()
    handle_sigterm(signum, unused_frame)
    self.assertTrue(mock_delete_pods.called)
    self.assertEqual(mock_sys.exit.call_args, call(signum))

@patch('calrissian.main.signal')
@patch('calrissian.main.handle_sigterm')
def test_install_signal_handler(self, mock_handle_sigterm, mock_signal):
    """install_signal_handler registers handle_sigterm for SIGTERM."""
    install_signal_handler()
    expected_call = call(mock_signal.SIGTERM, mock_handle_sigterm)
    self.assertEqual(mock_signal.signal.call_args, expected_call)

0 comments on commit a2fcc7a

Please sign in to comment.