Skip to content

Commit

Permalink
Refactor canceling to work through messaging and signals, not database
Browse files Browse the repository at this point in the history
If a cancel was attempted before, still allow attempting another cancel;
in this case, attempt to send the SIGTERM signal again.
Keep clicking, you might help!

Use queue name to cancel task call

Replace other cancel_callbacks with sigterm watcher
  adapt special inventory mechanism for this too

Pass watcher to any dependent local tasks

Move task to on_commit for race conditions

Remove existing shutdown reaping in favor of new stuff
  • Loading branch information
AlanCoding committed Mar 12, 2022
1 parent 22ad724 commit 537f3c3
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 43 deletions.
18 changes: 15 additions & 3 deletions awx/main/dispatch/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,16 @@ def cleanup(self):
running_uuids.extend(list(worker.managed_tasks.keys()))
reaper.reap(excluded_uuids=running_uuids)

def cancel_worker_process(self, celery_task_id):
    """
    Send SIGTERM to the worker process currently running the given task.

    Scans ``self.workers`` for the first worker whose ``current_task`` has a
    ``uuid`` equal to ``celery_task_id`` and signals that worker's pid. When
    no worker is running the task, a warning is logged and nothing is signaled.

    :param celery_task_id: uuid of the task whose worker should be signaled
    """
    for w in self.workers:
        task = w.current_task
        if task and task['uuid'] == celery_task_id:
            # Logger.warn is a deprecated alias, so Logger.warning is used instead
            logger.warning(f'Canceling task with id={celery_task_id}, task={task.get("task")}, args={task.get("args")}')
            os.kill(w.pid, signal.SIGTERM)
            break
    else:
        # for/else: reached only when the loop finished without hitting ``break``
        logger.warning(f'Could not find running process to cancel {celery_task_id}')

def up(self):
if self.full:
# if we can't spawn more workers, just toss this message into a
Expand All @@ -438,9 +448,11 @@ def write(self, preferred_queue, body):
if 'guid' in body:
GuidMiddleware.set_guid(body['guid'])
try:
# when the cluster heartbeat occurs, clean up internally
if isinstance(body, dict) and 'cluster_node_heartbeat' in body['task']:
self.cleanup()
if isinstance(body, dict) and 'task' in body:
if body['task'].endswith('.cancel_control_process'): # special local cancel triggered by job canceling
self.cancel_worker_process(body['args'][0])
if 'cluster_node_heartbeat' in body['task']: # when the cluster heartbeat occurs, clean up internally
self.cleanup()
if self.should_grow:
self.up()
# we don't care about "preferred queue" round robin distribution, just
Expand Down
15 changes: 15 additions & 0 deletions awx/main/models/unified_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,6 +1407,21 @@ def cancel(self, job_explanation=None, is_chain=False):
cancel_fields.append('job_explanation')
self.save(update_fields=cancel_fields)
self.websocket_emit_status("canceled")

def actually_cancel():
    """Submit the cancel task that matches the job's current state to the job's queue."""
    if not self.celery_task_id:
        # No task id has been recorded for this job, so there is nothing to
        # signal directly; hand the job id to cancel_unified_job instead.
        from awx.main.tasks.system import cancel_unified_job

        cancel_unified_job.apply_async([self.id], queue=self.get_queue_name())
        return

    from awx.main.tasks.system import cancel_control_process

    # This task is handled by the main dispatcher process itself, so the
    # sigterm goes out without sitting in the multiprocessing queue first.
    # That matters because users must be able to cancel jobs even when the
    # system is overloaded.
    cancel_control_process.apply_async([self.celery_task_id], queue=self.get_queue_name())

connection.on_commit(actually_cancel)
return self.cancel_flag

@property
Expand Down
20 changes: 3 additions & 17 deletions awx/main/tasks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import stat

# Django
from django.utils.timezone import now
from django.conf import settings
from django_guid.middleware import GuidMiddleware
from django.db import connections

# AWX
from awx.main.redact import UriCleaner
Expand Down Expand Up @@ -142,22 +142,6 @@ def event_handler(self, event_data):

return False

