Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization: use cached status everywhere. #410

Merged
merged 9 commits into from
Dec 31, 2020
1 change: 1 addition & 0 deletions changelog.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Changed
- ``CPUEnvironment`` and ``GPUEnvironment`` classes are deprecated (#381).
- Docstrings are now written in `numpydoc style <https://numpydoc.readthedocs.io/en/latest/format.html>`__ (#392).
- Default environment for the University of Minnesota Mangi cluster changed from SLURM to Torque (#393).
- Improved internal caching of scheduler status (#410).

Fixed
+++++
Expand Down
129 changes: 99 additions & 30 deletions flow/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,15 +1192,6 @@ def _generate_id(self, jobs, operation_name=None, index=0):
# By appending the unique job_op_id, we ensure that each id is truly unique.
return readable_name + job_op_id

def _get_status(self, jobs):
    """For a given job-aggregate, check the group's submission status.

    Looks up this group's uniquely generated id (via ``_generate_id``) in
    the ``_status`` mapping of the first job's project document.

    Parameters
    ----------
    jobs : tuple of jobs
        The aggregate of signac jobs whose cached status is queried.

    Returns
    -------
    JobStatus
        The cached status for this group and aggregate, or
        ``JobStatus.unknown`` if no cached entry exists.

    """
    try:
        return JobStatus(
            jobs[0]._project.document["_status"][self._generate_id(jobs)]
        )
    except KeyError:
        # Either no "_status" cache exists yet or this group/aggregate
        # combination has never been submitted.
        return JobStatus.unknown

def _create_submission_job_operation(
self,
entrypoint,
Expand Down Expand Up @@ -2042,6 +2033,25 @@ def scheduler_jobs(self, scheduler):
"""
yield from self._expand_bundled_jobs(scheduler.jobs())

def _get_cached_status(self):
"""Fetch all status information.

The project document key ``_status`` is returned as a plain dict, or an
empty dict if no status information is present.

Returns
-------
dict
Dictionary of cached status information. The keys are uniquely
generated ids for each group and job. The values are instances of
:class:`~.JobStatus`.

"""
try:
return self.document["_status"]()
except KeyError:
return {}

def _get_operations_status(self, jobs, cached_status):
"""Return a dict with information about job-operations for this aggregate.

Expand Down Expand Up @@ -2104,14 +2114,11 @@ def get_job_status(self, job, ignore_errors=False, cached_status=None):

"""
# TODO: Add support for aggregates for this method.
if cached_status is None:
cached_status = self._get_cached_status()
result = {}
result["job_id"] = str(job)
try:
if cached_status is None:
try:
cached_status = self.document["_status"]()
except KeyError:
cached_status = {}
result["operations"] = dict(
self._get_operations_status((job,), cached_status)
)
Expand Down Expand Up @@ -2187,7 +2194,7 @@ def _fetch_scheduler_status(self, jobs=None, file=None, ignore_errors=False):
status[submit_id] = int(
scheduler_info.get(submit_id, JobStatus.unknown)
)
self.document._status.update(status)
self.document["_status"].update(status)
except NoSchedulerError:
logger.debug("No scheduler available.")
except RuntimeError as error:
Expand All @@ -2197,7 +2204,7 @@ def _fetch_scheduler_status(self, jobs=None, file=None, ignore_errors=False):
else:
logger.info("Updated job status cache.")

def _get_group_status(self, group_name, ignore_errors=False, cached_status=None):
def _get_group_status(self, group_name, cached_status, ignore_errors=False):
"""Return status information about a group.

Status information is fetched for all jobs/aggregates associated with
Expand All @@ -2207,12 +2214,12 @@ def _get_group_status(self, group_name, ignore_errors=False, cached_status=None)
----------
group_name : str
Group name.
ignore_errors : bool
Whether to ignore exceptions raised during status check. (Default value = False)
cached_status : dict
Dictionary of cached status information. The keys are uniquely
generated ids for each group and job. The values are instances of
:class:`~.JobStatus`. (Default value = None)
:class:`~.JobStatus`.
ignore_errors : bool
Whether to ignore exceptions raised during status check. (Default value = False)

Returns
-------
Expand Down Expand Up @@ -2337,11 +2344,8 @@ def _fetch_status(

# Update the project's status cache
self._fetch_scheduler_status(aggregates, err, ignore_errors)
# Get status dict for all selected aggregates
try:
cached_status = self.document["_status"]()
except KeyError:
cached_status = {}
# Get project status cache
cached_status = self._get_cached_status()

get_job_labels = functools.partial(
self._get_job_labels,
Expand Down Expand Up @@ -3412,16 +3416,53 @@ def _get_submission_operations(
self,
aggregates,
default_directives,
cached_status,
names=None,
ignore_conditions=IgnoreConditions.NONE,
ignore_conditions_on_execution=IgnoreConditions.NONE,
):
r"""Grabs eligible :class:`~._JobOperation`\ s from :class:`~.FlowGroup`s."""
r"""Grabs eligible :class:`~._JobOperation`\ s from :class:`~.FlowGroup`\ s.

Parameters
----------
aggregates : sequence of aggregates
The aggregates to consider for submission.
default_directives : dict
The default directives to use for the operations. This is to allow
for user specified groups to 'inherit' directives from
``default_directives``. If no defaults are desired, the argument
can be set to an empty dictionary. This must be done explicitly,
however.
cached_status : dict
Dictionary of cached status information. The keys are uniquely
generated ids for each group and job. The values are instances of
:class:`~.JobStatus`.
names : iterable of :class:`str`
Only select operations that match the provided set of names
(interpreted as regular expressions), or all if the argument is
None. (Default value = None)
ignore_conditions : :class:`~.IgnoreConditions`
Specify if preconditions and/or postconditions are to be ignored
when determining eligibility. The default is
:class:`IgnoreConditions.NONE`.
ignore_conditions_on_execution : :class:`~.IgnoreConditions`
Specify if preconditions and/or postconditions are to be ignored
when determining eligibility after submitting. The default is
:class:`IgnoreConditions.NONE`.

Yields
------
:class:`~._SubmissionJobOperation`
Returns a :class:`~._SubmissionJobOperation` for submitting the
group. The :class:`~._JobOperation` will have directives that have
been collected appropriately from its contained operations.

"""
for group in self._gather_flow_groups(names):
for aggregate in self._get_aggregate_store(group.name).values():
if (
group._eligible(aggregate, ignore_conditions)
and self._eligible_for_submission(group, aggregate)
and self._eligible_for_submission(group, aggregate, cached_status)
and self._is_selected_aggregate(aggregate, aggregates)
):
yield group._create_submission_job_operation(
Expand Down Expand Up @@ -3875,11 +3916,13 @@ def submit(
# Gather all pending operations.
with self._potentially_buffered():
default_directives = self._get_default_directives()
cached_status = self._get_cached_status()
# The generator must be used *inside* the buffering context manager
# for performance reasons.
operation_generator = self._get_submission_operations(
aggregates,
default_directives,
cached_status,
names,
ignore_conditions,
ignore_conditions_on_execution,
Expand Down Expand Up @@ -4612,20 +4655,44 @@ def _get_aggregate_store(self, group):
return aggregate_store
return {}

def _eligible_for_submission(self, flow_group, jobs):
"""Check flow_group eligibility for submission with a job-aggregate.
def _eligible_for_submission(self, flow_group, jobs, cached_status):
"""Check group eligibility for submission with a job or aggregate.

By default, an operation is eligible for submission when it
is not considered active, that means already queued or running.

Parameters
----------
flow_group : :class:`~.FlowGroup`
The FlowGroup used to determine eligibility.
jobs : :class:`~signac.contrib.job.Job` or aggregate of jobs
The signac job or aggregate.
cached_status : dict
Dictionary of cached status information. The keys are uniquely
generated ids for each group and job. The values are instances of
:class:`~.JobStatus`.

Returns
-------
bool
Whether the group is eligible for submission with the provided job/aggregate.

"""
if flow_group is None or jobs is None:
return False
if flow_group._get_status(jobs) >= JobStatus.submitted:

def _group_is_submitted(flow_group):
"""Check if group has been submitted for the provided jobs."""
group_id = flow_group._generate_id(jobs)
job_status = JobStatus(cached_status.get(group_id, JobStatus.unknown))
return job_status >= JobStatus.submitted

if _group_is_submitted(flow_group):
return False
group_ops = set(flow_group)
for other_group in self._groups.values():
if group_ops & set(other_group):
if other_group._get_status(jobs) >= JobStatus.submitted:
if _group_is_submitted(other_group):
return False
return True

Expand Down Expand Up @@ -4750,9 +4817,11 @@ def _main_script(self, args):
with self._potentially_buffered():
names = args.operation_name if args.operation_name else None
default_directives = self._get_default_directives()
cached_status = self._get_cached_status()
operations = self._get_submission_operations(
aggregates,
default_directives,
cached_status,
names,
args.ignore_conditions,
args.ignore_conditions_on_execution,
Expand Down
41 changes: 32 additions & 9 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,7 +1262,11 @@ def test_script(self):
project = self.mock_project()
# For run mode single operation groups
for job in project:
job_ops = project._get_submission_operations([(job,)], {})
job_ops = project._get_submission_operations(
aggregates=[(job,)],
default_directives={},
cached_status={},
)
script = project._script(job_ops)
if job.sp.b % 2 == 0:
assert str(job) in script
Expand All @@ -1288,30 +1292,49 @@ def test_script(self):

def test_directives_hierarchy(self):
project = self.mock_project()
cached_status = project._get_cached_status()
for job in project:
# Test submit JobOperations
job_ops = project._get_submission_operations(
(job,), project._get_default_directives(), names=["group2"]
job_ops = list(
project._get_submission_operations(
aggregates=[(job,)],
Copy link
Member Author

@bdice bdice Dec 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was from a bug in the tests. When _get_submission_operations is called, it requires a sequence of aggregates, not a single aggregate. We validate this in the public method that calls _get_submission_operations, but this test calls an internal private method. As a result, this test was being fed invalid data and was passing for the wrong reason. Somewhere (perhaps _is_selected_aggregate), a single job was being compared to an aggregate containing only that job, and that made the test pass because no job operations were generated. I only found this problem because of a mistake in the Job.__eq__ method that was introduced in PR 442 and is being fixed in PR 455. I have expanded the test to ensure that a nonzero number of (submission/run) job operations are created, and fixed the problem.

default_directives=project._get_default_directives(),
cached_status=cached_status,
names=["group2"],
)
)
assert len(job_ops) == 1
assert all(
[job_op.directives.get("omp_num_threads", 0) == 4 for job_op in job_ops]
)
job_ops = project._get_submission_operations(
(job,), project._get_default_directives(), names=["op3"]
job_ops = list(
project._get_submission_operations(
aggregates=[(job,)],
default_directives=project._get_default_directives(),
cached_status=cached_status,
names=["op3"],
)
)
assert len(job_ops) == 1
assert all(
[job_op.directives.get("omp_num_threads", 0) == 1 for job_op in job_ops]
)
# Test run JobOperations
job_ops = project.groups["group2"]._create_run_job_operations(
project._entrypoint, project._get_default_directives(), (job,)
job_ops = list(
project.groups["group2"]._create_run_job_operations(
project._entrypoint, project._get_default_directives(), (job,)
)
)
assert len(job_ops) == 1
assert all(
[job_op.directives.get("omp_num_threads", 0) == 4 for job_op in job_ops]
)
job_ops = project.groups["op3"]._create_run_job_operations(
project._entrypoint, project._get_default_directives(), (job,)
job_ops = list(
project.groups["op3"]._create_run_job_operations(
project._entrypoint, project._get_default_directives(), (job,)
)
)
assert len(job_ops) == 1
assert all(
[job_op.directives.get("omp_num_threads", 0) == 1 for job_op in job_ops]
)
Expand Down