Skip to content

Commit

Permalink
Refactor canceling to work through messaging and signals, not database
Browse files Browse the repository at this point in the history
If canceled attempted before, still allow attempting another cancel
in this case, attempt to send the sigterm signal again.
Keep clicking, you might help!

Replace other cancel_callbacks with sigterm watcher
  adapt special inventory mechanism for this too

Get rid of the cancel_watcher method with exception in main thread

Handle academic case of sigterm race condition

Process cancelation as control signal

Fully connect cancel method and run_dispatcher to control

Never transition workflows directly to canceled, add logs
  • Loading branch information
AlanCoding committed Aug 29, 2022
1 parent 6d207d2 commit c26139d
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 90 deletions.
10 changes: 8 additions & 2 deletions awx/main/dispatch/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,24 @@ def status(self, *args, **kwargs):
def running(self, *args, **kwargs):
return self.control_with_reply('running', *args, **kwargs)

def cancel(self, task_ids, *args, **kwargs):
return self.control_with_reply('cancel', *args, extra_data={'task_ids': task_ids}, **kwargs)

@classmethod
def generate_reply_queue_name(cls):
return f"reply_to_{str(uuid.uuid4()).replace('-','_')}"

def control_with_reply(self, command, timeout=5):
def control_with_reply(self, command, timeout=5, extra_data=None):
logger.warning('checking {} {} for {}'.format(self.service, command, self.queuename))
reply_queue = Control.generate_reply_queue_name()
self.result = None

with pg_bus_conn() as conn:
conn.listen(reply_queue)
conn.notify(self.queuename, json.dumps({'control': command, 'reply_to': reply_queue}))
send_data = {'control': command, 'reply_to': reply_queue}
if extra_data:
send_data.update(extra_data)
conn.notify(self.queuename, json.dumps(send_data))

for reply in conn.events(select_timeout=timeout, yield_timeouts=True):
if reply is None:
Expand Down
13 changes: 12 additions & 1 deletion awx/main/dispatch/worker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def listening_on(self):
def control(self, body):
logger.warning(f'Received control signal:\n{body}')
control = body.get('control')
if control in ('status', 'running'):
if control in ('status', 'running', 'cancel'):
reply_queue = body['reply_to']
if control == 'status':
msg = '\n'.join([self.listening_on, self.pool.debug()])
Expand All @@ -72,6 +72,17 @@ def control(self, body):
for worker in self.pool.workers:
worker.calculate_managed_tasks()
msg.extend(worker.managed_tasks.keys())
elif control == 'cancel':
msg = []
task_ids = set(body['task_ids'])
for worker in self.pool.workers:
task = worker.current_task
if task and task['uuid'] in task_ids:
logger.warn(f'Sending SIGTERM to task id={task["uuid"]}, task={task.get("task")}, args={task.get("args")}')
os.kill(worker.pid, signal.SIGTERM)
msg.append(task['uuid'])
if task_ids and not msg:
logger.info(f'Could not locate running tasks to cancel with ids={task_ids}')

with pg_bus_conn() as conn:
conn.notify(reply_queue, json.dumps(msg))
Expand Down
22 changes: 21 additions & 1 deletion awx/main/management/commands/run_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved.
import logging
import yaml

from django.conf import settings
from django.core.cache import cache as django_cache
Expand Down Expand Up @@ -30,7 +31,16 @@ def add_arguments(self, parser):
'--reload',
dest='reload',
action='store_true',
help=('cause the dispatcher to recycle all of its worker processes;' 'running jobs will run to completion first'),
help=('cause the dispatcher to recycle all of its worker processes; running jobs will run to completion first'),
)
parser.add_argument(
'--cancel',
dest='cancel',
help=(
'Cancel a particular task id. Takes either a single id string, or a JSON list of multiple ids. '
'Can take in output from the --running argument as input to cancel all tasks. '
'Only running tasks can be canceled, queued tasks must be started before they can be canceled.'
),
)

def handle(self, *arg, **options):
Expand All @@ -42,6 +52,16 @@ def handle(self, *arg, **options):
return
if options.get('reload'):
return Control('dispatcher').control({'control': 'reload'})
if options.get('cancel'):
cancel_str = options.get('cancel')
try:
cancel_data = yaml.safe_load(cancel_str)
except Exception:
cancel_data = [cancel_str]
if not isinstance(cancel_data, list):
cancel_data = [cancel_str]
print(Control('dispatcher').cancel(cancel_data))
return

# It's important to close these because we're _about_ to fork, and we
# don't want the forked processes to inherit the open sockets
Expand Down
69 changes: 43 additions & 26 deletions awx/main/models/unified_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,23 +1395,6 @@ def signal_start(self, **kwargs):
# Done!
return True

