Skip to content

Commit

Permalink
Merge pull request #11745 from AlanCoding/cancel_rework_no_close
Browse files Browse the repository at this point in the history
Close database connections while processing job output
  • Loading branch information
AlanCoding authored Sep 6, 2022
2 parents b83b65d + f512971 commit 15964dc
Show file tree
Hide file tree
Showing 14 changed files with 141 additions and 90 deletions.
10 changes: 8 additions & 2 deletions awx/main/dispatch/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,24 @@ def status(self, *args, **kwargs):
def running(self, *args, **kwargs):
    """Send the 'running' control command and return the result.

    All positional and keyword arguments are forwarded unchanged to
    ``control_with_reply``.
    """
    command = 'running'
    return self.control_with_reply(command, *args, **kwargs)

def cancel(self, task_ids, *args, **kwargs):
    """Send the 'cancel' control command for the given task ids.

    ``task_ids`` is attached to the outgoing message as extra data under
    the ``'task_ids'`` key. Remaining positional and keyword arguments
    are forwarded to ``control_with_reply``, whose result is returned.
    """
    payload = {'task_ids': task_ids}
    return self.control_with_reply('cancel', *args, extra_data=payload, **kwargs)

@classmethod
def generate_reply_queue_name(cls):
    """Return a fresh reply queue name.

    The name is ``reply_to_`` followed by a random UUID4 whose dashes
    are replaced with underscores.
    """
    unique_suffix = str(uuid.uuid4()).replace('-', '_')
    return f"reply_to_{unique_suffix}"

def control_with_reply(self, command, timeout=5):
def control_with_reply(self, command, timeout=5, extra_data=None):
logger.warning('checking {} {} for {}'.format(self.service, command, self.queuename))
reply_queue = Control.generate_reply_queue_name()
self.result = None

with pg_bus_conn(new_connection=True) as conn:
conn.listen(reply_queue)
conn.notify(self.queuename, json.dumps({'control': command, 'reply_to': reply_queue}))
send_data = {'control': command, 'reply_to': reply_queue}
if extra_data:
send_data.update(extra_data)
conn.notify(self.queuename, json.dumps(send_data))

for reply in conn.events(select_timeout=timeout, yield_timeouts=True):
if reply is None:
Expand Down
13 changes: 12 additions & 1 deletion awx/main/dispatch/worker/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def listening_on(self):
def control(self, body):
logger.warning(f'Received control signal:\n{body}')
control = body.get('control')
if control in ('status', 'running'):
if control in ('status', 'running', 'cancel'):
reply_queue = body['reply_to']
if control == 'status':
msg = '\n'.join([self.listening_on, self.pool.debug()])
Expand All @@ -72,6 +72,17 @@ def control(self, body):
for worker in self.pool.workers:
worker.calculate_managed_tasks()
msg.extend(worker.managed_tasks.keys())
elif control == 'cancel':
msg = []
task_ids = set(body['task_ids'])
for worker in self.pool.workers:
task = worker.current_task
if task and task['uuid'] in task_ids:
logger.warn(f'Sending SIGTERM to task id={task["uuid"]}, task={task.get("task")}, args={task.get("args")}')
os.kill(worker.pid, signal.SIGTERM)
msg.append(task['uuid'])
if task_ids and not msg:
logger.info(f'Could not locate running tasks to cancel with ids={task_ids}')

with pg_bus_conn() as conn:
conn.notify(reply_queue, json.dumps(msg))
Expand Down
22 changes: 21 additions & 1 deletion awx/main/management/commands/run_dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2015 Ansible, Inc.
# All Rights Reserved.
import logging
import yaml

from django.conf import settings
from django.core.cache import cache as django_cache
Expand Down Expand Up @@ -30,7 +31,16 @@ def add_arguments(self, parser):
'--reload',
dest='reload',
action='store_true',
help=('cause the dispatcher to recycle all of its worker processes;' 'running jobs will run to completion first'),
help=('cause the dispatcher to recycle all of its worker processes; running jobs will run to completion first'),
)
parser.add_argument(
'--cancel',
dest='cancel',
help=(
'Cancel a particular task id. Takes either a single id string, or a JSON list of multiple ids. '
'Can take in output from the --running argument as input to cancel all tasks. '
'Only running tasks can be canceled, queued tasks must be started before they can be canceled.'
),
)

def handle(self, *arg, **options):
Expand All @@ -42,6 +52,16 @@ def handle(self, *arg, **options):
return
if options.get('reload'):
return Control('dispatcher').control({'control': 'reload'})
if options.get('cancel'):
cancel_str = options.get('cancel')
try:
cancel_data = yaml.safe_load(cancel_str)
except Exception:
cancel_data = [cancel_str]
if not isinstance(cancel_data, list):
cancel_data = [cancel_str]
print(Control('dispatcher').cancel(cancel_data))
return

# It's important to close these because we're _about_ to fork, and we
# don't want the forked processes to inherit the open sockets
Expand Down
5 changes: 5 additions & 0 deletions awx/main/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,11 @@ class TaskManagerJobMixin(TaskManagerUnifiedJobMixin):
class Meta:
abstract = True

def get_jobs_fail_chain(self):
    """Return the related jobs in this job's fail chain.

    The list holds ``self.project_update`` when ``project_update_id`` is
    set, and is empty otherwise.
    """
    return [self.project_update] if self.project_update_id else []


class TaskManagerUpdateOnLaunchMixin(TaskManagerUnifiedJobMixin):
class Meta:
Expand Down
69 changes: 43 additions & 26 deletions awx/main/models/unified_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,23 +1395,6 @@ def signal_start(self, **kwargs):
# Done!
return True

@property
def actually_running(self):
    """Report whether the dispatcher confirms it is running this job.

    Returns True only when the status is 'running', both
    ``celery_task_id`` and ``execution_node`` are set, and the task id
    appears in the dispatcher's reply to the 'running' command. A
    timeout or any other error while asking is logged and treated as
    "not running". This property only reports; it changes no state.
    """
    if not (self.status == 'running' and self.celery_task_id and self.execution_node):
        return False
    # The dispatcher asked is on the controller node when one is set,
    # otherwise on the execution node; log that same node on timeout.
    dispatcher_node = self.controller_node or self.execution_node
    timeout = 5
    try:
        return self.celery_task_id in ControlDispatcher('dispatcher', dispatcher_node).running(timeout=timeout)
    except socket.timeout:
        logger.error('could not reach dispatcher on {} within {}s'.format(dispatcher_node, timeout))
    except Exception:
        logger.exception("error encountered when checking task status")
    return False

@property
def can_cancel(self):
    """True when the job's current status is one of ``CAN_CANCEL``."""
    return self.status in CAN_CANCEL
Expand All @@ -1421,27 +1404,61 @@ def _build_job_explanation(self):
return 'Previous Task Canceled: {"job_type": "%s", "job_name": "%s", "job_id": "%s"}' % (self.model_to_str(), self.name, self.id)
return None

def fallback_cancel(self):
    """Re-send the dispatcher cancel request, reloading the task id if unset.

    When ``celery_task_id`` is empty on this in-memory model, only that
    field is re-read from the database before calling
    ``cancel_dispatcher_process``.
    """
    task_id_missing = not self.celery_task_id
    if task_id_missing:
        self.refresh_from_db(fields=['celery_task_id'])
    self.cancel_dispatcher_process()

def cancel_dispatcher_process(self):
    """Ask the dispatcher on this job's controller node to cancel it.

    Returns:
        bool: True if the dispatcher's reply lists this job's
        ``celery_task_id`` (request acknowledged and SIGTERM sent).
        False otherwise, including when the job has no task id, the
        dispatcher timed out, or any other error occurred.
    """
    if not self.celery_task_id:
        # Nothing to ask the dispatcher about; keep the return type boolean
        return False
    canceled = []
    timeout = 5
    try:
        # Use control and reply mechanism to cancel and obtain confirmation.
        # Pass the timeout explicitly so the logged value matches the one used.
        canceled = ControlDispatcher('dispatcher', self.controller_node).cancel([self.celery_task_id], timeout=timeout)
    except socket.timeout:
        logger.error(f'could not reach dispatcher on {self.controller_node} within {timeout}s')
    except Exception:
        logger.exception("error encountered when canceling task")
    return self.celery_task_id in canceled  # whether confirmation was obtained

def cancel(self, job_explanation=None, is_chain=False):
if self.can_cancel:
if not is_chain:
for x in self.get_jobs_fail_chain():
x.cancel(job_explanation=self._build_job_explanation(), is_chain=True)