def cancel_callback(self):
    """
    Ansible runner callback to tell the job when/if it is canceled

    Reloads ``self.instance`` from the database and returns True when the job
    should stop: the record no longer exists, its ``cancel_flag`` is set, or
    its status is ``canceled``. Returns False otherwise.
    """
    unified_job_id = self.instance.pk
    try:
        self.instance.refresh_from_db()
    except self.instance.DoesNotExist:
        # refresh_from_db() raises when the row is gone, so deletion has to be
        # detected here; a truthiness check on the instance afterwards never fires
        logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
        return True
    if self.instance.cancel_flag or self.instance.status == 'canceled':
        # total_seconds() covers the whole interval; timedelta.seconds discards full days
        cancel_wait = int((now() - self.instance.modified).total_seconds()) if self.instance.modified else 0
        if cancel_wait > 5:
            logger.warning('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
        return True
    return False

def finished_callback(self, runner_obj):
"""
Ansible runner callback triggered on finished run
Expand Down Expand Up @@ -186,6 +170,8 @@ def status_handler(self, status_data, runner_config):

with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, job_args=json.dumps(runner_config.command), job_cwd=runner_config.cwd, job_env=job_env)
# We opened a connection just for that save, close it here now
connections.close_all()
elif status_data['status'] == 'failed':
# For encrypted ssh_key_data, ansible-runner worker will open and write the
# ssh_key_data to a named pipe. Then, once the podman container starts, ssh-agent will
Expand Down
39 changes: 31 additions & 8 deletions awx/main/tasks/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import os
from pathlib import Path
import shutil
import signal
import stat
import yaml
import tempfile
Expand Down Expand Up @@ -86,6 +87,21 @@
logger = logging.getLogger('awx.main.tasks.jobs')


class SigtermWatcher:
    """
    Remembers whether one of the watched signals has been delivered.

    Creating an instance registers ``set_flag`` as the handler for every
    signal listed in ``SIGNALS``; ``cancel_callback`` reports the flag.
    """

    # signals whose delivery sets the flag
    SIGNALS = (signal.SIGTERM, signal.SIGINT)

    def __init__(self):
        # clear the flag first so it is defined before any handler can run
        self.sigterm_flag = False
        for signum in self.SIGNALS:
            signal.signal(signum, self.set_flag)

    def set_flag(self, *args):
        """Signal handler: mark the flag; the handler arguments are ignored."""
        self.sigterm_flag = True

    def cancel_callback(self):
        """Return the current value of the flag."""
        return self.sigterm_flag


def with_path_cleanup(f):
@functools.wraps(f)
def _wrapped(self, *args, **kwargs):
Expand All @@ -111,9 +127,14 @@ class BaseTask(object):
abstract = True
callback_class = RunnerCallback

def __init__(self):
def __init__(self, sigterm_watcher=None):
self.cleanup_paths = []
self.runner_callback = self.callback_class(model=self.model)
# start watching for SIGTERM before loading the model to catch cancel signal at any time
if sigterm_watcher:
self.sigterm_watcher = sigterm_watcher # inherit watcher from parent if local dependent task
else:
self.sigterm_watcher = SigtermWatcher()

def update_model(self, pk, _attempt=0, **updates):
return update_model(self.model, pk, _attempt=0, **updates)
Expand Down Expand Up @@ -455,8 +476,10 @@ def run(self, pk, **kwargs):
private_data_dir = self.build_private_data_dir(self.instance)
self.pre_run_hook(self.instance, private_data_dir)
self.instance.log_lifecycle("preparing_playbook")
if self.instance.cancel_flag:

if self.instance.cancel_flag or self.sigterm_watcher.cancel_callback():
self.instance = self.update_model(self.instance.pk, status='canceled')

if self.instance.status != 'running':
# Stop the task chain and prevent starting the job if it has
# already been canceled.
Expand Down Expand Up @@ -551,11 +574,11 @@ def run(self, pk, **kwargs):
event_handler=self.runner_callback.event_handler,
finished_callback=self.runner_callback.finished_callback,
status_handler=self.runner_callback.status_handler,
cancel_callback=self.runner_callback.cancel_callback,
cancel_callback=self.sigterm_watcher.cancel_callback,
**params,
)
else:
receptor_job = AWXReceptorJob(self, params)
receptor_job = AWXReceptorJob(self, params, sigterm_watcher=self.sigterm_watcher)
res = receptor_job.run()
self.unit_id = receptor_job.unit_id

Expand Down Expand Up @@ -941,7 +964,7 @@ def pre_run_hook(self, job, private_data_dir):
project_update_task = local_project_sync._get_task_class()
try:
# the job private_data_dir is passed so sync can download roles and collections there
sync_task = project_update_task(job_private_data_dir=private_data_dir)
sync_task = project_update_task(job_private_data_dir=private_data_dir, sigterm_watcher=self.sigterm_watcher)
sync_task.run(local_project_sync.id)
local_project_sync.refresh_from_db()
job = self.update_model(job.pk, scm_revision=local_project_sync.scm_revision)
Expand Down Expand Up @@ -1225,7 +1248,7 @@ def _update_dependent_inventories(self, project_update, dependent_inventory_sour
local_inv_update.log_lifecycle("execution_node_chosen")
try:
create_partition(local_inv_update.event_class._meta.db_table, start=local_inv_update.created)
inv_update_class().run(local_inv_update.id)
inv_update_class(sigterm_watcher=self.sigterm_watcher).run(local_inv_update.id)
except Exception:
logger.exception('{} Unhandled exception updating dependent SCM inventory sources.'.format(project_update.log_format))

Expand Down Expand Up @@ -1684,7 +1707,7 @@ def pre_run_hook(self, inventory_update, private_data_dir):

project_update_task = local_project_sync._get_task_class()
try:
sync_task = project_update_task(job_private_data_dir=private_data_dir)
sync_task = project_update_task(job_private_data_dir=private_data_dir, sigterm_watcher=self.sigterm_watcher)
sync_task.run(local_project_sync.id)
local_project_sync.refresh_from_db()
inventory_update.inventory_source.scm_last_revision = local_project_sync.scm_revision
Expand Down Expand Up @@ -1744,7 +1767,7 @@ def post_run_hook(self, inventory_update, status):

handler = SpecialInventoryHandler(
self.runner_callback.event_handler,
self.runner_callback.cancel_callback,
self.sigterm_watcher.cancel_callback,
verbosity=inventory_update.verbosity,
job_timeout=self.get_instance_timeout(self.instance),
start_time=inventory_update.started,
Expand Down
28 changes: 21 additions & 7 deletions awx/main/tasks/receptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

# Django
from django.conf import settings
from django.db import connections
from django.utils.translation import ugettext_lazy as _

# Runner
import ansible_runner
Expand Down Expand Up @@ -247,8 +249,11 @@ def worker_cleanup(node_name, vargs, timeout=300.0):
return stdout


AnsibleRunnerResult = namedtuple('result', ['status', 'rc'])


class AWXReceptorJob:
def __init__(self, task, runner_params=None):
def __init__(self, task, runner_params=None, sigterm_watcher=None):
self.task = task
self.runner_params = runner_params
self.unit_id = None
Expand All @@ -260,6 +265,8 @@ def __init__(self, task, runner_params=None):
if not settings.IS_K8S and self.work_type == 'local' and 'only_transmit_kwargs' not in self.runner_params:
self.runner_params['only_transmit_kwargs'] = True

self.sigterm_watcher = sigterm_watcher

def run(self):
# We establish a connection to the Receptor socket
receptor_ctl = get_receptor_ctl()
Expand All @@ -275,6 +282,7 @@ def run(self):
receptor_ctl.simple_command(f"work release {self.unit_id}")
except Exception:
logger.exception(f"Error releasing work unit {self.unit_id}.")
receptor_ctl.close()

@property
def sign_work(self):
Expand Down Expand Up @@ -330,6 +338,9 @@ def _run_internal(self, receptor_ctl):
shutil.rmtree(artifact_dir)

resultsock, resultfile = receptor_ctl.get_work_results(self.unit_id, return_socket=True, return_sockfile=True)

connections.close_all()

# Both "processor" and "cancel_watcher" are spawned in separate threads.
# We wait for the first one to return. If cancel_watcher returns first,
# we yank the socket out from underneath the processor, which will cause it
Expand All @@ -345,8 +356,12 @@ def _run_internal(self, receptor_ctl):
res = list(first_future.done)[0].result()
if res.status == 'canceled':
receptor_ctl.simple_command(f"work cancel {self.unit_id}")
resultsock.shutdown(socket.SHUT_RDWR)
resultfile.close()
# TODO: abort without status transition, recover later by restarting the processing step
self.task.instance.refresh_from_db(fields=['cancel_flag'])
if not self.task.instance.cancel_flag:
self.task.instance.job_explanation = _('Control process received shutdown signal and aborted job')
self.task.instance.save(update_fields=['job_explanation'])
return AnsibleRunnerResult('error', 1)
elif res.status == 'error':
try:
unit_status = receptor_ctl.simple_command(f'work status {self.unit_id}')
Expand Down Expand Up @@ -449,11 +464,10 @@ def cancel_watcher(self, processor_future):
if processor_future.done():
return processor_future.result()

if self.task.runner_callback.cancel_callback():
result = namedtuple('result', ['status', 'rc'])
return result('canceled', 1)
if self.sigterm_watcher.cancel_callback():
return AnsibleRunnerResult('canceled', 1)

time.sleep(1)
time.sleep(0.5)

@property
def pod_definition(self):
Expand Down
35 changes: 31 additions & 4 deletions awx/main/tasks/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,6 @@ def inform_cluster_of_shutdown():
try:
this_inst = Instance.objects.get(hostname=settings.CLUSTER_HOST_ID)
this_inst.mark_offline(update_last_seen=True, errors=_('Instance received normal shutdown signal'))
try:
reaper.reap(this_inst)
except Exception:
logger.exception('failed to reap jobs for {}'.format(this_inst.hostname))
logger.warning('Normal shutdown signal for instance {}, ' 'removed self from capacity pool.'.format(this_inst.hostname))
except Exception:
logger.exception('Encountered problem with normal shutdown signal.')
Expand Down Expand Up @@ -247,6 +243,37 @@ def handle_setting_changes(setting_keys):
reconfigure_rsyslog()


@task(queue=get_local_queuename)
def cancel_control_process(celery_task_id):
    """
    Placeholder task; it intentionally has no body.

    It exists to trigger a special action in awx.main.dispatch.pool.
    """


@task(queue=get_local_queuename)
def cancel_unified_job(unified_job_id):
    """
    This method exists to assure cancelation of jobs which had not yet been assigned
    a celery_task_id, which should be pending and waiting jobs

    Polls the job about once per second for as long as its status is in
    ACTIVE_STATES. As soon as a celery_task_id is seen, cancel_control_process
    is submitted for it and polling stops. Polling also stops if the job is
    deleted or leaves the active states.
    """
    try:
        unified_job = UnifiedJob.objects.get(pk=unified_job_id)
    except UnifiedJob.DoesNotExist:
        logger.info(f'Job id {unified_job_id} has been deleted, aborting cancel')
        return
    while unified_job.status in ACTIVE_STATES:
        if unified_job.celery_task_id:
            cancel_control_process.delay(unified_job.celery_task_id)
            logger.warning(f'sigterm issued to {unified_job.log_format} after it obtained a task id')
            return
        try:
            # celery_task_id has to be reloaded together with status; refreshing
            # status alone means a task id assigned after the initial load is
            # never observed and the branch above can never fire
            unified_job.refresh_from_db(fields=['status', 'celery_task_id'])
        except unified_job.DoesNotExist:
            logger.info(f'Job id {unified_job_id} has been deleted, cancel aborted')
            return
        time.sleep(1)
    logger.info(f'{unified_job.log_format} stopped before obtaining a task id, sigterm not needed')


@task(queue='tower_broadcast_all')
def delete_project_files(project_path):
# TODO: possibly implement some retry logic
Expand Down
2 changes: 1 addition & 1 deletion awx/main/utils/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(self, event_handler, cancel_callback, job_timeout, verbosity, start
def emit(self, record):
# check cancel and timeout status regardless of log level
this_time = now()
if (this_time - self.last_check).total_seconds() > 0.5: # cancel callback is expensive
if (this_time - self.last_check).total_seconds() > 0.1:
self.last_check = this_time
if self.cancel_callback():
raise PostRunError('Inventory update has been canceled', status='canceled')
Expand Down
2 changes: 1 addition & 1 deletion docs/ansible_runner_integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ In AWX, a task of a certain job type is kicked off (_i.e._, RunJob, RunProjectUp

The callbacks and handlers are:
* `event_handler`: Called each time a new event is created in `ansible-runner`. AWX will dispatch the event to `redis` to be processed on the other end by the callback receiver.
* `cancel_callback`: Called periodically by `ansible-runner`; this is so that AWX can inform `ansible-runner` if the job should be canceled or not.
* `cancel_callback`: Called periodically by `ansible-runner`; this is so that AWX can inform `ansible-runner` if the job should be canceled or not. This now applies only to system jobs; all other jobs are canceled via receptor.
* `finished_callback`: Called once by `ansible-runner` to denote that the process that was asked to run is finished. AWX will construct the special control event, `EOF`, with the associated total number of events that it observed.
* `status_handler`: Called by `ansible-runner` as the process transitions state internally. AWX uses the `starting` status to know that `ansible-runner` has made all of its decisions around the process that it will launch. AWX gathers and associates these decisions with the Job for historical observation.

Expand Down
3 changes: 1 addition & 2 deletions tools/docker-compose/supervisor.conf
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ nodaemon=true
command = make dispatcher
autostart = true
autorestart = true
stopwaitsecs = 1
stopsignal=KILL
stopwaitsecs = 5
stopasgroup=true
killasgroup=true
redirect_stderr=true
Expand Down

0 comments on commit 537f3c3

Please sign in to comment.