diff --git a/flow/__init__.py b/flow/__init__.py index 3b8c8a945..cc6920cf2 100644 --- a/flow/__init__.py +++ b/flow/__init__.py @@ -12,6 +12,7 @@ from . import scheduling from . import errors from . import testing +from .aggregates import get_aggregate_id from .project import IgnoreConditions from .project import FlowProject from .project import JobOperation @@ -38,6 +39,7 @@ 'scheduling', 'errors', 'testing', + 'get_aggregate_id', 'IgnoreConditions', 'FlowProject', 'JobOperation', diff --git a/flow/aggregates.py b/flow/aggregates.py index f5a00a32c..7d448d65f 100644 --- a/flow/aggregates.py +++ b/flow/aggregates.py @@ -2,6 +2,9 @@ # All rights reserved. # This software is licensed under the BSD 3-Clause License. import itertools + +from collections import OrderedDict +from collections.abc import Mapping from collections.abc import Iterable from hashlib import md5 @@ -44,7 +47,7 @@ def foo(*jobs): def __init__(self, aggregator_function=None, sort_by=None, sort_ascending=True, select=None): if aggregator_function is None: def aggregator_function(jobs): - return [jobs] + return (jobs,) if jobs else () if not callable(aggregator_function): raise TypeError("Expected callable for aggregator_function, got " @@ -258,7 +261,7 @@ def __call__(self, func=None): f'got {type(func)}.') -class _AggregatesStore: +class _AggregatesStore(Mapping): """This class holds the information of all the aggregates associated with a :class:`aggregator`. @@ -278,35 +281,35 @@ def __init__(self, aggregator, project): self._aggregator = aggregator # We need to register the aggregates for this instance using the - # project provided. - self._aggregates = [] - self._aggregate_ids = {} + # project provided. After registering, we store the aggregates + # keyed by the ids generated by the `get_aggregate_id` method. + self._aggregate_per_id = OrderedDict() self._register_aggregates(project) def __iter__(self): - yield from self._aggregates + yield from self._aggregate_per_id def __getitem__(self, id): - "Return an aggregate, if exists, using the id provided" + """Get the aggregate corresponding to the provided id.""" try: - return self._aggregate_ids[id] + return self._aggregate_per_id[id] except KeyError: raise LookupError(f'Unable to find the aggregate having id {id} in ' 'the FlowProject') - def __contains__(self, aggregate): - """Return whether an aggregate is stored in the this - instance of :py:class:`_AggregateStore` + def __contains__(self, id): + """Return whether an aggregate is stored in this instance of + :py:class:`_AggregatesStore`. - :param aggregate: - An aggregate of jobs. - :type aggregate: - tuple of :py:class:`signac.contrib.job.Job` + :param id: + The id of an aggregate of jobs. + :type id: + str """ - return get_aggregate_id(aggregate) in self._aggregate_ids + return id in self._aggregate_per_id def __len__(self): - return len(self._aggregates) + return len(self._aggregate_per_id) def __eq__(self, other): return type(self) == type(other) and self._aggregator == other._aggregator @@ -314,9 +317,18 @@ def __eq__(self, other): def __hash__(self): return hash(self._aggregator) + def keys(self): + return self._aggregate_per_id.keys() + + def values(self): + return self._aggregate_per_id.values() + + def items(self): + return self._aggregate_per_id.items() + def _register_aggregates(self, project): """If the instance of this class is called then we will - generate aggregates and store them in ``self._aggregates``. + generate aggregates and store them in ``self._aggregate_per_id``.
""" aggregated_jobs = self._generate_aggregates(project) self._create_nested_aggregate_list(aggregated_jobs, project) @@ -355,13 +367,11 @@ def _validate_and_filter_job(job): filter_aggregate = tuple(filter(_validate_and_filter_job, aggregate)) except TypeError: # aggregate is not iterable ValueError("Invalid aggregator_function provided by the user.") - # Store aggregate in this instance - self._aggregates.append(filter_aggregate) # Store aggregate by their ids in order to search through id - self._aggregate_ids[get_aggregate_id(filter_aggregate)] = filter_aggregate + self._aggregate_per_id[get_aggregate_id(filter_aggregate)] = filter_aggregate -class _DefaultAggregateStore: +class _DefaultAggregateStore(Mapping): """This class holds the information of the project associated with an operation function using the default aggregator, i.e. ``aggregator.groupsof(1)``. @@ -378,7 +388,7 @@ def __init__(self, project): def __iter__(self): for job in self._project: - yield (job,) + yield job.get_id() def __getitem__(self, id): "Return a tuple of a single job via job id." @@ -387,13 +397,23 @@ def __getitem__(self, id): except KeyError: raise LookupError(f"Did not find aggregate with id {id}.") - def __contains__(self, aggregate): + def __contains__(self, id): """Return whether the job is present in the project associated with this instance of :py:class:`_DefaultAggregateStore`. + + :param id: + The job id. + :type id: + str """ - # signac-flow internally assumes every aggregate to be a tuple. - # Hence this method will also get a tuple as an input. - return len(aggregate) == 1 and aggregate[0] in self._project + try: + self._project.open_job(id=id) + except KeyError: + return False + except LookupError: + raise + else: + return True def __len__(self): return len(self._project) @@ -404,6 +424,18 @@ def __eq__(self, other): def __hash__(self): return hash(repr(self._project)) + def keys(self): + for job in self._project: + yield job.get_id() + + def values(self): + for job in self._project: + yield (job,) + + def items(self): + for job in self._project: + yield (job.get_id(), (job,)) + def _register_aggregates(self, project): """We have to store self._project when this method is invoked This is because we will then iterate over that project in diff --git a/flow/project.py b/flow/project.py index 0bac3aad4..7ac2a7085 100644 --- a/flow/project.py +++ b/flow/project.py @@ -28,6 +28,7 @@ from itertools import islice from itertools import count from itertools import groupby +from itertools import chain from hashlib import sha1 import multiprocessing import threading @@ -47,6 +48,9 @@ from enum import IntFlag +from .aggregates import get_aggregate_id +from .aggregates import aggregator +from .aggregates import _DefaultAggregateStore from .environment import get_environment from .scheduling.base import ClusterJob from .scheduling.base import JobStatus @@ -302,7 +306,7 @@ def __str__(self): max_len = 3 min_len_unique_id = self._jobs[0]._project.min_len_unique_id() if len(self._jobs) > max_len: - shown = self._jobs[:max_len-2] + ['...'] + self._jobs[-1:] + shown = self._jobs[:max_len - 2] + ('...',) + self._jobs[-1:] else: shown = self._jobs return f"{self.name}[#{len(self._jobs)}]" \ @@ -686,8 +690,8 @@ def __str__(self): return "{type}(op_func='{op_func}')" \ "".format(type=type(self).__name__, op_func=self._op_func) - def __call__(self, job): - return self._op_func(job) + def __call__(self, *jobs): + return self._op_func(*jobs) class FlowGroupEntry(object): @@ -709,11 +713,15 @@ class 
FlowGroupEntry(object): commands to execute. :type options: str + :param aggregator: + aggregator object associated with the :py:class:`FlowGroup` + :type aggregator: + :py:class:`aggregator` """ - - def __init__(self, name, options=""): + def __init__(self, name, options="", aggregator=aggregator.groupsof(1)): self.name = name self.options = options + self.aggregator = aggregator def __call__(self, func): """Decorator that adds the function into the group's operations. @@ -866,8 +874,8 @@ def _resolve_directives(self, name, defaults, env): def _submit_cmd(self, entrypoint, ignore_conditions, jobs=None): entrypoint = self._determine_entrypoint(entrypoint, dict(), jobs) - cmd = "{} run -o {}".format(entrypoint, self.name) - cmd = cmd if jobs is None else cmd + ' -j {}'.format(' '.join(map(str, jobs))) + cmd = f"{entrypoint} run -o {self.name}" + cmd = cmd if jobs is None else cmd + f' -j {get_aggregate_id(jobs)}' cmd = cmd if self.options is None else cmd + ' ' + self.options if ignore_conditions != IgnoreConditions.NONE: return cmd.strip() + ' --ignore-conditions=' + str(ignore_conditions) @@ -879,7 +887,7 @@ def _run_cmd(self, entrypoint, operation_name, operation, directives, jobs): return operation(*jobs).lstrip() else: entrypoint = self._determine_entrypoint(entrypoint, directives, jobs) - return f"{entrypoint} exec {operation_name} {' '.join(map(str, jobs))}".lstrip() + return f"{entrypoint} exec {operation_name} {get_aggregate_id(jobs)}".lstrip() def __iter__(self): yield from self.operations.values() @@ -1025,7 +1033,7 @@ def _generate_id(self, jobs, operation_name=None, index=0): raise ValueError("Value for MAX_LEN_ID is too small ({}).".format(self.MAX_LEN_ID)) if len(jobs) > 1: - concat_jobs_str = str(jobs[0])[0:8]+'-'+str(jobs[-1])[0:8] + concat_jobs_str = str(jobs[0])[0:8] + '-' + str(jobs[-1])[0:8] else: concat_jobs_str = str(jobs[0])[0:8] @@ -1385,8 +1393,13 @@ def __init__(self, config=None, environment=None, entrypoint=None): # Register all groups with this project instance. self._groups = dict() + self._aggregator_per_group = dict() self._register_groups() + # Register all aggregates which are created for this project + self._stored_aggregates = dict() + self._register_aggregates() + def _setup_template_environment(self): """Setup the jinja2 template environment. @@ -1712,28 +1725,31 @@ def scheduler_jobs(self, scheduler): for sjob in self._expand_bundled_jobs(scheduler.jobs()): yield sjob - def _get_operations_status(self, job, cached_status): - "Return a dict with information about job-operations for this job." + def _get_operations_status(self, jobs, cached_status): + "Return a dict with information about job-operations for this aggregate." 
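# The aggregate ids used in the hunks above and below come from
# get_aggregate_id (flow/aggregates.py). A minimal sketch of its plausible
# behavior, assuming the 'agg-' prefix hinted at in _select_jobs_from_args
# and the md5 import in flow/aggregates.py (the exact digest construction
# is an assumption):
#
#     def get_aggregate_id(jobs):
#         if len(jobs) == 1:
#             return str(jobs[0])  # a single-job aggregate reuses the job id
#         blob = ','.join(job.get_id() for job in jobs)
#         return 'agg-' + md5(blob.encode('utf-8')).hexdigest()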
starting_dict = functools.partial(dict, scheduler_status=JobStatus.unknown) status_dict = defaultdict(starting_dict) - for group in self._groups.values(): - completed = group._complete((job,)) - eligible = False if completed else group._eligible((job,)) - scheduler_status = cached_status.get(group._generate_id((job,)), - JobStatus.unknown) - for operation in group.operations: - if scheduler_status >= status_dict[operation]['scheduler_status']: - status_dict[operation] = { - 'scheduler_status': scheduler_status, - 'eligible': eligible, - 'completed': completed - } - - for key in sorted(status_dict): - yield key, status_dict[key] + operation_names = list(self.operations.keys()) + groups = [self._groups[name] for name in operation_names] + for group in groups: + if get_aggregate_id(jobs) in self._get_aggregate_store(group.name): + completed = group._complete(jobs) + eligible = not completed and group._eligible(jobs) + scheduler_status = cached_status.get(group._generate_id(jobs), + JobStatus.unknown) + for operation in group.operations: + if scheduler_status >= status_dict[operation]['scheduler_status']: + status_dict[operation] = { + 'scheduler_status': scheduler_status, + 'eligible': eligible, + 'completed': completed + } + + yield from sorted(status_dict.items()) def get_job_status(self, job, ignore_errors=False, cached_status=None): - "Return a dict with detailed information about the status of a job." + "Return a dict with detailed information about the status of a job or an aggregate of jobs." + # TODO: Add support for aggregates for this method. result = dict() result['job_id'] = str(job) try: @@ -1742,7 +1758,7 @@ def get_job_status(self, job, ignore_errors=False, cached_status=None): cached_status = self.document['_status']._as_dict() except KeyError: cached_status = dict() - result['operations'] = OrderedDict(self._get_operations_status(job, cached_status)) + result['operations'] = OrderedDict(self._get_operations_status((job,), cached_status)) result['_operations_error'] = None except Exception as error: msg = "Error while getting operations status for job '{}': '{}'.".format(job, error) @@ -1756,7 +1772,7 @@ def get_job_status(self, job, ignore_errors=False, cached_status=None): result['labels'] = sorted(set(self.labels(job))) result['_labels_error'] = None except Exception as error: - logger.debug("Error while determining labels for job '{}': '{}'.".format(job, error)) + logger.debug(f"Error while determining labels for job '{job}': '{error}'.") if ignore_errors: result['labels'] = list() result['_labels_error'] = str(error) @@ -1768,8 +1784,6 @@ def _fetch_scheduler_status(self, jobs=None, file=None, ignore_errors=False): "Update the status docs." 
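# The status document updated below maps submission ids (built by
# FlowGroup._generate_id for each aggregate) to integer JobStatus values.
# A rough illustration of the stored cache, with hypothetical ids (the
# exact id format is defined in _generate_id above):
#
#     project.document._status == {
#         '<project>/<job-id-prefix>/op1/...': int(JobStatus.queued),
#         '<project>/<first-prefix>-<last-prefix>/pair_op/...': int(JobStatus.unknown),
#     }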
if file is None: file = sys.stderr - if jobs is None: - jobs = list(self) try: scheduler = self._environment.get_scheduler() @@ -1777,12 +1791,16 @@ scheduler_info = {sjob.name(): sjob.status() for sjob in self.scheduler_jobs(scheduler)} status = dict() print("Query scheduler...", file=file) - for job in tqdm(jobs, - desc="Fetching operation status", - total=len(jobs), file=file): - for group in self._groups.values(): - _id = group._generate_id((job,)) - status[_id] = int(scheduler_info.get(_id, JobStatus.unknown)) + for group in tqdm(self._groups.values(), + desc="Fetching operation status", + total=len(self._groups), file=file): + aggregate_store = self._get_aggregate_store(group.name) + for aggregate in tqdm(aggregate_store.values(), total=len(aggregate_store), + desc="Fetching aggregate status", + leave=False, file=file): + if self._is_selected_aggregate(aggregate, jobs): + submit_id = group._generate_id(aggregate) + status[submit_id] = int(scheduler_info.get(submit_id, JobStatus.unknown)) self.document._status.update(status) except NoSchedulerError: logger.debug("No scheduler available.") @@ -1793,7 +1811,88 @@ else: logger.info("Updated job status cache.") - def _fetch_status(self, jobs, err, ignore_errors, status_parallelization='thread'): + def _get_group_status(self, group_name, ignore_errors=False, cached_status=None): + "Return a dict with detailed information about the status of jobs per group." + group = self._groups[group_name] + status_dict = dict() + errors = dict() + aggregate_store = self._get_aggregate_store(group.name) + for aggregate_id, aggregate in tqdm(aggregate_store.items(), + desc="Collecting aggregate status info " + f"for operation {group.name}", + leave=False): + errors.setdefault(aggregate_id, '') + try: + job_op_id = group._generate_id(aggregate) + scheduler_status = cached_status.get(job_op_id, JobStatus.unknown) + completed = group._complete(aggregate) + eligible = False if completed else group._eligible(aggregate) + except Exception as error: + msg = "Error while getting operations status for aggregate " \ + "'{}': '{}'.".format(aggregate_id, error) + logger.debug(msg) + if ignore_errors: + errors[aggregate_id] += str(error) + '\n' + scheduler_status = JobStatus.unknown + completed = False + eligible = False + else: + raise + finally: + status_dict[aggregate_id] = { + 'scheduler_status': scheduler_status, + 'eligible': eligible, + 'completed': completed, + } + + return { + 'operation_name': group_name, + 'job_status_details': status_dict, + '_operation_error_per_job': errors, + } + + def _get_job_labels(self, job, ignore_errors=False): + "Return a dict with information about the labels of a job." + result = dict() + result['job_id'] = str(job) + try: + result['labels'] = sorted(set(self.labels(job))) + except Exception as error: + logger.debug("Error while determining labels for job '{}': '{}'.".format(job, error)) + if ignore_errors: + result['labels'] = list() + result['_labels_error'] = str(error) + else: + raise + else: + result['_labels_error'] = None + return result + + def _fetch_status(self, aggregates, distinct_jobs, err, ignore_errors, + status_parallelization='thread'): + """Fetch the status of either all aggregates in the project or only + those specified by the user. + + :param aggregates: + The aggregates for which a user requested to fetch status.
+ :type aggregates: + list + :param distinct_jobs: + Distinct jobs fetched from the ids provided in the ``jobs`` argument. + This is used for fetching labels for a job because a label is not associated + with an aggregate. + :type distinct_jobs: + list of :py:class:`signac.contrib.job.Job` + :param ignore_errors: + Fetch status even if querying the scheduler fails. + :type ignore_errors: + bool + :param status_parallelization: + Nature of parallelization for fetching the status. + Default value parallelizes using ``multiprocessing.ThreadPool()`` + :type status_parallelization: + str + """ # The argument status_parallelization is used so that _fetch_status method # gets to know whether the deprecated argument no_parallelization passed # while calling print_status is True or False. This can also be done by @@ -1802,27 +1901,20 @@ def _fetch_status(self, jobs, err, ignore_errors, status_parallelization='thread # to do proper deprecation, it is not required for now. # Update the project's status cache - self._fetch_scheduler_status(jobs, err, ignore_errors) - # Get status dict for all selected jobs - - def _print_progress(x): - print("Updating status: ", end='', file=err) - err.flush() - n = max(1, int(len(jobs) / 10)) - for i, _ in enumerate(x): - if (i % n) == 0: - print('.', end='', file=err) - err.flush() - yield _ - + self._fetch_scheduler_status(aggregates, err, ignore_errors) + # Get status dict for all selected aggregates try: cached_status = self.document['_status']._as_dict() except KeyError: cached_status = dict() - _get_job_status = functools.partial(self.get_job_status, - ignore_errors=ignore_errors, - cached_status=cached_status) + get_job_labels = functools.partial(self._get_job_labels, + ignore_errors=ignore_errors) + get_group_status = functools.partial(self._get_group_status, + ignore_errors=ignore_errors, + cached_status=cached_status) + + operation_names = list(self.operations.keys()) with self._potentially_buffered(): try: @@ -1830,15 +1922,21 @@ def _print_progress(x): with contextlib.closing(ThreadPool()) as pool: # First attempt at parallelized status determination. # This may fail on systems that don't allow threads. 
- return list(tqdm( - iterable=pool.imap(_get_job_status, jobs), - desc="Collecting job status info", total=len(jobs), file=err)) + label_results = list(tqdm( + iterable=pool.imap(get_job_labels, distinct_jobs), + desc="Collecting job label info", total=len(distinct_jobs), + file=err)) + op_results = list(tqdm( + iterable=pool.imap(get_group_status, operation_names), + desc="Collecting operation status", total=len(operation_names), + file=err)) elif status_parallelization == 'process': with contextlib.closing(Pool()) as pool: try: import pickle - results = self._fetch_status_in_parallel( - pool, pickle, jobs, ignore_errors, cached_status) + l_results, g_results = self._fetch_status_in_parallel( + pool, pickle, distinct_jobs, operation_names, ignore_errors, + cached_status) except Exception as error: if not isinstance(error, (pickle.PickleError, self._PickleError)) and\ 'pickle' not in str(error).lower(): @@ -1854,52 +1952,104 @@ raise error else: try: - results = self._fetch_status_in_parallel( - pool, cloudpickle, jobs, ignore_errors, cached_status) + l_results, g_results = self._fetch_status_in_parallel( + pool, cloudpickle, distinct_jobs, operation_names, + ignore_errors, cached_status) except self._PickleError as error: raise RuntimeError( "Unable to parallelize execution due to a pickling " "error: {}.".format(error)) - return list(tqdm( - iterable=results, - desc="Collecting job status info", total=len(jobs), file=err)) + label_results = list(tqdm( + iterable=l_results, desc="Collecting job label info", + total=len(distinct_jobs), file=err)) + op_results = list(tqdm( + iterable=g_results, desc="Collecting operation status", + total=len(operation_names), file=err)) elif status_parallelization == 'none': - return list(tqdm( - iterable=map(_get_job_status, jobs), - desc="Collecting job status info", total=len(jobs), file=err)) + label_results = list(tqdm( + iterable=map(get_job_labels, distinct_jobs), + desc="Collecting job label info", total=len(distinct_jobs), + file=err)) + op_results = list(tqdm( + iterable=map(get_group_status, operation_names), + desc="Collecting operation status", total=len(operation_names), + file=err)) else: raise RuntimeError("Configuration value status_parallelization is invalid. " - "You can set it to 'thread', 'parallel', or 'none'" - ) + "You can set it to 'thread', 'process', or 'none'") except RuntimeError as error: if "can't start new thread" not in error.args: raise # unrelated error - t = time.time() - num_jobs = len(jobs) - statuses = [] - for i, job in enumerate(jobs): - statuses.append(_get_job_status(job)) - if time.time() - t > 0.2: # status interval - print( - 'Collecting job status info: {}/{}'.format(i+1, num_jobs), - end='\r', file=err) - t = time.time() - # Always print the completed progressbar. - print('Collecting job status info: {}/{}'.format(i+1, num_jobs), file=err) - return statuses - - def _fetch_status_in_parallel(self, pool, pickle, jobs, ignore_errors, cached_status): + def print_status(iterable, fetch_status, description): + t = time.time() + num_itr = len(iterable) + results = [] + for i, itr in enumerate(iterable): + results.append(fetch_status(itr)) + # The status interval 0.2 seconds is used since we expect the + # status for an aggregate to be fetched within that interval. + if time.time() - t > 0.2: + print(f'{description}: {i+1}/{num_itr}', end='\r', file=err) + t = time.time() + # Always print the completed progressbar.
+ print(f'{description}: {i+1}/{num_itr}', file=err) + return results + + label_results = print_status(distinct_jobs, get_job_labels, + "Collecting job label info") + op_results = print_status(operation_names, get_group_status, + "Collecting operation status") + + results = [] + index = {} + for i, job in enumerate(distinct_jobs): + results_entry = dict() + results_entry['job_id'] = str(job) + results_entry['operations'] = dict() + results_entry['_operations_error'] = None + results_entry['labels'] = list() + results_entry['_labels_error'] = None + results.append(results_entry) + index[job.get_id()] = i + + for op_result in op_results: + for id, aggregates_status in op_result['job_status_details'].items(): + aggregate = self._get_aggregate_from_id(id) + if not self._is_selected_aggregate(aggregate, aggregates): + continue + error = op_result['_operation_error_per_job'].get(id, None) + for job in aggregate: + results[index[job.get_id()]]['operations'][op_result['operation_name']] = \ + aggregates_status + results[index[job.get_id()]]['_operations_error'] = error + + for label_result in label_results: + results[index[label_result['job_id']]]['labels'] = label_result['labels'] + results[index[label_result['job_id']]]['_labels_error'] = label_result['_labels_error'] + + return results + + def _fetch_status_in_parallel(self, pool, pickle, jobs, groups, ignore_errors, cached_status): try: - s_project = pickle.dumps(self) - s_tasks = [(pickle.loads, s_project, job.get_id(), ignore_errors, cached_status) - for job in jobs] + # Pickling the whole project loses necessary information, so we + # explicitly pickle the pieces we need and restore them inside the + # serialized helpers. + s_root = pickle.dumps(self.root_directory()) + s_label_funcs = pickle.dumps(self._label_functions) + s_groups = pickle.dumps(self._groups) + s_groups_aggregate = pickle.dumps(self._stored_aggregates) + s_tasks_labels = [(pickle.loads, s_root, job.get_id(), ignore_errors, + s_label_funcs, 'fetch_labels') for job in jobs] + s_tasks_groups = [(pickle.loads, s_root, group, ignore_errors, cached_status, + s_groups, s_groups_aggregate, 'fetch_status') for group in groups] except Exception as error: # Masking all errors since they must be pickling related. raise self._PickleError(error) - results = pool.imap(_serialized_get_job_status, s_tasks) + label_results = pool.starmap(_serializer, s_tasks_labels) + group_results = pool.starmap(_serializer, s_tasks_groups) - return results + return label_results, group_results PRINT_STATUS_ALL_VARYING_PARAMETERS = True """This constant can be used to signal that the print_status() method is supposed @@ -1918,7 +2068,7 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, :param jobs: Only execute operations for the given jobs, or all if the argument is omitted. :type jobs: - Sequence of instances :class:`.Job`. + Sequence of instances of :class:`.Job`. :param overview: Aggregate an overview of the project' status. :type overview: @@ -2005,8 +2155,16 @@ file = sys.stdout if err is None: err = sys.stderr - if jobs is None: - jobs = self # all jobs + + aggregates = self._convert_aggregates_from_jobs(jobs) + if aggregates is not None: + # Collect the distinct jobs contained in the aggregates passed by the user.
+ distinct_jobs = set() + for aggregate in aggregates: + for job in aggregate: + distinct_jobs.add(job) + else: + distinct_jobs = self if eligible_jobs_max_lines is None: eligible_jobs_max_lines = flow_config.get_config_value('eligible_jobs_max_lines') @@ -2049,7 +2207,8 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, ] with prof(single=False): - tmp = self._fetch_status(jobs, err, ignore_errors, status_parallelization) + tmp = self._fetch_status(aggregates, distinct_jobs, err, ignore_errors, + status_parallelization) prof._mergeFileTiming() @@ -2113,7 +2272,8 @@ def print_status(self, jobs=None, overview=True, overview_max_lines=None, "results may be highly inaccurate.") else: - tmp = self._fetch_status(jobs, err, ignore_errors, status_parallelization) + tmp = self._fetch_status(aggregates, distinct_jobs, err, ignore_errors, + status_parallelization) profiling_results = None operations_errors = {s['_operations_error'] for s in tmp} @@ -2157,8 +2317,8 @@ def _incomplete(s): # Optionally expand parameters argument to all varying parameters. if parameters is self.PRINT_STATUS_ALL_VARYING_PARAMETERS: parameters = list( - sorted({key for job in jobs for key in job.sp.keys() if - len(set([to_hashable(job.sp().get(key)) for job in jobs])) > 1})) + sorted({key for job in distinct_jobs for key in job.sp.keys() if + len(set([to_hashable(job.sp().get(key)) for job in distinct_jobs])) > 1})) if parameters: # get parameters info @@ -2256,6 +2416,7 @@ def _add_dummy_operation(job): context['op_counter'].append(('[{} more operations omitted]'.format(n), '')) status_renderer = StatusRenderer() + # We have to make a deep copy of the template environment if we're # using a process Pool for parallelism. Somewhere in the process of # manually pickling and dispatching tasks to individual processes @@ -2307,7 +2468,7 @@ def _run_operations(self, operations=None, pretend=False, np=None, if timeout is not None and timeout < 0: timeout = None if operations is None: - operations = list(self._get_pending_operations(self)) + operations = list(self._get_pending_operations()) else: operations = list(operations) # ensure list @@ -2400,13 +2561,14 @@ def _run_operations_in_parallel(self, pool, pickle, operations, progress, timeou """ try: - s_project = pickle.dumps(self) - s_tasks = [(pickle.loads, s_project, self._dumps_op(op)) + s_root = pickle.dumps(self.root_directory()) + s_ops = pickle.dumps(self._operations) + s_tasks = [(pickle.loads, s_root, self._dumps_op(op), s_ops, 'run_operations') for op in tqdm(operations, desc='Serialize tasks', file=sys.stderr)] except Exception as error: # Masking all errors since they must be pickling related. raise self._PickleError(error) - results = [pool.apply_async(_execute_serialized_operation, task) for task in s_tasks] + results = [pool.apply_async(_serializer, task) for task in s_tasks] for result in tqdm(results) if progress else results: result.get(timeout=timeout) @@ -2444,8 +2606,9 @@ def _execute_operation(self, operation, timeout=None, pretend=False): except Exception as e: assert len(operation._jobs) == 1 raise UserOperationError( - 'An exception was raised during operation {operation.name} ' - 'for job {operation._jobs[0]}.'.format(operation=operation)) from e + f'An exception was raised during operation {operation.name} ' + f'for job or aggregate with id {get_aggregate_id(operation._jobs)}.' 
+ ) from e def _get_default_directives(self): return {name: self.groups[name].operation_directives.get(name, dict()) @@ -2466,9 +2629,10 @@ def run(self, jobs=None, names=None, pretend=False, np=None, timeout=None, num=N See also: :meth:`~.run_operations` :param jobs: - Only execute operations for the given jobs, or all if the argument is omitted. + Only execute operations for the given jobs or aggregates of jobs, + or all if the argument is omitted. :type jobs: - Sequence of instances :class:`.Job`. + Sequence of instances of :class:`.Job` or aggregate of instances of :class:`.Job`. :param names: Only execute operations that are in the provided set of names, or all, if the argument is omitted. @@ -2506,12 +2670,13 @@ def run(self, jobs=None, names=None, pretend=False, np=None, timeout=None, num=N Specify the order of operations, possible values are: * 'none' or None (no specific order) * 'by-job' (operations are grouped by job) + * 'by-op' (operations are grouped by operation) * 'cyclic' (order operations cyclic by job) * 'random' (shuffle the execution order randomly) * callable (a callable returning a comparison key for an operation used to sort operations) - The default value is `none`, which is equivalent to `by-job` in the current + The default value is `none`, which is equivalent to `by-op` in the current implementation. .. note:: @@ -2527,9 +2692,7 @@ def run(self, jobs=None, names=None, pretend=False, np=None, timeout=None, num=N :type ignore_conditions: :py:class:`~.IgnoreConditions` """ - # If no jobs argument is provided, we run operations for all jobs. - if jobs is None: - jobs = self + aggregates = self._convert_aggregates_from_jobs(jobs) # Get all matching FlowGroups if isinstance(names, str): @@ -2563,7 +2726,8 @@ def log(msg, lvl=logging.INFO): reached_execution_limit = Event() def select(operation): - self._verify_aggregate_project(operation._jobs) + if not self._is_selected_aggregate(operation._jobs, aggregates): + return False if num is not None and select.total_execution_count >= num: reached_execution_limit.set() @@ -2607,12 +2771,11 @@ def select(operation): with self._potentially_buffered(): operations = [] for flow_group in flow_groups: - for job in jobs: + for aggregate in self._get_aggregate_store(flow_group.name).values(): operations.extend( flow_group._create_run_job_operations( self._entrypoint, default_directives, - (job,), ignore_conditions)) - + aggregate, ignore_conditions)) operations = list(filter(select, operations)) finally: if messages: @@ -2622,21 +2785,30 @@ def select(operation): if not operations: break # No more pending operations or execution limits reached. + def key_func_by_job(op): + # In order to group the aggregates in a by-job manner, we need + # to first sort the aggregates using their aggregate id. 
+ return get_aggregate_id(op._jobs) + # Optionally re-order operations for execution if order argument is provided: if callable(order): operations = list(sorted(operations, key=order)) - elif order == 'cyclic': - groups = [list(group) - for _, group in groupby(operations, key=lambda op: op._jobs)] - operations = list(roundrobin(*groups)) elif order == 'random': random.shuffle(operations) - elif order is None or order in ('none', 'by-job'): - pass # by-job is the default order + elif order in ('by-job', 'cyclic'): + groups = [list(group) + for _, group in groupby(sorted(operations, key=key_func_by_job), + key=key_func_by_job)] + if order == 'cyclic': + operations = list(roundrobin(*groups)) + else: + operations = list(chain(*groups)) + elif order is None or order in ('none', 'by-op'): + pass # by-op is the default order else: raise ValueError( "Invalid value for the 'order' argument, valid arguments are " - "'none', 'by-job', 'cyclic', 'random', None, or a callable.") + "'none', 'by-op', 'by-job', 'cyclic', 'random', None, or a callable.") logger.info( "Executing {} operation(s) (Pass # {:02d})...".format(len(operations), i_pass)) @@ -2666,23 +2838,24 @@ def _gather_flow_groups(self, names=None): " -o/--operation option.") return operations - def _get_submission_operations(self, jobs, default_directives, names=None, + def _get_submission_operations(self, aggregates, default_directives, names=None, ignore_conditions=IgnoreConditions.NONE, ignore_conditions_on_execution=IgnoreConditions.NONE): """Grabs _JobOperations that are eligible to run from FlowGroups.""" for group in self._gather_flow_groups(names): - for job in jobs: + for aggregate in self._get_aggregate_store(group.name).values(): if ( - group._eligible((job,), ignore_conditions) and - self._eligible_for_submission(group, (job,)) + group._eligible(aggregate, ignore_conditions) and + self._eligible_for_submission(group, aggregate) and + self._is_selected_aggregate(aggregate, aggregates) ): yield group._create_submission_job_operation( entrypoint=self._entrypoint, default_directives=default_directives, - jobs=(job,), index=0, + jobs=aggregate, index=0, ignore_conditions_on_execution=ignore_conditions_on_execution) - def _get_pending_operations(self, jobs, operation_names=None, + def _get_pending_operations(self, jobs=None, operation_names=None, ignore_conditions=IgnoreConditions.NONE): "Get all pending operations for the given selection." assert not isinstance(operation_names, str) @@ -2694,12 +2867,57 @@ def _verify_group_compatibility(self, groups): """Verifies that all selected groups can be submitted together.""" return all(a.isdisjoint(b) for a in groups for b in groups if a != b) - def _verify_aggregate_project(self, aggregate): - """Verifies that all aggregates belongs to the same project.""" - for job in aggregate: - if job not in self: - raise ValueError("Job {} is not present " - "in the project".format(job)) + def _aggregate_is_in_project(self, aggregate): + """Verifies that the aggregate belongs to this project.""" + return any(get_aggregate_id(aggregate) in aggregates + for aggregates in self._stored_aggregates) + + @staticmethod + def _is_selected_aggregate(aggregate, jobs): + """Verifies whether the aggregate is present in the provided jobs. + + Providing ``jobs=None`` indicates that no specific job is provided by + the user and hence ``aggregate`` is eligible for further evaluation. + + Always returns True if jobs is None. 
+ """ + return (jobs is None) or (aggregate in jobs) + + def _get_aggregate_from_id(self, id): + # Iterate over all the instances of stored aggregates and search for the + # aggregate in those instances. + for aggregate_store in self._stored_aggregates: + if id in aggregate_store: + return aggregate_store[id] + # Raise error as didn't find the id in any of the stored objects + raise LookupError(f"Did not find aggregate with id {id} in the project") + + def _convert_aggregates_from_jobs(self, jobs): + # The jobs parameter in public methods like ``run``, ``submit``, ``status`` may + # accept either a signac job or an aggregate. We convert that job / aggregate + # (which may be of any type (e.g. list)) to an aggregate of type ``tuple``. + if jobs is not None: + # aggregates must be a set to prevent duplicate entries + aggregates = set() + for aggregate in jobs: + # User can still pass signac jobs. + if isinstance(aggregate, signac.contrib.job.Job): + if aggregate not in self: + raise LookupError(f"Did not find job {aggregate} in the project") + aggregates.add((aggregate,)) + else: + try: + aggregate = tuple(aggregate) + except TypeError: + raise TypeError('Invalid argument provided by a user. Please provide ' + 'a valid signac job or an aggregate of jobs instead.') + else: + if not self._aggregate_is_in_project(aggregate): + raise LookupError(f"Did not find aggregate {aggregate} in the project") + aggregates.add(aggregate) # An aggregate provided by the user + return list(aggregates) + else: + return None @contextlib.contextmanager def _potentially_buffered(self): @@ -2980,10 +3198,9 @@ def submit(self, bundle_size=1, jobs=None, names=None, num=None, parallel=False, :type ignore_conditions: :py:class:`~.IgnoreConditions` """ - # Regular argument checks and expansion - if jobs is None: - jobs = self # select all jobs + aggregates = self._convert_aggregates_from_jobs(jobs) + # Regular argument checks and expansion if isinstance(names, str): raise ValueError( "The 'names' argument must be a sequence of strings, however you " @@ -3011,7 +3228,7 @@ def submit(self, bundle_size=1, jobs=None, names=None, num=None, parallel=False, default_directives = self._get_default_directives() # The generator must be used *inside* the buffering context manager # for performance reasons. - operation_generator = self._get_submission_operations(jobs, + operation_generator = self._get_submission_operations(aggregates, default_directives, names, ignore_conditions, @@ -3382,20 +3599,12 @@ def completed_operations(self, job): if op._complete((job,)): yield name - def _job_operations(self, job, ignore_conditions=IgnoreConditions.NONE): - "Yield instances of _JobOperation constructed for specific jobs." - for name in self.operations: - group = self._groups[name] - yield from group._create_run_job_operations(entrypoint=self._entrypoint, jobs=(job,), - default_directives=dict(), - ignore_conditions=ignore_conditions, - index=0) - - def _next_operations(self, jobs, ignore_conditions=IgnoreConditions.NONE): - """Determine the next eligible operations for jobs. + def _next_operations(self, jobs=None, ignore_conditions=IgnoreConditions.NONE): + """Determine the next eligible operations for aggregates. :param jobs: - The signac job handles. + The signac job handles. By default all the aggregates are evaluated to get + the next operation associated. 
:type jobs: tuple of :class:`~signac.contrib.job.Job` :param ignore_conditions: @@ -3406,9 +3615,15 @@ def _next_operations(self, jobs, ignore_conditions=IgnoreConditions.NONE): :yield: All instances of :class:`~._JobOperation` jobs are eligible for. """ - for job in jobs: - for op in self._job_operations(job, ignore_conditions): - yield op + for name in self.operations: + group = self._groups[name] + for aggregate in self._get_aggregate_store(group.name).values(): + if not self._is_selected_aggregate(aggregate, jobs): + continue + yield from group._create_run_job_operations(entrypoint=self._entrypoint, + default_directives=dict(), + jobs=aggregate, + ignore_conditions=ignore_conditions) @deprecated(deprecated_in="0.11", removed_in="0.13", current_version=__version__) def next_operations(self, *jobs, ignore_conditions=IgnoreConditions.NONE): @@ -3416,8 +3631,8 @@ def next_operations(self, *jobs, ignore_conditions=IgnoreConditions.NONE): :param jobs: The signac job handles. - :type job: - :class:`~signac.contrib.job.Job` + :type jobs: + Sequence of instances of :class:`.Job`. :param ignore_conditions: Specify if pre and/or post conditions check is to be ignored for eligibility check. The default is :py:class:`IgnoreConditions.NONE`. @@ -3426,12 +3641,26 @@ def next_operations(self, *jobs, ignore_conditions=IgnoreConditions.NONE): :yield: All instances of :class:`~.JobOperation` jobs are eligible for. """ - for job in jobs: - for op in self._job_operations(job, ignore_conditions): - # JobOperation is just meand to deal with a single job and not a tuple of jobs. - # Hence we have to make sure that a JobOperation instance hold a single job. - assert len(op._jobs) == 1 - yield JobOperation(op.id, op.name, op._jobs[0], op._cmd, op.directives) + for name in self.operations: + group = self._groups[name] + aggregate_store = self._get_aggregate_store(group.name) + + # Only yield JobOperation instances from the default aggregates + if not isinstance(aggregate_store, _DefaultAggregateStore): + continue + + for aggregate in aggregate_store.values(): + # JobOperation handles a single job and not an aggregate of + # jobs. Hence the single job in that aggregate should be + # present in the jobs passed by a user. + if aggregate[0] not in jobs: + continue + + for op in group._create_run_job_operations( + entrypoint=self._entrypoint, jobs=aggregate, default_directives={}, + ignore_conditions=ignore_conditions, index=0 + ): + yield JobOperation(op.id, op.name, op._jobs[0], op._cmd, op.directives) @classmethod def operation(cls, func, name=None): @@ -3458,17 +3687,18 @@ def hello(job): "An operation with name '{}' is already registered.".format(name)) if name in cls._GROUP_NAMES: raise ValueError("A group with name '{}' is already registered.".format(name)) - signature = inspect.signature(func) for i, (k, v) in enumerate(signature.parameters.items()): if i and v.default is inspect.Parameter.empty: raise ValueError( "Only the first argument in an operation argument may not have " "a default value! 
({})".format(name)) + if not getattr(func, '_flow_aggregate', False): + func._flow_aggregate = aggregator.groupsof(1) # Append the name and function to the class registry cls._OPERATION_FUNCTIONS.append((name, func)) - cls._GROUPS.append(FlowGroupEntry(name=name, options="")) + cls._GROUPS.append(FlowGroupEntry(name=name, options="", aggregator=func._flow_aggregate)) if hasattr(func, '_flow_groups'): func._flow_groups.append(name) else: @@ -3524,6 +3754,13 @@ def _register_operations(self): else: self._operations[name] = FlowOperation(op_func=func, **params) + def _register_aggregates(self): + """Generate aggregates for every operation or group in a FlowProject""" + stored_aggregates = {} + for _aggregator, groups in self._aggregator_per_group.items(): + stored_aggregates[_aggregator._create_AggregatesStore(self)] = groups + self._stored_aggregates = stored_aggregates + @classmethod def make_group(cls, name, options=""): """Make a FlowGroup named ``name`` and return a decorator to make groups. @@ -3548,13 +3785,16 @@ def foo(job): A string to append to submissions can be any valid :meth:`FlowOperation.run` option. :type options: str + :param aggregator_obj: + aggregator object associated with the :py:class:`FlowGroup` + :type aggregator_obj: + :py:class:`aggregator` """ if name in cls._GROUP_NAMES: raise ValueError("Repeat definition of group with name '{}'.".format(name)) else: cls._GROUP_NAMES.add(name) - - group_entry = FlowGroupEntry(name, options) + group_entry = FlowGroupEntry(name, options, aggregator.groupsof(1)) cls._GROUPS.append(group_entry) return group_entry @@ -3565,9 +3805,14 @@ def _register_groups(self): for cls in type(self).__mro__: group_entries.extend(getattr(cls, '_GROUPS', [])) - # Initialize all groups without operations + aggregators = defaultdict(list) + # Initialize all groups without operations. + # Also store the aggregates we need to store all the groups associated + # with each aggregator. for entry in group_entries: self._groups[entry.name] = FlowGroup(entry.name, options=entry.options) + aggregators[entry.aggregator].append(entry.name) + self._aggregator_per_group = dict(aggregators) # Add operations and directives to group for (op_name, op) in self._operations.items(): @@ -3596,6 +3841,23 @@ def operations(self): def groups(self): return self._groups + def _get_aggregate_store(self, group): + """Return aggregate store associated with the FlowGroup. + + :param group: + The name of the FlowGroup whose aggregate store will be returned. + :type group: + str + :returns: + Aggregate store containing aggregates associated with the provided FlowGroup. + :rtype: + :py:class:`_DefaultAggregateStore` + """ + for aggregate_store, groups in self._stored_aggregates.items(): + if group in groups: + return aggregate_store + return {} + def _eligible_for_submission(self, flow_group, jobs): """Determine if a flow_group is eligible for submission with a given job-aggregate. @@ -3616,7 +3878,7 @@ def _eligible_for_submission(self, flow_group, jobs): def _main_status(self, args): "Print status overview." 
- jobs = self._select_jobs_from_args(args) + aggregates = self._select_jobs_from_args(args) if args.compact and not args.unroll: logger.warn("The -1/--one-line argument is incompatible with " "'--stack' and will be ignored.") @@ -3629,9 +3891,7 @@ def _main_status(self, args): start = time.time() try: - self.print_status(jobs=jobs, **args) - except NoSchedulerError: - self.print_status(jobs=jobs, **args) + self.print_status(jobs=aggregates, **args) except Exception as error: if show_traceback: logger.error( @@ -3645,8 +3905,13 @@ def _main_status(self, args): error = error.__cause__ # Always show the user traceback cause. traceback.print_exception(type(error), error, error.__traceback__) else: + if aggregates is None: + length_jobs = sum(len(aggregate_store) + for aggregate_store in self._stored_aggregates) + else: + length_jobs = len(aggregates) # Use small offset to account for overhead with few jobs - delta_t = (time.time() - start - 0.5) / max(len(jobs), 1) + delta_t = (time.time() - start - 0.5) / max(length_jobs, 1) config_key = 'status_performance_warn_threshold' warn_threshold = flow_config.get_config_value(config_key) if not args['profile'] and delta_t > warn_threshold >= 0: @@ -3665,20 +3930,20 @@ def _main_status(self, args): def _main_next(self, args): "Determine the jobs that are eligible for a specific operation." - for op in self._next_operations(self): + for op in self._next_operations(): if args.name in op.name: - print(' '.join(map(str, op._jobs))) + print(get_aggregate_id(op._jobs)) def _main_run(self, args): "Run all (or select) job operations." # Select jobs: - jobs = self._select_jobs_from_args(args) + aggregates = self._select_jobs_from_args(args) # Setup partial run function, because we need to call this either # inside some context managers or not based on whether we need # to switch to the project root directory or not. run = functools.partial(self.run, - jobs=jobs, names=args.operation_name, pretend=args.pretend, + jobs=aggregates, names=args.operation_name, pretend=args.pretend, np=args.parallel, timeout=args.timeout, num=args.num, num_passes=args.num_passes, progress=args.progress, order=args.order, @@ -3694,13 +3959,13 @@ def _main_run(self, args): def _main_script(self, args): "Generate a script for the execution of operations." # Select jobs: - jobs = self._select_jobs_from_args(args) + aggregates = self._select_jobs_from_args(args) # Gather all pending operations or generate them based on a direct command... with self._potentially_buffered(): names = args.operation_name if args.operation_name else None default_directives = self._get_default_directives() - operations = self._get_submission_operations(jobs, default_directives, names, + operations = self._get_submission_operations(aggregates, default_directives, names, args.ignore_conditions, args.ignore_conditions_on_execution) operations = list(islice(operations, args.num)) @@ -3717,20 +3982,17 @@ def _main_submit(self, args): kwargs = vars(args) # Select jobs: - jobs = self._select_jobs_from_args(args) + aggregates = self._select_jobs_from_args(args) # Fetch the scheduler status. 
if not args.test: - self._fetch_scheduler_status(jobs) + self._fetch_scheduler_status(aggregates) names = args.operation_name if args.operation_name else None - self.submit(jobs=jobs, names=names, **kwargs) + self.submit(jobs=aggregates, names=names, **kwargs) def _main_exec(self, args): - if len(args.job_id): - jobs = [self.open_job(id=jid) for jid in args.job_id] - else: - jobs = self + aggregates = self._select_jobs_from_args(args) try: operation = self._operations[args.operation] @@ -3744,24 +4006,36 @@ def operation_function(job): except KeyError: raise KeyError("Unknown operation '{}'.".format(args.operation)) - for job in jobs: - operation_function(job) + for aggregate in self._get_aggregate_store(args.operation).values(): + if self._is_selected_aggregate(aggregate, aggregates): + operation_function(*aggregate) def _select_jobs_from_args(self, args): - "Select jobs with the given command line arguments ('-j/-f/--doc-filter')." - if args.job_id and (args.filter or args.doc_filter): + "Select jobs with the given command line arguments ('-j/-f/--doc-filter/--jobid')." + if ( + not args.func == self._main_exec and + args.job_id and (args.filter or args.doc_filter) + ): raise ValueError( "Cannot provide both -j/--job-id and -f/--filter or --doc-filter in combination.") if args.job_id: - try: - return [self.open_job(id=job_id) for job_id in args.job_id] - except KeyError as error: - raise LookupError("Did not find job with id {}.".format(error)) - else: + # aggregates must be a set to prevent duplicate entries + aggregates = set() + for id in args.job_id: + # TODO: We need to add support for aggregation id parameter + # for the -j flag ('agg-...') + try: + aggregates.add((self.open_job(id=id),)) + except KeyError as error: + raise LookupError("Did not find job with id {}.".format(error)) + return list(aggregates) + elif 'filter' in args or 'doc_filter' in args: filter_ = parse_filter_arg(args.filter) doc_filter = parse_filter_arg(args.doc_filter) return JobsCursor(self, filter_, doc_filter) + else: + return None def main(self, parser=None): """Call this function to use the main command line interface. 
@@ -3891,7 +4165,7 @@ class MyProject(FlowProject): execution_group.add_argument( '--order', type=str, - choices=['none', 'by-job', 'cyclic', 'random'], + choices=['none', 'by-op', 'by-job', 'cyclic', 'random'], default=None, help="Specify the execution order of operations for each execution pass.") execution_group.add_argument( @@ -4039,20 +4313,27 @@ def _show_traceback_and_exit(error): _show_traceback_and_exit(error) -def _execute_serialized_operation(loads, project, operation): - """Invoke the _execute_operation() method on a serialized project instance.""" - project = loads(project) - project._execute_operation(project._loads_op(operation)) - - -def _serialized_get_job_status(s_task): - """Invoke the _get_job_status() method on a serialized project instance.""" - loads = s_task[0] - project = loads(s_task[1]) - job = project.open_job(id=s_task[2]) - ignore_errors = s_task[3] - cached_status = s_task[4] - return project.get_job_status(job, ignore_errors=ignore_errors, cached_status=cached_status) +def _serializer(loads, root, *args): + root = loads(root) + project = FlowProject.get_project(root) + if args[-1] == 'run_operations': + operation = args[0] + project._operations = loads(args[1]) + project._execute_operation(project._loads_op(operation)) + elif args[-1] == 'fetch_labels': + job = project.open_job(id=args[0]) + ignore_errors = args[1] + project._label_functions = loads(args[2]) + return project._get_job_labels(job, ignore_errors=ignore_errors) + elif args[-1] == 'fetch_status': + group = args[0] + ignore_errors = args[1] + cached_status = args[2] + groups = loads(args[3]) + project._groups = groups + groups_aggregate = loads(args[4]) + project._stored_aggregates = groups_aggregate + return project._get_group_status(group, ignore_errors, cached_status) # Status-related helper functions diff --git a/setup.cfg b/setup.cfg index a497f3c84..1d872c180 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,6 +32,6 @@ ignore = E123,E126,E226,E241,E704,W503,W504 [bumpversion:file:.zenodo.json] [tool:pytest] -filterwarnings = +filterwarnings = ignore:.*get_id is deprecated.*:DeprecationWarning ignore:.*The env argument is deprecated*:DeprecationWarning diff --git a/tests/test_aggregates.py b/tests/test_aggregates.py index 655e50caa..4a5c2c71a 100644 --- a/tests/test_aggregates.py +++ b/tests/test_aggregates.py @@ -63,7 +63,7 @@ def test_default_init(self): aggregate_instance = aggregator() test_list = (1, 2, 3, 4, 5) assert aggregate_instance._sort_by is None - assert aggregate_instance._aggregator_function(test_list) == [test_list] + assert aggregate_instance._aggregator_function(test_list) == (test_list,) assert aggregate_instance._select is None def test_invalid_aggregator_function(self, setUp, project): @@ -147,26 +147,24 @@ def helper_aggregator_function(jobs): aggregate_instance = aggregator(helper_aggregator_function) aggregate_instance = aggregate_instance._create_AggregatesStore(project) aggregate_job_manual = helper_aggregator_function(project) - assert [aggregate for aggregate in aggregate_job_manual] == \ - list(aggregate_instance) + assert tuple(aggregate_job_manual) == tuple(aggregate_instance.values()) # Testing aggregator function returning aggregates of all the jobs aggregate_instance = aggregator(lambda jobs: [jobs]) aggregate_instance = aggregate_instance._create_AggregatesStore(project) - assert [tuple(project)] == list(aggregate_instance) + assert (tuple(project),) == tuple(aggregate_instance.values()) def test_valid_sort_by(self, setUp, project): helper_sort = 
partial(sorted, key=lambda job: job.sp.i) aggregate_instance = aggregator(sort_by='i') aggregate_instance = aggregate_instance._create_AggregatesStore(project) - - assert [tuple(helper_sort(project))] == list(aggregate_instance) + assert (tuple(helper_sort(project)),) == tuple(aggregate_instance.values()) def test_valid_descending_sort(self, setUp, project): helper_sort = partial(sorted, key=lambda job: job.sp.i, reverse=True) aggregate_instance = aggregator(sort_by='i', sort_ascending=False) aggregate_instance = aggregate_instance._create_AggregatesStore(project) - assert [tuple(helper_sort(project))] == list(aggregate_instance) + assert (tuple(helper_sort(project)),) == tuple(aggregate_instance.values()) def test_groups_of_valid_num(self, setUp, project): valid_values = [1, 2, 3, 6, 10] @@ -186,7 +184,7 @@ def test_groups_of_valid_num(self, setUp, project): # We also check the length of every aggregate in order to ensure # proper aggregation. - for j, aggregate in enumerate(aggregate_instance): + for j, aggregate in enumerate(aggregate_instance.values()): if j == expected_len - 1: # Checking for the last aggregate assert len(aggregate) == expected_length_per_aggregate[i][1] else: @@ -195,7 +193,7 @@ def test_groups_of_valid_num(self, setUp, project): def test_groupby_with_valid_string_key(self, setUp, project): aggregate_instance = aggregator.groupby('even') aggregate_instance = aggregate_instance._create_AggregatesStore(project) - for aggregate in aggregate_instance: + for aggregate in aggregate_instance.values(): even = aggregate[0].sp.even assert all(even == job.sp.even for job in aggregate) assert len(aggregate_instance) == 2 @@ -210,7 +208,7 @@ def test_groupby_with_invalid_string_key(self, setUp, project): def test_groupby_with_default_key_for_string(self, setUp, project): aggregate_instance = aggregator.groupby('half', default=-1) aggregate_instance = aggregate_instance._create_AggregatesStore(project) - for aggregate in aggregate_instance: + for aggregate in aggregate_instance.values(): half = aggregate[0].sp.get('half', -1) assert all(half == job.sp.get('half', -1) for job in aggregate) assert len(aggregate_instance) == 6 @@ -232,7 +230,7 @@ def test_groupby_with_invalid_Iterable_key(self, setUp, project): def test_groupby_with_valid_default_key_for_Iterable(self, setUp, project): aggregate_instance = aggregator.groupby(['half', 'even'], default=[-1, -1]) aggregate_instance = aggregate_instance._create_AggregatesStore(project) - for aggregate in aggregate_instance: + for aggregate in aggregate_instance.values(): half = aggregate[0].sp.get('half', -1) even = aggregate[0].sp.get('even', -1) assert all( @@ -247,7 +245,7 @@ def keyfunction(job): aggregate_instance = aggregator.groupby(keyfunction) aggregate_instance = aggregate_instance._create_AggregatesStore(project) - for aggregate in aggregate_instance: + for aggregate in aggregate_instance.values(): even = aggregate[0].sp.even assert all(even == job.sp.even for job in aggregate) assert len(aggregate_instance) == 2 @@ -271,7 +269,7 @@ def _select(job): for job in project: if _select(job): selected_jobs.append((job,)) - assert list(aggregate_instance) == selected_jobs + assert list(aggregate_instance.values()) == selected_jobs def test_storing_hashing(self, setUp, project, list_of_aggregates): # Since we need to store groups on a per aggregate basis in the project, @@ -296,7 +294,7 @@ def _create_storing(aggregator): # objects which stores aggregates. 
list_of_storing = set(map(_create_storing, list_of_aggregates)) for stored_aggregate in list_of_storing: - for aggregate in stored_aggregate: + for aggregate in stored_aggregate.values(): assert aggregate == stored_aggregate[get_aggregate_id(aggregate)] def test_get_invalid_id(self, setUp, project): @@ -315,8 +313,8 @@ def test_contains(self, setUp, project): aggregator_instance = aggregator()._create_AggregatesStore(project) default_aggregator = aggregator.groupsof(1)._create_AggregatesStore(project) # Test for an aggregate of all jobs - assert jobs in aggregator_instance - assert jobs not in default_aggregator + assert get_aggregate_id(jobs) in aggregator_instance + assert get_aggregate_id(jobs) not in default_aggregator # Test for an aggregate of single job - assert not (jobs[0],) in aggregator_instance - assert (jobs[0],) in default_aggregator + assert not jobs[0].get_id() in aggregator_instance + assert jobs[0].get_id() in default_aggregator diff --git a/tests/test_project.py b/tests/test_project.py index 0e11d78ad..b9ac6a6f5 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -183,7 +183,7 @@ def test_status_performance(self): MockScheduler.reset() time = timeit.timeit( - lambda: project._fetch_status(project, StringIO(), + lambda: project._fetch_status(None, project, StringIO(), ignore_errors=False), number=10) assert time < 10 @@ -434,7 +434,7 @@ def test_context(job): project = self.mock_project(A) for job in project: job.doc.np = 3 - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert 'mpirun -np 3 python' in next_op.cmd break @@ -453,20 +453,20 @@ def a(job): # test setting neither nranks nor omp_num_threads for job in project: - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == 1 # test only setting nranks for i, job in enumerate(project): job.doc.nranks = i+1 - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == next_op.directives['nranks'] del job.doc['nranks'] # test only setting omp_num_threads for i, job in enumerate(project): job.doc.omp_num_threads = i+1 - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == next_op.directives['omp_num_threads'] del job.doc['omp_num_threads'] @@ -475,7 +475,7 @@ def a(job): job.doc.omp_num_threads = i+1 job.doc.nranks = i % 3 + 1 expected_np = (i + 1) * (i % 3 + 1) - for next_op in project._next_operations((job,)): + for next_op in project._next_operations([(job,)]): assert next_op.directives['np'] == expected_np def test_copy_conditions(self): @@ -599,7 +599,7 @@ def test_next_operations(self): project = self.mock_project() even_jobs = [job for job in project if job.sp.b % 2 == 0] for job in project: - for i, op in enumerate(project._next_operations((job,))): + for i, op in enumerate(project._next_operations([(job,)])): assert op._jobs == (job,) if job in even_jobs: assert op.name == ['op1', 'op2', 'op3'][i] @@ -613,7 +613,7 @@ def test_get_job_status(self): status = project.get_job_status(job) assert status['job_id'] == job.get_id() assert len(status['operations']) == len(project.operations) - for op in project._next_operations((job,)): + for op in project._next_operations([(job,)]): assert op.name in status['operations'] op_status = status['operations'][op.name] assert op_status['eligible'] == 
                 project.operations[op.name]._eligible((job,))
@@ -663,7 +663,7 @@ def test_project_status_heterogeneous_schema(self):
     def test_script(self):
         project = self.mock_project()
         for job in project:
-            script = project._script(project._next_operations((job,)))
+            script = project._script(project._next_operations([(job,)]))
             if job.sp.b % 2 == 0:
                 assert str(job) in script
                 assert 'echo "hello"' in script
@@ -683,7 +683,7 @@ def test_script_with_custom_script(self):
             file.write("THIS IS A CUSTOM SCRIPT!\n")
             file.write("{% endblock %}\n")
         for job in project:
-            script = project._script(project._next_operations((job,)))
+            script = project._script(project._next_operations([(job,)]))
             assert "THIS IS A CUSTOM SCRIPT" in script
             if job.sp.b % 2 == 0:
                 assert str(job) in script
@@ -738,15 +738,13 @@ class TestExecutionProject(TestProjectBase):
 
     def test_pending_operations_order(self):
         # The execution order of local runs is internally assumed to be
-        # 'by-job' by default. A failure of this unit tests means that
-        # a 'by-job' order must be implemented explicitly within the
-        # FlowProject.run() function.
+        # 'by-op' by default.
         project = self.mock_project()
-        ops = list(project._get_pending_operations(self.project.find_jobs()))
-        # The length of the list of operations grouped by job is equal
-        # to the length of its set if and only if the operations are grouped
-        # by job already:
-        jobs_order_none = [job._id for job, _ in groupby(ops, key=lambda op: op._jobs[0])]
+        ops = list(project._get_pending_operations())
+        # The length of the list of job-operations grouped by operation is
+        # equal to the length of its set if and only if the job-operations
+        # are already grouped by operation:
+        jobs_order_none = [name for name, _ in groupby(ops, key=lambda op: op.name)]
         assert len(jobs_order_none) == len(set(jobs_order_none))
 
     def test_run(self, subtests):
@@ -758,7 +756,7 @@ def test_run(self, subtests):
         def sort_key(op):
             return op.name, op._jobs[0].get_id()
 
-        for order in (None, 'none', 'cyclic', 'by-job', 'random', sort_key):
+        for order in (None, 'none', 'cyclic', 'by-job', 'by-op', 'random', sort_key):
             for job in self.project.find_jobs():  # clear
                 job.remove()
             with subtests.test(order=order):
@@ -913,7 +911,7 @@ def test_submit_operations(self):
         project = self.mock_project()
         operations = []
         for job in project:
-            operations.extend(project._next_operations((job,)))
+            operations.extend(project._next_operations([(job,)]))
         assert len(list(MockScheduler.jobs())) == 0
         cluster_job_id = project._store_bundled(operations)
         with redirect_stderr(StringIO()):
@@ -994,7 +992,7 @@ def test_submit_status(self):
             if job not in even_jobs:
                 continue
             list(project.labels(job))
-            next_op = list(project._next_operations((job,)))[0]
+            next_op = list(project._next_operations([(job,)]))[0]
             assert next_op.name == 'op1'
             assert next_op._jobs == (job,)
         with redirect_stderr(StringIO()):
@@ -1002,7 +1000,7 @@ def test_submit_status(self):
         assert len(list(MockScheduler.jobs())) == num_jobs_submitted
 
         for job in project:
-            next_op = list(project._next_operations((job,)))[0]
+            next_op = list(project._next_operations([(job,)]))[0]
             assert next_op.get_status() == JobStatus.submitted
 
         MockScheduler.step()
@@ -1010,7 +1008,7 @@ def test_submit_status(self):
         project._fetch_scheduler_status(file=StringIO())
         for job in project:
-            next_op = list(project._next_operations((job,)))[0]
+            next_op = list(project._next_operations([(job,)]))[0]
             assert next_op.get_status() == JobStatus.queued
 
         MockScheduler.step()
@@ -1026,7 +1024,7 @@ def test_submit_operations_bad_directive(self):
         project = self.mock_project()
         operations = []
         for job in project:
-            operations.extend(project._next_operations((job,)))
+            operations.extend(project._next_operations([(job,)]))
         assert len(list(MockScheduler.jobs())) == 0
         cluster_job_id = project._store_bundled(operations)
         stderr = StringIO()
@@ -1182,7 +1180,7 @@ def test_main_status(self):
                     op_lines.append(next(lines))
                 except StopIteration:
                     continue
-            for op in project._next_operations((job,)):
+            for op in project._next_operations([(job,)]):
                 assert any(op.name in op_line for op_line in op_lines)
 
     def test_main_script(self):
@@ -1234,7 +1232,7 @@ def test_script(self):
         project = self.mock_project()
         # For run mode single operation groups
        for job in project:
-            job_ops = project._get_submission_operations((job,), dict())
+            job_ops = project._get_submission_operations([(job,)], dict())
             script = project._script(job_ops)
             if job.sp.b % 2 == 0:
                 assert str(job) in script
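
For reviewers: below is a minimal sketch (not part of the patch) of the mapping-style access pattern the updated tests exercise. It assumes a pre-initialized FlowProject with at least one job, and it calls the internal `_create_AggregatesStore` helper exactly as the tests above do; treat it as an illustration of the id-based store interface, not as public API.

    from flow import FlowProject, aggregator, get_aggregate_id

    project = FlowProject.get_project()  # assumes an initialized project

    # Aggregate stores behave like mappings from aggregate id to a tuple
    # of jobs, so both lookup and membership checks are id-based.
    store = aggregator()._create_AggregatesStore(project)
    for aggregate_id, aggregate in store.items():
        assert store[aggregate_id] == aggregate  # lookup by id returns the job tuple
        assert aggregate_id in store             # __contains__ now takes an id, not a tuple

    # The default store (aggregator.groupsof(1)) maps each job id to a
    # one-job aggregate.
    default_store = aggregator.groupsof(1)._create_AggregatesStore(project)
    job = next(iter(project))
    assert default_store[job.get_id()] == (job,)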