cancel_fields = []
if not self.cancel_flag:
self.cancel_flag = True
self.start_args = '' # blank field to remove encrypted passwords
cancel_fields = ['cancel_flag', 'start_args']
if self.status in ('pending', 'waiting', 'new'):
self.status = 'canceled'
cancel_fields.append('status')
if self.status == 'running' and not self.actually_running:
self.status = 'canceled'
cancel_fields.append('status')
cancel_fields.extend(['cancel_flag', 'start_args'])
connection.on_commit(lambda: self.websocket_emit_status("canceled"))

if job_explanation is not None:
self.job_explanation = job_explanation
cancel_fields.append('job_explanation')
self.save(update_fields=cancel_fields)
self.websocket_emit_status("canceled")

controller_notified = False
if self.celery_task_id:
controller_notified = self.cancel_dispatcher_process()

else:
# Avoid race condition where we have stale model from pending state but job has already started,
# it's checking the signal but not cancel_flag, so re-send the signal after this database commit
connection.on_commit(self.fallback_cancel)

# If a SIGTERM signal was sent to the control process, and acked by the dispatcher
# then we want to let its own cleanup change status, otherwise change status now
if not controller_notified:
if self.status != 'canceled':
self.status = 'canceled'
cancel_fields.append('status')

self.save(update_fields=cancel_fields)

return self.cancel_flag

@property
Expand Down
5 changes: 2 additions & 3 deletions awx/main/models/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,11 +723,10 @@ def get_notification_friendly_name(self):
def preferred_instance_groups(self):
return []

@property
def actually_running(self):
def cancel_dispatcher_process(self):
# WorkflowJobs don't _actually_ run anything in the dispatcher, so
# there's no point in asking the dispatcher if it knows about this task
return self.status == 'running'
return True


class WorkflowApprovalTemplate(UnifiedJobTemplate, RelatedJobsMixin):
Expand Down
27 changes: 3 additions & 24 deletions awx/main/tasks/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@
import stat

# Django
from django.utils.timezone import now
from django.conf import settings
from django_guid import get_guid
from django.utils.functional import cached_property
from django.db import connections

# AWX
from awx.main.redact import UriCleaner
from awx.main.constants import MINIMAL_EVENTS, ANSIBLE_RUNNER_NEEDS_UPDATE_MESSAGE
from awx.main.utils.update_model import update_model
from awx.main.queue import CallbackQueueDispatcher
from awx.main.tasks.signals import signal_callback

logger = logging.getLogger('awx.main.tasks.callback')

Expand Down Expand Up @@ -175,28 +174,6 @@ def event_handler(self, event_data):

return False

def cancel_callback(self):
    """
    Ansible runner callback to tell the job when/if it is canceled

    Returns True when the job should stop: a signal was received, the
    job could not be re-read or no longer exists, or it is flagged or
    marked as canceled. Returns False otherwise.
    """
    unified_job_id = self.instance.pk
    if signal_callback():
        return True
    try:
        self.instance = self.update_model(unified_job_id)
    except Exception:
        logger.exception(f'Encountered error during cancel check for {unified_job_id}, canceling now')
        return True
    if not self.instance:
        logger.error('unified job {} was deleted while running, canceling'.format(unified_job_id))
        return True
    if self.instance.cancel_flag or self.instance.status == 'canceled':
        cancel_wait = 0
        if self.instance.modified:
            # total_seconds() includes whole days; timedelta.seconds alone
            # wraps around every 24 hours and would under-report long waits
            cancel_wait = int((now() - self.instance.modified).total_seconds())
        if cancel_wait > 5:
            logger.warning('Request to cancel {} took {} seconds to complete.'.format(self.instance.log_format, cancel_wait))
        return True
    return False

