From b88097b24b09249ac52095809aca6d4dd3851c30 Mon Sep 17 00:00:00 2001
From: Abhishek Varghese
Date: Mon, 31 Jul 2023 01:57:22 +0000
Subject: [PATCH 1/3] Added initial multiprocessing code

---
 src/rl.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/rl.py b/src/rl.py
index f1b7cb657..a1c83d1ea 100644
--- a/src/rl.py
+++ b/src/rl.py
@@ -11,6 +11,8 @@
 import pickle
 import sys
 import re
+from mpire import WorkerPool
+from multiprocessing import Process
 from pathlib import Path
 from operator import itemgetter
 from typing import (List, Optional, Dict, Tuple, Union, Any, Set,
@@ -94,6 +96,7 @@ def main():
                         dest="blacklisted_tactics")
     parser.add_argument("--resume", choices=["no", "yes", "ask"], default="ask")
     parser.add_argument("--save-every", type=int, default=20)
+    parser.add_argument("--num-eval-workers", type=int, default=5)
     evalGroup = parser.add_mutually_exclusive_group()
     evalGroup.add_argument("--evaluate", action="store_true")
     evalGroup.add_argument("--evaluate-baseline", action="store_true")
@@ -392,7 +395,8 @@ def reinforce_jobs(args: argparse.Namespace) -> None:
                                                     initial_replay_buffer = replay_buffer)
             evaluate_results(args, evaluation_worker, test_tasks)
         else:
-            evaluate_results(args, worker, tasks)
+            dispatch_evaluation_workers(args, predictor, worker.v_network, worker.target_v_network, switch_dict, worker.replay_buffer, tasks)
+            # evaluate_results(args, worker, tasks)
     if args.verifyvval:
         verify_vvals(args, worker, tasks)
@@ -794,6 +798,20 @@ def evaluate_proof(args: argparse.Namespace,
 
     return proof_succeeded
 
+def dispatch_evaluation_workers(args, predictor, v_network, target_network, switch_dict, replay_buffer, tasks) :
+
+    workers = [ ReinforcementWorker(args, predictor, v_network, target_network, switch_dict,
+                                    initial_replay_buffer = replay_buffer) for _ in range(args.num_eval_workers) ]
+    Processes = []
+    for i in range(args.num_eval_workers) :
+        Processes.append(Process(target = evaluate_results, args = (args, workers[i],
+            tasks[len(tasks)*i//args.num_eval_workers : len(tasks)*(i+1)//args.num_eval_workers])))
+        Processes[-1].start()
+
+    for i in range(args.num_eval_workers) :
+        Processes[i].join()
+
+
 def evaluate_results(args: argparse.Namespace,
                      worker: ReinforcementWorker,
                      tasks: List[RLTask]) -> None:
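Note on the dispatch scheme in this first patch: the slice tasks[len(tasks)*i//args.num_eval_workers : len(tasks)*(i+1)//args.num_eval_workers] hands worker i a contiguous, near-equal share of the task list, so every task is covered exactly once even when the list length is not divisible by the worker count. The following is a minimal, self-contained sketch of that fork/join pattern; the _work function and the toy task list are illustrative stand-ins, not code from this repository.

    from multiprocessing import Process

    def _work(chunk):
        # Stand-in for per-worker evaluation of one slice of tasks.
        print(f"evaluating {len(chunk)} tasks")

    if __name__ == "__main__":
        tasks = list(range(103))
        num_workers = 5
        processes = []
        for i in range(num_workers):
            # Worker i gets the contiguous slice tasks[len*i//n : len*(i+1)//n].
            chunk = tasks[len(tasks) * i // num_workers
                          : len(tasks) * (i + 1) // num_workers]
            processes.append(Process(target=_work, args=(chunk,)))
            processes[-1].start()
        for p in processes:
            p.join()
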
From 3f3d0c1e6fc023fd650a3d40f1dd08b6934b4d94 Mon Sep 17 00:00:00 2001
From: Abhishek Varghese
Date: Tue, 1 Aug 2023 15:53:04 +0000
Subject: [PATCH 2/3] Added working code

---
 src/rl.py | 46 +++++++++++++++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 11 deletions(-)

diff --git a/src/rl.py b/src/rl.py
index a1c83d1ea..6d51496c4 100644
--- a/src/rl.py
+++ b/src/rl.py
@@ -12,7 +12,7 @@
 import sys
 import re
 from mpire import WorkerPool
-from multiprocessing import Process
+from multiprocessing import Process, Manager, Queue
 from pathlib import Path
 from operator import itemgetter
 from typing import (List, Optional, Dict, Tuple, Union, Any, Set,
@@ -395,6 +395,9 @@ def reinforce_jobs(args: argparse.Namespace) -> None:
                                                     initial_replay_buffer = replay_buffer)
             evaluate_results(args, evaluation_worker, test_tasks)
         else:
+            # with WorkerPool(n_jobs = args.num_eval_workers, shared_variables = worker) as pool :
+            #     results = pool.map(evaluate_results, () )
+
             dispatch_evaluation_workers(args, predictor, worker.v_network, worker.target_v_network, switch_dict, worker.replay_buffer, tasks)
             # evaluate_results(args, worker, tasks)
     if args.verifyvval:
         verify_vvals(args, worker, tasks)
@@ -799,17 +802,38 @@ def evaluate_proof(args: argparse.Namespace,
 
 def dispatch_evaluation_workers(args, predictor, v_network, target_network, switch_dict, replay_buffer, tasks) :
-
+
+    return_queue = Queue()
     workers = [ ReinforcementWorker(args, predictor, v_network, target_network, switch_dict,
                                     initial_replay_buffer = replay_buffer) for _ in range(args.num_eval_workers) ]
-    Processes = []
+    processes = []
     for i in range(args.num_eval_workers) :
-        Processes.append(Process(target = evaluate_results, args = (args, workers[i],
-            tasks[len(tasks)*i//args.num_eval_workers : len(tasks)*(i+1)//args.num_eval_workers])))
-        Processes[-1].start()
+        processes.append(Process(target = evaluation_worker, args = (args, workers[i],
+            tasks[len(tasks)*i//args.num_eval_workers : len(tasks)*(i+1)//args.num_eval_workers], return_queue)))
+        processes[-1].start()
+        print("Process", i, "started")
+
+    for process in processes :
+        process.join()
 
-    for i in range(args.num_eval_workers) :
-        Processes[i].join()
+    results = []
+    while not return_queue.empty() :
+        results.append(return_queue.get())
+
+    assert len(results) == args.num_eval_workers, "Here's the queue : " + str(results)
+
+    proofs_completed = sum(results)
+    print(f"{proofs_completed} out of {len(tasks)} "
+          f"tasks successfully proven "
+          f"({stringified_percent(proofs_completed, len(tasks))}%)")
+
+
+def evaluation_worker(args: argparse.Namespace,
+                      worker: ReinforcementWorker,
+                      tasks: List[RLTask], return_queue) :
+
+    return_queue.put(evaluate_results(args, worker,tasks))
+    return
 
 
 def evaluate_results(args: argparse.Namespace,
@@ -819,9 +843,9 @@ def evaluate_results(args: argparse.Namespace,
     for task in tasks:
         if worker.evaluate_job(task.to_job(), task.tactic_prefix):
             proofs_completed += 1
-    print(f"{proofs_completed} out of {len(tasks)} "
-          f"tasks successfully proven "
-          f"({stringified_percent(proofs_completed, len(tasks))}%)")
+
+    return proofs_completed
+
 
 
 def verify_vvals(args: argparse.Namespace,
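Note on the result collection in this patch: each worker now pushes its proof count onto a shared multiprocessing.Queue and the parent sums the counts, which is why evaluate_results returns proofs_completed instead of printing it. Below is a minimal sketch of the same collect-via-queue pattern, with a stand-in _worker rather than the repository's evaluation code; it takes exactly one result per worker, since Queue.empty() is only approximate across processes, which is what the assert in the patch guards against.

    from multiprocessing import Process, Queue

    def _worker(chunk, return_queue):
        # Stand-in result: count the even tasks in this slice.
        return_queue.put(sum(1 for t in chunk if t % 2 == 0))

    if __name__ == "__main__":
        tasks = list(range(50))
        num_workers = 4
        return_queue = Queue()
        processes = []
        for i in range(num_workers):
            chunk = tasks[len(tasks) * i // num_workers
                          : len(tasks) * (i + 1) // num_workers]
            processes.append(Process(target=_worker, args=(chunk, return_queue)))
            processes[-1].start()
        # One get() per worker; draining before join() also keeps workers from
        # blocking while flushing their queue buffers.
        results = [return_queue.get() for _ in range(num_workers)]
        for p in processes:
            p.join()
        print(f"{sum(results)} out of {len(tasks)} tasks successfully proven")
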
From 818c85f1bad814e47963e3f77c7912492553dec9 Mon Sep 17 00:00:00 2001
From: Abhishek Varghese
Date: Fri, 4 Aug 2023 17:08:03 +0000
Subject: [PATCH 3/3] Added working multiprocessing

---
 src/rl.py | 52 ++++++++++++++++++++++++++++++++++------------------
 1 file changed, 34 insertions(+), 18 deletions(-)

diff --git a/src/rl.py b/src/rl.py
index 6d51496c4..f51d5d1d2 100644
--- a/src/rl.py
+++ b/src/rl.py
@@ -11,8 +11,7 @@
 import pickle
 import sys
 import re
-from mpire import WorkerPool
-from multiprocessing import Process, Manager, Queue
+from multiprocessing import Process, Queue, set_start_method
 from pathlib import Path
 from operator import itemgetter
 from typing import (List, Optional, Dict, Tuple, Union, Any, Set,
@@ -390,16 +389,11 @@ def reinforce_jobs(args: argparse.Namespace) -> None:
             save_state(args, worker, step)
     if args.evaluate or args.evaluate_baseline:
         if args.test_file:
-            test_tasks = read_tasks_file(args, args.test_file, False)
-            evaluation_worker = ReinforcementWorker(args, predictor, v_network, target_network, switch_dict,
-                                                    initial_replay_buffer = replay_buffer)
-            evaluate_results(args, evaluation_worker, test_tasks)
-        else:
-            # with WorkerPool(n_jobs = args.num_eval_workers, shared_variables = worker) as pool :
-            #     results = pool.map(evaluate_results, () )
+            evaluation_tasks = read_tasks_file(args, args.test_file, False)
+        else :
+            evaluation_tasks = tasks
+        dispatch_evaluation_workers(args, switch_dict, evaluation_tasks)
 
-            dispatch_evaluation_workers(args, predictor, worker.v_network, worker.target_v_network, switch_dict, worker.replay_buffer, tasks)
-            # evaluate_results(args, worker, tasks)
     if args.verifyvval:
         verify_vvals(args, worker, tasks)
@@ -419,10 +413,13 @@ def __init__(self, term_encoder: 'coq2vec.CoqTermRNNVectorizer',
         super().__init__(term_encoder, max_num_hypotheses)
         self.obl_cache = OrderedDict()
         self.max_size = max_size #TODO: Add in arguments if desired
+
+
     def obligations_to_vectors_cached(self, obls: List[Obligation]) \
             -> torch.FloatTensor:
         encoded_obl_size = self.term_encoder.hidden_size * (self.max_num_hypotheses + 1)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         cached_results = []
         for obl in obls:
             r = self.obl_cache.get(obl, None)
@@ -431,7 +428,7 @@ def obligations_to_vectors_cached(self, obls: List[Obligation]) \
             cached_results.append(r)
 
         encoded = run_network_with_cache(
-            lambda x: self.obligations_to_vectors(x).view(len(x), encoded_obl_size),
+            lambda x: self.obligations_to_vectors(x).to(device).view(len(x), encoded_obl_size),
            [coq2vec.Obligation(list(obl.hypotheses), obl.goal) for obl in obls],
            cached_results)
 
@@ -467,6 +464,8 @@ def _load_encoder_state(self, encoder_state: Any) -> None:
         num_hyps = 5
         assert self.obligation_encoder is None, "Can't load weights twice!"
         self.obligation_encoder = CachedObligationEncoder(term_encoder, num_hyps)
+
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         insize = term_encoder.hidden_size * (num_hyps + 1)
         self.network = nn.Sequential(
             nn.Linear(insize, 120),
@@ -475,6 +474,7 @@ def _load_encoder_state(self, encoder_state: Any) -> None:
             nn.ReLU(),
             nn.Linear(84, 1),
         )
+        self.network.to(device)
         self.optimizer = optim.RMSprop(self.network.parameters(), lr=self.learning_rate)
         self.adjuster = scheduler.StepLR(self.optimizer, self.batch_step,
                                          self.lr_step)
@@ -801,14 +801,13 @@ def evaluate_proof(args: argparse.Namespace,
 
     return proof_succeeded
 
-def dispatch_evaluation_workers(args, predictor, v_network, target_network, switch_dict, replay_buffer, tasks) :
+def dispatch_evaluation_workers(args, switch_dict, tasks) :
 
     return_queue = Queue()
-    workers = [ ReinforcementWorker(args, predictor, v_network, target_network, switch_dict,
-                                    initial_replay_buffer = replay_buffer) for _ in range(args.num_eval_workers) ]
+
     processes = []
     for i in range(args.num_eval_workers) :
-        processes.append(Process(target = evaluation_worker, args = (args, workers[i],
+        processes.append(Process(target = evaluation_worker, args = (args, switch_dict,
             tasks[len(tasks)*i//args.num_eval_workers : len(tasks)*(i+1)//args.num_eval_workers], return_queue)))
         processes[-1].start()
         print("Process", i, "started")
@@ -819,7 +818,7 @@ def dispatch_evaluation_workers(args, predictor, v_network, target_network, swit
     results = []
     while not return_queue.empty() :
         results.append(return_queue.get())
-    
+
     assert len(results) == args.num_eval_workers, "Here's the queue : " + str(results)
 
     proofs_completed = sum(results)
@@ -829,9 +828,24 @@ def dispatch_evaluation_workers(args, predictor, v_network, target_network, swit
 
 
 def evaluation_worker(args: argparse.Namespace,
-                      worker: ReinforcementWorker,
+                      switch_dict,
                       tasks: List[RLTask], return_queue) :
 
+    predictor = MemoizingPredictor(get_predictor(args))
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    replay_buffer, steps_already_done, network_state, tnetwork_state, random_state = \
+        torch.load(str(args.output_file), map_location=device)
+    random.setstate(random_state)
+    print(f"Resuming from existing weights of {steps_already_done} steps")
+    v_network = VNetwork(None, args.learning_rate,
+                         args.batch_step, args.lr_step)
+    target_network = VNetwork(None, args.learning_rate,
+                              args.batch_step, args.lr_step)
+    target_network.obligation_encoder = v_network.obligation_encoder
+    v_network.load_state(network_state)
+    target_network.load_state(tnetwork_state)
+    worker = ReinforcementWorker(args, predictor, v_network, target_network, switch_dict,
+                                 initial_replay_buffer = replay_buffer)
     return_queue.put(evaluate_results(args, worker,tasks))
     return
 
@@ -970,6 +984,7 @@ def run_network_with_cache(f: Callable[[List[T]], torch.FloatTensor],
         if output_list[i] is None:
             uncached_values.append(value)
             uncached_value_indices.append(i)
+
     if len(uncached_values) > 0:
         new_results = f(uncached_values)
         for idx, result in zip(uncached_value_indices, new_results):
@@ -984,4 +999,5 @@ def tactic_prefix_is_usable(tactic_prefix: List[str]):
 
 
 if __name__ == "__main__":
+    set_start_method('spawn')
     main()
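Note on the final revision: each worker process now rebuilds its own predictor and value networks by reloading the saved weights from args.output_file, and the start method is switched to 'spawn' so that workers do not inherit a forked CUDA context from the parent. The sketch below shows that "spawn plus reload per worker" pattern in isolation; SmallNet and checkpoint.pt are hypothetical stand-ins, not the repository's VNetwork or output file.

    import torch
    import torch.nn as nn
    from multiprocessing import Process, Queue, set_start_method

    class SmallNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 1)

        def forward(self, x):
            return self.linear(x)

    def _worker(ckpt_path, return_queue):
        # Each spawned worker loads its own copy of the weights onto its own
        # device instead of inheriting CUDA state from the parent.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = SmallNet().to(device)
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
        with torch.no_grad():
            return_queue.put(net(torch.ones(1, 4, device=device)).item())

    if __name__ == "__main__":
        # 'spawn' starts workers in a fresh interpreter, avoiding a fork of a
        # process that may already hold a CUDA context.
        set_start_method('spawn')
        ckpt_path = "checkpoint.pt"
        torch.save(SmallNet().state_dict(), ckpt_path)
        return_queue = Queue()
        procs = [Process(target=_worker, args=(ckpt_path, return_queue))
                 for _ in range(2)]
        for p in procs:
            p.start()
        results = [return_queue.get() for _ in range(len(procs))]
        for p in procs:
            p.join()
        print(results)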