diff --git a/compiler_opt/distributed/local/cpu_affinity.py b/compiler_opt/distributed/local/cpu_affinity.py
new file mode 100644
index 00000000..466bb248
--- /dev/null
+++ b/compiler_opt/distributed/local/cpu_affinity.py
@@ -0,0 +1,55 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility functions to set CPU affinities when running the main process and
+subprocesses simultaneously."""
+import itertools
+
+import gin
+import psutil
+
+_NR_CPUS = psutil.cpu_count()
+
+_CPU_CONFIG = {  # List of CPU numbers in cache-sharing order.
+    # 'google-epyc' assumes logical cores k and k + N/2 map to the same
+    # physical core, and that consecutive core numbers share an L3 cache.
+    'google-epyc':
+        list(
+            itertools.chain(
+                *zip(range(_NR_CPUS // 2), range(_NR_CPUS // 2, _NR_CPUS))))
+}
+
+
+@gin.configurable
+def set_and_get(is_main_process: bool,
+                max_cpus=_NR_CPUS,
+                min_main_cpu: int = 32,
+                arch: str = 'google-epyc'):
+  """Sets the cpu affinity of the current process.
+
+  Args:
+    is_main_process: whether the caller is the main process.
+    max_cpus: maximal number of cpus to use.
+    min_main_cpu: number of cpus to assign to the main process.
+    arch: the system type, used to infer the cpu cache architecture.
+
+  Returns:
+    The list of cpus the process is set to use.
+  """
+  config = _CPU_CONFIG[arch][:max_cpus]
+  if is_main_process:
+    cpus = config[:min_main_cpu]
+  else:
+    cpus = config[min_main_cpu:]
+  if not cpus:
+    raise ValueError('Attempting to set cpu affinity of process to nothing.')
+  psutil.Process().cpu_affinity(cpus)
+  return list(cpus)
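For readers unfamiliar with the interleaved ordering above, here is a small illustrative sketch (not part of the change; the 8-CPU machine and `min_main_cpu=2` are made-up values) of what `_CPU_CONFIG['google-epyc']` produces and how `set_and_get` splits it between the main process and the subprocesses:

```python
import itertools

# Hypothetical machine with 8 logical CPUs: zip(range(4), range(4, 8)) pairs
# each core with its hyperthread sibling, giving the cache-sharing order
# [0, 4, 1, 5, 2, 6, 3, 7].
nr_cpus = 8
order = list(
    itertools.chain(*zip(range(nr_cpus // 2), range(nr_cpus // 2, nr_cpus))))
assert order == [0, 4, 1, 5, 2, 6, 3, 7]

# With min_main_cpu=2, the main process would get the first two entries and
# the subprocesses the remainder.
main_cpus, worker_cpus = order[:2], order[2:]
assert main_cpus == [0, 4]
assert worker_cpus == [1, 5, 2, 6, 3, 7]
```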
diff --git a/compiler_opt/distributed/local/cpu_affinity_test.py b/compiler_opt/distributed/local/cpu_affinity_test.py
new file mode 100644
index 00000000..c6781d29
--- /dev/null
+++ b/compiler_opt/distributed/local/cpu_affinity_test.py
@@ -0,0 +1,30 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Test for cpu_affinity."""
+
+from absl.testing import absltest
+from compiler_opt.distributed.local import cpu_affinity
+# pylint: disable=protected-access
+
+
+class CpuAffinityTest(absltest.TestCase):
+
+  def test_tally(self):
+    for v in cpu_affinity._CPU_CONFIG.values():
+      self.assertLen(set(v), cpu_affinity._NR_CPUS)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py
index ff040d8e..07a28126 100644
--- a/compiler_opt/distributed/worker.py
+++ b/compiler_opt/distributed/worker.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Common abstraction for a worker contract."""
 
-from typing import Iterable, Optional, Protocol, TypeVar
+from typing import Iterable, Optional, Protocol, TypeVar, runtime_checkable
 
 
 class Worker(Protocol):
@@ -25,6 +25,13 @@ def is_priority_method(cls, method_name: str) -> bool:
     return False
 
 
+@runtime_checkable
+class ContextAwareWorker(Worker, Protocol):
+
+  def set_context(self, local: bool) -> None:
+    return
+
+
 T = TypeVar('T')
diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py
index 15a6d32e..1dccdd54 100644
--- a/compiler_opt/rl/compilation_runner.py
+++ b/compiler_opt/rl/compilation_runner.py
@@ -283,7 +283,8 @@ def is_priority_method(cls, method_name: str) -> bool:
   def __init__(self,
                clang_path: Optional[str] = None,
                launcher_path: Optional[str] = None,
-               moving_average_decay_rate: float = 1):
+               moving_average_decay_rate: float = 1,
+               compilation_timeout: Optional[int] = None):
     """Initialization of CompilationRunner class.
 
     Args:
@@ -294,7 +295,7 @@ def __init__(self,
     self._clang_path = clang_path
     self._launcher_path = launcher_path
     self._moving_average_decay_rate = moving_average_decay_rate
-    self._compilation_timeout = _COMPILATION_TIMEOUT.value
+    self._compilation_timeout = compilation_timeout or _COMPILATION_TIMEOUT.value
     self._cancellation_manager = WorkerCancellationManager()
 
     # re-allow the cancellation manager accept work.
diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py
index e29abf31..81fdffe8 100644
--- a/compiler_opt/rl/corpus.py
+++ b/compiler_opt/rl/corpus.py
@@ -122,6 +122,10 @@ def filter(self, p: re.Pattern):
     """Filters module specs, keeping those which match the provided pattern."""
     self._module_specs = [ms for ms in self._module_specs if p.match(ms.name)]
 
+  @property
+  def modules(self):
+    return list(self._module_specs)
+
   def __len__(self):
     return len(self._module_specs)
diff --git a/compiler_opt/rl/inlining/inlining_runner.py b/compiler_opt/rl/inlining/inlining_runner.py
index 69730f0a..a0483849 100644
--- a/compiler_opt/rl/inlining/inlining_runner.py
+++ b/compiler_opt/rl/inlining/inlining_runner.py
@@ -71,6 +71,8 @@ def compile_fn(
         cancelled work.
       RuntimeError: if llvm-size produces unexpected output.
     """
+    if cancellation_manager is None:
+      cancellation_manager = self._cancellation_manager
     working_dir = tempfile.mkdtemp()
 
     log_path = os.path.join(working_dir, 'log')
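Regarding the `worker.py` change above: `ContextAwareWorker` is declared `@runtime_checkable` so that orchestration code can probe an object structurally with `isinstance` before calling `set_context`. A minimal sketch of that pattern (the `FakeWorker` class and `configure` helper are hypothetical, not part of this change):

```python
from compiler_opt.distributed import worker


class FakeWorker(worker.Worker):
  """Hypothetical worker that adapts its behavior to where it runs."""

  def __init__(self):
    self._local = True

  def set_context(self, local: bool) -> None:
    self._local = local


def configure(w) -> None:
  # runtime_checkable makes this a structural check: any object exposing the
  # protocol's methods passes, regardless of its declared base classes.
  if isinstance(w, worker.ContextAwareWorker):
    w.set_context(local=False)


configure(FakeWorker())  # set_context is called
configure(object())      # silently skipped: not a ContextAwareWorker
```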
diff --git a/compiler_opt/rl/local_validation_data_collector.py b/compiler_opt/rl/local_validation_data_collector.py
new file mode 100644
index 00000000..f3e41b56
--- /dev/null
+++ b/compiler_opt/rl/local_validation_data_collector.py
@@ -0,0 +1,243 @@
+# coding=utf-8
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Validation data collection module."""
+import concurrent.futures
+import threading
+import time
+from typing import Dict, List, Optional, Tuple
+
+from absl import logging
+
+from compiler_opt.distributed import worker
+from compiler_opt.distributed.local import buffered_scheduler
+from compiler_opt.distributed.local import cpu_affinity
+from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
+
+from compiler_opt.rl import corpus
+
+
+class LocalValidationDataCollector(worker.ContextAwareWorker):
+  """Local implementation of a validation data collector.
+
+  Args:
+    cps: the validation corpus to collect data on.
+    worker_pool_args: arguments used to construct the underlying worker pool.
+    reward_stat_map: per-module reward stats, used to seed default rewards.
+    max_cpus: maximal number of cpus to use for the validation workers.
+  """
+
+  def __init__(self, cps: corpus.Corpus, worker_pool_args, reward_stat_map,
+               max_cpus):
+    self._num_modules = len(cps) if cps is not None else 0
+    self._corpus: corpus.Corpus = cps
+    self._default_rewards = {}
+
+    self._running_policy = None
+    self._default_futures: List[worker.WorkerFuture] = []
+    self._current_work: List[Tuple[corpus.ModuleSpec, worker.WorkerFuture]] = []
+    self._last_time = None
+    self._elapsed_time = 0
+
+    self._context_local = True
+
+    # Bail out only after the attributes above are set: other methods expect
+    # them to exist even when there is no validation corpus.
+    if not cps:
+      return
+
+    affinities = cpu_affinity.set_and_get(
+        is_main_process=False, max_cpus=max_cpus)
+
+    # Add some runner-specific arguments for the worker pool.
+    logging.info('Validation data collector using %d workers.', len(affinities))
+    worker_pool_args['count'] = len(affinities)
+    worker_pool_args['moving_average_decay_rate'] = 1
+    worker_pool_args['compilation_timeout'] = 1200
+
+    # Borrow from the external reward_stat_map in case it got loaded from disk
+    # and already has some values. On a fresh run this will be recalculated
+    # from scratch both in the main data collector and here. It would be ideal
+    # if both shared the same dict, but that would be too complex to implement.
+    for name, data in reward_stat_map.items():
+      if name not in self._default_rewards:
+        self._default_rewards[name] = {}
+      for identifier, reward_stat in data.items():
+        self._default_rewards[name][identifier] = reward_stat.default_reward
+
+    self._pool = LocalWorkerPool(**worker_pool_args)
+    self._worker_pool = self._pool.stubs
+
+    for i, p in zip(affinities, self._worker_pool):
+      p.set_nice(19)
+      p.set_affinity([i])
+
+  # BEGIN: ContextAwareWorker methods
+  @classmethod
+  def is_priority_method(cls, _: str) -> bool:
+    # Everything is a priority: this is essentially a synchronous RPC endpoint.
+    return True
+
+  def set_context(self, local: bool):
+    self._context_local = local
+
+  # END: ContextAwareWorker methods
+
+  def _schedule_jobs(self, policy_path, module_specs):
+    default_jobs = []
+    for module_spec in module_specs:
+      if module_spec.name not in self._default_rewards:
+        # Job tuple: (module_spec, policy_path, reward_only,
+        # cancellation_manager).
+        default_jobs.append((module_spec, '', True, None))
+
+    default_rewards_lock = threading.Lock()
+
+    def create_update_rewards(spec_name):
+
+      def updater(f: concurrent.futures.Future):
+        if f.exception() is not None:
+          return
+        reward_stat = f.result()
+        with default_rewards_lock:
+          rewards = self._default_rewards.setdefault(spec_name, {})
+          for identifier, (_, default_reward) in reward_stat.items():
+            rewards[identifier] = default_reward
+
+      return updater
+
+    # Job tuple: (module_spec, policy_path, reward_only, cancellation_manager).
+    policy_jobs = [
+        (module_spec, policy_path, True, None) for module_spec in module_specs
+    ]
+
+    def work_factory(job):
+
+      def work(w):
+        return w.compile_fn(*job)
+
+      return work
+
+    work = [work_factory(job) for job in default_jobs]
+    work += [work_factory(job) for job in policy_jobs]
+
+    futures = buffered_scheduler.schedule(work, self._worker_pool, buffer=10)
+
+    self._default_futures = futures[:len(default_jobs)]
+    policy_futures = futures[len(default_jobs):]
+
+    for job, future in zip(default_jobs, self._default_futures):
+      future.add_done_callback(create_update_rewards(job[0].name))
+
+    return policy_futures
+
+  def collect_data_async(
+      self,
+      policy_path: str,
+      step: int = 0) -> Optional[Dict[int, Dict[str, float]]]:
+    """Collects data for a given policy without blocking.
+
+    Args:
+      policy_path: the path to the policy directory to collect data with.
+      step: the step number associated with policy_path.
+
+    Returns:
+      The validation results as a dictionary, or None if the data is not
+      ready yet.
+    """
+    if self._num_modules == 0:
+      return None
+
+    # Resume immediately, so that if new jobs get scheduled below, they run
+    # while the previous batch's results are being processed.
+    self.resume_children()
+    finished_work = [
+        (spec, res) for spec, res in self._current_work if res.done()
+    ]
+
+    # Check whether default rewards are still being collected.
+    if self._default_futures:
+      finished_default_work = sum(res.done() for res in self._default_futures)
+      if finished_default_work != len(self._default_futures):
+        logging.info('%d out of %d default-rewards modules are finished.',
+                     finished_default_work, len(self._default_futures))
+        return None
+
+    if len(finished_work) != len(self._current_work):  # on 1st iter both are 0
+      logging.info('%d out of %d modules are finished.', len(finished_work),
+                   len(self._current_work))
+      return None
+    module_specs = self._corpus.modules
+    results = self._schedule_jobs(policy_path, module_specs)
+    self._current_work = list(zip(module_specs, results))
+    prev_policy = self._running_policy
+    self._running_policy = step
+
+    if not finished_work:  # On the 1st iteration nothing has run yet.
+      return None
+
+    # Since all work is done, reset the clock. Essential if the child
+    # processes were never paused.
+    if self._last_time is not None:
+      cur_time = time.time()
+      self._elapsed_time += cur_time - self._last_time
+      self._last_time = cur_time
+
+    successful_work = [(spec, res.result())
+                       for spec, res in finished_work
+                       if not worker.get_exception(res)]
+    failures = len(finished_work) - len(successful_work)
+
+    logging.info('%d of %d modules finished in %d seconds (%d failures).',
+                 len(finished_work), self._num_modules, self._elapsed_time,
+                 failures)
+
+    sum_policy = 0
+    sum_default = 0
+    for spec, res in successful_work:
+      # res format: {_DEFAULT_IDENTIFIER: (None, native_size)}
+      for identifier, (_, policy_reward) in res.items():
+        sum_policy += policy_reward
+        sum_default += self._default_rewards[spec.name][identifier]
+
+    if sum_default <= 0:
+      raise ValueError('Sum of default rewards must be positive.')
+    reward = 1 - sum_policy / sum_default
+
+    monitor_dict = {
+        prev_policy: {
+            'success_modules': len(successful_work),
+            'compile_wall_time': self._elapsed_time,
+            'sum_reward': reward
+        }
+    }
+    self._elapsed_time = 0  # Reset only once a batch fully completes.
+    return monitor_dict
+
+  def pause_children(self):
+    if not self._context_local or self._running_policy is None:
+      return
+
+    for p in self._worker_pool:
+      p.pause_all_work()
+
+    if self._last_time is not None:
+      self._elapsed_time += time.time() - self._last_time
+      self._last_time = None
+
+  def resume_children(self):
+    last_time_was_none = False
+    if self._last_time is None:
+      last_time_was_none = True
+      self._last_time = time.time()
+
+    if not self._context_local or self._running_policy is None:
+      return
+
+    # Only pause_children() resets last_time to None, so only resume then.
+    if last_time_was_none:
+      for p in self._worker_pool:
+        p.resume_all_work()
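To make the aggregate reward computed in `collect_data_async` concrete, here is a tiny worked example with made-up per-module sizes (assuming, as in the inlining problem, that the per-module rewards are native object sizes):

```python
# Hypothetical (policy_size, default_size) pairs for three modules.
results = [(300, 320), (450, 430), (190, 250)]

sum_policy = sum(p for p, _ in results)    # 940
sum_default = sum(d for _, d in results)   # 1000

# Mirrors collect_data_async: positive when the policy shrinks the validation
# corpus overall, negative when it grows it.
reward = 1 - sum_policy / sum_default
assert abs(reward - 0.06) < 1e-9  # i.e. a 6% aggregate size reduction
```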
diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py
index 4a0a47e8..b02a23c4 100644
--- a/compiler_opt/rl/train_locally.py
+++ b/compiler_opt/rl/train_locally.py
@@ -29,14 +29,15 @@
 from tf_agents.system import system_multiprocessing as multiprocessing
 from typing import List
 
+from compiler_opt.distributed.local import cpu_affinity
 from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool
 from compiler_opt.rl import agent_creators
 from compiler_opt.rl import compilation_runner
 from compiler_opt.rl import constant
 from compiler_opt.rl import corpus
 from compiler_opt.rl import data_reader
-from compiler_opt.rl import gin_external_configurables  # pylint: disable=unused-import
 from compiler_opt.rl import local_data_collector
+from compiler_opt.rl import local_validation_data_collector
 from compiler_opt.rl import policy_saver
 from compiler_opt.rl import random_net_distillation
 from compiler_opt.rl import registry
@@ -46,6 +47,8 @@
                     'Root directory for writing logs/summaries/checkpoints.')
 flags.DEFINE_string('data_path', None,
                     'Path to directory containing the corpus.')
+flags.DEFINE_string('validation_data_path', None,
+                    'Path to directory containing the validation corpus.')
 flags.DEFINE_integer(
     'num_workers', None,
     'Number of parallel data collection workers. `None` for max available')
@@ -103,6 +106,17 @@ def train_eval(agent_name=constant.AgentName.PPO,
                      problem_config.flags_to_delete())
   logging.info('Done loading module specs from corpus.')
 
+  val_cps = None
+  if FLAGS.validation_data_path is not None:
+    logging.info('Loading module specs from validation corpus at %s.',
+                 FLAGS.validation_data_path)
+    val_cps = corpus.Corpus(FLAGS.validation_data_path,
+                            problem_config.flags_to_add(),
+                            problem_config.flags_to_delete())
+    logging.info('Done loading module specs from validation corpus.')
+  else:
+    logging.info('Validation corpus data path not specified.')
+
   dataset_fn = data_reader.create_sequence_example_dataset_fn(
       agent_name=agent_name,
       time_step_spec=time_step_spec,
@@ -129,11 +143,19 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
   }
   logging.info('Loaded Reward Stat Map from disk, containing %d modules',
                len(reward_stat_map))
-
+  val_pool_args = {'worker_class': problem_config.get_runner_type()}
   with LocalWorkerPool(
       worker_class=problem_config.get_runner_type(),
       count=FLAGS.num_workers,
-      moving_average_decay_rate=moving_average_decay_rate) as worker_pool:
+      moving_average_decay_rate=moving_average_decay_rate
+  ) as worker_pool, LocalWorkerPool(
+      worker_class=local_validation_data_collector.LocalValidationDataCollector,
+      count=1,
+      cps=val_cps,
+      worker_pool_args=val_pool_args,
+      reward_stat_map=reward_stat_map,
+      max_cpus=FLAGS.num_workers) as validation_collector_pool:
+
     data_collector = local_data_collector.LocalDataCollector(
         cps=cps,
         num_modules=num_modules,
@@ -141,6 +163,10 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
         parser=sequence_example_iterator_fn,
         reward_stat_map=reward_stat_map)
 
+    validation_collector = validation_collector_pool[0]
+
+    if val_cps is not None:
+      cpu_affinity.set_and_get(is_main_process=True, max_cpus=FLAGS.num_workers)
     # Repeat for num_policy_iterations iterations.
     t1 = time.time()
     while (llvm_trainer.global_step_numpy() <
@@ -154,10 +180,24 @@ def sequence_example_iterator_fn(seq_ex: List[str]):
 
       policy_path = os.path.join(root_dir, 'policy',
                                  str(llvm_trainer.global_step_numpy()))
+      # Pausing is done before saving to give the validation collector's
+      # children time to receive the stop signal, minimizing the risk of them
+      # running at the same time as the main data_collector's children.
+      if val_cps is not None:
+        validation_collector.pause_children()
+        time.sleep(15)
       saver.save(policy_path)
 
+      policy_fullpath = os.path.join(policy_path, deploy_policy_name)
       dataset_iter, monitor_dict = data_collector.collect_data(
-          policy_path=os.path.join(policy_path, deploy_policy_name))
+          policy_path=policy_fullpath)
+      if val_cps is not None:
+        validation_dict_maybe = validation_collector.collect_data_async(
+            policy_path=policy_fullpath,
+            step=llvm_trainer.global_step_numpy()).result()
+        if validation_dict_maybe is not None:
+          llvm_trainer.write_validation_data(validation_dict_maybe)
+
       llvm_trainer.train(dataset_iter, monitor_dict, num_iterations)
 
       data_collector.on_dataset_consumed(dataset_iter)
diff --git a/compiler_opt/rl/trainer.py b/compiler_opt/rl/trainer.py
index 912eae43..74814fd5 100644
--- a/compiler_opt/rl/trainer.py
+++ b/compiler_opt/rl/trainer.py
@@ -176,6 +176,12 @@ def _save_checkpoint(self):
     if tf.math.equal(self._global_step % self._checkpoint_interval, 0):
       self._checkpointer.save(global_step=self._global_step)
 
+  def write_validation_data(self, monitor_dict):
+    with tf.name_scope('validation/'):
+      for step, d in monitor_dict.items():
+        for key, value in d.items():
+          tf.summary.scalar(name=key, data=value, step=int(step))
+
   def global_step_numpy(self):
     return self._global_step.numpy()
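Finally, a small standalone usage sketch of how the validation results flow into TensorBoard via `write_validation_data` (the writer path and metric values are made up; `write_validation_data` itself writes to whatever summary writer is currently the default, which the sketch sets up explicitly):

```python
import tensorflow as tf

# Shape produced by LocalValidationDataCollector.collect_data_async: keyed by
# the global step at which the evaluated policy was saved.
validation_dict = {
    400: {
        'success_modules': 98,
        'compile_wall_time': 512.0,
        'sum_reward': 0.042,
    }
}

writer = tf.summary.create_file_writer('/tmp/validation_demo')
with writer.as_default():
  # Mirrors Trainer.write_validation_data.
  with tf.name_scope('validation/'):
    for step, metrics in validation_dict.items():
      for key, value in metrics.items():
        tf.summary.scalar(name=key, data=value, step=int(step))
writer.flush()
```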