def worker(work_path,
           run_task,
           forward_common_files,
           forward_files,
           backward_files,
           mdata,
           inter_type):
    """Run one relaxation task through the dispatcher selected by ``mdata``.

    ``run_equi`` schedules this function once per work path via
    ``Pool.apply_async``, so each call handles exactly one task.

    Parameters
    ----------
    work_path
        Working directory handed to the dispatcher / submission.
    run_task
        Name of the single task to run (``run_equi`` passes the basename of
        a task directory). It is wrapped in a one-element list wherever the
        dispatcher API takes a collection of tasks.
    forward_common_files, forward_files, backward_files
        File collections forwarded unchanged to the dispatcher API.
        NOTE(review): presumably lists of file names -- confirm with caller.
    mdata
        Machine configuration; must support ``.get``. The optional
        ``'api_version'`` key (default ``'0.9'``) selects the code path.
    inter_type
        Interaction type forwarded to ``util.get_machine_info``.

    Returns
    -------
    None
        Any exception raised by the dispatcher calls propagates to the
        caller.
    """
    machine, resources, command, group_size = util.get_machine_info(mdata, inter_type)
    # NOTE(review): the dispatcher is built even when the new-API branch is
    # taken; kept as-is because make_dispatcher may have side effects that
    # cannot be verified from this file -- confirm whether it can be moved
    # into the legacy branch.
    disp = make_dispatcher(machine, resources, work_path, [run_task], group_size)
    print("%s --> Runing... " % (work_path))

    api_version = mdata.get('api_version', '0.9')
    use_legacy_api = LooseVersion(api_version) < LooseVersion('1.0')

    if use_legacy_api:
        warnings.warn("the dpdispatcher will be updated to new version."
                      "And the interface may be changed. Please check the documents for more details")
        disp.run_jobs(resources,
                      command,
                      work_path,
                      [run_task],
                      group_size,
                      forward_common_files,
                      forward_files,
                      backward_files,
                      outlog='outlog',
                      errlog='errlog')
    else:
        submission = make_submission(
            mdata_machine=machine,
            mdata_resource=resources,
            commands=[command],
            work_path=work_path,
            # Must be a list: a bare string would be treated as a sequence
            # of single-character task names.
            run_tasks=[run_task],
            group_size=group_size,
            forward_common_files=forward_common_files,
            forward_files=forward_files,
            backward_files=backward_files,
            outlog='outlog',
            errlog='errlog'
        )
        submission.run_submission()
work_path_list = [] @@ -150,45 +197,28 @@ def run_equi(confs, if len(run_tasks) == 0: return else: - # if LooseVersion() run_tasks = [os.path.basename(ii) for ii in all_task] machine, resources, command, group_size = util.get_machine_info(mdata, inter_type) print('%d tasks will be submited '%len(run_tasks)) + multiple_ret = [] for ii in range(len(work_path_list)): work_path = work_path_list[ii] - disp = make_dispatcher(machine, resources, work_path, [run_tasks[ii]], group_size) - print("%s --> Runing... "%(work_path)) - - api_version = mdata.get('api_version', '0.9') - if LooseVersion(api_version) < LooseVersion('1.0'): - warnings.warn(f"the dpdispatcher will be updated to new version." - f"And the interface may be changed. Please check the documents for more details") - disp.run_jobs(resources, - command, - work_path, - [run_tasks[ii]], - group_size, - forward_common_files, - forward_files, - backward_files, - outlog='outlog', - errlog='errlog') - elif LooseVersion(api_version) >= LooseVersion('1.0'): - submission = make_submission( - mdata_machine=machine, - mdata_resource=resources, - commands=[command], - work_path=work_path, - run_tasks=run_tasks, - group_size=group_size, - forward_common_files=forward_common_files, - forward_files=forward_files, - backward_files=backward_files, - outlog = 'outlog', - errlog = 'errlog' - ) - submission.run_submission() + ret = pool.apply_async(worker, (work_path, + run_tasks[ii], + forward_common_files, + forward_files, + backward_files, + mdata, + inter_type, + )) + multiple_ret.append(ret) + pool.close() + pool.join() + for ii in range(len(multiple_ret)): + if not multiple_ret[ii].successful(): + raise RuntimeError("Task %d is not successful! work_path: %s " % (ii, work_path_list[ii])) + print('finished') def post_equi(confs, inter_param): # find all POSCARs and their name like mp-xxx