@property
def actually_running(self):
# returns True if the job is running in the appropriate dispatcher process
running = False
if all([self.status == 'running', self.celery_task_id, self.execution_node]):
# If the job is marked as running, but the dispatcher
# doesn't know about it (or the dispatcher doesn't reply),
# then cancel the job
timeout = 5
try:
running = self.celery_task_id in ControlDispatcher('dispatcher', self.controller_node or self.execution_node).running(timeout=timeout)
except socket.timeout:
logger.error('could not reach dispatcher on {} within {}s'.format(self.execution_node, timeout))
except Exception:
logger.exception("error encountered when checking task status")
return running

@property
def can_cancel(self):
return bool(self.status in CAN_CANCEL)
Expand All @@ -1421,27 +1404,61 @@ def _build_job_explanation(self):
return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id)
return None

def fallback_cancel(self):
if not self.celery_task_id:
self.refresh_from_db(fields=['celery_task_id'])
self.cancel_dispatcher_process()

def cancel_dispatcher_process(self):
"""Returns True if dispatcher running this job acknowledged request and sent SIGTERM"""
if not self.celery_task_id:
return
canceled = []
try:
# Use control and reply mechanism to cancel and obtain confirmation
timeout = 5
canceled = ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id])
except socket.timeout:
logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s')
except Exception:
logger.exception("error encountered when checking task status")
return bool(self.celery_task_id in canceled) # True or False, whether confirmation was obtained

def cancel(self, job_explanation=None, is_chain=False):
if self.can_cancel:
if not is_chain:
for x in self.get_jobs_fail_chain():
x.cancel(job_explanation=self._build_job_explanation(), is_chain=True)

cancel_fields = []
if not self.cancel_flag:
self.cancel_flag = True
self.start_args = '' # blank field to remove encrypted passwords
cancel_fields = ['cancel_flag', 'start_args']
if self.status in ('pending', 'waiting', 'new'):
self.status = 'canceled'
cancel_fields.append('status')
if self.status == 'running' and not self.actually_running:
self.status = 'canceled'
cancel_fields.append('status')
cancel_fields.extend(['cancel_flag', 'start_args'])
connection.on_commit(lambda: self.websocket_emit_status("canceled"))

if job_explanation is not None:
self.job_explanation = job_explanation
cancel_fields.append('job_explanation')
self.save(update_fields=cancel_fields)
self.websocket_emit_status("canceled")

controller_notified = False
if self.celery_task_id:
controller_notified = self.cancel_dispatcher_process()

else:
# Avoid race condition where we have stale model from pending state but job has already started,
# its checking signal but not cancel_flag, so re-send signal after this database commit
connection.on_commit(self.fallback_cancel)

# If a SIGTERM signal was sent to the control process, and acked by the dispatcher
# then we want to let its own cleanup change status, otherwise change status now
if not controller_notified:
if self.status != 'canceled':
self.status = 'canceled'
cancel_fields.append('status')

self.save(update_fields=cancel_fields)

return self.cancel_flag

@property
Expand Down
5 changes: 2 additions & 3 deletions awx/main/models/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,11 +723,10 @@ def get_notification_friendly_name(self):
def preferred_instance_groups(self):
return []

@property
def actually_running(self):
def cancel_dispatcher_process(self):
# WorkflowJobs don't _actually_ run anything in the dispatcher, so
# there's no point in asking the dispatcher if it knows about this task
return self.status == 'running'
return True


class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin):
Expand Down
27 changes: 3 additions & 24 deletions awx/main/tasks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
import stat

# Django
from django.utils.timezone import now
from django.conf import settings
from django_guid import get_guid
from django.utils.functional import cached_property
from django.db import connections

# AWX
from awx.main.redact import UriCleaner
from awx.main.constants import MINIMAL_EVENTS, ANSIBLE_RUNNER_NEEDS_UPDATE_MESSAGE
from awx.main.utils.update_model import update_model
from awx.main.queue import CallbackQueueDispatcher
from awx.main.tasks.signals import signal_callback

logger = logging.getLogger('awx.main.tasks.callback')

Expand Down Expand Up @@ -175,28 +174,6 @@ def event_handler(self, event_data):

return False

