Refactor canceling to work through messaging and signals, not database
If cancel was attempted before, still allow attempting another cancel;
in this case, attempt to send the SIGTERM signal again.
Keep clicking, you might help!

Use the job's queue name when dispatching the cancel task call

Replace other cancel_callbacks with the SIGTERM watcher;
  adapt the special inventory handler mechanism for this too

Pass watcher to any dependent local tasks
AlanCoding committed Jan 14, 2022
1 parent 6dda5f4 commit aea4c05
Showing 6 changed files with 62 additions and 32 deletions.
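
Taken together, these changes replace the old database-polling cancel path with a message-and-signal path: UnifiedJob.cancel() publishes a cancel_unified_job message to the job's queue, the dispatcher pool on that node finds the worker whose current task matches the celery_task_id and sends it SIGTERM, and a SigtermWatcher inside the task process turns that signal into a cheap in-memory cancel_callback. The snippet below is a self-contained illustration of that signal-to-flag pattern; it mirrors the SigtermWatcher class added in awx/main/tasks/jobs.py, with a small __main__ demo added here for illustration only.

    import os
    import signal
    import time

    class SigtermWatcher:
        """Trap SIGTERM/SIGINT and expose the result as a cheap callback."""

        def __init__(self):
            self.sigterm_flag = False
            for s in (signal.SIGTERM, signal.SIGINT):
                signal.signal(s, self.set_flag)

        def set_flag(self, *args):
            self.sigterm_flag = True

        def cancel_callback(self):
            # no database round trip, just an attribute read
            return self.sigterm_flag

    if __name__ == '__main__':
        watcher = SigtermWatcher()
        os.kill(os.getpid(), signal.SIGTERM)  # stand-in for the dispatcher canceling this worker
        time.sleep(0.1)                       # let the interpreter run the handler
        print(watcher.cancel_callback())      # -> True
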
18 changes: 15 additions & 3 deletions awx/main/dispatch/pool.py
@@ -422,6 +422,16 @@ def cleanup(self):
running_uuids.extend(list(worker.managed_tasks.keys()))
reaper.reap(excluded_uuids=running_uuids)

def cancel_job(self, celery_task_id):
for w in self.workers:
task = w.current_task
if task and task['uuid'] == celery_task_id:
logger.warn(f'Canceling task with id={celery_task_id}, task={task.get("task")}, args={task.get("args")}')
os.kill(w.pid, signal.SIGTERM)
break
else:
logger.warn(f'Could not find running process to cancel {celery_task_id}')

def up(self):
if self.full:
# if we can't spawn more workers, just toss this message into a
@@ -435,9 +445,11 @@ def write(self, preferred_queue, body):
if 'guid' in body:
GuidMiddleware.set_guid(body['guid'])
try:
# when the cluster heartbeat occurs, clean up internally
if isinstance(body, dict) and 'cluster_node_heartbeat' in body['task']:
self.cleanup()
if isinstance(body, dict):
if body['task'] == 'cancel_unified_job': # special local cancel triggered by job canceling
self.cancel_job(body['args'][0])
if 'cluster_node_heartbeat' in body['task']: # when the cluster heartbeat occurs, clean up internally
self.cleanup()
if self.should_grow:
self.up()
# we don't care about "preferred queue" round robin distribution, just
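
For reference, the message that write() now intercepts has roughly this shape; 'task' and 'args' are the keys the code above reads, while the values and the optional 'guid' content are illustrative:

    body = {
        'task': 'cancel_unified_job',              # matched by name in write()
        'args': ['<celery task id>'],              # args[0] is handed to cancel_job()
        'guid': '<request-tracing id, optional>',  # consumed by GuidMiddleware when present
    }

Note the for/else in cancel_job(): the 'Could not find running process' warning is only logged when no worker's current task matches the id.
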
4 changes: 4 additions & 0 deletions awx/main/models/unified_jobs.py
@@ -1407,6 +1407,10 @@ def cancel(self, job_explanation=None, is_chain=False):
cancel_fields.append('job_explanation')
self.save(update_fields=cancel_fields)
self.websocket_emit_status("canceled")
if self.celery_task_id:
from awx.main.tasks import cancel_unified_job

cancel_unified_job.apply_async([self.celery_task_id], queue=self.get_queue_name())
return self.cancel_flag

@property
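
cancel() still sets cancel_flag and saves the model, but it now also dispatches the cancel message to self.get_queue_name(), i.e. to the node actually running the job, which is what lets that node's dispatcher deliver the SIGTERM. Per the commit message, repeating the cancel simply re-sends the message. A usage sketch; the caller, the manager lookup, and the pk are illustrative, while cancel(job_explanation=...) is the signature shown above:

    # e.g. from an API view or management command handling a cancel request
    job = UnifiedJob.objects.get(pk=42)
    job.cancel(job_explanation='canceled by an operator')  # sets cancel_flag, then apply_async(...)
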
56 changes: 31 additions & 25 deletions awx/main/tasks/jobs.py
@@ -9,6 +9,7 @@
import os
from pathlib import Path
import shutil
import signal
import stat
import yaml
import tempfile
@@ -22,7 +23,6 @@
from django_guid.middleware import GuidMiddleware
from django.conf import settings
from django.db import transaction, DatabaseError
from django.utils.timezone import now


# Runner
@@ -76,6 +76,21 @@
logger = logging.getLogger('awx.main.tasks.jobs')


class SigtermWatcher:
SIGNALS = (signal.SIGTERM, signal.SIGINT)

def __init__(self):
self.sigterm_flag = False
for s in self.SIGNALS:
signal.signal(s, self.set_flag)

def set_flag(self, *args):
self.sigterm_flag = True

def cancel_callback(self):
return self.sigterm_flag


def with_path_cleanup(f):
@functools.wraps(f)
def _wrapped(self, *args, **kwargs):
@@ -100,13 +115,18 @@ class BaseTask(object):
event_model = None
abstract = True

def __init__(self):
def __init__(self, sigterm_watcher=None):
self.cleanup_paths = []
self.parent_workflow_job_id = None
self.host_map = {}
self.guid = GuidMiddleware.get_guid()
self.job_created = None
self.recent_event_timings = deque(maxlen=settings.MAX_WEBSOCKET_EVENT_RATE)
# start watching for SIGTERM before loading the model to catch cancel signal at any time
if sigterm_watcher:
self.sigterm_watcher = sigterm_watcher # inherit watcher from parent if local dependent task
else:
self.sigterm_watcher = SigtermWatcher()

def update_model(self, pk, _attempt=0, **updates):
"""Reload the model instance from the database and update the
@@ -530,22 +550,6 @@ def event_handler(self, event_data):

return False

def cancel_callback(self):
"""
Ansible runner callback to tell the job when/if it is canceled
"""
unified_job_id = self.instance.pk
self.instance = self.update_model(unified_job_id)
if not self.instance:
logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
return True
if self.instance.cancel_flag or self.instance.status == 'canceled':
cancel_wait = (now() - self.instance.modified).seconds if self.instance.modified else 0
if cancel_wait > 5:
logger.warn('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
return True
return False

def finished_callback(self, runner_obj):
"""
Ansible runner callback triggered on finished run
@@ -622,8 +626,10 @@ def run(self, pk, **kwargs):
private_data_dir = self.build_private_data_dir(self.instance)
self.pre_run_hook(self.instance, private_data_dir)
self.instance.log_lifecycle("preparing_playbook")
if self.instance.cancel_flag:

if self.instance.cancel_flag or self.sigterm_watcher.cancel_callback():
self.instance = self.update_model(self.instance.pk, status='canceled')

if self.instance.status != 'running':
# Stop the task chain and prevent starting the job if it has
# already been canceled.
@@ -710,11 +716,11 @@ def run(self, pk, **kwargs):
event_handler=self.event_handler,
finished_callback=self.finished_callback,
status_handler=self.status_handler,
cancel_callback=self.cancel_callback,
cancel_callback=self.sigterm_watcher.cancel_callback,
**params,
)
else:
receptor_job = AWXReceptorJob(self, params)
receptor_job = AWXReceptorJob(self, params, sigterm_watcher=self.sigterm_watcher)
res = receptor_job.run()
self.unit_id = receptor_job.unit_id

@@ -1118,7 +1124,7 @@ def pre_run_hook(self, job, private_data_dir):
project_update_task = local_project_sync._get_task_class()
try:
# the job private_data_dir is passed so sync can download roles and collections there
sync_task = project_update_task(job_private_data_dir=private_data_dir)
sync_task = project_update_task(job_private_data_dir=private_data_dir, sigterm_watcher=self.sigterm_watcher)
sync_task.run(local_project_sync.id)
local_project_sync.refresh_from_db()
job = self.update_model(job.pk, scm_revision=local_project_sync.scm_revision)
@@ -1411,7 +1417,7 @@ def _update_dependent_inventories(self, project_update, dependent_inventory_sour
local_inv_update.log_lifecycle("execution_node_chosen")
try:
create_partition(local_inv_update.event_class._meta.db_table, start=local_inv_update.created)
inv_update_class().run(local_inv_update.id)
inv_update_class(sigterm_watcher=self.sigterm_watcher).run(local_inv_update.id)
except Exception:
logger.exception('{} Unhandled exception updating dependent SCM inventory sources.'.format(project_update.log_format))

@@ -1870,7 +1876,7 @@ def pre_run_hook(self, inventory_update, private_data_dir):

project_update_task = local_project_sync._get_task_class()
try:
sync_task = project_update_task(job_private_data_dir=private_data_dir)
sync_task = project_update_task(job_private_data_dir=private_data_dir, sigterm_watcher=self.sigterm_watcher)
sync_task.run(local_project_sync.id)
local_project_sync.refresh_from_db()
inventory_update.inventory_source.scm_last_revision = local_project_sync.scm_revision
@@ -1929,7 +1935,7 @@ def post_run_hook(self, inventory_update, status):

handler = SpecialInventoryHandler(
self.event_handler,
self.cancel_callback,
self.sigterm_watcher.cancel_callback,
verbosity=inventory_update.verbosity,
job_timeout=self.get_instance_timeout(self.instance),
start_time=inventory_update.started,
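
The net effect in jobs.py: every BaseTask owns a SigtermWatcher, or inherits the parent's when it is a dependent local task such as a project sync or inventory update, and watcher.cancel_callback replaces the old cancel_callback that re-fetched the job from the database on every poll. A rough sketch of how the watcher plugs into ansible-runner; the private data dir and playbook are illustrative, and the real parameters are assembled inside BaseTask.run():

    import ansible_runner

    watcher = SigtermWatcher()
    ansible_runner.interface.run(
        private_data_dir='/tmp/demo',             # illustrative
        playbook='demo.yml',                      # illustrative
        cancel_callback=watcher.cancel_callback,  # polled by runner; returning True aborts the run
    )
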
8 changes: 5 additions & 3 deletions awx/main/tasks/receptor.py
@@ -256,7 +256,7 @@ def run(self):


class AWXReceptorJob:
def __init__(self, task, runner_params=None):
def __init__(self, task, runner_params=None, sigterm_watcher=None):
self.task = task
self.runner_params = runner_params
self.unit_id = None
@@ -268,6 +268,8 @@ def __init__(self, task, runner_params=None):
if not settings.IS_K8S and self.work_type == 'local' and 'only_transmit_kwargs' not in self.runner_params:
self.runner_params['only_transmit_kwargs'] = True

self.sigterm_watcher = sigterm_watcher

def run(self):
# We establish a connection to the Receptor socket
receptor_ctl = get_receptor_ctl()
@@ -450,11 +452,11 @@ def cancel_watcher(self, processor_future):
if processor_future.done():
return processor_future.result()

if self.task.cancel_callback():
if self.sigterm_watcher.cancel_callback():
result = namedtuple('result', ['status', 'rc'])
return result('canceled', 1)

time.sleep(1)
time.sleep(0.5)

@property
def pod_definition(self):
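
Because cancel_watcher() now consults the in-process SIGTERM flag instead of the old task cancel_callback, which meant a database read on every check, the poll interval can be tightened from one second to half a second without adding load; the loop still returns a ('canceled', 1) result tuple so callers see a normal job outcome.
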
6 changes: 6 additions & 0 deletions awx/main/tasks/system.py
@@ -247,6 +247,12 @@ def handle_setting_changes(setting_keys):
reconfigure_rsyslog()


@task(queue=get_local_queuename)
def cancel_unified_job(celery_task_id):
"""Triggers special action in awx.main.dispatch.pool, this is a placeholder"""
pass


@task(queue='tower_broadcast_all')
def delete_project_files(project_path):
# TODO: possibly implement some retry logic
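
The empty body is intentional: the @task registration gives the message a routable name and a default local queue via get_local_queuename, while the real work happens in the dispatcher's write() in pool.py, which recognizes 'cancel_unified_job' by name and sends the SIGTERM. The function itself is a no-op, so nothing further happens if a worker also picks the message up.
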
2 changes: 1 addition & 1 deletion awx/main/utils/handlers.py
@@ -75,7 +75,7 @@ def __init__(self, event_handler, cancel_callback, job_timeout, verbosity, start
def emit(self, record):
# check cancel and timeout status regardless of log level
this_time = now()
if (this_time - self.last_check).total_seconds() > 0.5: # cancel callback is expensive
if (this_time - self.last_check).total_seconds() > 0.1:
self.last_check = this_time
if self.cancel_callback():
raise PostRunError('Inventory update has been canceled', status='canceled')
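
With the callback reduced to an in-memory flag check, the old 'cancel callback is expensive' caveat no longer applies, so the special inventory handler can test for cancellation every 0.1 seconds instead of every 0.5 and notice a cancel sooner.
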