def finished_callback(self, runner_obj):
"""
Ansible runner callback triggered on finished run
Expand Down Expand Up @@ -227,6 +204,8 @@ def status_handler(self, status_data, runner_config):

with disable_activity_stream():
self.instance = self.update_model(self.instance.pk, job_args=json.dumps(runner_config.command), job_cwd=runner_config.cwd, job_env=job_env)
# We opened a connection just for that save, close it here now
connections.close_all()
elif status_data['status'] == 'failed':
# For encrypted ssh_key_data, ansible-runner worker will open and write the
# ssh_key_data to a named pipe. Then, once the podman container starts, ssh-agent will
Expand Down
5 changes: 3 additions & 2 deletions awx/main/tasks/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def run(self, pk, **kwargs):
self.instance.log_lifecycle("preparing_playbook")
if self.instance.cancel_flag or signal_callback():
self.instance = self.update_model(self.instance.pk, status='canceled')

if self.instance.status != 'running':
# Stop the task chain and prevent starting the job if it has
# already been canceled.
Expand Down Expand Up @@ -589,7 +590,7 @@ def run(self, pk, **kwargs):
event_handler=self.runner_callback.event_handler,
finished_callback=self.runner_callback.finished_callback,
status_handler=self.runner_callback.status_handler,
cancel_callback=self.runner_callback.cancel_callback,
cancel_callback=signal_callback,
**params,
)
else:
Expand Down Expand Up @@ -1626,7 +1627,7 @@ def post_run_hook(self, inventory_update, status):

handler = SpecialInventoryHandler(
self.runner_callback.event_handler,
self.runner_callback.cancel_callback,
signal_callback,
verbosity=inventory_update.verbosity,
job_timeout=self.get_instance_timeout(self.instance),
start_time=inventory_update.started,
Expand Down
48 changes: 23 additions & 25 deletions awx/main/tasks/receptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# Django
from django.conf import settings
from django.db import connections

# Runner
import ansible_runner
Expand All @@ -25,6 +26,7 @@
cleanup_new_process,
)
from awx.main.constants import MAX_ISOLATED_PATH_COLON_DELIMITER
from awx.main.tasks.signals import signal_state, signal_callback, SignalExit

# Receptorctl
from receptorctl.socket_interface import ReceptorControl
Expand Down Expand Up @@ -335,24 +337,32 @@ def _run_internal(self, receptor_ctl):
shutil.rmtree(artifact_dir)

resultsock, resultfile = receptor_ctl.get_work_results(self.unit_id, return_socket=True, return_sockfile=True)
# Both "processor" and "cancel_watcher" are spawned in separate threads.
# We wait for the first one to return. If cancel_watcher returns first,
# we yank the socket out from underneath the processor, which will cause it
# to exit. A reference to the processor_future is passed into the cancel_watcher_future,
# Which exits if the job has finished normally. The context manager ensures we do not
# leave any threads laying around.
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:

connections.close_all()

# "processor" and the main thread will be separate threads.
# If a cancel happens, the main thread will encounter an exception, in which case
# we yank the socket out from underneath the processor, which will cause it to exit.
# The ThreadPoolExecutor context manager ensures we do not leave any threads lying around.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
processor_future = executor.submit(self.processor, resultfile)
cancel_watcher_future = executor.submit(self.cancel_watcher, processor_future)
futures = [processor_future, cancel_watcher_future]
first_future = concurrent.futures.wait(futures, return_when=concurrent.futures.FIRST_COMPLETED)

res = list(first_future.done)[0].result()
if res.status == 'canceled':
try:
signal_state.raise_exception = True
# address race condition where SIGTERM was issued after this dispatcher task started
if signal_callback():
raise SignalExit()
res = processor_future.result()
except SignalExit:
receptor_ctl.simple_command(f"work cancel {self.unit_id}")
resultsock.shutdown(socket.SHUT_RDWR)
resultfile.close()
elif res.status == 'error':
result = namedtuple('result', ['status', 'rc'])
res = result('canceled', 1)
finally:
signal_state.raise_exception = False

if res.status == 'error':
# If ansible-runner ran, but an error occurred at runtime, the traceback information
# is saved via the status_handler passed in to the processor.
if 'result_traceback' in self.task.runner_callback.extra_update_fields:
Expand Down Expand Up @@ -446,18 +456,6 @@ def work_type(self):
return 'local'
return 'ansible-runner'

@cleanup_new_process
def cancel_watcher(self, processor_future):
    """Poll once per second until the processor finishes or a cancel is seen.

    Returns the processor future's result once it is done. If the
    runner callback's ``cancel_callback`` reports a cancel first,
    returns a ``result`` namedtuple with status 'canceled' and rc 1.
    """
    while not processor_future.done():
        if self.task.runner_callback.cancel_callback():
            canceled_result = namedtuple('result', ['status', 'rc'])
            return canceled_result('canceled', 1)
        time.sleep(1)
    return processor_future.result()

@property
def pod_definition(self):
ee = self.task.instance.execution_environment
Expand Down
Loading

0 comments on commit 15964dc

Please sign in to comment.