def cancel_callback(self):
"""
Ansible runner callback to tell the job when/if it is canceled
"""
unified_job_id = self.instance.pk
if signal_callback():
return True
try:
self.instance = self.update_model(unified_job_id)
except Exception:
logger.exception(f'Encountered error during cancel check for {unified_job_id}, canceling now')
return True
if not self.instance:
logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
return True
if self.instance.cancel_flag or self.instance.status == 'canceled':
cancel_wait = (now() - self.instance.modified).seconds if self.instance.modified else 0
if cancel_wait > 5:
logger.warning('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
return True
return False

def finished_callback(self, runner_obj):
"""
Ansible runner callback triggered on finished run
Expand Down Expand Up @@ -227,6 +204,8 @@ def status_handler(self, status_data, runner_config):

with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, job_args=json.dumps(runner_config.command), job_cwd=runner_config.cwd, job_env=job_env)
# We opened a connection just for that save, close it here now
connections.close_all()
elif status_data['status'] == 'failed':
# For encrypted ssh_key_data, ansible-runner worker will open and write the
# ssh_key_data to a named pipe. Then, once the podman container starts, ssh-agent will
Expand Down
5 changes: 3 additions & 2 deletions awx/main/tasks/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def run(self, pk, **kwargs):
self.instance.log_lifecycle("preparing_playbook")
if self.instance.cancel_flag or signal_callback():
self.instance = self.update_model(self.instance.pk, status='canceled')

if self.instance.status != 'running':
# Stop the task chain and prevent starting the job if it has
# already been canceled.
Expand Down Expand Up @@ -589,7 +590,7 @@ def run(self, pk, **kwargs):
event_handler=self.runner_callback.event_handler,
finished_callback=self.runner_callback.finished_callback,
status_handler=self.runner_callback.status_handler,
cancel_callback=self.runner_callback.cancel_callback,
cancel_callback=signal_callback,
**params,
)
else:
Expand Down Expand Up @@ -1622,7 +1623,7 @@ def post_run_hook(self, inventory_update, status):

handler = SpecialInventoryHandler(
self.runner_callback.event_handler,
self.runner_callback.cancel_callback,
signal_callback,
verbosity=inventory_update.verbosity,
job_timeout=self.get_instance_timeout(self.instance),
start_time=inventory_update.started,
Expand Down
48 changes: 23 additions & 25 deletions awx/main/tasks/receptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# Django
from django.conf import settings
from django.db import connections

# Runner
import ansible_runner
Expand All @@ -25,6 +26,7 @@
cleanup_new_process,
)
from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER
from awx.main.tasks.signals import signal_state, signal_callback, SignalExit

# Receptorctl
from receptorctl.socket_interface import ReceptorControl
Expand Down Expand Up @@ -335,24 +337,32 @@ def _run_internal(self, receptor_ctl):
shutil.rmtree(artifact_dir)

resultsock, resultfile = receptor_ctl.get_work_results(self.unit_id, return_socket=True, return_sockfile=True)
# Both "processor" and "cancel_watcher" are spawned in separate threads.
# We wait for the first one to return. If cancel_watcher returns first,
# we yank the socket out from underneath the processor, which will cause it
# to exit. A reference to the processor_future is passed into the cancel_watcher_future,
# Which exits if the job has finished normally. The context manager ensures we do not
# leave any threads laying around.
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:

connections.close_all()

# "processor" and the main thread will be separate threads.
# If a cancel happens, the main thread will encounter an exception, in which case
# we yank the socket out from underneath the processor, which will cause it to exit.
# The ThreadPoolExecutor context manager ensures we do not leave any threads laying around.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
processor_future = executor.submit(self.processor, resultfile)
cancel_watcher_future = executor.submit(self.cancel_watcher, processor_future)
futures = [processor_future, cancel_watcher_future]
first_future = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)

res = list(first_future.done)[0].result()
if res.status == 'canceled':
try:
signal_state.raise_exception = True
# address race condition where SIGTERM was issued after this dispatcher task started
if signal_callback():
raise SignalExit()
res = processor_future.result()
except SignalExit:
receptor_ctl.simple_command(f"work cancel {self.unit_id}")
resultsock.shutdown(socket.SHUT_RDWR)
resultfile.close()
elif res.status == 'error':
result = namedtuple('result', ['status', 'rc'])
res = result('canceled', 1)
finally:
signal_state.raise_exception = False

if res.status == 'error':
# If ansible-runner ran, but an error occured at runtime, the traceback information
# is saved via the status_handler passed in to the processor.
if 'result_traceback' in self.task.runner_callback.extra_update_fields:
Expand Down Expand Up @@ -446,18 +456,6 @@ def work_type(self):
return 'local'
return 'ansible-runner'

@cleanup_new_process
def cancel_watcher(self, processor_future):
while True:
if processor_future.done():
return processor_future.result()

if self.task.runner_callback.cancel_callback():
result = namedtuple('result', ['status', 'rc'])
return result('canceled', 1)

time.sleep(1)

@property
def pod_definition(self):
ee = self.task.instance.execution_environment
Expand Down
8 changes: 8 additions & 0 deletions awx/main/tasks/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,27 @@
__all__ = ['with_signal_handling', 'signal_callback']


class SignalExit(Exception):
pass


class SignalState:
def reset(self):
self.sigterm_flag = False
self.is_active = False
self.original_sigterm = None
self.original_sigint = None
self.raise_exception = False

def __init__(self):
self.reset()

def set_flag(self, *args):
"""Method to pass into the python signal.signal method to receive signals"""
self.sigterm_flag = True
if self.raise_exception:
self.raise_exception = False # so it is not raised a second time in error handling
raise SignalExit()

def connect_signals(self):
self.original_sigterm = signal.getsignal(signal.SIGTERM)
Expand Down
Loading

0 comments on commit c26139d

Please sign in to comment.