Make DefaultAggregateStore picklable #383

Merged
12 commits merged on Dec 28, 2020
13 changes: 13 additions & 0 deletions flow/aggregates.py
@@ -538,6 +538,19 @@ def _register_aggregates(self, project):
"""
self._project = project

def __setstate__(self, data):
Member

Do we still need to define these methods? I thought that using cloudpickle might make this unnecessary.

Member

I ran some tests and wanted to update my question. I don't think this block of code is needed if we're using cloudpickle, since the only problem I'm aware of is an issue with pickling locally defined functions (which cloudpickle solves). If that's correct, we shouldn't need to manually define __getstate__ and __setstate__ for any classes.

@b-butler @kidrahahjo Is this correct? If not, could you help provide a minimal case where these are required and commit it to the tests?
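
For context, here is a minimal sketch (not from this PR) of the difference that comment refers to: the standard pickle module cannot serialize locally defined functions, while cloudpickle serializes them by value. The `make_label_function` name is illustrative only.

```python
import pickle

import cloudpickle


def make_label_function(prefix):
    # A locally defined closure: the standard pickle module cannot serialize this.
    def label(job_id):
        return f"{prefix}:{job_id}"

    return label


label = make_label_function("status")

try:
    pickle.dumps(label)
except (pickle.PicklingError, AttributeError) as error:
    print(f"pickle failed: {error}")

# cloudpickle serializes the closure by value, so it round-trips fine.
restored = cloudpickle.loads(cloudpickle.dumps(label))
print(restored("abc123"))  # status:abc123
```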

Collaborator Author

The project attribute gets lost if we don't set it manually. I'm not sure why this is happening, though.
The error can be reproduced by running pytest tests/test_project -k parallel -x after removing these __getstate__ and __setstate__ methods.

Member

I spent quite a bit of time on this and tracked down the core issue. The problem has to do with pickling/unpickling the project configuration. We monkey-patch the signac project configuration in signac-flow, and when that gets unpickled, it points to the configobj from signac. Since the flow key isn't defined there, the project configuration can't be reconstructed during unpickling.

Member

I think we can adopt this workaround with __getstate__ / __setstate__ for now, and add a comment describing the core issue. I'll add that to this PR and approve it.

Member
@bdice, Dec 22, 2020

This gets even better. I dug deeper into the problem and found a cleaner solution. Here's my best explanation:

Recall that the FlowProject owns aggregate stores, which may include a _DefaultAggregateStore that holds a reference to the original FlowProject. The circular reference is not intrinsically a problem, but there is a complication. The _DefaultAggregateStore is used as a key in a dict, which means that it must be hashable. When the __hash__ method is called during unpickling, the _DefaultAggregateStore's hash function returns hash(repr(self._project)). This means that in order to reconstruct the pickled project, the FlowProject (because of the _DefaultAggregateStore) must compute its own __repr__. However, that relies on FlowProject._config already being set, since the repr returns something like Project.get_project('/path/to/project') via the Project._rd() method, which requires the root directory from the config. The _DefaultAggregateStore instance cannot be used as a dict key in the FlowProject (perhaps in _aggregator_per_group or _stored_aggregates?) because it's not hashable until the FlowProject is fully initialized.

If it were possible to control the order in which data is unpickled, it might be possible to break this circular chain and avoid the error. However, rather than defining __getstate__ and __setstate__ for the FlowProject or _DefaultAggregateStore, I think it would be easiest just to store self._hash_repr = repr(self._project) when the _DefaultAggregateStore is created, and define its __hash__ method to return hash(self._hash_repr). It ends up being only two lines (instead of adding entirely new methods), and I think it might be safer and more future-proof.
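
A rough sketch of the two-line alternative described above (illustrative only, not the exact change that was merged):

```python
class _DefaultAggregateStore:
    def __init__(self, project):
        self._project = project
        # Cache the project repr at construction time so that __hash__ no
        # longer requires a fully initialized FlowProject during unpickling.
        self._hash_repr = repr(project)

    def __hash__(self):
        return hash(self._hash_repr)
```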

Member

@kidrahahjo @b-butler Tagging you both so you can see the above comment. I think this PR is good to go, now that I've fixed the issue described above. Is my explanation clear? Did I leave a good enough comment in the new code?

self._project = data["project"]
self._project.__dict__ = data["project_attributes"]

def __getstate__(self):
# We also need to store project attributes as they get lost
# during the process of pickling.
project_attributes = self._project.__dict__.copy()
return {
"project": self._project,
"project_attributes": project_attributes,
}
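
For readers unfamiliar with the protocol used in the diff above, this is a self-contained toy sketch of the __getstate__/__setstate__ contract. The Store and Project classes below are illustrative stand-ins, not the flow classes, and the toy does not reproduce the attribute loss itself.

```python
import pickle


class Project:
    """Illustrative stand-in for a FlowProject."""

    def __init__(self, root):
        self.root = root


class Store:
    """Toy illustration of the __getstate__/__setstate__ contract."""

    def __init__(self, project):
        self._project = project

    def __getstate__(self):
        # Whatever this returns is what pickle serializes for the instance.
        return {
            "project": self._project,
            "project_attributes": self._project.__dict__.copy(),
        }

    def __setstate__(self, data):
        # On unpickling, this receives the dict returned by __getstate__
        # instead of pickle updating __dict__ directly.
        self._project = data["project"]
        self._project.__dict__ = data["project_attributes"]


store = Store(Project("/path/to/project"))
restored = pickle.loads(pickle.dumps(store))
print(restored._project.root)  # /path/to/project
```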


def get_aggregate_id(jobs):
"""Generate hashed id for an aggregate of jobs.
154 changes: 40 additions & 114 deletions flow/project.py
@@ -30,6 +30,7 @@
from multiprocessing import Event, Pool, TimeoutError, cpu_count
from multiprocessing.pool import ThreadPool

import cloudpickle
import jinja2
import signac
from deprecation import deprecated
@@ -2313,7 +2314,7 @@ def _fetch_status(
file=err,
)
)
op_results = list(
group_results = list(
tqdm(
iterable=pool.imap(get_group_status, operation_names),
desc="Collecting operation status",
@@ -2324,64 +2325,32 @@
elif status_parallelization == "process":
with contextlib.closing(Pool()) as pool:
try:
import pickle

l_results, g_results = self._fetch_status_in_parallel(
(
label_results,
group_results,
) = self._fetch_status_in_parallel(
pool,
pickle,
distinct_jobs,
operation_names,
ignore_errors,
cached_status,
)
except Exception as error:
if (
not isinstance(
error, (pickle.PickleError, self._PickleError)
)
and "pickle" not in str(error).lower()
):
raise # most likely not a pickle related error...

try:
import cloudpickle
except ImportError: # The cloudpickle package is not available.
logger.error(
"Unable to parallelize execution due to a "
"pickling error. "
"\n\n - Try to install the 'cloudpickle' package, "
"e.g., with 'pip install cloudpickle'!\n"
)
raise error
else:
try:
(
l_results,
g_results,
) = self._fetch_status_in_parallel(
pool,
cloudpickle,
distinct_jobs,
operation_names,
ignore_errors,
cached_status,
)
except self._PickleError as error:
raise RuntimeError(
"Unable to parallelize execution due to a pickling "
f"error: {error}."
)
except self._PickleError as error:
raise RuntimeError(
"Unable to parallelize execution due to a pickling "
f"error: {error}."
)
label_results = list(
tqdm(
iterable=l_results,
iterable=label_results,
desc="Collecting job label info",
total=len(distinct_jobs),
file=err,
)
)
op_results = list(
group_results = list(
tqdm(
iterable=g_results,
iterable=group_results,
desc="Collecting operation status",
total=len(operation_names),
file=err,
@@ -2396,7 +2365,7 @@ def print_status(iterable, fetch_status, description):
file=err,
)
)
op_results = list(
group_results = list(
tqdm(
iterable=map(get_group_status, operation_names),
desc="Collecting operation status",
@@ -2434,7 +2403,7 @@ def print_status(iterable, fetch_status, description):
label_results = print_status(
distinct_jobs, get_job_labels, "Collecting job label info"
)
op_results = print_status(
group_results = print_status(
operation_names, get_group_status, "Collecting operation status"
)

@@ -2450,7 +2419,7 @@ def print_status(iterable, fetch_status, description):
results.append(results_entry)
index[job.get_id()] = i

for op_result in op_results:
for op_result in group_results:
for aggregate_id, aggregate_status in op_result[
"job_status_details"
].items():
@@ -2473,45 +2442,36 @@ def print_status(iterable, fetch_status, description):
return results

def _fetch_status_in_parallel(
self, pool, pickle, jobs, groups, ignore_errors, cached_status
self, pool, jobs, groups, ignore_errors, cached_status
):
try:
# Since pickling the project results in loss of necessary information,
# we explicitly pickle all the necessary information and then mock it
# in the serialized methods.
s_root = pickle.dumps(self.root_directory())
s_label_funcs = pickle.dumps(self._label_functions)
s_groups = pickle.dumps(self._groups)
s_groups_aggregate = pickle.dumps(self._stored_aggregates)
s_tasks_labels = [
serialized_project = cloudpickle.dumps(self)
serialized_tasks_labels = [
(
pickle.loads,
s_root,
cloudpickle.loads,
serialized_project,
job.get_id(),
ignore_errors,
s_label_funcs,
"fetch_labels",
)
for job in jobs
]
s_tasks_groups = [
serialized_tasks_groups = [
(
pickle.loads,
s_root,
cloudpickle.loads,
serialized_project,
group,
ignore_errors,
cached_status,
s_groups,
s_groups_aggregate,
"fetch_status",
)
for group in groups
]
except Exception as error: # Masking all errors since they must be pickling related.
raise self._PickleError(error)

label_results = pool.starmap(_serializer, s_tasks_labels)
group_results = pool.starmap(_serializer, s_tasks_groups)
label_results = pool.starmap(_serializer, serialized_tasks_labels)
group_results = pool.starmap(_serializer, serialized_tasks_groups)

return label_results, group_results
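
The overall pattern, serializing the project once with cloudpickle and having each worker rebuild it from the byte string, can be sketched in isolation as below. DemoProject, describe, and _worker are illustrative names, not part of signac-flow.

```python
from multiprocessing import Pool

import cloudpickle


class DemoProject:
    """Illustrative stand-in for a FlowProject."""

    def describe(self, job_id):
        return f"status of {job_id}"


def _worker(loads, serialized_project, job_id):
    # Runs in a worker process: rebuild the project from the byte string,
    # then perform the per-task work.
    project = loads(serialized_project)
    return project.describe(job_id)


if __name__ == "__main__":
    project = DemoProject()
    # Serialize the project once in the parent process ...
    serialized_project = cloudpickle.dumps(project)
    # ... and ship the byte string to every task, mirroring the tuples built above.
    tasks = [
        (cloudpickle.loads, serialized_project, job_id)
        for job_id in ("a1b2", "c3d4")
    ]
    with Pool() as pool:
        print(pool.starmap(_worker, tasks))
```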

@@ -2990,38 +2950,14 @@ def _run_operations(
"Parallelized execution of %i operation(s).", len(operations)
)
try:
import pickle

self._run_operations_in_parallel(
pool, pickle, operations, progress, timeout
pool, operations, progress, timeout
)
except self._PickleError as error:
raise RuntimeError(
"Unable to parallelize execution due to a pickling "
"error: {}.".format(error)
)
logger.debug("Used cPickle module for serialization.")
except Exception as error:
if (
not isinstance(error, (pickle.PickleError, self._PickleError))
and "pickle" not in str(error).lower()
):
raise # most likely not a pickle related error...

try:
import cloudpickle
except ImportError: # The cloudpickle package is not available.
logger.error(
"Unable to parallelize execution due to a pickling error. "
"\n\n - Try to install the 'cloudpickle' package, e.g., with "
"'pip install cloudpickle'!\n"
)
raise error
else:
try:
self._run_operations_in_parallel(
pool, cloudpickle, operations, progress, timeout
)
except self._PickleError as error:
raise RuntimeError(
"Unable to parallelize execution due to a pickling "
f"error: {error}."
)

@deprecated(deprecated_in="0.11", removed_in="0.13", current_version=__version__)
def run_operations(
@@ -3070,26 +3006,23 @@ def _job_operation_from_tuple(self, data):
all_directives.update(directives)
return _JobOperation(id, name, jobs, cmd, all_directives)

def _run_operations_in_parallel(self, pool, pickle, operations, progress, timeout):
def _run_operations_in_parallel(self, pool, operations, progress, timeout):
"""Execute operations in parallel.

This function executes the given list of operations with the provided
process pool.

Since pickling of the project instance is likely to fail, we manually
pickle the project instance and the operations before submitting them
to the process pool to enable us to try different pool and pickle
module combinations.
to the process pool.
"""
try:
serialized_root = pickle.dumps(self.root_directory())
serialized_operations = pickle.dumps(self._operations)
serialized_project = cloudpickle.dumps(self)
serialized_tasks = [
(
pickle.loads,
serialized_root,
cloudpickle.loads,
serialized_project,
self._job_operation_to_tuple(operation),
serialized_operations,
"run_operations",
)
for operation in tqdm(
@@ -5139,26 +5072,19 @@ def _show_traceback_and_exit(error):
_show_traceback_and_exit(error)


def _serializer(loads, root, *args):
root = loads(root)
project = FlowProject.get_project(root)
def _serializer(loads, project, *args):
project = loads(project)
if args[-1] == "run_operations":
operation_data = args[0]
project._operations = loads(args[1])
project._execute_operation(project._job_operation_from_tuple(operation_data))
elif args[-1] == "fetch_labels":
job = project.open_job(id=args[0])
ignore_errors = args[1]
project._label_functions = loads(args[2])
return project._get_job_labels(job, ignore_errors=ignore_errors)
elif args[-1] == "fetch_status":
group = args[0]
ignore_errors = args[1]
cached_status = args[2]
groups = loads(args[3])
project._groups = groups
groups_aggregate = loads(args[4])
project._stored_aggregates = groups_aggregate
return project._get_group_status(group, ignore_errors, cached_status)
return None
