Skip to content

Commit

Permalink
Sync polling/fix pipeline/fix cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Blanca Fuentes Monjas committed Feb 26, 2025
1 parent 3697acf commit 98cbe20
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 35 deletions.
20 changes: 10 additions & 10 deletions reframe/core/schedulers/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,16 @@ async def poll(self, *jobs):
t_start = time.strftime(
'%F', time.localtime(min(job.submit_time for job in jobs))
)
completed = await _run_strict(
f'sacct -S {t_start} -P '
f'-j {",".join(job.jobid for job in jobs)} '
f'-o jobid,state,exitcode,end,nodelist'
)
# completed = _run_strict_s(
# completed = await _run_strict(
# f'sacct -S {t_start} -P '
# f'-j {",".join(job.jobid for job in jobs)} '
# f'-o jobid,state,exitcode,end,nodelist'
# )
completed = _run_strict_s(
f'sacct -S {t_start} -P '
f'-j {",".join(job.jobid for job in jobs)} '
f'-o jobid,state,exitcode,end,nodelist'
)

self._update_state_count += 1

Expand Down Expand Up @@ -531,12 +531,12 @@ async def _cancel_if_blocked(self, job, reasons=None):
return

if not reasons:
completed = await osext.run_command(
'squeue -h -j %s -o %%r' % job.jobid
)
# completed = osext.run_command_s(
# completed = await osext.run_command(
# 'squeue -h -j %s -o %%r' % job.jobid
# )
completed = osext.run_command_s(
'squeue -h -j %s -o %%r' % job.jobid
)
if hasattr(current_task(), 'aborting'):
raise asyncio.CancelledError
reasons = completed.stdout.splitlines()
Expand Down
46 changes: 21 additions & 25 deletions reframe/frontend/executors/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ def on_task_success(self, task):
if c in self._task_index:
self._task_index[c].ref_count -= 1

# _cleanup_all(self._retired_tasks, not self.keep_stage_files)
if self.timeout_expired():
raise RunSessionTimeout('maximum session duration exceeded')

Expand Down Expand Up @@ -341,8 +340,9 @@ def execute(self, testcases):
self._exit()

def _exit(self):
pass
# Clean up all remaining tasks
asyncio_run(_cleanup_all(self._retired_tasks, not self.keep_stage_files))
# asyncio_run(_cleanup_all(self._retired_tasks, not self.keep_stage_files))


class AsyncioExecutionPolicy(ExecutionPolicy, TaskEventListener):
Expand Down Expand Up @@ -476,7 +476,7 @@ async def _runcase(self, case, task):
max_jobs = self._max_jobs[partname]
while len(self._partition_tasks[partname])+1 > max_jobs:
getlogger().debug2(f'Hit the max job limit of {partname}: {max_jobs}')
await asyncio.sleep(2)
await asyncio.sleep(0.001)
self._partition_tasks[partname].add(task)
await task.compile()

Expand All @@ -491,9 +491,6 @@ async def _runcase(self, case, task):
self._pollctl.reset_snooze_time(sched.registered_name)
while True:
if not self.dry_run_mode:
# Check if the task was completed
if task.compile_complete():
break
if (getpollcontroller().is_time_to_poll(sched.registered_name)): # and
# getpollcontroller()._poll_event[sched.registered_name].is_set()):
getlogger().debug2("Jobs to poll"
Expand Down Expand Up @@ -530,23 +527,24 @@ async def _runcase(self, case, task):
await self._pollctl.snooze(sched.registered_name)
if task.compile_complete():
break
else:
# yield control to another task to give it the chance to check their status
await asyncio.sleep(0)
# else:
# # yield control to another task to give it the chance to check their status
# await asyncio.sleep(0)
# We need to check the timeout inside the while loop
if self.timeout_expired():
raise RunSessionTimeout(
'maximum session duration exceeded'
)
else:
if self._pipeline_statistics:
getlogger().debug2(f"{task}, compiling to ready_run")
self._update_pipeline_progress('compiling', 'ready_run', 1)
await task.compile_wait()
self._partition_tasks[partname].remove(task)
partname = _get_partition_name(task, phase='run')
max_jobs = self._max_jobs[partname]
while len(self._partition_tasks[partname])+1 > max_jobs:
await asyncio.sleep(2)
await asyncio.sleep(0.001)
self._partition_tasks[partname].add(task)
await task.run()
# If CompileOnly, no polling for run jobs
Expand All @@ -560,8 +558,6 @@ async def _runcase(self, case, task):
self._pollctl.reset_snooze_time(sched.registered_name)
while True:
if not self.dry_run_mode:
if task.run_complete():
break
if (getpollcontroller().is_time_to_poll(sched.registered_name)): # and
# getpollcontroller()._poll_event[sched.registered_name].is_set()):
getlogger().debug2("Jobs to poll"
Expand Down Expand Up @@ -597,14 +593,13 @@ async def _runcase(self, case, task):
await self._pollctl.snooze(sched.registered_name)
if task.run_complete():
break
else:
await asyncio.sleep(0)
if self.timeout_expired():
raise RunSessionTimeout(
'maximum session duration exceeded'
)
else:
if self._pipeline_statistics:
getlogger().debug2(f"{task}, running to completing")
self._update_pipeline_progress('running', 'completing', 1)
await task.run_wait()
self._partition_tasks[partname].remove(task)
Expand All @@ -621,18 +616,14 @@ async def _runcase(self, case, task):
if self._pipeline_statistics:
self._update_pipeline_progress('completing', 'retired', 1)

if self._pipeline_statistics:
num_retired = len(self._retired_tasks)

# await _cleanup_all(self._retired_tasks, not self.keep_stage_files)
if self._pipeline_statistics:
num_retired_actual = num_retired - len(self._retired_tasks)
if task.ref_count == 0:
with contextlib.suppress(TaskExit):
await task.cleanup(not self.keep_stage_files)

# Some tests might not be cleaned up because they are
# waiting for dependencies or because their dependencies
# have failed.
if self._pipeline_statistics:
self._update_pipeline_progress(
'retired', 'completed', num_retired_actual
'retired', 'completed', 1
)

except TaskExit:
Expand Down Expand Up @@ -675,7 +666,7 @@ async def check_deps(self, task):
while not (self.deps_skipped(task) or self.deps_failed(task) or
self.deps_succeeded(task)):
getlogger().debug2(f'{task.info()} waiting for dependencies')
await asyncio.sleep(0)
await asyncio.sleep(0.001)

if self.deps_skipped(task):
return "skipped"
Expand Down Expand Up @@ -771,7 +762,12 @@ def on_task_failure(self, task):
self.printer.status('FAIL', msg, just='right')

if self._pipeline_statistics:
self._update_pipeline_progress(task._failed_state, 'fail', 1)
old_state = task._failed_state
if old_state == 'running':
old_state = 'ready_run'
elif old_state == 'compiling':
old_state = 'ready_compile'
self._update_pipeline_progress(old_state, 'fail', 1)

_print_perf(task)
if task.failed_stage == 'sanity':
Expand Down

0 comments on commit 98cbe20

Please sign in to comment.