diff --git a/awx/main/dispatch/pool.py b/awx/main/dispatch/pool.py
index 97e2fa630a1a..5f3730d9046b 100644
--- a/awx/main/dispatch/pool.py
+++ b/awx/main/dispatch/pool.py
@@ -422,6 +422,16 @@ def cleanup(self):
                 running_uuids.extend(list(worker.managed_tasks.keys()))
         reaper.reap(excluded_uuids=running_uuids)
 
+    def cancel_job(self, celery_task_id):
+        for w in self.workers:
+            task = w.current_task
+            if task and task['uuid'] == celery_task_id:
+                logger.warn(f'Canceling task with id={celery_task_id}, task={task.get("task")}, args={task.get("args")}')
+                os.kill(w.pid, signal.SIGTERM)
+                break
+        else:
+            logger.warn(f'Could not find running process to cancel {celery_task_id}')
+
     def up(self):
         if self.full:
             # if we can't spawn more workers, just toss this message into a
@@ -435,9 +445,11 @@ def write(self, preferred_queue, body):
         if 'guid' in body:
             GuidMiddleware.set_guid(body['guid'])
         try:
-            # when the cluster heartbeat occurs, clean up internally
-            if isinstance(body, dict) and 'cluster_node_heartbeat' in body['task']:
-                self.cleanup()
+            if isinstance(body, dict):
+                if body['task'] == 'cancel_unified_job':  # special local cancel triggered by job canceling
+                    self.cancel_job(body['args'][0])
+                if 'cluster_node_heartbeat' in body['task']:  # when the cluster heartbeat occurs, clean up internally
+                    self.cleanup()
             if self.should_grow:
                 self.up()
             # we don't care about "preferred queue" round robin distribution, just
diff --git a/awx/main/models/unified_jobs.py b/awx/main/models/unified_jobs.py
index 489cba9799e2..05066c6a6ec7 100644
--- a/awx/main/models/unified_jobs.py
+++ b/awx/main/models/unified_jobs.py
@@ -1407,6 +1407,10 @@ def cancel(self, job_explanation=None, is_chain=False):
                     cancel_fields.append('job_explanation')
                 self.save(update_fields=cancel_fields)
                 self.websocket_emit_status("canceled")
+                if self.celery_task_id:
+                    from awx.main.tasks import cancel_unified_job
+
+                    cancel_unified_job.apply_async([self.celery_task_id], queue=self.get_queue_name())
         return self.cancel_flag
 
     @property
diff --git a/awx/main/tasks/jobs.py b/awx/main/tasks/jobs.py
index c1a5baf36308..4bac1cf04b0a 100644
--- a/awx/main/tasks/jobs.py
+++ b/awx/main/tasks/jobs.py
@@ -9,6 +9,7 @@
 import os
 from pathlib import Path
 import shutil
+import signal
 import stat
 import yaml
 import tempfile
@@ -22,7 +23,6 @@
 from django_guid.middleware import GuidMiddleware
 from django.conf import settings
 from django.db import transaction, DatabaseError
-from django.utils.timezone import now
 
 
 # Runner
@@ -76,6 +76,21 @@
 logger = logging.getLogger('awx.main.tasks.jobs')
 
 
+class SigtermWatcher:
+    SIGNALS = (signal.SIGTERM, signal.SIGINT)
+
+    def __init__(self):
+        self.sigterm_flag = False
+        for s in self.SIGNALS:
+            signal.signal(s, self.set_flag)
+
+    def set_flag(self, *args):
+        self.sigterm_flag = True
+
+    def cancel_callback(self):
+        return self.sigterm_flag
+
+
 def with_path_cleanup(f):
     @functools.wraps(f)
     def _wrapped(self, *args, **kwargs):
@@ -100,13 +115,18 @@ class BaseTask(object):
     event_model = None
     abstract = True
 
-    def __init__(self):
+    def __init__(self, sigterm_watcher=None):
         self.cleanup_paths = []
         self.parent_workflow_job_id = None
         self.host_map = {}
         self.guid = GuidMiddleware.get_guid()
         self.job_created = None
         self.recent_event_timings = deque(maxlen=settings.MAX_WEBSOCKET_EVENT_RATE)
+        # start watching for SIGTERM before loading the model to catch cancel signal at any time
+        if sigterm_watcher:
+            self.sigterm_watcher = sigterm_watcher  # inherit watcher from parent if local dependent task
+        else:
+            self.sigterm_watcher = SigtermWatcher()
 
     def update_model(self, pk, _attempt=0, **updates):
         """Reload the model instance from the database and update the
@@ -530,22 +550,6 @@ def event_handler(self, event_data):
 
         return False
 
-    def cancel_callback(self):
-        """
-        Ansible runner callback to tell the job when/if it is canceled
-        """
-        unified_job_id = self.instance.pk
-        self.instance = self.update_model(unified_job_id)
-        if not self.instance:
-            logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
-            return True
-        if self.instance.cancel_flag or self.instance.status == 'canceled':
-            cancel_wait = (now() - self.instance.modified).seconds if self.instance.modified else 0
-            if cancel_wait > 5:
-                logger.warn('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
-            return True
-        return False
-
     def finished_callback(self, runner_obj):
         """
         Ansible runner callback triggered on finished run
@@ -622,8 +626,10 @@ def run(self, pk, **kwargs):
             private_data_dir = self.build_private_data_dir(self.instance)
             self.pre_run_hook(self.instance, private_data_dir)
             self.instance.log_lifecycle("preparing_playbook")
-            if self.instance.cancel_flag:
+
+            if self.instance.cancel_flag or self.sigterm_watcher.cancel_callback():
                 self.instance = self.update_model(self.instance.pk, status='canceled')
+
             if self.instance.status != 'running':
                 # Stop the task chain and prevent starting the job if it has
                 # already been canceled.
@@ -710,11 +716,11 @@ def run(self, pk, **kwargs):
                     event_handler=self.event_handler,
                     finished_callback=self.finished_callback,
                     status_handler=self.status_handler,
-                    cancel_callback=self.cancel_callback,
+                    cancel_callback=self.sigterm_watcher.cancel_callback,
                     **params,
                 )
             else:
-                receptor_job = AWXReceptorJob(self, params)
+                receptor_job = AWXReceptorJob(self, params, sigterm_watcher=self.sigterm_watcher)
                 res = receptor_job.run()
                 self.unit_id = receptor_job.unit_id
 
@@ -1118,7 +1124,7 @@ def pre_run_hook(self, job, private_data_dir):
             project_update_task = local_project_sync._get_task_class()
             try:
                 # the job private_data_dir is passed so sync can download roles and collections there
-                sync_task = project_update_task(job_private_data_dir=private_data_dir)
+                sync_task = project_update_task(job_private_data_dir=private_data_dir, sigterm_watcher=self.sigterm_watcher)
                 sync_task.run(local_project_sync.id)
                 local_project_sync.refresh_from_db()
                 job = self.update_model(job.pk, scm_revision=local_project_sync.scm_revision)
@@ -1411,7 +1417,7 @@ def _update_dependent_inventories(self, project_update, dependent_inventory_sour
             local_inv_update.log_lifecycle("execution_node_chosen")
             try:
                 create_partition(local_inv_update.event_class._meta.db_table, start=local_inv_update.created)
-                inv_update_class().run(local_inv_update.id)
+                inv_update_class(sigterm_watcher=self.sigterm_watcher).run(local_inv_update.id)
             except Exception:
                 logger.exception('{} Unhandled exception updating dependent SCM inventory sources.'.format(project_update.log_format))
 
@@ -1870,7 +1876,7 @@ def pre_run_hook(self, inventory_update, private_data_dir):
 
             project_update_task = local_project_sync._get_task_class()
             try:
-                sync_task = project_update_task(job_private_data_dir=private_data_dir)
+                sync_task = project_update_task(job_private_data_dir=private_data_dir, sigterm_watcher=self.sigterm_watcher)
                 sync_task.run(local_project_sync.id)
                 local_project_sync.refresh_from_db()
                 inventory_update.inventory_source.scm_last_revision = local_project_sync.scm_revision
@@ -1929,7 +1935,7 @@ def post_run_hook(self, inventory_update, status):
 
             handler = SpecialInventoryHandler(
                 self.event_handler,
-                self.cancel_callback,
+                self.sigterm_watcher.cancel_callback,
                 verbosity=inventory_update.verbosity,
                 job_timeout=self.get_instance_timeout(self.instance),
                 start_time=inventory_update.started,
diff --git a/awx/main/tasks/receptor.py b/awx/main/tasks/receptor.py
index 0a68800a4d08..49c7172de2a4 100644
--- a/awx/main/tasks/receptor.py
+++ b/awx/main/tasks/receptor.py
@@ -256,7 +256,7 @@ def run(self):
 
 
 class AWXReceptorJob:
-    def __init__(self, task, runner_params=None):
+    def __init__(self, task, runner_params=None, sigterm_watcher=None):
         self.task = task
         self.runner_params = runner_params
         self.unit_id = None
@@ -268,6 +268,8 @@ def __init__(self, task, runner_params=None):
         if not settings.IS_K8S and self.work_type == 'local' and 'only_transmit_kwargs' not in self.runner_params:
             self.runner_params['only_transmit_kwargs'] = True
 
+        self.sigterm_watcher = sigterm_watcher
+
     def run(self):
         # We establish a connection to the Receptor socket
         receptor_ctl = get_receptor_ctl()
@@ -450,11 +452,11 @@ def cancel_watcher(self, processor_future):
             if processor_future.done():
                 return processor_future.result()
 
-            if self.task.cancel_callback():
+            if self.sigterm_watcher.cancel_callback():
                 result = namedtuple('result', ['status', 'rc'])
                 return result('canceled', 1)
 
-            time.sleep(1)
+            time.sleep(0.5)
 
     @property
     def pod_definition(self):
diff --git a/awx/main/tasks/system.py b/awx/main/tasks/system.py
index e596668f89fe..7a7c1ad7179e 100644
--- a/awx/main/tasks/system.py
+++ b/awx/main/tasks/system.py
@@ -247,6 +247,12 @@ def handle_setting_changes(setting_keys):
         reconfigure_rsyslog()
 
 
+@task(queue=get_local_queuename)
+def cancel_unified_job(celery_task_id):
+    """Triggers special action in awx.main.dispatch.pool, this is a placeholder"""
+    pass
+
+
 @task(queue='tower_broadcast_all')
 def delete_project_files(project_path):
     # TODO: possibly implement some retry logic
diff --git a/awx/main/utils/handlers.py b/awx/main/utils/handlers.py
index ef761159ed24..7e0276c54c6f 100644
--- a/awx/main/utils/handlers.py
+++ b/awx/main/utils/handlers.py
@@ -75,7 +75,7 @@ def __init__(self, event_handler, cancel_callback, job_timeout, verbosity, start
     def emit(self, record):
         # check cancel and timeout status regardless of log level
         this_time = now()
-        if (this_time - self.last_check).total_seconds() > 0.5:  # cancel callback is expensive
+        if (this_time - self.last_check).total_seconds() > 0.1:
             self.last_check = this_time
             if self.cancel_callback():
                 raise PostRunError('Inventory update has been canceled', status='canceled')
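
Reviewer note (not part of the patch): the cancel path added here is UnifiedJob.cancel() -> cancel_unified_job dispatched to the job's queue -> the dispatcher pool's write() intercepts the message and cancel_job() sends SIGTERM to the worker whose current_task matches the celery task id -> the SigtermWatcher installed in BaseTask.__init__ flips a process-local flag that ansible-runner's cancel_callback and AWXReceptorJob.cancel_watcher poll. Below is a minimal, self-contained sketch of that signal-to-flag pattern under stated assumptions: fake_job and the fork/kill parent are stand-ins for ansible-runner's polling and the dispatcher respectively, not AWX code; only SigtermWatcher mirrors the class added in awx/main/tasks/jobs.py.

# Minimal sketch of the SIGTERM-to-flag cancel pattern (illustration only, not AWX code).
# The child process plays the role of a dispatcher worker running a job; the parent plays
# the role of pool.cancel_job() delivering SIGTERM. POSIX-only because of os.fork().
import os
import signal
import time


class SigtermWatcher:
    SIGNALS = (signal.SIGTERM, signal.SIGINT)

    def __init__(self):
        self.sigterm_flag = False
        for s in self.SIGNALS:
            signal.signal(s, self.set_flag)

    def set_flag(self, *args):
        self.sigterm_flag = True

    def cancel_callback(self):
        return self.sigterm_flag


def fake_job():
    # Stand-in for a running job; ansible-runner polls cancel_callback the same way.
    watcher = SigtermWatcher()
    while not watcher.cancel_callback():
        time.sleep(0.5)
    print('cancel flag observed, cleaning up and exiting')


if __name__ == '__main__':
    pid = os.fork()
    if pid == 0:
        fake_job()
        os._exit(0)
    time.sleep(1)
    os.kill(pid, signal.SIGTERM)  # what cancel_job() does to the matching worker pid
    os.waitpid(pid, 0)

The point of polling a process-local flag instead of re-querying the database (the removed BaseTask.cancel_callback) is that the check becomes cheap, which is why this diff can tighten cancel_watcher's sleep to 0.5s and SpecialInventoryHandler's check interval to 0.1s.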