diff --git a/flow/__init__.py b/flow/__init__.py index 3b8c8a945..2ee02babb 100644 --- a/flow/__init__.py +++ b/flow/__init__.py @@ -12,12 +12,15 @@ from . import scheduling from . import errors from . import testing +from .aggregate import Aggregate +from .aggregate import get_aggregate_id from .project import IgnoreConditions from .project import FlowProject from .project import JobOperation from .project import label from .project import classlabel from .project import staticlabel +from .project import make_aggregates from .operations import cmd from .operations import directives from .operations import run @@ -34,6 +37,9 @@ __all__ = [ + 'Aggregate', + 'make_aggregates', + 'get_aggregate_id', 'environment', 'scheduling', 'errors', diff --git a/flow/aggregate.py b/flow/aggregate.py new file mode 100644 index 000000000..aaacb02d9 --- /dev/null +++ b/flow/aggregate.py @@ -0,0 +1,215 @@ +# Copyright (c) 2020 The Regents of the University of Michigan +# All rights reserved. +# This software is licensed under the BSD 3-Clause License. +from collections.abc import Iterable +from hashlib import md5 +from itertools import groupby +from itertools import zip_longest +from tqdm import tqdm + + +class Aggregate: + """Decorator for operation functions that are to be aggregated. + + By default, if the aggregator parameter is not passed, + an aggregate of all jobs will be created. + + .. code-block:: python + + example_aggregate = Aggregate() + @example_aggregate + @FlowProject.operation + def foo(*jobs): + print(len(jobs)) + + :param aggregator: + Information describing how to aggregate jobs. Is a callable that + takes in a list of jobs and can return or yield subsets of jobs as + an iterable. The default behavior is creating a single aggregate + of all jobs + :type aggregator: + callable + :param sort: + Before aggregating, sort the jobs by a given statepoint parameter. + The default behavior is no sorting. + :type sort: + str or NoneType + :param reverse: + States if the jobs are to be sorted in reverse order. + The default value is False. + :type reverse: + bool + :param select: + Condition for filtering individual jobs. This is passed as the + callable argument to `filter`. + The default behavior is no filtering. 
+ :type select: + callable or NoneType + """ + + def __init__(self, aggregator=None, sort=None, reverse=False, select=None): + if aggregator is None: + def aggregator(jobs): + return [jobs] + + if not callable(aggregator): + raise TypeError("Expected callable for aggregator, got {}" + "".format(type(aggregator))) + + if sort is not None and not isinstance(sort, str): + raise TypeError("Expected string sort parameter, got {}" + "".format(type(sort))) + + if select is not None and not callable(select): + raise TypeError("Expected callable for select, got {}" + "".format(type(select))) + + if getattr(aggregator, '_num', False): + self._is_aggregate = False if aggregator._num == 1 else True + else: + self._is_aggregate = True + + self._aggregator = aggregator + self._sort = sort + self._reverse = reverse + self._select = select + + @classmethod + def groupsof(cls, num=1, sort=None, reverse=False, select=None): + # copied from: https://docs.python.org/3/library/itertools.html#itertools.zip_longest + try: + num = int(num) + if num <= 0: + raise ValueError('The num parameter should have a value greater than 0') + except TypeError: + raise TypeError('The num parameter should be an integer') + + def aggregator(jobs): + args = [iter(jobs)] * num + return zip_longest(*args) + setattr(aggregator, '_num', num) + return cls(aggregator, sort, reverse, select) + + @classmethod + def groupby(cls, key, default=None, sort=None, reverse=False, select=None): + if isinstance(key, str): + if default is None: + def keyfunction(job): + return job.sp[key] + else: + def keyfunction(job): + return job.sp.get(key, default) + + elif isinstance(key, Iterable): + keys = list(key) + + if default is None: + def keyfunction(job): + return [job.sp[key] for key in keys] + else: + if isinstance(default, Iterable): + if len(default) != len(keys): + raise ValueError("Expected length of default argument is {}, " + "got {}.".format(len(keys), len(default))) + else: + raise TypeError("Invalid default argument. Expected Iterable, " + "got {}".format(type(default))) + + def keyfunction(job): + return [job.sp.get(key, default[i]) for i, key in enumerate(keys)] + + elif callable(key): + keyfunction = key + + else: + raise TypeError("Invalid key argument. Expected either str, Iterable " + "or a callable, got {}".format(type(key))) + + def aggregator(jobs): + for key, group in groupby(sorted(jobs, key=keyfunction), key=keyfunction): + yield group + + return cls(aggregator, sort, reverse, select) + + def _create_MakeAggregate(self): + return MakeAggregate(self._aggregator, self._sort, self._reverse, self._select) + + def __call__(self, func=None): + if callable(func): + setattr(func, '_flow_aggregate', self._create_MakeAggregate()) + return func + else: + raise TypeError('Invalid argument passed while calling ' + 'the aggregate instance. Expected a callable, ' + 'got {}.'.format(type(func))) + + +class MakeAggregate(Aggregate): + r"""This class handles the creation of aggregates. + + .. note:: + This class should not be instantiated by users directly. + :param \*args: + Passed to the constructor of :py:class:`Aggregate`. + """ + def __init__(self, *args): + super(MakeAggregate, self).__init__(*args) + + def __call__(self, obj, group_name='unknown-operation', project=None): + "Return aggregated jobs." 
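A minimal usage sketch of the decorator API introduced above; the ``FlowProject`` subclass ``Project`` and the statepoint key ``temperature`` are assumed names and are not part of this patch.

.. code-block:: python

    # Sketch only: `Project` and the statepoint key `temperature` are assumed names.
    from flow import FlowProject, Aggregate


    class Project(FlowProject):
        pass


    @Aggregate.groupsof(2)                            # aggregates of at most two jobs
    @Project.operation
    def pairwise(*jobs):
        print('pairwise over', len(jobs), 'job(s)')


    @Aggregate.groupby('temperature', default=300)    # one aggregate per temperature value
    @Project.operation
    def per_temperature(*jobs):
        print('temperature group of', len(jobs), 'job(s)')


    if __name__ == '__main__':
        Project().main()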
+ aggregated_jobs = list(obj) + if self._select is not None: + aggregated_jobs = list(filter(self._select, aggregated_jobs)) + if self._sort is not None: + aggregated_jobs = list(sorted(aggregated_jobs, + key=lambda job: job.sp[self._sort], + reverse=bool(self._reverse))) + + aggregated_jobs = self._aggregator([job for job in aggregated_jobs]) + aggregated_jobs = self._create_nested_aggregate_list(aggregated_jobs, group_name, project) + if not len(aggregated_jobs): + return [] + return aggregated_jobs + + def _create_nested_aggregate_list(self, aggregated_jobs, group_name, project): + # This method converts the returned subset of jobs as an Iterable + # from an aggregator function to a subset of jobs as list. + aggregated_jobs = list(aggregated_jobs) + nested_aggregates = [] + + desc = f"Collecting aggregates for {group_name}" + for aggregate in tqdm(aggregated_jobs, total=len(aggregated_jobs), + desc=desc, leave=False): + try: + filter_aggregate = [] + for job in aggregate: + if job is None: + continue + if project is not None: + if job not in project: + raise ValueError(f'The signac job {str(job)} not found in {project}') + filter_aggregate.append(job) + filter_aggregate = tuple(filter_aggregate) + if project is not None: + project._aggregates_ids[get_aggregate_id(filter_aggregate)] = \ + filter_aggregate + nested_aggregates.append(filter_aggregate) + except Exception: + raise ValueError("Invalid aggregator function provided by " + "the user.") + return nested_aggregates + + +def get_aggregate_id(jobs): + """Generate hashed id for an aggregate of jobs. + + :param jobs: + The signac job handles + :type jobs: + tuple + """ + if len(jobs) == 1: + return str(jobs[0]) # Return job id as it's already unique + + blob = ''.join((job.get_id() for job in jobs)) + return f'agg-{md5(blob.encode()).hexdigest()}' diff --git a/flow/project.py b/flow/project.py index 280e13a8c..de1ff9e68 100644 --- a/flow/project.py +++ b/flow/project.py @@ -47,6 +47,8 @@ from enum import IntFlag +from .aggregate import Aggregate +from .aggregate import get_aggregate_id from .environment import get_environment from .scheduling.base import ClusterJob from .scheduling.base import JobStatus @@ -169,7 +171,7 @@ def __init__(self, condition, tag=None): @classmethod def isfile(cls, filename): - "True if the specified file exists for this job." + "True if the specified file exists for this job-aggregate." 
         def _isfile(*jobs):
             return all(job.isfile(filename) for job in jobs)
 
@@ -177,8 +179,8 @@ def _isfile(*jobs):
 
     @classmethod
     def true(cls, key):
-        """True if the specified key is present in the job document and
-        evaluates to True."""
+        """True if the specified key is present in the job document
+        of every job in the job-aggregate and evaluates to True."""
         def _document(*jobs):
             return all(job.document.get(key, False) for job in jobs)
 
@@ -186,8 +188,8 @@ def _document(*jobs):
 
     @classmethod
     def false(cls, key):
-        """True if the specified key is present in the job document and
-        evaluates to False."""
+        """True if the specified key is present in the job document
+        of every job in the job-aggregate and evaluates to False."""
         def _no_document(*jobs):
             return all(not job.document.get(key, False) for job in jobs)
         return cls(_no_document, 'false_' + key)
@@ -200,7 +202,8 @@ def never(cls, func):
 
     @classmethod
     def not_(cls, condition):
-        """Returns ``not condition(job)`` for the provided condition function."""
+        """Evaluates to True if the provided condition function evaluates
+        to False when called with the full job-aggregate."""
         def _not(*jobs):
             return not condition(*jobs)
         return cls(_not,
@@ -236,8 +239,40 @@ def _make_bundles(operations, size=None):
             break
 
 
+def make_aggregates(jobs, aggregator=None):
+    """Utility function for the generation of aggregates.
+
+    This function returns the generated aggregates using the
+    aggregator parameter passed by the user.
+
+    By default, a single aggregate containing all of the passed jobs is returned.
+
+    :param jobs:
+        The signac job handles.
+    :type jobs:
+        Iterable of type :class:`signac.contrib.job.Job`
+    :param aggregator:
+        The aggregator object to aggregate jobs.
+    :type aggregator:
+        :class:`Aggregate`
+    :return:
+        Aggregated jobs.
+    :rtype:
+        list
+    """
+    if aggregator is None:
+        aggregator = Aggregate()
+
+    if not isinstance(aggregator, Aggregate):
+        raise ValueError("Please provide a valid aggregator object. "
+                         "Expected parameter of type Aggregate, "
+                         f"got {type(aggregator)}")
+
+    return aggregator._create_MakeAggregate()(jobs)
+
+
 class _JobOperation(object):
-    """This class represents the information needed to execute one group for one job.
+    """This class represents the information needed to execute one group for a job-aggregate.
 
     The execution or submission of a :py:class:`FlowGroup` uses a passed-in command
     which can either be a string or function with no arguments that returns a shell
@@ -301,11 +336,10 @@ def __init__(self, id, name, jobs, cmd, directives=None):
             self.directives._keys_set_by_user = keys_set_by_user
 
     def __str__(self):
-        assert len(self._jobs) > 0
         max_len = 3
         min_len_unique_id = self._jobs[0]._project.min_len_unique_id()
         if len(self._jobs) > max_len:
-            shown = self._jobs[:max_len-2] + ['...'] + self._jobs[-1:]
+            shown = list(self._jobs[:max_len-2]) + ['...'] + list(self._jobs[-1:])
         else:
             shown = self._jobs
         return f"{self.name}[#{len(self._jobs)}]" \
@@ -344,12 +378,6 @@ def cmd(self):
 
     def set_status(self, value):
         "Store the operation's status."
-        # Since #324 doesn't include actual aggregation, it is guaranteed that the length
-        # of self._jobs is equal to 1, hence we won't be facing the problem for lost
-        # aggregates. #335 introduces the concept of storing aggregates which will
-        # help retrieve the information of lost aggregates. The storage of aggregates
-        # will be similar to bundles hence no change will be made to this method.
-        # This comment should be removed after #335 gets merged.
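A short usage sketch for the new ``make_aggregates`` helper and ``get_aggregate_id``; it assumes an already-initialized signac project in the working directory and is not part of the patch itself.

.. code-block:: python

    # Sketch only: assumes an initialized signac project in the working directory.
    import signac
    from flow import Aggregate, make_aggregates, get_aggregate_id

    project = signac.get_project()

    # Default behavior: a single aggregate containing every job of the project.
    whole_project = make_aggregates(project)

    # Fixed-size aggregates of three jobs each (the final one may be smaller).
    triples = make_aggregates(project, Aggregate.groupsof(3))

    for aggregate in triples:
        # One-job aggregates reuse the job id; larger aggregates get an
        # 'agg-' prefixed md5 digest of the concatenated job ids.
        print(get_aggregate_id(aggregate), len(aggregate))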
         self._jobs[0]._project.document.setdefault('_status', dict())[self.id] = int(value)
 
     def get_status(self):
@@ -440,7 +468,7 @@ def __repr__(self):
 
 
 class _SubmissionJobOperation(_JobOperation):
-    R"""This class represents the information needed to submit one group for one job.
+    R"""This class represents the information needed to submit one group for a job-aggregate.
 
     This class extends :py:class:`_JobOperation` to include a set of groups
     that will be executed via the "run" command. These groups are known at
@@ -450,7 +478,7 @@ class _SubmissionJobOperation(_JobOperation):
         Passed to the constructor of :py:class:`_JobOperation`.
     :param eligible_operations:
         A list of :py:class:`_JobOperation` that will be executed when this
-        submitted job is executed.
+        submitted job-aggregate is executed.
     :type eligible_operations:
         list
     :param operations_with_unmet_preconditions:
@@ -497,14 +525,14 @@ def __init__(
 
 
 class _FlowCondition(object):
-    """A _FlowCondition represents a condition as a function of a signac job.
+    """A _FlowCondition represents a condition as a function of a job-aggregate.
 
     The __call__() function of a _FlowCondition object may return either True
     or False, representing whether the condition is met or not.
     This can be used to build a graph of conditions and operations.
 
     :param callback:
-        A callable with one positional argument (the job).
+        A callable that accepts the job(s) of an aggregate as positional arguments.
     :type callback:
         callable
     """
@@ -516,11 +544,10 @@ def __call__(self, jobs):
         try:
             return self._callback(*jobs)
         except Exception as e:
-            assert len(jobs) == 1
             raise UserConditionError(
                 'An exception was raised while evaluating the condition {name} '
-                'for job {jobs}.'.format(name=self._callback.__name__,
-                                         jobs=', '.join(map(str, jobs)))) from e
+                'for job(s) {jobs}.'.format(name=self._callback.__name__,
+                                            jobs=', '.join(map(str, jobs)))) from e
 
     def __hash__(self):
         return hash(self._callback)
@@ -530,14 +557,15 @@ def __eq__(self, other):
 
 
 class BaseFlowOperation(object):
-    """A BaseFlowOperation represents a data space operation, operating on any job.
+    """A BaseFlowOperation represents a data space operation, operating on any job-aggregate.
 
     Every BaseFlowOperation is associated with a specific command.
 
     Pre-requirements (pre) and post-conditions (post) can be used to
-    trigger an operation only when certain conditions are met. Conditions are unary
-    callables, which expect an instance of job as their first and only positional
-    argument and return either True or False.
+    trigger an operation only when certain conditions are met. Conditions are
+    callables that accept one or more positional arguments, each an instance
+    of :class:`~signac.contrib.job.Job`, and return either
+    True or False.
 
     An operation is considered "eligible" for execution when all pre-requirements
     are met and when at least one of the post-conditions is not met.
@@ -548,7 +576,7 @@ class BaseFlowOperation(object):
         This class should not be instantiated directly.
 
     :param cmd:
-        The command to execute operation; should be a function of job.
+        The command to execute the operation; should be a function of one or more jobs.
     :type cmd:
         str or callable
     :param pre:
@@ -634,13 +662,14 @@ class FlowCmdOperation(BaseFlowOperation):
     """A BaseFlowOperation that holds a shell executable command.
 
     When an operation has the ``@cmd`` directive specified, it is instantiated
-    as a FlowCmdOperation. The operation should be a function of
-    :py:class:`~signac.contrib.job.Job`. The command (cmd) may
-    either be a unary callable that expects an instance of
-    :class:`~signac.contrib.job.Job` as its only positional argument and returns
-    a string containing valid shell commands, or the string of commands itself.
-    In either case, the resulting string may contain any attributes of the job placed
-    in curly braces, which will then be substituted by Python string formatting.
+    as a FlowCmdOperation. The operation may be a function of multiple
+    :py:class:`~signac.contrib.job.Job` instances, depending on the aggregator used.
+    The command (cmd) may either be a callable that expects one or more
+    instances of :class:`~signac.contrib.job.Job` as positional arguments and
+    returns a string containing valid shell commands, or the string of commands itself.
+    In either case, the resulting string may contain any attributes of the job or
+    aggregate of jobs placed in curly braces, which will then be substituted by
+    Python string formatting.
 
     .. note::
         This class should not be instantiated directly.
@@ -665,9 +694,15 @@ def __call__(self, *jobs, **kwargs):
         job = jobs[0] if len(jobs) == 1 else None
 
         if callable(self._cmd):
-            return self._cmd(job).format(job=job)
+            cmd_val = self._cmd(*jobs)
+            if len(jobs) == 1:
+                try:
+                    return cmd_val.format(job=job)
+                except KeyError:
+                    pass
+            return cmd_val.format(*jobs)
         else:
-            return self._cmd.format(job=job)
+            return self._cmd.format(*jobs)
 
 
 class FlowOperation(BaseFlowOperation):
@@ -688,8 +723,8 @@ def __str__(self):
         return "{type}(op_func='{op_func}')" \
                "".format(type=type(self).__name__, op_func=self._op_func)
 
-    def __call__(self, job):
-        return self._op_func(job)
+    def __call__(self, jobs):
+        return self._op_func(*jobs)
 
 
 class FlowGroupEntry(object):
@@ -711,10 +746,15 @@ class FlowGroupEntry(object):
         commands to execute.
     :type options:
         str
+    :param aggregate:
+        Object that aggregates jobs for a :py:class:`FlowGroup`.
+    :type aggregate:
+        :py:class:`MakeAggregate`
     """
-    def __init__(self, name, options=""):
+    def __init__(self, name, options="", aggregate=None):
         self.name = name
         self.options = options
+        self.aggregate = aggregate
 
     def __call__(self, func):
         """Decorator that adds the function into the group's operations.
@@ -770,7 +810,7 @@ class FlowGroup(object):
     """A FlowGroup represents a subset of a workflow for a project.
 
     Any :py:class:`FlowGroup` is associated with one or more instances of
-    :py:class:`BaseFlowOperation`.
+    :py:class:`BaseFlowOperation` that share the same aggregation.
 
     In the example below, the directives will be {'nranks': 4} for op1 and
     {'nranks': 2, 'executable': 'python3'} for op2
@@ -815,12 +855,16 @@ def op2(job):
         This lets options like ``--num_passes`` to be given to a group.
     :type options:
         str
+    :param aggregate:
+        Object that aggregates jobs for this group.
+    :type aggregate:
+        :py:class:`MakeAggregate`
     """
 
     MAX_LEN_ID = 100
 
     def __init__(self, name, operations=None, operation_directives=None,
-                 options=""):
+                 options="", aggregate=None):
         self.name = name
         self.options = options
         # An OrderedDict is not necessary here, but is used to ensure
@@ -830,6 +874,7 @@ def __init__(self, name, operations=None, operation_directives=None,
             self.operation_directives = dict()
         else:
             self.operation_directives = operation_directives
+        self.aggregate = aggregate
 
     def _set_entrypoint_item(self, entrypoint, directives, key, default, jobs):
         """Set a value (executable, path) for entrypoint in command.
@@ -892,7 +937,7 @@ def evaluate(value): def _submit_cmd(self, entrypoint, ignore_conditions, jobs=None): entrypoint = self._determine_entrypoint(entrypoint, dict(), jobs) cmd = "{} run -o {}".format(entrypoint, self.name) - cmd = cmd if jobs is None else cmd + ' -j {}'.format(' '.join(map(str, jobs))) + cmd = cmd if jobs is None else cmd + f' -j {get_aggregate_id(jobs)}' cmd = cmd if self.options is None else cmd + ' ' + self.options if ignore_conditions != IgnoreConditions.NONE: return cmd.strip() + ' --ignore-conditions=' + str(ignore_conditions) @@ -904,7 +949,7 @@ def _run_cmd(self, entrypoint, operation_name, operation, directives, jobs): return operation(*jobs).lstrip() else: entrypoint = self._determine_entrypoint(entrypoint, directives, jobs) - return f"{entrypoint} exec {operation_name} {' '.join(map(str, jobs))}".lstrip() + return f"{entrypoint} exec {operation_name} {get_aggregate_id(jobs)}".lstrip() def __iter__(self): yield from self.operations.values() @@ -988,7 +1033,7 @@ def complete(self, job): """ return self._complete((job,)) - def add_operation(self, name, operation, directives=None): + def add_operation(self, name, operation, aggregate, directives=None): """Add an operation to the FlowGroup. :param name: @@ -999,12 +1044,18 @@ def add_operation(self, name, operation, directives=None): The workflow operation to add to the FlowGroup. :type operation: :py:class:`BaseFlowOperation` + :param aggregate: + Object that aggregates jobs for this group. + :type aggregate: + :py:class:`MakeAggregate` :param directives: The operation specific directives. :type directives: dict """ self.operations[name] = operation + if self.aggregate is None: + self.aggregate = aggregate if directives is not None: self.operation_directives[name] = directives @@ -1406,6 +1457,14 @@ def __init__(self, config=None, environment=None, entrypoint=None): self._groups = dict() self._register_groups() + # Register all aggregates which are created for this project + self._aggregates = dict() + self._aggregates_ids = dict() + # We'd wan't to keep track of orphan aggregates for status view inorder to fetch + # the actual aggregate from the aggregate id of those aggregates. + self._orphan = dict() + self.register_aggregates() + def _setup_template_environment(self): """Setup the jinja2 template environment. @@ -1699,6 +1758,71 @@ def _store_bundled(self, operations): file.write(operation.id + '\n') return bid + def _fn_aggregate(self, id): + "Return the canonical name to store aggregate information." + return os.path.join(self.root_directory(), '.aggregate', id) + + def _store_aggregates(self, operations): + """Store aggregate information on a per operation basis. + + This enables status check of aggregates which were + formed previously but are not present currently for any + operation. + + :param operations: + The operations to be submitted. 
+ :type operations: + A sequence of instances of :py:class:`.JobOperation` + """ + for operation in operations: + op_fn = '{}.txt'.format(operation.name) + aggregate_wid = '{} {}'.format(operation.id, ' '.join(map(str, operation._jobs))) + fn_aggregate = self._fn_aggregate(op_fn) + os.makedirs(os.path.dirname(fn_aggregate), exist_ok=True) + if os.path.exists(fn_aggregate): + with open(fn_aggregate, 'r+') as file: + agg_file_contents = file.read() + if aggregate_wid in agg_file_contents.splitlines(): + continue + else: + file.write(aggregate_wid + ' \n') + else: + with open(fn_aggregate, 'w') as file: + file.write(aggregate_wid + ' \n') + + def _fetch_aggregates(self, group): + """Fetch submitted aggregates for a group. + + This enables the fetching status of lost aggregates. + + :param group: + FlowGroup associated with the operation. + :type group: + list :py:class:`flow.FlowGroup` + :yields: + All aggregates which were stored during submission. + """ + dir = '.aggregate/{}.txt'.format(group.name) + if os.path.exists(dir): + with open(dir, 'r') as file: + for obj in file: + aggregate = [] + _ids = obj.split(' ') + submission_id = _ids[0] + job_ids = _ids[1:-1] + try: + for job_id in job_ids: + aggregate.append(self.open_job(id=job_id)) + + aggregate = tuple(aggregate) + # Checking whether the aggregate and the submission id match. + # If not, then a user must have changed the submission id. + # Hence skip this aggregate. + if group._generate_id(aggregate) == submission_id: + yield aggregate + except KeyError: # Not able to open the job via job id. + pass + def _expand_bundled_jobs(self, scheduler_jobs): "Expand jobs which were submitted as part of a bundle." for job in scheduler_jobs: @@ -1728,64 +1852,81 @@ def scheduler_jobs(self, scheduler): for sjob in self._expand_bundled_jobs(scheduler.jobs()): yield sjob - def _get_operations_status(self, job, cached_status): - "Return a dict with information about job-operations for this job." + def _get_operations_status(self, jobs, cached_status): + "Return a dict with information about job-operations for this job-aggregate." starting_dict = functools.partial(dict, scheduler_status=JobStatus.unknown) status_dict = defaultdict(starting_dict) for group in self._groups.values(): - completed = group._complete((job,)) - eligible = False if completed else group._eligible((job,)) - scheduler_status = cached_status.get(group._generate_id((job,)), - JobStatus.unknown) - for operation in group.operations: - if scheduler_status >= status_dict[operation]['scheduler_status']: - status_dict[operation] = { - 'scheduler_status': scheduler_status, - 'eligible': eligible, - 'completed': completed - } + for aggregate in self.aggregates[group.name]: + if aggregate != jobs: + continue + completed = group._complete(aggregate) + eligible = False if completed else group._eligible(aggregate) + scheduler_status = cached_status.get(group._generate_id(aggregate), + JobStatus.unknown) + for operation in group.operations: + if scheduler_status >= status_dict[operation]['scheduler_status']: + status_dict[operation] = { + 'scheduler_status': scheduler_status, + 'eligible': eligible, + 'completed': completed + } for key in sorted(status_dict): yield key, status_dict[key] def get_job_status(self, job, ignore_errors=False, cached_status=None): - "Return a dict with detailed information about the status of a job." + "Return a dict with detailed information about the status of a job or an aggregate of jobs." 
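The bookkeeping written by ``_store_aggregates`` and read back by ``_fetch_aggregates`` is one record per line in ``.aggregate/<operation-name>.txt``: the submission id followed by the job ids of the aggregate and a trailing space. A small illustration of how one such record is parsed; the ids below are placeholders.

.. code-block:: python

    # Illustration only: the submission id and job ids are placeholders.
    record = "analysis-0123abcd 9bfd29df07674bc4aa960cf6 14fb5d512f51b8725b46eb22 \n"
    fields = record.split(' ')
    submission_id = fields[0]
    job_ids = fields[1:-1]   # the trailing space leaves an empty last field to drop
    # _fetch_aggregates re-opens each job id and yields the aggregate only if
    # group._generate_id(aggregate) still matches submission_id.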
result = dict() - result['job_id'] = str(job) + if isinstance(job, signac.contrib.job.Job): + distinct_jobs = [job] + jobs = (job,) + result['job_id'] = get_aggregate_id(jobs) + else: + jobs = tuple(job) # Avoid confusion for developers as an aggregate can also be passed + self._verify_aggregate_project([jobs]) + distinct_jobs = [job for job in jobs] + result['aggregate_id'] = get_aggregate_id(jobs) try: if cached_status is None: try: cached_status = self.document['_status']._as_dict() except KeyError: cached_status = dict() - result['operations'] = OrderedDict(self._get_operations_status(job, cached_status)) + result['operations'] = OrderedDict(self._get_operations_status(jobs, cached_status)) result['_operations_error'] = None except Exception as error: - msg = "Error while getting operations status for job '{}': '{}'.".format(job, error) + msg = "Error while getting operations status for job(s) '{}': '{}'.".format(job, error) logger.debug(msg) if ignore_errors: result['operations'] = dict() result['_operations_error'] = str(error) else: raise - try: - result['labels'] = sorted(set(self.labels(job))) - result['_labels_error'] = None - except Exception as error: - logger.debug("Error while determining labels for job '{}': '{}'.".format(job, error)) - if ignore_errors: - result['labels'] = list() - result['_labels_error'] = str(error) - else: - raise + result['labels'] = dict() + result['_labels_error'] = dict() + for job in distinct_jobs: + try: + result['labels'][str(job)] = sorted(set(self.labels(job))) + result['_labels_error'][str(job)] = None + except Exception as error: + logger.debug(f"Error while determining labels for job '{job}': '{error}'.") + if ignore_errors: + result['labels'][str(job)] = list() + result['_labels_error'][str(job)] = str(error) + else: + raise + + # if len(distinct_jobs) == 1: # Doesn't break the current API + # result['labels'] = result['labels'][str(distinct_jobs[0])] + # result['_labels_error'] = result['_labels_error'][str(job)] + return result def _fetch_scheduler_status(self, jobs=None, file=None, ignore_errors=False): "Update the status docs." 
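For orientation, the per-aggregate status record assembled above now looks roughly as follows; the ids and label names are illustrative only. Labels remain keyed per job because they are per-job properties, while the operation status applies to the aggregate as a whole.

.. code-block:: python

    # Illustrative shape only; ids and label names are made up.
    status = {
        'aggregate_id': 'agg-03d51b4c...',
        'operations': {},                 # OrderedDict keyed by operation name
        '_operations_error': None,
        'labels': {'9bfd29df...': ['relaxed'], '14fb5d51...': []},
        '_labels_error': {'9bfd29df...': None, '14fb5d51...': None},
    }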
if file is None: file = sys.stderr - if jobs is None: - jobs = list(self) try: scheduler = self._environment.get_scheduler() @@ -1793,12 +1934,16 @@ def _fetch_scheduler_status(self, jobs=None, file=None, ignore_errors=False): scheduler_info = {sjob.name(): sjob.status() for sjob in self.scheduler_jobs(scheduler)} status = dict() print("Query scheduler...", file=file) - for job in tqdm(jobs, - desc="Fetching operation status", - total=len(jobs), file=file): - for group in self._groups.values(): - _id = group._generate_id((job,)) - status[_id] = int(scheduler_info.get(_id, JobStatus.unknown)) + for group in tqdm(self._groups.values(), + desc="Fetching operation status", + total=len(self._groups), file=file): + aggregated_jobs = self.aggregates[group.name] + for aggregate in tqdm(aggregated_jobs, total=len(aggregated_jobs), + desc="Fetching aggregate info for aggregate", + leave=False, file=file): + if self._verify_aggregate_in_jobs(aggregate, jobs): + _id = group._generate_id(aggregate) + status[_id] = int(scheduler_info.get(_id, JobStatus.unknown)) self.document._status.update(status) except NoSchedulerError: logger.debug("No scheduler available.") @@ -1809,7 +1954,129 @@ def _fetch_scheduler_status(self, jobs=None, file=None, ignore_errors=False): else: logger.info("Updated job status cache.") - def _fetch_status(self, jobs, err, ignore_errors, status_parallelization='thread'): + def _get_group_status(self, group_name, ignore_errors=False, cached_status=None, + orphan=False): + "Return a dict with detailed information about the status of jobs per group." + result = dict() + group = self._groups[group_name] + # The group associated with the group_name contains only a + # single operation having its name equals group_name. + # Hence set the operation_name to group_name + result['operation_name'] = group_name + status_dict = dict() + errors = dict() + + aggregates = self.aggregates[group.name] + + if orphan: + fetched_aggregates = self._fetch_aggregates(group) + for aggregate in fetched_aggregates: + if aggregate not in aggregates: + self._orphan[get_aggregate_id(aggregate)] = aggregate + aggregates.append(aggregate) + + def _get_job_ids(aggregate): + return ' '.join(map(str, aggregate)) + + def _get_agg_details(aggregate, is_aggregate): + if is_aggregate: + return 'aggregate - {}'.format(len(aggregate)) + else: + return 'non-aggregate' + + for aggregate in tqdm(aggregates, + desc=f"Collecting aggregate status info for operation {group.name}", + leave=False, total=len(aggregates)): + for job in aggregate: + if errors.get(str(job), None) is None: + errors[str(job)] = None + try: + _id = group._generate_id(aggregate) + completed = group._complete(aggregate) + eligible = False if completed else group._eligible(aggregate) + scheduler_status = cached_status.get(_id, JobStatus.unknown) + status_dict[get_aggregate_id(aggregate)] = { + 'scheduler_status': scheduler_status, + 'eligible': eligible, + 'completed': completed, + 'is_aggregate': group.aggregate._is_aggregate, + 'job_ids': _get_job_ids(aggregate), + 'aggregate_detail': _get_agg_details( + aggregate, + group.aggregate._is_aggregate) + } + except Exception as error: + msg = "Error while getting operations status for aggregate " \ + "'{}': '{}'.".format(_get_job_ids(aggregate), error) + logger.debug(msg) + status_dict[get_aggregate_id(aggregate)] = { + 'scheduler_status': JobStatus.unknown, + 'eligible': False, + 'completed': False, + 'is_aggregate': group.aggregate._is_aggregate, + 'job_ids': _get_job_ids(aggregate), + 'aggregate_detail': 
_get_agg_details( + aggregate, + group.aggregate._is_aggregate) + } + if ignore_errors: + for job in aggregate: + if errors[str(job)] is None: + errors[str(job)] = str(error) + else: + errors[str(job)] += '\n' + str(error) + else: + raise + + result['aggregate_details'] = status_dict + result['_operation_error_per_job'] = errors + return result + + def _get_job_labels(self, job, ignore_errors=False): + "Return a dict with information about the labels of a job." + result = dict() + result['job_id'] = str(job) + try: + result['labels'] = sorted(set(self.labels(job))) + result['_labels_error'] = None + except Exception as error: + logger.debug("Error while determining labels for job '{}': '{}'.".format(job, error)) + if ignore_errors: + result['labels'] = list() + result['_labels_error'] = str(error) + else: + raise + return result + + def _fetch_status(self, jobs, distinct_jobs, err, ignore_errors, + status_parallelization='thread', orphan=False): + """Fetch status associated for either all the jobs associated in a project + or jobs specified by a user + + :param jobs: + The aggregates which a user requested to fetch status for. + :type jobs: + list of aggregates + :param distinct_jobs: + Distinct jobs fetched from the ids provided in the ``jobs`` argument. + This is used for fetching labels for a job because a label is not associated + with an aggregate. + :type distinct_jobs: + list of :py:class:`signac.contrib.job.Job` + :param ignore_errors: + Fetch status even if querying the scheduler fails. + :type ignore_errors: + bool + :param status_parallelization: + Nature of parallelization for fetching the status. + Default value parallelizes using ``multiprocessing.ThreadPool()`` + :type status_parallelization: + str + :param orphan: + Print status for lost aggregates. + :type orphan: + bool + """ # The argument status_parallelization is used so that _fetch_status method # gets to know whether the deprecated argument no_parallelization passed # while calling print_status is True or False. This can also be done by @@ -1821,24 +2088,49 @@ def _fetch_status(self, jobs, err, ignore_errors, status_parallelization='thread self._fetch_scheduler_status(jobs, err, ignore_errors) # Get status dict for all selected jobs - def _print_progress(x): - print("Updating status: ", end='', file=err) - err.flush() - n = max(1, int(len(jobs) / 10)) - for i, _ in enumerate(x): - if (i % n) == 0: - print('.', end='', file=err) - err.flush() - yield _ + def _print_status(iterable, fetch_status, description): + t = time.time() + num_itr = len(iterable) + results = [] + for i, itr in enumerate(iterable): + results.append(fetch_status(itr)) + # The status interval 0.2 seconds is used since we expect the + # status for an aggregate to be fetched within that interval + if time.time() - t > 0.2: + print(f'{description}: {i+1}/{num_itr}', end='\r', file=err) + t = time.time() + # Always print the completed progressbar. 
+ print(f'{description}: {i+1}/{num_itr}', file=err) + return results try: cached_status = self.document['_status']._as_dict() except KeyError: cached_status = dict() - _get_job_status = functools.partial(self.get_job_status, - ignore_errors=ignore_errors, - cached_status=cached_status) + _get_job_labels = functools.partial(self._get_job_labels, + ignore_errors=ignore_errors) + + _get_group_status = functools.partial(self._get_group_status, + ignore_errors=ignore_errors, + cached_status=cached_status, + orphan=orphan) + + singleton_groups = [op for op in self.operations] + + def _generate_results_with_tqdm(iterable, map=map, desc=None, len_itr=None): + if iterable == 'groups': + return list(tqdm( + iterable=map(_get_group_status, singleton_groups), + desc="Collecting operation status", + total=len(singleton_groups), file=err)) + elif iterable == 'job-labels': + return list(tqdm( + iterable=map(_get_job_labels, distinct_jobs), + desc="Collecting job label info", total=len(distinct_jobs), file=err)) + else: + return list(tqdm( + iterable=iterable, desc=desc, total=len_itr, file=err)) with self._potentially_buffered(): try: @@ -1846,15 +2138,15 @@ def _print_progress(x): with contextlib.closing(ThreadPool()) as pool: # First attempt at parallelized status determination. # This may fail on systems that don't allow threads. - return list(tqdm( - iterable=pool.imap(_get_job_status, jobs), - desc="Collecting job status info", total=len(jobs), file=err)) + label_results = _generate_results_with_tqdm('job-labels', pool.imap) + op_results = _generate_results_with_tqdm('groups', pool.imap) elif status_parallelization == 'process': with contextlib.closing(Pool()) as pool: try: import pickle - results = self._fetch_status_in_parallel( - pool, pickle, jobs, ignore_errors, cached_status) + l_results, g_results = self._fetch_status_in_parallel( + pool, pickle, distinct_jobs, singleton_groups, ignore_errors, + cached_status, orphan) except Exception as error: if not isinstance(error, (pickle.PickleError, self._PickleError)) and\ 'pickle' not in str(error).lower(): @@ -1870,52 +2162,90 @@ def _print_progress(x): raise error else: try: - results = self._fetch_status_in_parallel( - pool, cloudpickle, jobs, ignore_errors, cached_status) + l_results, g_results = self._fetch_status_in_parallel( + pool, cloudpickle, distinct_jobs, singleton_groups, + ignore_errors, cached_status, orphan) except self._PickleError as error: raise RuntimeError( "Unable to parallelize execution due to a pickling " "error: {}.".format(error)) - return list(tqdm( - iterable=results, - desc="Collecting job status info", total=len(jobs), file=err)) + label_results = _generate_results_with_tqdm( + l_results, desc="Collecting job label info", len_itr=len(distinct_jobs)) + op_results = _generate_results_with_tqdm( + g_results, desc="Collecting operation status", + len_itr=len(singleton_groups)) elif status_parallelization == 'none': - return list(tqdm( - iterable=map(_get_job_status, jobs), - desc="Collecting job status info", total=len(jobs), file=err)) + label_results = _generate_results_with_tqdm('job-labels') + op_results = _generate_results_with_tqdm('groups') else: raise RuntimeError("Configuration value status_parallelization is invalid. 
" - "You can set it to 'thread', 'parallel', or 'none'" - ) + "You can set it to 'thread', 'parallel', or 'none'") except RuntimeError as error: if "can't start new thread" not in error.args: raise # unrelated error + label_results = _print_status(distinct_jobs, _get_job_labels, + "Collecting job label info") + op_results = _print_status(singleton_groups, _get_group_status, + "Collecting operation status") + + results, index, i = list(), dict(), 0 + + for job in distinct_jobs: + result = dict() + result['job_id'] = str(job) + result['operations'] = dict() + result['labels'] = list() + result['_operations_error'] = list() + for op_result in op_results: + result['operations'][op_result['operation_name']] = [] + result['_labels_error'] = list() + results.append(result) + index[str(job)] = i + i += 1 + + for op_result in op_results: + for id, aggregate_status in op_result['aggregate_details'].items(): + if self._aggregates_ids.get(id, False): + aggregate = self._aggregates_ids[id] + elif status_parallelization != 'process': # Got orphan aggregates + # Unable to set self._lost_aggregate while parallelizing + # using 'process'. + aggregate = self._orphan[id] + else: + warnings.warn("Can't fetch orphan aggregates while parallelizing " + "using `process`. Please use `thread` or `none` instead.") + continue + + if not self._verify_aggregate_in_jobs(aggregate, jobs): + continue + for job in aggregate: + results[index[str(job)]]['operations'][op_result['operation_name']].append( + aggregate_status + ) + results[index[str(job)]]['_operations_error'] = \ + op_result['_operation_error_per_job'][str(job)] + + for label_result in label_results: + results[index[label_result['job_id']]]['labels'] = label_result['labels'] + results[index[label_result['job_id']]]['_labels_error'] = label_result['_labels_error'] - t = time.time() - num_jobs = len(jobs) - statuses = [] - for i, job in enumerate(jobs): - statuses.append(_get_job_status(job)) - if time.time() - t > 0.2: # status interval - print( - 'Collecting job status info: {}/{}'.format(i+1, num_jobs), - end='\r', file=err) - t = time.time() - # Always print the completed progressbar. - print('Collecting job status info: {}/{}'.format(i+1, num_jobs), file=err) - return statuses - - def _fetch_status_in_parallel(self, pool, pickle, jobs, ignore_errors, cached_status): + return results + + def _fetch_status_in_parallel(self, pool, pickle, jobs, groups, ignore_errors, + cached_status, orphan): try: s_project = pickle.dumps(self) - s_tasks = [(pickle.loads, s_project, job.get_id(), ignore_errors, cached_status) - for job in jobs] + s_tasks_labels = [(pickle.loads, s_project, job.get_id(), ignore_errors) + for job in jobs] + s_tasks_groups = [(pickle.loads, s_project, group, ignore_errors, cached_status, + orphan) for group in groups] except Exception as error: # Masking all errors since they must be pickling related. 
raise self._PickleError(error) - results = pool.imap(_serialized_get_job_status, s_tasks) + label_results = pool.imap(_serialized_get_job_labels, s_tasks_labels) + group_results = pool.imap(_serialized_get_group_status, s_tasks_groups) - return results + return label_results, group_results PRINT_STATUS_ALL_VARYING_PARAMETERS = True """This constant can be used to signal that the print_status() method is supposed @@ -1928,13 +2258,15 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, unroll=True, compact=False, pretty=False, file=None, err=None, ignore_errors=False, no_parallelize=False, template=None, profile=False, - eligible_jobs_max_lines=None, output_format='terminal'): + eligible_jobs_max_lines=None, output_format='terminal', + orphan=False): """Print the status of the project. :param jobs: - Only execute operations for the given jobs, or all if the argument is omitted. + Only execute operations for the given jobs or aggregates, + or all if the argument is omitted. :type jobs: - Sequence of instances :class:`.Job`. + Sequence of instances :class:`.Job` or list of :class:`.Job`. :param overview: Aggregate an overview of the project' status. :type overview: @@ -2012,6 +2344,10 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, 'terminal' (default), 'markdown' or 'html'. :type output_format: str + :param orphan: + Print status for lost aggregates. + :type orphan: + bool :return: A Renderer class object that contains the rendered string. :rtype: @@ -2021,8 +2357,18 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, file = sys.stdout if err is None: err = sys.stderr - if jobs is None: - jobs = self # all jobs + + # Convert all the signac jobs into an aggregate of 1 + aggregates = self._convert_aggregates_from_jobs(jobs) + + if aggregates is not None: + # Fetch all the distinct jobs from all the jobs or aggregate passed by the user + distinct_jobs = set() + for aggregate in aggregates: + for job in aggregate: + distinct_jobs.add(job) + else: + distinct_jobs = self if eligible_jobs_max_lines is None: eligible_jobs_max_lines = flow_config.get_config_value('eligible_jobs_max_lines') @@ -2065,7 +2411,8 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, ] with prof(single=False): - tmp = self._fetch_status(jobs, err, ignore_errors, status_parallelization) + tmp = self._fetch_status(aggregates, distinct_jobs, err, ignore_errors, + status_parallelization, orphan) prof._mergeFileTiming() @@ -2129,7 +2476,8 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, "results may be highly inaccurate.") else: - tmp = self._fetch_status(jobs, err, ignore_errors, status_parallelization) + tmp = self._fetch_status(aggregates, distinct_jobs, err, ignore_errors, + status_parallelization, orphan) profiling_results = None operations_errors = {s['_operations_error'] for s in tmp} @@ -2148,7 +2496,7 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, # eligible operation. def _incomplete(s): - return any(op['eligible'] for op in s['operations'].values()) + return any([op['eligible'] for op in ops] for ops in s['operations'].values()) tmp = list(filter(_incomplete, tmp)) @@ -2173,8 +2521,8 @@ def _incomplete(s): # Optionally expand parameters argument to all varying parameters. 
if parameters is self.PRINT_STATUS_ALL_VARYING_PARAMETERS: parameters = list( - sorted({key for job in jobs for key in job.sp.keys() if - len(set([to_hashable(job.sp().get(key)) for job in jobs])) > 1})) + sorted({key for job in distinct_jobs for key in job.sp.keys() if + len(set([to_hashable(job.sp().get(key)) for job in distinct_jobs])) > 1})) if parameters: # get parameters info @@ -2206,9 +2554,6 @@ def get(k, m): # get detailed view info status_legend = ' '.join('[{}]:{}'.format(v, k) for k, v in self.ALIASES.items()) - if compact: - num_operations = len(self._operations) - if pretty: OPERATION_STATUS_SYMBOLS = OrderedDict([ ('ineligible', '\u25cb'), # open circle @@ -2244,32 +2589,56 @@ def get(k, m): context['alias_bool'] = {True: 'Y', False: 'N'} context['scheduler_status_code'] = _FMT_SCHEDULER_STATUS context['status_legend'] = status_legend - if compact: - context['extra_num_operations'] = max(num_operations-1, 0) if not unroll: context['operation_status_legend'] = operation_status_legend context['operation_status_symbols'] = OPERATION_STATUS_SYMBOLS def _add_dummy_operation(job): - job['operations'][''] = { + job['operations'][''] = [{ 'completed': False, 'eligible': False, - 'scheduler_status': JobStatus.dummy} + 'scheduler_status': JobStatus.dummy}] for job in context['jobs']: - has_eligible_ops = any([v['eligible'] for v in job['operations'].values()]) + # If a job has no operation eligible then we add a dummy operation + # Where we show an empty list of operations + for a in job['operations'].values(): + has_eligible_ops = any([b['eligible'] for b in a]) + if has_eligible_ops: + break if not has_eligible_ops and not context['all_ops']: _add_dummy_operation(job) - op_counter = Counter() - for job in context['jobs']: - for k, v in job['operations'].items(): - if k != '' and v['eligible']: - op_counter[k] += 1 - context['op_counter'] = op_counter.most_common(eligible_jobs_max_lines) - n = len(op_counter) - len(context['op_counter']) + aggregate_dict = defaultdict(lambda: []) + aggregate_counter = Counter() + aggregates_per_op = defaultdict(lambda: []) + + for i, job in enumerate(context['jobs']): + total_ops = 0 + for op_name, ops in job['operations'].items(): + if op_name != '': + for op in ops: + if op['eligible']: + total_ops += 1 + agg_id = op['job_ids'].split(' ') + if agg_id not in aggregate_dict[op_name]: + aggregate_dict[op_name].append(agg_id) + if op['is_aggregate']: + aggregates_per_op[op_name].append( + (agg_id, op['scheduler_status']) + ) + if detailed and compact: + context['jobs'][i]['extra_num_operations'] = total_ops - 1 + + for op, ags in aggregate_dict.items(): + aggregate_counter[op] = len(ags) + + context['aggregate_counter'] = aggregate_counter.most_common(eligible_jobs_max_lines) + n = len(aggregate_counter) - len(context['aggregate_counter']) if n > 0: - context['op_counter'].append(('[{} more operations omitted]'.format(n), '')) + context['aggregate_counter'].append(('[{} more operations omitted]'.format(n), '')) + + context['detailed_ags'] = aggregates_per_op status_renderer = StatusRenderer() # We have to make a deep copy of the template environment if we're @@ -2323,7 +2692,7 @@ def _run_operations(self, operations=None, pretend=False, np=None, if timeout is not None and timeout < 0: timeout = None if operations is None: - operations = list(self._get_pending_operations(self)) + operations = list(self._get_pending_operations()) else: operations = list(operations) # ensure list @@ -2454,12 +2823,14 @@ def _execute_operation(self, operation, 
timeout=None, pretend=False): "Executing operation '{}' with current interpreter " "process ({}).".format(operation, os.getpid())) try: - self._operations[operation.name](*operation._jobs) + self._operations[operation.name](operation._jobs) except Exception as e: - assert len(self._jobs) == 1 raise UserOperationError( 'An exception was raised during operation {operation.name} ' - 'for job {operation._jobs[0]}.'.format(operation=operation)) from e + 'for aggregate {aggregate}.' + ''.format(operation=operation, + aggregate=', '.join(map(str, operation._jobs)))) \ + from e def _get_default_directives(self): return {name: self.groups[name].operation_directives.get(name, dict()) @@ -2520,12 +2891,14 @@ def run(self, jobs=None, names=None, pretend=False, np=None, timeout=None, num=N Specify the order of operations, possible values are: * 'none' or None (no specific order) * 'by-job' (operations are grouped by job) + * 'by-op' (operations are grouped by operation) + * 'by-job' (operations are grouped by job) * 'cyclic' (order operations cyclic by job) * 'random' (shuffle the execution order randomly) * callable (a callable returning a comparison key for an operation used to sort operations) - The default value is `none`, which is equivalent to `by-job` in the current + The default value is `none`, which is equivalent to `by-op` in the current implementation. .. note:: @@ -2541,9 +2914,8 @@ def run(self, jobs=None, names=None, pretend=False, np=None, timeout=None, num=N :type ignore_conditions: :py:class:`~.IgnoreConditions` """ - # If no jobs argument is provided, we run operations for all jobs. - if jobs is None: - jobs = self + # Convert all the signac jobs into an aggregate of 1 + aggregates = self._convert_aggregates_from_jobs(jobs) # Get all matching FlowGroups if isinstance(names, str): @@ -2577,7 +2949,8 @@ def log(msg, lvl=logging.INFO): reached_execution_limit = Event() def select(operation): - self._verify_aggregate_project(operation._jobs) + if not self._verify_aggregate_in_jobs(operation._jobs, aggregates): + return False if num is not None and select.total_execution_count >= num: reached_execution_limit.set() @@ -2621,12 +2994,11 @@ def select(operation): with self._potentially_buffered(): operations = [] for flow_group in flow_groups: - for job in jobs: + for aggregate in self.aggregates[flow_group.name]: operations.extend( flow_group._create_run_job_operations( self._entrypoint, default_directives, - (job,), ignore_conditions)) - + aggregate, ignore_conditions)) operations = list(filter(select, operations)) finally: if messages: @@ -2636,34 +3008,48 @@ def select(operation): if not operations: break # No more pending operations or execution limits reached. + def key_func_by_job(op): + # In order to group the aggregates in a by-job manner, we need + # to first sort the aggregates using their aggregate id. 
+ return get_aggregate_id(op._jobs) + # Optionally re-order operations for execution if order argument is provided: if callable(order): operations = list(sorted(operations, key=order)) elif order == 'cyclic': groups = [list(group) - for _, group in groupby(operations, key=lambda op: op._jobs)] + for _, group in groupby(sorted(operations, key=key_func_by_job), + key=key_func_by_job)] operations = list(roundrobin(*groups)) elif order == 'random': random.shuffle(operations) - elif order is None or order in ('none', 'by-job'): - pass # by-job is the default order + elif order == 'by-job': + groups = [list(group) + for _, group in groupby(sorted(operations, key=key_func_by_job), + key=key_func_by_job)] + operations = list(roundrobin(*groups)) + elif order is None or order in ('none', 'by-op'): + pass # by-op is the default order else: raise ValueError( "Invalid value for the 'order' argument, valid arguments are " - "'none', 'by-job', 'cyclic', 'random', None, or a callable.") + "'none', 'by-op', 'by-job', 'cyclic', 'random', None, or a callable.") logger.info( "Executing {} operation(s) (Pass # {:02d})...".format(len(operations), i_pass)) self._run_operations(operations, pretend=pretend, np=np, timeout=timeout, progress=progress) - def _generate_operations(self, cmd, jobs, requires=None): + def _generate_operations(self, cmd, aggregates, requires=None): "Generate job-operations for a given 'direct' command." - for job in jobs: - if requires and set(requires).difference(self.labels(job)): + for aggregate in aggregates: + if( + len(aggregate) > 1 or + requires and set(requires).difference(self.labels(*aggregate)) + ): continue - cmd_ = cmd.format(job=job) - yield _JobOperation(name=cmd_.replace(' ', '-'), cmd=cmd_, jobs=(job,)) + cmd_ = cmd.format(job=aggregate[0]) + yield _JobOperation(name=cmd_.replace(' ', '-'), cmd=cmd_, jobs=aggregate) def _gather_flow_groups(self, names=None): """Grabs FlowGroups that match any of a set of names.""" @@ -2693,18 +3079,19 @@ def _get_submission_operations(self, jobs, default_directives, names=None, ignore_conditions_on_execution=IgnoreConditions.NONE): """Grabs _JobOperations that are eligible to run from FlowGroups.""" for group in self._gather_flow_groups(names): - for job in jobs: + for aggregate in self.aggregates[group.name]: if ( - group._eligible((job,), ignore_conditions) and - self._eligible_for_submission(group, (job,)) + group._eligible(aggregate, ignore_conditions) and + self._eligible_for_submission(group, aggregate) and + self._verify_aggregate_in_jobs(aggregate, jobs) ): yield group._create_submission_job_operation( entrypoint=self._entrypoint, default_directives=default_directives, - jobs=(job,), index=0, + jobs=aggregate, index=0, ignore_conditions_on_execution=ignore_conditions_on_execution) - def _get_pending_operations(self, jobs, operation_names=None, + def _get_pending_operations(self, jobs=None, operation_names=None, ignore_conditions=IgnoreConditions.NONE): "Get all pending operations for the given selection." 
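A self-contained sketch of the grouping used by the 'cyclic' and 'by-job' orderings above: operations are sorted and grouped by the id of their aggregate, then drawn round-robin from the groups. The small ``roundrobin`` below stands in for flow's internal helper of the same name, and the ``(aggregate_id, operation)`` tuples are placeholders.

.. code-block:: python

    # Standalone sketch; the (aggregate_id, operation) tuples are placeholders and
    # the local `roundrobin` stands in for flow's internal helper.
    from itertools import chain, groupby, zip_longest


    def roundrobin(*iterables):
        sentinel = object()
        interleaved = chain.from_iterable(zip_longest(*iterables, fillvalue=sentinel))
        return [item for item in interleaved if item is not sentinel]


    def by_aggregate(op):
        return op[0]


    ops = [('agg-a', 'op1'), ('agg-b', 'op1'), ('agg-a', 'op2'), ('agg-b', 'op2')]
    groups = [list(group) for _, group in
              groupby(sorted(ops, key=by_aggregate), key=by_aggregate)]
    print(roundrobin(*groups))
    # [('agg-a', 'op1'), ('agg-b', 'op1'), ('agg-a', 'op2'), ('agg-b', 'op2')]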
assert not isinstance(operation_names, str) @@ -2716,12 +3103,61 @@ def _verify_group_compatibility(self, groups): """Verifies that all selected groups can be submitted together.""" return all(a.isdisjoint(b) for a in groups for b in groups if a != b) - def _verify_aggregate_project(self, aggregate): + def _verify_aggregate_project(self, aggregates): """Verifies that all aggregates belongs to the same project.""" - for job in aggregate: - if job not in self: - raise ValueError("Job {} is not present " - "in the project".format(job)) + if aggregates is None: + return + + for aggregate in aggregates: + aggregate_id = get_aggregate_id(aggregate) + if not self._aggregates_ids.get(aggregate_id, False): + raise LookupError(f"Did not find the aggregate {aggregate} having " + f"id {aggregate_id} in the project") + + def _get_aggregate_from_id(self, id): + try: + return self._aggregates_ids[id] + except KeyError: + raise LookupError(f"Did not find aggregate having id {id} in the project") + + def _convert_aggregates_from_jobs(self, jobs): + # The jobs parameter in the public methods like ``run``, ``submit``, ``status`` + # may accept either a signac job or an aggregate. + # We convert that job / aggregate (which may be of any type (eg. list)) to an + # aggregate of type ``tuple`` + + if jobs is not None: + aggregates = set() # Set in order to prevent duplicate entries + for aggregate in jobs: + # User can still pass signac jobs + if isinstance(aggregate, signac.contrib.job.Job): + if aggregate not in self: + raise LookupError(f"Did not find the job {aggregate} in the project") + aggregates.add((aggregate,)) + else: + try: + aggregate = tuple(aggregate) + id = get_aggregate_id(aggregate) + if not self._aggregates_ids.get(id, False): + raise LookupError(f"Did not find aggregate {aggregate} in the project") + aggregates.add(aggregate) # An aggregate provided by the user + except TypeError: + raise TypeError('Invalid argument provided by a user. Please provide ' + 'a valid signac job or an aggregate of jobs instead.') + else: + aggregates = None + + return aggregates + + def _verify_aggregate_in_jobs(self, aggregate, jobs): + """Verifies whether the aggregate is present in the jobs provided by the users + or not. + """ + if jobs is None: + return True + elif aggregate not in jobs: + return False + return True @contextlib.contextmanager def _potentially_buffered(self): @@ -3002,10 +3438,10 @@ def submit(self, bundle_size=1, jobs=None, names=None, num=None, parallel=False, :type ignore_conditions: :py:class:`~.IgnoreConditions` """ - # Regular argument checks and expansion - if jobs is None: - jobs = self # select all jobs + # Convert all the signac jobs passed by a user into a tuple of 1 job. + aggregates = self._convert_aggregates_from_jobs(jobs) + # Regular argument checks and expansion if isinstance(names, str): raise ValueError( "The 'names' argument must be a sequence of strings, however you " @@ -3031,11 +3467,15 @@ def submit(self, bundle_size=1, jobs=None, names=None, num=None, parallel=False, # Gather all pending operations. 
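With the conversion above, the public ``run``, ``submit``, and ``print_status`` methods accept plain signac jobs (treated as one-job aggregates) as well as aggregates the project already knows about. A brief sketch; ``project`` is a ``FlowProject`` instance and ``per_temperature`` is the assumed operation name from the earlier sketch.

.. code-block:: python

    # Sketch only: `project` is a FlowProject instance and 'per_temperature' is an
    # assumed aggregate operation name.
    job = next(iter(project))
    project.run(jobs=[job])               # a plain signac job becomes a one-job aggregate

    aggregate = project.aggregates['per_temperature'][0]
    project.run(jobs=[aggregate], names=['per_temperature'])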
with self._potentially_buffered(): default_directives = self._get_default_directives() - operations = self._get_submission_operations(jobs, default_directives, names, - ignore_conditions, + operations = self._get_submission_operations(aggregates, default_directives, + names, ignore_conditions, ignore_conditions_on_execution) if num is not None: operations = list(islice(operations, num)) + else: + operations = list(operations) + + self._store_aggregates(operations) # Bundle them up and submit. for bundle in _make_bundles(operations, bundle_size): @@ -3297,6 +3737,10 @@ def _add_print_status_args(cls, parser): type=str, default='terminal', help="Set status output format: terminal, markdown, or html.") + parser.add_argument( + '--orphan', + action='store_true', + help="Fetch status for lost aggregates.") def labels(self, job): """Yields all labels for the given ``job``. @@ -3412,20 +3856,21 @@ def completed_operations(self, job): if op._complete((job,)): yield name - def _job_operations(self, job, ignore_conditions=IgnoreConditions.NONE): - "Yield instances of _JobOperation constructed for specific jobs." - for name in self.operations: - group = self._groups[name] - yield from group._create_run_job_operations(entrypoint=self._entrypoint, jobs=(job,), - default_directives=dict(), - ignore_conditions=ignore_conditions, - index=0) + def _job_operations(self, jobs, group, ignore_conditions=IgnoreConditions.NONE): + """Yield instances of _JobOperation constructed for the specific job-aggregate associated + with the group.""" + yield from group._create_run_job_operations(entrypoint=self._entrypoint, jobs=jobs, + default_directives=dict(), + ignore_conditions=ignore_conditions, + index=0) - def _next_operations(self, jobs, ignore_conditions=IgnoreConditions.NONE): - """Determine the next eligible operations for jobs. + def _next_operations(self, jobs=None, ignore_conditions=IgnoreConditions.NONE): + """Determine the next eligible operations for aggregates created + while initialization. :param jobs: The signac job handles. + By default all the aggregates are evaluated to get the next opeation associated. :type jobs: tuple of :class:`~signac.contrib.job.Job` :param ignore_conditions: @@ -3436,9 +3881,16 @@ def _next_operations(self, jobs, ignore_conditions=IgnoreConditions.NONE): :yield: All instances of :class:`~._JobOperation` jobs are eligible for. """ - for job in jobs: - for op in self._job_operations(job, ignore_conditions): - yield op + for name in self.operations: + group = self._groups[name] + for aggregate in self.aggregates[group.name]: + if not self._verify_aggregate_in_jobs(aggregate, jobs): + continue + yield from group._create_run_job_operations(entrypoint=self._entrypoint, + jobs=aggregate, + default_directives=dict(), + ignore_conditions=ignore_conditions, + index=0) @deprecated(deprecated_in="0.11", removed_in="0.13", current_version=__version__) def next_operations(self, *jobs, ignore_conditions=IgnoreConditions.NONE): @@ -3456,12 +3908,16 @@ def next_operations(self, *jobs, ignore_conditions=IgnoreConditions.NONE): :yield: All instances of :class:`~.JobOperation` jobs are eligible for. """ - for job in jobs: - for op in self._job_operations(job, ignore_conditions): - # JobOperation is just meand to deal with a single job and not a tuple of jobs. - # Hence we have to make sure that a JobOperation instance hold a single job. 
- assert len(op._jobs) == 1 - yield JobOperation(op.id, op.name, op._jobs[0], op._cmd, op.directives) + for name in self.operations: + group = self._groups[name] + for aggregate in self.aggregates[group.name]: + # JobOperation is just meant to deal with a single job and not an aggregate + # of jobs. Hence we select aggregates of length one, and the job in that + # aggregate should be present in the jobs passed by the user. + if len(aggregate) > 1 and aggregate[0] not in jobs: + continue + for op in self._job_operations(aggregate, group, ignore_conditions): + yield JobOperation(op.id, op.name, op._jobs[0], op._cmd, op.directives) @classmethod def operation(cls, func, name=None): @@ -3489,13 +3945,6 @@ def hello(job): if name in cls._GROUP_NAMES: raise ValueError("A group with name '{}' is already registered.".format(name)) - signature = inspect.signature(func) - for i, (k, v) in enumerate(signature.parameters.items()): - if i and v.default is inspect.Parameter.empty: - raise ValueError( - "Only the first argument in an operation argument may not have " - "a default value! ({})".format(name)) - # Append the name and function to the class registry cls._OPERATION_FUNCTIONS.append((name, func)) cls._GROUPS.append(FlowGroupEntry(name=name, options="")) @@ -3548,14 +3997,30 @@ def _register_operations(self): 'pre': pre_conditions.get(func, None), 'post': post_conditions.get(func, None)} + if not getattr(func, '_flow_aggregate', False): + aggregate = Aggregate.groupsof(1) + func._flow_aggregate = aggregate._create_MakeAggregate() + # Construct FlowOperation: if getattr(func, '_flow_cmd', False): self._operations[name] = FlowCmdOperation(cmd=func, **params) else: self._operations[name] = FlowOperation(op_func=func, **params) + def register_aggregates(self): + """Generate aggregates for every operation or group in a FlowProject.""" + for name in self._groups: + self._aggregates[name] = self._groups[name].aggregate(self, group_name=name, + project=self) + + def _reregister_default_aggregates(self): + for name in self._groups: + if not self._groups[name].aggregate._is_aggregate: + # If we use group.aggregate then we'll face pickling issues + self._aggregates[name] = [(job,) for job in self] + @classmethod - def make_group(cls, name, options=""): + def make_group(cls, name, options="", aggregate=None): """Make a FlowGroup named ``name`` and return a decorator to make groups. .. code-block:: python @@ -3578,13 +4043,21 @@ def foo(job): A string to append to submissions can be any valid :meth:`FlowOperation.run` option. :type options: str + :param aggregate: + Aggregate object that describes how jobs are aggregated for this group. + The default value is None, which results in aggregates of one job each. + :type aggregate: + :class:`~.Aggregate`.
""" if name in cls._GROUP_NAMES: raise ValueError("Repeat definition of group with name '{}'.".format(name)) else: cls._GROUP_NAMES.add(name) - group_entry = FlowGroupEntry(name, options) + if aggregate is None: + aggregate = Aggregate.groupsof(1) + + group_entry = FlowGroupEntry(name, options, aggregate._create_MakeAggregate()) cls._GROUPS.append(group_entry) return group_entry @@ -3597,7 +4070,8 @@ def _register_groups(self): # Initialize all groups without operations for entry in group_entries: - self._groups[entry.name] = FlowGroup(entry.name, options=entry.options) + self._groups[entry.name] = FlowGroup(entry.name, options=entry.options, + aggregate=entry.aggregate) # Add operations and directives to group for (op_name, op) in self._operations.items(): @@ -3605,12 +4079,12 @@ def _register_groups(self): func = op._cmd else: func = op._op_func - + aggregate = func._flow_aggregate if hasattr(func, '_flow_groups'): operation_directives = getattr(func, '_flow_group_operation_directives', dict()) for group_name in func._flow_groups: self._groups[group_name].add_operation( - op_name, op, operation_directives.get(group_name, None)) + op_name, op, aggregate, operation_directives.get(group_name, None)) # For singleton groups add directives self._groups[op_name].operation_directives[op_name] = getattr(func, @@ -3626,6 +4100,14 @@ def operations(self): def groups(self): return self._groups + @property + def aggregates(self): + "A dictionary of aggregated jobs per group or an operation" + # Re-register default aggregates (aggregates of one) before + # returning in order to avoid returning invalid aggregates + self._reregister_default_aggregates() + return self._aggregates + def _eligible_for_submission(self, flow_group, jobs): """Determine if a flow_group is eligible for submission with a given job-aggregate. @@ -3646,7 +4128,7 @@ def _eligible_for_submission(self, flow_group, jobs): def _main_status(self, args): "Print status overview." - jobs = self._select_jobs_from_args(args) + aggregates = self._select_jobs_from_args(args) if args.compact and not args.unroll: logger.warn("The -1/--one-line argument is incompatible with " "'--stack' and will be ignored.") @@ -3659,9 +4141,9 @@ def _main_status(self, args): start = time.time() try: - self.print_status(jobs=jobs, **args) + self.print_status(jobs=aggregates, **args) except NoSchedulerError: - self.print_status(jobs=jobs, **args) + self.print_status(jobs=aggregates, **args) except Exception as error: if show_traceback: logger.error( @@ -3675,8 +4157,9 @@ def _main_status(self, args): error = error.__cause__ # Always show the user traceback cause. traceback.print_exception(type(error), error, error.__traceback__) else: + length_jobs = len(self) if aggregates is None else len(aggregates) # Use small offset to account for overhead with few jobs - delta_t = (time.time() - start - 0.5) / max(len(jobs), 1) + delta_t = (time.time() - start - 0.5) / max(length_jobs, 1) config_key = 'status_performance_warn_threshold' warn_threshold = flow_config.get_config_value(config_key) if not args['profile'] and delta_t > warn_threshold >= 0: @@ -3695,20 +4178,20 @@ def _main_status(self, args): def _main_next(self, args): "Determine the jobs that are eligible for a specific operation." - for op in self._next_operations(self): + for op in self._next_operations(): if args.name in op.name: - print(' '.join(map(str, op._jobs))) + print(get_aggregate_id(op._jobs)) def _main_run(self, args): "Run all (or select) job operations." 
# Select jobs: - jobs = self._select_jobs_from_args(args) + aggregate = self._select_jobs_from_args(args) # Setup partial run function, because we need to call this either # inside some context managers or not based on whether we need # to switch to the project root directory or not. run = functools.partial(self.run, - jobs=jobs, names=args.operation_name, pretend=args.pretend, + jobs=aggregate, names=args.operation_name, pretend=args.pretend, np=args.parallel, timeout=args.timeout, num=args.num, num_passes=args.num_passes, progress=args.progress, order=args.order, @@ -3730,7 +4213,7 @@ def _main_script(self, args): raise ValueError( "Cannot use the -o/--operation-name and the --cmd options in combination!") # Select jobs: - jobs = self._select_jobs_from_args(args) + aggregates = self._select_jobs_from_args(args) # Gather all pending operations or generate them based on a direct command... with self._potentially_buffered(): @@ -3738,11 +4221,11 @@ def _main_script(self, args): warnings.warn("The --cmd option for script is deprecated as of " "0.10 and will be removed in 0.12.", DeprecationWarning) - operations = self._generate_operations(args.cmd, jobs, args.requires) + operations = self._generate_operations(args.cmd, aggregates, args.requires) else: names = args.operation_name if args.operation_name else None default_directives = self._get_default_directives() - operations = self._get_submission_operations(jobs, default_directives, names, + operations = self._get_submission_operations(aggregates, default_directives, names, args.ignore_conditions, args.ignore_conditions_on_execution) operations = list(islice(operations, args.num)) @@ -3759,26 +4242,39 @@ def _main_submit(self, args): kwargs = vars(args) # Select jobs: - jobs = self._select_jobs_from_args(args) + aggregates = self._select_jobs_from_args(args) # Fetch the scheduler status. if not args.test: - self._fetch_scheduler_status(jobs) + self._fetch_scheduler_status(aggregates) names = args.operation_name if args.operation_name else None - self.submit(jobs=jobs, names=names, **kwargs) + self.submit(jobs=aggregates, names=names, **kwargs) def _main_exec(self, args): - if len(args.jobid): - jobs = [self.open_job(id=jid) for jid in args.jobid] + if args.jobid: + aggregates = set() + for _id in args.jobid: + if _id.startswith('agg-'): + aggregates.add(self._get_aggregate_from_id(_id)) + continue + try: + aggregates.add((self.open_job(id=_id),)) + except KeyError: + raise LookupError(f'Did not find job with id {_id}.') + + self._verify_aggregate_project(aggregates) else: - jobs = self + # No specific jobs selected by the user so accept all the jobs. + # See ``_verify_aggregate_in_jobs`` for more details. + aggregates = None + try: operation = self._operations[args.operation] if isinstance(operation, FlowCmdOperation): - def operation_function(job): - cmd = operation(job).format(job=job) + def operation_function(jobs): + cmd = operation(*jobs) subprocess.run(cmd, shell=True, check=True) else: operation_function = operation @@ -3786,8 +4282,9 @@ def operation_function(job): except KeyError: raise KeyError("Unknown operation '{}'.".format(args.operation)) - for job in jobs: - operation_function(job) + aggregates = self.aggregates[args.operation] if aggregates is None else aggregates + for aggregate in aggregates: + operation_function(aggregate) def _select_jobs_from_args(self, args): "Select jobs with the given command line arguments ('-j/-f/--doc-filter')." 
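Aggregates are addressed on the command line by an id generated with ``get_aggregate_id``; the id carries an ``agg-`` prefix, which is how ``-j/--job-id`` (and the ``exec`` subcommand) distinguish it from a plain job id. A sketch mirroring the new tests (the operation name is illustrative):

.. code-block:: python

    from flow import get_aggregate_id

    # ``project`` is assumed to be a FlowProject with registered aggregates;
    # passing the project yields the id of the aggregate of all its jobs.
    agg_id = get_aggregate_id(project)

    # The id can be used wherever a job id is accepted, e.g.:
    #
    #     python project.py run -o agg_op1 -j <agg_id>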
@@ -3796,14 +4293,22 @@ def _select_jobs_from_args(self, args): "Cannot provide both -j/--job-id and -f/--filter or --doc-filter in combination.") if args.job_id: - try: - return [self.open_job(id=job_id) for job_id in args.job_id] - except KeyError as error: - raise LookupError("Did not find job with id {}.".format(error)) - else: + aggregate = set() # set in order to avoid duplicate ids + for id in args.job_id: + if id.startswith('agg'): + aggregate.add(self._get_aggregate_from_id(id)) + continue + try: + aggregate.add((self.open_job(id=id),)) + except KeyError as error: + raise LookupError("Did not find job with id {}.".format(error)) + return list(aggregate) + elif args.filter or args.doc_filter: filter_ = parse_filter_arg(args.filter) doc_filter = parse_filter_arg(args.doc_filter) return JobsCursor(self, filter_, doc_filter) + else: + return None def main(self, parser=None): """Call this function to use the main command line interface. @@ -3933,7 +4438,7 @@ class MyProject(FlowProject): execution_group.add_argument( '--order', type=str, - choices=['none', 'by-job', 'cyclic', 'random'], + choices=['none', 'by-op', 'by-job', 'cyclic', 'random'], default=None, help="Specify the execution order of operations for each execution pass.") execution_group.add_argument( @@ -4087,14 +4592,24 @@ def _execute_serialized_operation(loads, project, operation): project._execute_operation(project._loads_op(operation)) -def _serialized_get_job_status(s_task): - """Invoke the _get_job_status() method on a serialized project instance.""" +def _serialized_get_job_labels(s_task): + """Invoke the _get_job_labels() method on a serialized project instance.""" loads = s_task[0] project = loads(s_task[1]) job = project.open_job(id=s_task[2]) ignore_errors = s_task[3] + return project._get_job_labels(job, ignore_errors=ignore_errors) + + +def _serialized_get_group_status(s_task): + """Invoke the _get_group_status() method on a serialized project instance.""" + loads = s_task[0] + project = loads(s_task[1]) + group = s_task[2] + ignore_errors = s_task[3] cached_status = s_task[4] - return project.get_job_status(job, ignore_errors=ignore_errors, cached_status=cached_status) + aggregates = s_task[5] + return project._get_group_status(group, ignore_errors, cached_status, aggregates) # Status-related helper functions diff --git a/flow/templates/base_status.jinja b/flow/templates/base_status.jinja index fb5617125..855c20023 100644 --- a/flow/templates/base_status.jinja +++ b/flow/templates/base_status.jinja @@ -27,10 +27,10 @@ {% endblock%} {% block operation_summary %} -| operation | number of eligible jobs | -| --------- | ----------------------- | -{% for op, n_jobs in op_counter %} -| {{ op }} | {{ n_jobs }} | +| operation | number of eligible aggregates | +| --------- | ----------------------------- | +{% for ops, n_aggs in aggregate_counter %} +| {{ ops }} | {{ n_aggs }} | {% endfor %} {% endblock %} {% endif %} @@ -39,20 +39,43 @@ {% block detailed %} {% if detailed %} {{ 'Detailed View: \n' }} -| job_id | operation | {{ para_head }}labels | -| ------ | --------- | {{ ns.dash }}-------- | +| job_id | operation | aggregation_status | {{ para_head }}labels | +| ------ | --------- | ------------------ | {{ ns.dash }}-------- | {% for job in jobs %} {% if parameters %} {% set para_output = ns.field_parameters | format(*job['parameters'].values()) %} {% endif %} -{% for key, value in job['operations'].items() if value | job_filter(scheduler_status_code, all_ops) %} +{% for key, ops in job['operations'].items() %} {% 
if loop.first %} -| {{job['job_id']}} | {{ field_operation | highlight(value['eligible'], pretty) | format(key, '['+scheduler_status_code[value['scheduler_status']]+']') }} | {{ para_output }}{{ job['labels'] | join(', ') }} | +{% for value in ops if value | job_filter(scheduler_status_code, all_ops) %} +| {{job['job_id']}} | {{ field_operation | highlight(value['eligible'], pretty) | format(key, '['+scheduler_status_code[value['scheduler_status']]+']') }} | {{ value['aggregate_detail'] | highlight(value['aggregate_detail'], pretty) }} | {{ para_output }}{{ job['labels'] | join(', ') }} | +{% endfor %} {% else %} -| | {{ field_operation | highlight(value['eligible'], pretty) | format(key, '['+scheduler_status_code[value['scheduler_status']]+']') }} | {{ para_output }}{{ job['labels'] | join(', ') }} | +{% for value in ops if value | job_filter(scheduler_status_code, all_ops) %} +| | {{ field_operation | highlight(value['eligible'], pretty) | format(key, '['+scheduler_status_code[value['scheduler_status']]+']') }} | {{ value['aggregate_detail'] | highlight(value['aggregate_detail'], pretty) }} | {{ para_output }}{{ job['labels'] | join(', ') }} | +{% endfor %} {% endif %} {% endfor %} {% endfor %} + +{{ 'Detailed Aggregate View: \n' }} +| operation | jobs_in_aggregate | length_of_aggregate | status | +| --------- | ----------------- | ------------------- | ------ | +{% for op, agg_stati in detailed_ags.items() %} +{% for aggregate, stati in agg_stati %} +{% for job in aggregate %} +{% if loop.first %} +| {{op}} | {{job | highlight(job, pretty) }} | {{aggregate|length}} | {{'['+scheduler_status_code[stati]+']'}} | +{% else %} +| {{op}} | {{job | highlight(job, pretty) }} | {{aggregate|length}} | {{ '['+scheduler_status_code[stati]+']' }} | +{% if loop.last %} +| | | | +{% endif %} +{% endif %} +{% endfor %} +{% endfor %} +{% endfor %} + {{ status_legend }} {% endif %} {% endblock %} diff --git a/flow/templates/base_status_compact.jinja b/flow/templates/base_status_compact.jinja index 9550f6396..af828d871 100644 --- a/flow/templates/base_status_compact.jinja +++ b/flow/templates/base_status_compact.jinja @@ -9,18 +9,21 @@ {% set para_output = ns.field_parameters | format(*job['parameters'].values()) %} {% endif %} {% if all_ops %} -{% set key, value = job['operations'].items() | first() %} -| {{ job['job_id'] }} | {{ field_operation | highlight(value['eligible'], pretty) | format(key, '[' + scheduler_status_code[value['scheduler_status']] + ']', '+(' + extra_num_operations | string() + ')') }} | {{ para_output }}{{ job['labels'] | join(', ') }} | +{% set key, ops = job['operations'].items() | first() %} +{% set value = ops | first() %} +| {{ job['job_id'] }} | {{ field_operation | highlight(value['eligible'], pretty) | format(key, '[' + scheduler_status_code[value['scheduler_status']] + ']', '+(' + job['extra_num_operations'] | string() + ')') }} | {{ para_output }}{{ job['labels'] | join(', ') }} | {% else %} {% set ns.extra_num_operation = -1 %} {% set ns.if_first_eligible_operation = True %} -{% for key, value in job['operations'].items() if value | job_filter(scheduler_status_code, all_ops) %} +{% for key, ops in job['operations'].items() %} +{% for value in ops if value | job_filter(scheduler_status_code, all_ops) %} {% set ns.extra_num_operation = ns.extra_num_operation + 1 %} {% if loop.first %} {% set ns.first_operation_key = key %} {% set ns.first_operation_value = value %} {% endif %} {% endfor %} +{% endfor %} | {{ job['job_id'] }} | {{ field_operation | 
highlight(ns.first_operation_value['eligible'], pretty) | format(ns.first_operation_key, '['+scheduler_status_code[ns.first_operation_value['scheduler_status']]+']', '+('+ns.extra_num_operation | string()+')') }} | {{ para_output }}{{ job['labels'] | join(', ') }} | {% endif %} {% endfor %} diff --git a/flow/templates/base_status_expand.jinja b/flow/templates/base_status_expand.jinja index 3741948a6..b58daa455 100644 --- a/flow/templates/base_status_expand.jinja +++ b/flow/templates/base_status_expand.jinja @@ -8,14 +8,18 @@ {% block detailed %} {{ super () }} {{ 'Operations: \n' }} -| job_id | operation | eligible | cluster_status | -| ------ | --------- | -------- | -------------- | +| job_id | operation | aggregation_status | eligible | cluster_status | +| ------ | --------- | ------------------ | -------- | -------------- | {% for job in jobs %} -{% for key, value in job['operations'].items() if value | job_filter(scheduler_status_code, all_ops) %} +{% for key, ops in job['operations'].items() %} {% if loop.first %} -| {{ job['job_id'] }} | {{ '%s' | highlight(value['eligible'], pretty) | format(key) }} | {{ alias_bool[value['eligible']] }} | {{ scheduler_status_code[value['scheduler_status']] }} | +{% for value in ops if value | job_filter(scheduler_status_code, all_ops) %} +| {{ job['job_id'] }} | {{ '%s' | highlight(value['eligible'], pretty) | format(key) }} | {{ value['aggregate_detail'] | highlight(value['aggregate_detail'], pretty) }} | {{ alias_bool[value['eligible']] }} | {{ scheduler_status_code[value['scheduler_status']] }} | +{% endfor %} {% else %} -| | {{ '%s' | highlight(value['eligible'], pretty) | format(key) }} | {{ alias_bool[value['eligible']] }} | {{ scheduler_status_code[value['scheduler_status']] }} | +{% for value in ops if value | job_filter(scheduler_status_code, all_ops) %} +| | {{ '%s' | highlight(value['eligible'], pretty) | format(key) }} | {{ value['aggregate_detail'] | highlight(value['aggregate_detail'], pretty) }} | {{ alias_bool[value['eligible']] }} | {{ scheduler_status_code[value['scheduler_status']] }} | +{% endfor %} {% endif %} {% endfor %} {% endfor %} diff --git a/flow/templates/base_status_stack.jinja b/flow/templates/base_status_stack.jinja index 2bb9e7ba8..d90479475 100644 --- a/flow/templates/base_status_stack.jinja +++ b/flow/templates/base_status_stack.jinja @@ -1,16 +1,18 @@ {% extends "base_status.jinja" %} {% block detailed %} {{ 'Detailed View: \n' }} -| job_id | {{ para_head }}labels | -| ------ | {{ ns.dash }}-------- | +| job_id | aggregation_status | {{ para_head }}labels | +| ------ | ------------------ | {{ ns.dash }}-------- | {% set field_operation = '%s %s %s' %} {% for job in jobs %} {% if parameters %} {% set para_output = ns.field_parameters | format(*job['parameters'].values()) %} {% endif %} -| {{job['job_id']}} | {{ para_output }}{{ job['labels'] | join(', ') }} | -{% for key, value in job['operations'].items() if value | job_filter(scheduler_status_code, all_ops) %} -| {{ field_operation | highlight(value['eligible'], pretty) | format(value | get_operation_status(operation_status_symbols), key, '['+scheduler_status_code[value['scheduler_status']]+']') }} | +| {{job['job_id']}} | | {{ para_output }}{{ job['labels'] | join(', ') }} | +{% for key, ops in job['operations'].items() %} +{% for value in ops if value | job_filter(scheduler_status_code, all_ops) %} +| {{ field_operation | highlight(value['eligible'], pretty) | format(value | get_operation_status(operation_status_symbols), key, 
'['+scheduler_status_code[value['scheduler_status']]+']') }} | {{ value['aggregate_detail'] | highlight(value['aggregate_detail'], pretty) }}| +{% endfor %} {% endfor %} {% endfor %} {{ operation_status_legend }} diff --git a/tests/define_aggregate_test_project.py b/tests/define_aggregate_test_project.py new file mode 100644 index 000000000..17677bd76 --- /dev/null +++ b/tests/define_aggregate_test_project.py @@ -0,0 +1,51 @@ +from flow import FlowProject, Aggregate + + +class _AggregateTestProject(FlowProject): + pass + + +group1 = _AggregateTestProject.make_group(name="group_agg", aggregate=Aggregate()) + + +@_AggregateTestProject.label +def aggregate_doc_condition(job): + try: + return job._project.document['average'] + except KeyError: + return False + + +@_AggregateTestProject.operation +@Aggregate() +@_AggregateTestProject.post.true('average') +def agg_op1(*jobs): + sum = 0 + for job in jobs: + sum += job.sp.i + for job in jobs: + job.document.sum = sum + + +@_AggregateTestProject.operation +@group1 +@_AggregateTestProject.post.true('average') +def agg_op2(*jobs): + average = 0 + for job in jobs: + average += job.sp.i + average /= len(jobs) + for job in jobs: + job.document.average = average + + +@_AggregateTestProject.operation +@group1 +@_AggregateTestProject.post.true('test3') +def agg_op3(*jobs): + for job in jobs: + job.document.test3 = True + + +if __name__ == '__main__': + _AggregateTestProject().main() diff --git a/tests/define_template_test_project.py b/tests/define_template_test_project.py index 6738cf8df..24c707fd4 100644 --- a/tests/define_template_test_project.py +++ b/tests/define_template_test_project.py @@ -9,6 +9,7 @@ class TestProject(flow.FlowProject): group1 = TestProject.make_group(name="group1") +group2 = TestProject.make_group(name="group2", aggregate=flow.Aggregate.groupsof(2)) @TestProject.operation @@ -52,3 +53,18 @@ def gpu_op(job): @flow.directives(ngpu=TestProject.ngpu, nranks=TestProject.nranks) def mpi_gpu_op(job): pass + + +@TestProject.operation +@flow.Aggregate.groupsof(2) +@group2 +def serial_agg_op(*jobs): + pass + + +@TestProject.operation +@flow.directives(np=TestProject.np) +@flow.Aggregate.groupsof(2) +@group2 +def parallel_agg_op(*jobs): + pass diff --git a/tests/generate_template_reference_data.py b/tests/generate_template_reference_data.py index f247512d1..34bf45916 100755 --- a/tests/generate_template_reference_data.py +++ b/tests/generate_template_reference_data.py @@ -138,6 +138,16 @@ def _store_bundled(self, operations): flow.FlowProject._store_bundled = _store_bundled +# We don't need to store the information of lost aggregates for +# testing templates. Hence we need to mock this method in order to +# avoid needing to make a file. +def _store_aggregates(self, operations): + pass + + +flow.FlowProject._store_aggregates = _store_aggregates + + def get_masked_flowproject(p): """Mock environment-dependent attributes and functions. 
Need to mock sys.executable before the FlowProject is instantiated, and then modify the @@ -161,9 +171,22 @@ def main(args): "Use `-f/--force` to overwrite.".format(ARCHIVE_DIR)) return + def mock_submit(fp, env, jobs, names, bundle_size, **kwargs): + tmp_out = io.TextIOWrapper(io.BytesIO(), sys.stdout.encoding) + with redirect_stdout(tmp_out): + try: + fp.submit( + env=env, bundle_size=bundle_size, jobs=jobs, names=names, + force=True, pretend=True, **kwargs) + except jinja2.TemplateError as e: + print('ERROR:', e) # Shows template error in output script + tmp_out.seek(0) + return tmp_out + with signac.TemporaryProject(name=PROJECT_NAME) as p: init(p) fp = get_masked_flowproject(p) + agg_ops = ['serial_agg_op', 'parallel_agg_op', 'group2'] for job in fp: with job: @@ -173,17 +196,7 @@ def main(args): if 'bundle' in parameters: bundle = parameters.pop('bundle') fn = 'script_{}.sh'.format('_'.join(bundle)) - tmp_out = io.TextIOWrapper(io.BytesIO(), sys.stdout.encoding) - with redirect_stdout(tmp_out): - try: - fp.submit( - env=env, jobs=[job], names=bundle, pretend=True, - force=True, bundle_size=len(bundle), **parameters) - except jinja2.TemplateError as e: - print('ERROR:', e) # Shows template error in output script - - # Filter out non-header lines - tmp_out.seek(0) + tmp_out = mock_submit(fp, env, [job], bundle, len(bundle), **parameters) with open(fn, 'w') as f: with redirect_stdout(f): print(tmp_out.read(), end='') @@ -198,17 +211,12 @@ def main(args): 'gpu' in op.lower()): continue fn = 'script_{}.sh'.format(op) - tmp_out = io.TextIOWrapper(io.BytesIO(), sys.stdout.encoding) - with redirect_stdout(tmp_out): - try: - fp.submit( - env=env, jobs=[job], names=[op], - pretend=True, force=True, **parameters) - except jinja2.TemplateError as e: - print('ERROR:', e) # Shows template error in output script - - # Filter out non-header lines and the job-name line - tmp_out.seek(0) + if op in agg_ops: + continue + # tmp_out = mock_submit(fp, env, fp, [op], 1, **parameters) + else: + tmp_out = mock_submit(fp, env, [job], [op], 1, **parameters) + with open(fn, 'w') as f: with redirect_stdout(f): print(tmp_out.read(), end='') diff --git a/tests/template_reference_data.tar.gz b/tests/template_reference_data.tar.gz index e74c0951e..eeb3c9e54 100644 Binary files a/tests/template_reference_data.tar.gz and b/tests/template_reference_data.tar.gz differ diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py new file mode 100644 index 000000000..7f302a6a4 --- /dev/null +++ b/tests/test_aggregate.py @@ -0,0 +1,253 @@ +import pytest + +from functools import partial +from tempfile import TemporaryDirectory + +import signac +from flow.aggregate import Aggregate + + +class AggregateProjectSetup: + project_class = signac.Project + entrypoint = dict(path='') + + @pytest.fixture + def setUp(self, request): + self._tmp_dir = TemporaryDirectory(prefix='flow-aggregate_') + request.addfinalizer(self._tmp_dir.cleanup) + self.project = self.project_class.init_project( + name='AggregateTestProject', + root=self._tmp_dir.name) + + def mock_project(self): + project = self.project_class.get_project(root=self._tmp_dir.name) + for i in range(10): + even = i % 2 == 0 + if even: + project.open_job(dict(i=i, half=i / 2, even=even)).init() + else: + project.open_job(dict(i=i, even=even)).init() + return project + + @pytest.fixture + def project(self): + return self.mock_project() + + +class TestAggregate(AggregateProjectSetup): + + def test_default_init(self): + aggregate_instance = Aggregate() + test_list = (1, 2, 3, 4, 
5) + assert aggregate_instance._sort is None + assert aggregate_instance._aggregator(test_list) == [test_list] + assert aggregate_instance._select is None + + def test_invalid_aggregator(self, setUp, project): + aggregators = ['str', 1, {}] + for aggregator in aggregators: + with pytest.raises(TypeError): + Aggregate(aggregator) + + def test_invalid_sort(self): + sort_list = [1, {}, lambda x: x] + for sort in sort_list: + with pytest.raises(TypeError): + Aggregate(sort=sort) + + def test_invalid_select(self): + selectors = ['str', 1, []] + for _select in selectors: + with pytest.raises(TypeError): + Aggregate(select=_select) + + def test_invalid_call(self): + call_params = ['str', 1, None] + for param in call_params: + with pytest.raises(TypeError): + Aggregate()(param) + + def test_call_without_function(self): + aggregate_instance = Aggregate() + with pytest.raises(TypeError): + aggregate_instance() + + def test_call_with_function(self): + aggregate_instance = Aggregate() + + def test_function(x): + return x + + assert not getattr(test_function, '_flow_aggregate', False) + test_function = aggregate_instance(test_function) + assert getattr(test_function, '_flow_aggregate', False) + + def test_with_decorator_with_pre_initialization(self): + aggregate_instance = Aggregate() + + @aggregate_instance + def test_function(x): + return x + + assert getattr(test_function, '_flow_aggregate', False) + + def test_with_decorator_without_pre_initialization(self): + @Aggregate() + def test_function(x): + return x + + assert getattr(test_function, '_flow_aggregate', False) + + def test_groups_of_invalid_num(self): + invalid_values = [{}, 'str', -1, -1.5] + for invalid_value in invalid_values: + with pytest.raises((TypeError, ValueError)): + Aggregate.groupsof(invalid_value) + + def test_group_by_invalid_key(self): + with pytest.raises(TypeError): + Aggregate.groupby(1) + + def test_groupby_with_valid_type_default_for_Iterable(self, setUp, project): + Aggregate.groupby(['half', 'even'], default=[-1, -1]) + + def test_groupby_with_invalid_type_default_key_for_Iterable(self, setUp, project): + with pytest.raises(TypeError): + Aggregate.groupby(['half', 'even'], default=-1) + + def test_groupby_with_invalid_length_default_key_for_Iterable(self, setUp, project): + with pytest.raises(ValueError): + Aggregate.groupby(['half', 'even'], default=[-1, -1, -1]) + + +class TestMakeAggregate(AggregateProjectSetup): + + def test_valid_aggregator_non_partial(self, setUp, project): + # Return groups of 1 + def helper_aggregator(jobs): + for job in jobs: + yield (job, ) + + aggregate_instance = Aggregate(helper_aggregator) + + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregate_job_manual = helper_aggregator(project) + aggregate_job_via_aggregator = aggregate_instance(project) + + assert [aggregate for aggregate in aggregate_job_manual] == \ + aggregate_job_via_aggregator + + def test_valid_aggregator_partial(self, setUp, project): + aggregate_instance = Aggregate(lambda jobs: [jobs]) + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregate_job_via_aggregator = aggregate_instance(project) + + assert [tuple([job for job in project])] == \ + aggregate_job_via_aggregator + + def test_valid_sort(self, setUp, project): + helper_sort = partial(sorted, key=lambda job: job.sp.i) + aggregate_instance = Aggregate(sort='i') + aggregate_instance = aggregate_instance._create_MakeAggregate() + + assert([tuple(helper_sort(project))] == aggregate_instance(project)) + + def 
test_valid_reversed_sort(self, setUp, project): + helper_sort = partial(sorted, key=lambda job: job.sp.i, reverse=True) + aggregate_instance = Aggregate(sort='i', reverse=True) + aggregate_instance = aggregate_instance._create_MakeAggregate() + + assert([tuple(helper_sort(project))] == aggregate_instance(project)) + + def test_groups_of_valid_num(self, setUp, project): + valid_values = [1, 2, 3, 6] + + total_jobs = len(project) + + for valid_value in valid_values: + aggregate_instance = Aggregate.groupsof(valid_value) + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregate_job_via_aggregator = aggregate_instance(project) + if total_jobs % valid_value == 0: + length_of_aggregate = total_jobs/valid_value + else: + length_of_aggregate = int(total_jobs/valid_value) + 1 + assert len(aggregate_job_via_aggregator) == length_of_aggregate + + def test_groupby_with_valid_string_key(self, setUp, project): + aggregate_instance = Aggregate.groupby('even') + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregates = 0 + for agg in aggregate_instance(project): + aggregates += 1 + assert aggregates == 2 + + def test_groupby_with_invalid_string_key(self, setUp, project): + aggregate_instance = Aggregate.groupby('invalid_key') + aggregate_instance = aggregate_instance._create_MakeAggregate() + with pytest.raises(KeyError): + for agg in aggregate_instance(project): + pass + + def test_groupby_with_default_key_for_string(self, setUp, project): + aggregate_instance = Aggregate.groupby('half', default=-1) + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregates = 0 + for agg in aggregate_instance(project): + aggregates += 1 + assert aggregates == 6 + + def test_groupby_with_Iterable_key(self, setUp, project): + aggregate_instance = Aggregate.groupby(['i', 'even']) + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregates = 0 + for agg in aggregate_instance(project): + aggregates += 1 + assert aggregates == 10 + + def test_groupby_with_invalid_Iterable_key(self, setUp, project): + aggregate_instance = Aggregate.groupby(['half', 'even']) + aggregate_instance = aggregate_instance._create_MakeAggregate() + with pytest.raises(KeyError): + for agg in aggregate_instance(project): + pass + + def test_groupby_with_valid_default_key_for_Iterable(self, setUp, project): + aggregate_instance = Aggregate.groupby(['half', 'even'], default=[-1, -1]) + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregates = 0 + for agg in aggregate_instance(project): + aggregates += 1 + assert aggregates == 6 + + def test_groupby_with_callable_key(self, setUp, project): + def keyfunction(job): + return job.sp['even'] + + aggregate_instance = Aggregate.groupby(keyfunction) + aggregate_instance = aggregate_instance._create_MakeAggregate() + aggregates = 0 + for agg in aggregate_instance(project): + aggregates += 1 + assert aggregates == 2 + + def test_groupby_with_invalid_callable_key(self, setUp, project): + def keyfunction(job): + return job.sp['half'] + aggregate_instance = Aggregate.groupby(keyfunction) + aggregate_instance = aggregate_instance._create_MakeAggregate() + with pytest.raises(KeyError): + for agg in aggregate_instance(project): + pass + + def test_valid_select(self, setUp, project): + def _select(job): + return job.sp.i > 5 + + aggregate_instance = Aggregate.groupsof(1, select=_select) + aggregate_instance = aggregate_instance._create_MakeAggregate() + selected_jobs = [] + for job in project: + if _select(job): + 
selected_jobs.append((job,)) + assert aggregate_instance(project) == selected_jobs diff --git a/tests/test_project.py b/tests/test_project.py index d4397499c..557ef4a2d 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -20,6 +20,7 @@ import signac import flow from flow import FlowProject, cmd, with_job, directives +from flow import get_aggregate_id from flow.scheduling.base import Scheduler from flow.scheduling.base import ClusterJob from flow.scheduling.base import JobStatus @@ -33,6 +34,7 @@ from define_test_project import _TestProject from define_test_project import _DynamicTestProject from define_dag_test_project import DagTestProject +from define_aggregate_test_project import _AggregateTestProject @contextmanager @@ -143,6 +145,7 @@ def recursive_update(d, u): project.open_job(dict(a=a, b=b)).init() project.open_job(dict(a=dict(a=a), b=b)).init() project._entrypoint = self.entrypoint + project.register_aggregates() return project @@ -163,6 +166,7 @@ def mock_project(self): project = self.project_class.get_project(root=self._tmp_dir.name) for i in range(1000): project.open_job(dict(i=i)).init() + project.register_aggregates() return project @pytest.mark.skipif(signac.__version__ < '1.3.0', @@ -179,7 +183,7 @@ def test_status_performance(self): MockScheduler.reset() time = timeit.timeit( - lambda: project._fetch_status(project, StringIO(), + lambda: project._fetch_status(None, project, StringIO(), ignore_errors=False), number=10) assert time < 10 @@ -430,7 +434,7 @@ def test_context(job): project = self.mock_project(A) for job in project: job.doc.np = 3 - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert 'mpirun -np 3 python' in next_op.cmd break @@ -449,20 +453,20 @@ def a(job): # test setting neither nranks nor omp_num_threads for job in project: - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == 1 # test only setting nranks for i, job in enumerate(project): job.doc.nranks = i+1 - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == next_op.directives['nranks'] del job.doc['nranks'] # test only setting omp_num_threads for i, job in enumerate(project): job.doc.omp_num_threads = i+1 - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == next_op.directives['omp_num_threads'] del job.doc['omp_num_threads'] @@ -471,7 +475,7 @@ def a(job): job.doc.omp_num_threads = i+1 job.doc.nranks = i % 3 + 1 expected_np = (i + 1) * (i % 3 + 1) - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == expected_np def test_copy_conditions(self): @@ -595,7 +599,7 @@ def test_next_operations(self): project = self.mock_project() even_jobs = [job for job in project if job.sp.b % 2 == 0] for job in project: - for i, op in enumerate(project._next_operations((job,))): + for i, op in enumerate(project._next_operations([(job,)])): assert op._jobs == (job,) if job in even_jobs: assert op.name == ['op1', 'op2', 'op3'][i] @@ -609,7 +613,7 @@ def test_get_job_status(self): status = project.get_job_status(job) assert status['job_id'] == job.get_id() assert len(status['operations']) == len(project.operations) - for op in project._next_operations((job,)): + for op in project._next_operations([(job,)]): assert op.name 
in status['operations'] op_status = status['operations'][op.name] assert op_status['eligible'] == project.operations[op.name]._eligible((job,)) @@ -659,7 +663,7 @@ def test_project_status_heterogeneous_schema(self): def test_script(self): project = self.mock_project() for job in project: - script = project._script(project._next_operations((job,))) + script = project._script(project._next_operations([(job,)])) if job.sp.b % 2 == 0: assert str(job) in script assert 'echo "hello"' in script @@ -679,7 +683,7 @@ def test_script_with_custom_script(self): file.write("THIS IS A CUSTOM SCRIPT!\n") file.write("{% endblock %}\n") for job in project: - script = project._script(project._next_operations((job,))) + script = project._script(project._next_operations([(job,)])) assert "THIS IS A CUSTOM SCRIPT" in script if job.sp.b % 2 == 0: assert str(job) in script @@ -734,15 +738,15 @@ class TestExecutionProject(TestProjectBase): def test_pending_operations_order(self): # The execution order of local runs is internally assumed to be - # 'by-job' by default. A failure of this unit tests means that - # a 'by-job' order must be implemented explicitly within the + # 'by-op' by default. A failure of this unit tests means that + # a 'by-op' order must be implemented explicitly within the # FlowProject.run() function. project = self.mock_project() - ops = list(project._get_pending_operations(self.project.find_jobs())) + ops = list(project._get_pending_operations()) # The length of the list of operations grouped by job is equal # to the length of its set if and only if the operations are grouped # by job already: - jobs_order_none = [job._id for job, _ in groupby(ops, key=lambda op: op._jobs[0])] + jobs_order_none = [name for name, _ in groupby(ops, key=lambda op: op.name)] assert len(jobs_order_none) == len(set(jobs_order_none)) def test_run(self, subtests): @@ -754,7 +758,7 @@ def test_run(self, subtests): def sort_key(op): return op.name, op._jobs[0].get_id() - for order in (None, 'none', 'cyclic', 'by-job', 'random', sort_key): + for order in (None, 'none', 'cyclic', 'by-job', 'by-op', 'random', sort_key): for job in self.project.find_jobs(): # clear job.remove() with subtests.test(order=order): @@ -909,7 +913,7 @@ def test_submit_operations(self): project = self.mock_project() operations = [] for job in project: - operations.extend(project._next_operations((job,))) + operations.extend(project._next_operations([(job,)])) assert len(list(MockScheduler.jobs())) == 0 cluster_job_id = project._store_bundled(operations) with redirect_stderr(StringIO()): @@ -990,7 +994,7 @@ def test_submit_status(self): if job not in even_jobs: continue list(project.labels(job)) - next_op = list(project._next_operations((job,)))[0] + next_op = list(project._next_operations([(job,)]))[0] assert next_op.name == 'op1' assert next_op._jobs == (job,) with redirect_stderr(StringIO()): @@ -998,7 +1002,7 @@ def test_submit_status(self): assert len(list(MockScheduler.jobs())) == num_jobs_submitted for job in project: - next_op = list(project._next_operations((job,)))[0] + next_op = list(project._next_operations([(job,)]))[0] assert next_op.get_status() == JobStatus.submitted MockScheduler.step() @@ -1006,7 +1010,7 @@ def test_submit_status(self): project._fetch_scheduler_status(file=StringIO()) for job in project: - next_op = list(project._next_operations((job,)))[0] + next_op = list(project._next_operations([(job,)]))[0] assert next_op.get_status() == JobStatus.queued MockScheduler.step() @@ -1022,7 +1026,7 @@ def 
test_submit_operations_bad_directive(self): project = self.mock_project() operations = [] for job in project: - operations.extend(project._next_operations((job,))) + operations.extend(project._next_operations([(job,)])) assert len(list(MockScheduler.jobs())) == 0 cluster_job_id = project._store_bundled(operations) stderr = StringIO() @@ -1178,7 +1182,7 @@ def test_main_status(self): op_lines.append(next(lines)) except StopIteration: continue - for op in project._next_operations((job,)): + for op in project._next_operations([(job,)]): assert any(op.name in op_line for op_line in op_lines) def test_main_script(self): @@ -1230,7 +1234,7 @@ def test_script(self): project = self.mock_project() # For run mode single operation groups for job in project: - job_ops = project._get_submission_operations((job,), dict()) + job_ops = project._get_submission_operations([(job,)], dict()) script = project._script(job_ops) if job.sp.b % 2 == 0: assert str(job) in script @@ -1525,3 +1529,176 @@ def test_main_submit(self): class TestGroupDynamicProjectMainInterface(TestProjectMainInterface): project_class = _DynamicTestProject + + +class TestAggregationProjectMainInterface(TestProjectBase): + project_class = _AggregateTestProject + entrypoint = dict( + path=os.path.realpath(os.path.join(os.path.dirname(__file__), + 'define_aggregate_test_project.py')) + ) + + def mock_project(self): + project = self.project_class.get_project(root=self._tmp_dir.name) + for i in range(30): + project.open_job(dict(i=i)).init() + project._entrypoint = self.entrypoint + project.register_aggregates() + return project + + def switch_to_cwd(self): + os.chdir(self.cwd) + + @pytest.fixture(autouse=True) + def setup_main_interface(self, request): + self.project = self.mock_project() + self.cwd = os.getcwd() + os.chdir(self._tmp_dir.name) + request.addfinalizer(self.switch_to_cwd) + + def call_subcmd(self, subcmd): + # Determine path to project module and construct command. 
+ fn_script = inspect.getsourcefile(type(self.project)) + _cmd = 'python {} {}'.format(fn_script, subcmd) + try: + with add_path_to_environment_pythonpath(os.path.abspath(self.cwd)): + with switch_to_directory(self.project.root_directory()): + return subprocess.check_output(_cmd.split(), stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError as error: + print(error, file=sys.stderr) + print(error.output, file=sys.stderr) + raise + + def generate_str_jobop(self, jobs): + # This method is supposed to replicate the job representation of an + # _JobOperation instance + jobs = list(jobs) + max_len = 3 + min_len_unique_id = jobs[0]._project.min_len_unique_id() + if len(jobs) > max_len: + shown = list(jobs[:max_len-2]) + ['...'] + list(jobs[-1:]) + else: + shown = jobs + return f"[#{len(jobs)}]" \ + f"({', '.join([str(element)[:min_len_unique_id] for element in shown])})" + + def test_main_run(self): + project = self.mock_project() + assert len(project) + for job in project: + assert not job.doc.get('sum', False) + self.call_subcmd( + f'run -o agg_op1 -j {get_aggregate_id(project)}' + ) + sum = 0 + for job in project: + sum += job.sp.i + + for job in project: + assert job.doc['sum'] == sum + + def test_main_script(self): + project = self.mock_project() + assert len(project) + hashed_aggregate_id = get_aggregate_id(project) + script_output = self.call_subcmd( + f'script -o agg_op1 -j {hashed_aggregate_id}' + ).decode().splitlines() + assert self.generate_str_jobop(project) in '\n'.join(script_output) + assert f'-o agg_op1 -j {hashed_aggregate_id}' in '\n'.join(script_output) + + def test_main_submit(self): + project = self.mock_project() + assert len(project) + # Assert that correct output for group submission is given + hashed_aggregate_id = get_aggregate_id(project) + submit_output = self.call_subcmd( + f'submit -o agg_op1 -j {hashed_aggregate_id} --pretend' + ).decode().splitlines() + output_string = '\n'.join(submit_output) + assert self.generate_str_jobop(project) in output_string + assert f'run -o agg_op1 -j {hashed_aggregate_id}' in output_string + + +class TestGroupAggregationProjectMainInterface(TestProjectBase): + project_class = _AggregateTestProject + entrypoint = dict( + path=os.path.realpath(os.path.join(os.path.dirname(__file__), + 'define_aggregate_test_project.py')) + ) + + def mock_project(self): + project = self.project_class.get_project(root=self._tmp_dir.name) + for i in range(30): + project.open_job(dict(i=i)).init() + project._entrypoint = self.entrypoint + project.register_aggregates() + return project + + def switch_to_cwd(self): + os.chdir(self.cwd) + + @pytest.fixture(autouse=True) + def setup_main_interface(self, request): + self.project = self.mock_project() + self.cwd = os.getcwd() + os.chdir(self._tmp_dir.name) + request.addfinalizer(self.switch_to_cwd) + + def call_subcmd(self, subcmd): + # Determine path to project module and construct command. 
+ fn_script = inspect.getsourcefile(type(self.project)) + _cmd = 'python {} {}'.format(fn_script, subcmd) + try: + with add_path_to_environment_pythonpath(os.path.abspath(self.cwd)): + with switch_to_directory(self.project.root_directory()): + return subprocess.check_output(_cmd.split(), stderr=subprocess.DEVNULL) + except subprocess.CalledProcessError as error: + print(error, file=sys.stderr) + print(error.output, file=sys.stderr) + raise + + def generate_str_jobop(self, jobs): + # This method is supposed to replicate the job representation of an + # _JobOperation instance + jobs = list(jobs) + max_len = 3 + min_len_unique_id = jobs[0]._project.min_len_unique_id() + if len(jobs) > max_len: + shown = list(jobs[:max_len-2]) + ['...'] + list(jobs[-1:]) + else: + shown = jobs + return f"[#{len(jobs)}]" \ + f"({', '.join([str(element)[:min_len_unique_id] for element in shown])})" + + def test_main_run(self): + project = self.mock_project() + assert len(project) + for job in project: + assert not job.doc.get('average', False) + assert not job.doc.get('test3', False) + self.call_subcmd(f'run -o group_agg -j {get_aggregate_id(project)}') + for job in project: + assert job.doc.get('average', False) + assert job.doc.get('test3', False) + + def test_main_script(self): + project = self.mock_project() + assert len(project) + hashed_aggregate_id = get_aggregate_id(project) + script_output = self.call_subcmd( + f'script -o group_agg -j {hashed_aggregate_id}' + ).decode().splitlines() + assert self.generate_str_jobop(project) in '\n'.join(script_output) + assert f'-o group_agg -j {hashed_aggregate_id}' in '\n'.join(script_output) + + def test_main_submit(self): + project = self.mock_project() + assert len(project) + hashed_aggregate_id = get_aggregate_id(project) + submit_output = self.call_subcmd( + f'submit -o group_agg -j {hashed_aggregate_id} --pretend' + ).decode().splitlines() + output_string = '\n'.join(submit_output) + assert self.generate_str_jobop(project) in output_string + assert f'run -o group_agg -j {hashed_aggregate_id}' in output_string diff --git a/tests/test_templates.py b/tests/test_templates.py index 0c273124b..919c731e9 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -27,6 +27,19 @@ def find_envs(): yield env +def mock_submit(fp, env, jobs, names, bundle_size, **kwargs): + tmp_out = io.TextIOWrapper( + io.BytesIO(), sys.stdout.encoding) + with open(os.devnull, 'w') as devnull: + with redirect_stderr(devnull): + with redirect_stdout(tmp_out): + fp.submit( + env=env, bundle_size=bundle_size, jobs=jobs, names=names, + force=True, pretend=True, **kwargs) + tmp_out.seek(0) + return tmp_out + + @pytest.mark.parametrize('env', find_envs()) def test_env(env): # Force asserts to show the full file when failures occur. 
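The ``mock_submit`` helper above wraps ``FlowProject.submit`` in pretend mode and captures the generated submission script; the hunks that follow call it like this (usage sketch, variable names taken from the surrounding test):

.. code-block:: python

    # Capture the pretend-submission output for a single job and operation
    # and collect it for comparison against the stored reference script.
    tmp_out = mock_submit(fp, env, [job], [op], 1, **parameters)
    generated.extend(tmp_out.read().splitlines())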
@@ -37,26 +50,20 @@ def test_env(env): fp = gen.get_masked_flowproject(p) fp.import_from(origin=gen.ARCHIVE_DIR) jobs = fp.find_jobs(dict(environment=_env_name(env))) + fp.register_aggregates() if not len(jobs): raise RuntimeError( "No reference data for environment {}!".format(_env_name(env)) ) reference = [] generated = [] + agg_ops = ['serial_agg_op', 'parallel_agg_op', 'group2'] for job in jobs: parameters = job.sp.parameters() if 'bundle' in parameters: bundle = parameters.pop('bundle') - tmp_out = io.TextIOWrapper( - io.BytesIO(), sys.stdout.encoding) - with open(os.devnull, 'w') as devnull: - with redirect_stderr(devnull): - with redirect_stdout(tmp_out): - fp.submit( - env=env, jobs=[job], names=bundle, pretend=True, - force=True, bundle_size=len(bundle), **parameters) - tmp_out.seek(0) - msg = "---------- Bundled submission of job {}".format(job) + tmp_out = mock_submit(fp, env, [job], bundle, len(bundle), **parameters) + msg = f"---------- Bundled submission of job {job}" generated.extend([msg] + tmp_out.read().splitlines()) with open(job.fn('script_{}.sh'.format('_'.join(bundle)))) as file: @@ -72,16 +79,14 @@ def test_env(env): 'gpu' in parameters['partition'].lower(), 'gpu' in op.lower()): continue - tmp_out = io.TextIOWrapper( - io.BytesIO(), sys.stdout.encoding) - with open(os.devnull, 'w') as devnull: - with redirect_stderr(devnull): - with redirect_stdout(tmp_out): - fp.submit( - env=env, jobs=[job], - names=[op], pretend=True, force=True, **parameters) - tmp_out.seek(0) - msg = "---------- Submission of operation {} for job {}.".format(op, job) + if op in agg_ops: + continue + # tmp_out = mock_submit(fp, env, None, [op], 1, **parameters) + # msg = f"---------- Submission of operation {op} for jobs " \ + # f"{' '.join(map(str, jobs))}." + else: + tmp_out = mock_submit(fp, env, [job], [op], 1, **parameters) + msg = f"---------- Submission of operation {op} for job {job}." generated.extend([msg] + tmp_out.read().splitlines()) with open(job.fn('script_{}.sh'.format(op))) as file: