From bd8f87a41c8b94441f66478d90cb03130d927d96 Mon Sep 17 00:00:00 2001 From: Jin Xin Ng Date: Fri, 19 Aug 2022 15:23:36 -0700 Subject: [PATCH] Introduce Validation Data Collector - Allows concurrent evaluation of models on a separate dataset during training, with --validation_data_path - This is done with minimal impact on training time by only utilizing the CPU for the validation dataset when it is mostly idle doing tf.train(), and pinning processes to specific CPUs - The amount of impact can be adjusted via a gin.config on cpu_affinity.py - CPU affinities are only optimized for internal AMD-Zen based systems at the moment, but can be extended in the future. --- .../distributed/local/cpu_affinity.py | 55 ++++ .../distributed/local/cpu_affinity_test.py | 30 +++ compiler_opt/distributed/worker.py | 9 +- compiler_opt/rl/compilation_runner.py | 5 +- compiler_opt/rl/corpus.py | 4 + compiler_opt/rl/inlining/inlining_runner.py | 2 + .../rl/local_validation_data_collector.py | 243 ++++++++++++++++++ compiler_opt/rl/train_locally.py | 48 +++- compiler_opt/rl/trainer.py | 6 + 9 files changed, 395 insertions(+), 7 deletions(-) create mode 100644 compiler_opt/distributed/local/cpu_affinity.py create mode 100644 compiler_opt/distributed/local/cpu_affinity_test.py create mode 100644 compiler_opt/rl/local_validation_data_collector.py diff --git a/compiler_opt/distributed/local/cpu_affinity.py b/compiler_opt/distributed/local/cpu_affinity.py new file mode 100644 index 00000000..466bb248 --- /dev/null +++ b/compiler_opt/distributed/local/cpu_affinity.py @@ -0,0 +1,55 @@ +# coding=utf-8 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility functions to set cpu affinities when operating main and subprocesses +simultaneously.""" +import gin +import psutil +import itertools + +_NR_CPUS = psutil.cpu_count() + +_CPU_CONFIG = { # List of CPU numbers in cache-sharing order. + # 'google-epyc' assumes logical core 0 and N/2 are the same physical core. + # Also, L3 cache is assumed to be shared between consecutive core numbers. + 'google-epyc': + list( + itertools.chain( + *zip(range(_NR_CPUS // 2), range(_NR_CPUS // 2, _NR_CPUS)))) +} + + +@gin.configurable +def set_and_get(is_main_process: bool, + max_cpus=_NR_CPUS, + min_main_cpu: int = 32, + arch: str = 'google-epyc'): + """ + Sets the cpu affinity of the current process to appropriate values, and + returns the list of cpus the process is set to use. + Args: + is_main_process: whether the caller is the main process. + max_cpus: maximal number of cpus to use + min_main_cpu: number of cpus to assign to the main process. + arch: the system type, used to infer the cpu cache architecture. + """ + config = _CPU_CONFIG[arch][:max_cpus] + if is_main_process: + cpus = config[:min_main_cpu] + else: + cpus = config[min_main_cpu:] + if len(cpus) == 0: + raise ValueError('Attempting to set cpu affinity of process to nothing.') + psutil.Process().cpu_affinity(cpus) + return list(cpus) diff --git a/compiler_opt/distributed/local/cpu_affinity_test.py b/compiler_opt/distributed/local/cpu_affinity_test.py new file mode 100644 index 00000000..c6781d29 --- /dev/null +++ b/compiler_opt/distributed/local/cpu_affinity_test.py @@ -0,0 +1,30 @@ +# coding=utf-8 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test for cpu_affinity.""" + +from absl.testing import absltest +from compiler_opt.distributed.local import cpu_affinity +# pylint: disable=protected-access + + +class CpuAffinityTest(absltest.TestCase): + + def test_tally(self): + for v in cpu_affinity._CPU_CONFIG: + self.assertLen(set(v), cpu_affinity._NR_CPUS) + + +if __name__ == '__main__': + absltest.main() diff --git a/compiler_opt/distributed/worker.py b/compiler_opt/distributed/worker.py index ff040d8e..07a28126 100644 --- a/compiler_opt/distributed/worker.py +++ b/compiler_opt/distributed/worker.py @@ -14,7 +14,7 @@ # limitations under the License. """Common abstraction for a worker contract.""" -from typing import Iterable, Optional, Protocol, TypeVar +from typing import Iterable, Optional, TypeVar, Protocol, runtime_checkable class Worker(Protocol): @@ -25,6 +25,13 @@ def is_priority_method(cls, method_name: str) -> bool: return False +@runtime_checkable +class ContextAwareWorker(Worker, Protocol): + + def set_context(self, local: bool) -> None: + return + + T = TypeVar('T') diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index 15a6d32e..1dccdd54 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -283,7 +283,8 @@ def is_priority_method(cls, method_name: str) -> bool: def __init__(self, clang_path: Optional[str] = None, launcher_path: Optional[str] = None, - moving_average_decay_rate: float = 1): + moving_average_decay_rate: float = 1, + compilation_timeout=_COMPILATION_TIMEOUT.value): """Initialization of CompilationRunner class. Args: @@ -294,7 +295,7 @@ def __init__(self, self._clang_path = clang_path self._launcher_path = launcher_path self._moving_average_decay_rate = moving_average_decay_rate - self._compilation_timeout = _COMPILATION_TIMEOUT.value + self._compilation_timeout = compilation_timeout self._cancellation_manager = WorkerCancellationManager() # re-allow the cancellation manager accept work. diff --git a/compiler_opt/rl/corpus.py b/compiler_opt/rl/corpus.py index e29abf31..81fdffe8 100644 --- a/compiler_opt/rl/corpus.py +++ b/compiler_opt/rl/corpus.py @@ -122,6 +122,10 @@ def filter(self, p: re.Pattern): """Filters module specs, keeping those which match the provided pattern.""" self._module_specs = [ms for ms in self._module_specs if p.match(ms.name)] + @property + def modules(self): + return list(self._module_specs) + def __len__(self): return len(self._module_specs) diff --git a/compiler_opt/rl/inlining/inlining_runner.py b/compiler_opt/rl/inlining/inlining_runner.py index 69730f0a..a0483849 100644 --- a/compiler_opt/rl/inlining/inlining_runner.py +++ b/compiler_opt/rl/inlining/inlining_runner.py @@ -71,6 +71,8 @@ def compile_fn( cancelled work. RuntimeError: if llvm-size produces unexpected output. """ + if cancellation_manager is None: + cancellation_manager = self._cancellation_manager working_dir = tempfile.mkdtemp() log_path = os.path.join(working_dir, 'log') diff --git a/compiler_opt/rl/local_validation_data_collector.py b/compiler_opt/rl/local_validation_data_collector.py new file mode 100644 index 00000000..f3e41b56 --- /dev/null +++ b/compiler_opt/rl/local_validation_data_collector.py @@ -0,0 +1,243 @@ +# coding=utf-8 +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Validation data collection module.""" +import concurrent.futures +import threading +import time +from typing import Dict, Optional, List, Tuple + +from absl import logging + +from compiler_opt.distributed import worker +from compiler_opt.distributed.local import buffered_scheduler +from compiler_opt.distributed.local import cpu_affinity +from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool + +from compiler_opt.rl import corpus + + +class LocalValidationDataCollector(worker.ContextAwareWorker): + """Local implementation of a validation data collector + Args: + module_specs: List of module specs to use + worker_pool_args: Pool of workers to use + """ + + def __init__(self, cps: corpus.Corpus, worker_pool_args, reward_stat_map, + max_cpus): + self._num_modules = len(cps) if cps is not None else 0 + self._corpus: corpus.Corpus = cps + self._default_rewards = {} + + self._running_policy = None + self._default_futures: List[worker.WorkerFuture] = [] + self._current_work: List[Tuple[corpus.ModuleSpec, worker.WorkerFuture]] = [] + self._last_time = None + self._elapsed_time = 0 + + self._context_local = True + + # Check a bit later so some expected vars have been set first. + if not cps: + return + + affinities = cpu_affinity.set_and_get( + is_main_process=False, max_cpus=max_cpus) + + # Add some runner specific flags. + logging.info('Validation data collector using %d workers.', len(affinities)) + worker_pool_args['count'] = len(affinities) + worker_pool_args['moving_average_decay_rate'] = 1 + worker_pool_args['compilation_timeout'] = 1200 + + # Borrow from the external reward_stat_map in case it got loaded from disk + # and already has some values. On a fresh run this will be recalculated + # from scratch in the main data collector and here. It would be ideal if + # both shared the same dict, but that would be too complex to implement. + for name, data in reward_stat_map.items(): + if name not in self._default_rewards: + self._default_rewards[name] = {} + for identifier, reward_stat in data.items(): + self._default_rewards[name][identifier] = reward_stat.default_reward + + self._pool = LocalWorkerPool(**worker_pool_args) + self._worker_pool = self._pool.stubs + + for i, p in zip(affinities, self._worker_pool): + p.set_nice(19) + p.set_affinity([i]) + + # BEGIN: ContextAwareWorker methods + @classmethod + def is_priority_method(cls, _: str) -> bool: + # Everything is a priority: this is essentially a synchronous RPC endpoint. + return True + + def set_context(self, local: bool): + self._context_local = local + + # END: ContextAwareWorker methods + + def _schedule_jobs(self, policy_path, module_specs): + default_jobs = [] + for module_spec in module_specs: + if module_spec.name not in self._default_rewards: + # The bool is reward_only, None is cancellation_manager + default_jobs.append((module_spec, '', True, None)) + + default_rewards_lock = threading.Lock() + + def create_update_rewards(spec_name): + + def updater(f: concurrent.futures.Future): + if f.exception() is not None: + reward_stat = f.result() + for identifier, (_, default_reward) in reward_stat: + with default_rewards_lock: + self._default_rewards[spec_name][identifier] = default_reward + + return updater + + # The bool is reward_only, None is cancellation_manager + policy_jobs = [ + (module_spec, policy_path, True, None) for module_spec in module_specs + ] + + def work_factory(job): + + def work(w): + return w.compile_fn(*job) + + return work + + work = [work_factory(job) for job in default_jobs] + work += [work_factory(job) for job in policy_jobs] + + futures = buffered_scheduler.schedule(work, self._worker_pool, buffer=10) + + self._default_futures = futures[:len(default_jobs)] + policy_futures = futures[len(default_jobs):] + + for job, future in zip(default_jobs, self._default_futures): + future.add_done_callback(create_update_rewards(job[0])) + + return policy_futures + + def collect_data_async( + self, + policy_path: str, + step: int = 0) -> Optional[Dict[tuple, Dict[str, float]]]: + """Collect data for a given policy. + + Args: + policy_path: the path to the policy directory to collect data with. + step: the step number associated with the policy_path + + Returns: + Either returns data in the form of a dictionary, or returns None if the + data is not ready yet. + """ + if self._num_modules == 0: + return None + + # Resume immediately, so that if new jobs are scheduled, + # they run while processing last batch's results + self.resume_children() + finished_work = [ + (spec, res) for spec, res in self._current_work if res.done() + ] + + # Check if there are default rewards being collected. + if len(self._default_futures) > 0: + finished_default_work = sum(res.done() for res in self._default_futures) + if finished_default_work != len(self._default_futures): + logging.info('%d out of %d default-rewards modules are finished.', + finished_default_work, len(self._default_futures)) + return None + + if len(finished_work) != len(self._current_work): # on 1st iter both are 0 + logging.info('%d out of %d modules are finished.', len(finished_work), + len(self._current_work)) + return None + module_specs = self._corpus.modules + results = self._schedule_jobs(policy_path, module_specs) + self._current_work = list(zip(module_specs, results)) + prev_policy = self._running_policy + self._running_policy = step + + if len(finished_work) == 0: # 1st iteration this is 0 + return None + + # Since all work is done: reset clock. Essential if processes never paused. + if self._last_time is not None: + cur_time = time.time() + self._elapsed_time += cur_time - self._last_time + self._last_time = cur_time + + successful_work = [(spec, res.result()) + for spec, res in finished_work + if not worker.get_exception(res)] + failures = len(finished_work) - len(successful_work) + + logging.info('%d of %d modules finished in %d seconds (%d failures).', + len(finished_work), self._num_modules, self._elapsed_time, + failures) + + sum_policy = 0 + sum_default = 0 + for spec, res in successful_work: + # res format: {_DEFAULT_IDENTIFIER: (None, native_size)} + for identifier, (_, policy_reward) in res: + sum_policy += policy_reward + sum_default += self._default_rewards[spec.name][identifier] + + if sum_default <= 0: + raise ValueError('Sum of default rewards is 0.') + reward = 1 - sum_policy / sum_default + + monitor_dict = { + prev_policy: { + 'success_modules': len(successful_work), + 'compile_wall_time': self._elapsed_time, + 'sum_reward': reward + } + } + self._elapsed_time = 0 # Only on completion this is reset + return monitor_dict + + def pause_children(self): + if not self._context_local or self._running_policy is None: + return + + for p in self._worker_pool: + p.pause_all_work() + + if self._last_time is not None: + self._elapsed_time += time.time() - self._last_time + self._last_time = None + + def resume_children(self): + last_time_was_none = False + if self._last_time is None: + last_time_was_none = True + self._last_time = time.time() + + if not self._context_local or self._running_policy is None: + return + + # Only pause changes last_time to None. + if last_time_was_none: + for p in self._worker_pool: + p.resume_all_work() diff --git a/compiler_opt/rl/train_locally.py b/compiler_opt/rl/train_locally.py index 4a0a47e8..b02a23c4 100644 --- a/compiler_opt/rl/train_locally.py +++ b/compiler_opt/rl/train_locally.py @@ -29,14 +29,15 @@ from tf_agents.system import system_multiprocessing as multiprocessing from typing import List +from compiler_opt.distributed.local import cpu_affinity from compiler_opt.distributed.local.local_worker_manager import LocalWorkerPool from compiler_opt.rl import agent_creators from compiler_opt.rl import compilation_runner from compiler_opt.rl import constant from compiler_opt.rl import corpus from compiler_opt.rl import data_reader -from compiler_opt.rl import gin_external_configurables # pylint: disable=unused-import from compiler_opt.rl import local_data_collector +from compiler_opt.rl import local_validation_data_collector from compiler_opt.rl import policy_saver from compiler_opt.rl import random_net_distillation from compiler_opt.rl import registry @@ -46,6 +47,8 @@ 'Root directory for writing logs/summaries/checkpoints.') flags.DEFINE_string('data_path', None, 'Path to directory containing the corpus.') +flags.DEFINE_string('validation_data_path', None, + 'Path to directory containing the validation corpus.') flags.DEFINE_integer( 'num_workers', None, 'Number of parallel data collection workers. `None` for max available') @@ -103,6 +106,17 @@ def train_eval(agent_name=constant.AgentName.PPO, problem_config.flags_to_delete()) logging.info('Done loading module specs from corpus.') + val_cps = None + if FLAGS.validation_data_path is not None: + logging.info('Loading module specs from validation corpus at %s.', + FLAGS.validation_data_path) + val_cps = corpus.Corpus(FLAGS.validation_data_path, + problem_config.flags_to_add(), + problem_config.flags_to_delete()) + logging.info('Done loading module specs from validation corpus.') + else: + logging.info('Validation corpus data path not specified.') + dataset_fn = data_reader.create_sequence_example_dataset_fn( agent_name=agent_name, time_step_spec=time_step_spec, @@ -129,11 +143,19 @@ def sequence_example_iterator_fn(seq_ex: List[str]): } logging.info('Loaded Reward Stat Map from disk, containing %d modules', len(reward_stat_map)) - + val_pool_args = {'worker_class': problem_config.get_runner_type()} with LocalWorkerPool( worker_class=problem_config.get_runner_type(), count=FLAGS.num_workers, - moving_average_decay_rate=moving_average_decay_rate) as worker_pool: + moving_average_decay_rate=moving_average_decay_rate + ) as worker_pool, LocalWorkerPool( + worker_class=local_validation_data_collector.LocalValidationDataCollector, + count=1, + cps=val_cps, + worker_pool_args=val_pool_args, + reward_stat_map=reward_stat_map, + max_cpus=FLAGS.num_workers) as validation_collector_pool: + data_collector = local_data_collector.LocalDataCollector( cps=cps, num_modules=num_modules, @@ -141,6 +163,10 @@ def sequence_example_iterator_fn(seq_ex: List[str]): parser=sequence_example_iterator_fn, reward_stat_map=reward_stat_map) + validation_collector = validation_collector_pool[0] + + if val_cps is not None: + cpu_affinity.set_and_get(is_main_process=True, max_cpus=FLAGS.num_workers) # Repeat for num_policy_iterations iterations. t1 = time.time() while (llvm_trainer.global_step_numpy() < @@ -154,10 +180,24 @@ def sequence_example_iterator_fn(seq_ex: List[str]): policy_path = os.path.join(root_dir, 'policy', str(llvm_trainer.global_step_numpy())) + # Pausing is done before saving to give the validation collector's + # children time to receive the stop signal, minimizing the risk of it + # executing simultaneously with the main data_collector's children. + if val_cps is not None: + validation_collector.pause_children() + time.sleep(15) saver.save(policy_path) + policy_fullpath = os.path.join(policy_path, deploy_policy_name) dataset_iter, monitor_dict = data_collector.collect_data( - policy_path=os.path.join(policy_path, deploy_policy_name)) + policy_path=policy_fullpath) + if val_cps is not None: + validation_dict_maybe = validation_collector.collect_data_async( + policy_path=policy_fullpath, + step=llvm_trainer.global_step_numpy()).result() + if validation_dict_maybe is not None: + llvm_trainer.write_validation_data(validation_dict_maybe) + llvm_trainer.train(dataset_iter, monitor_dict, num_iterations) data_collector.on_dataset_consumed(dataset_iter) diff --git a/compiler_opt/rl/trainer.py b/compiler_opt/rl/trainer.py index 912eae43..74814fd5 100644 --- a/compiler_opt/rl/trainer.py +++ b/compiler_opt/rl/trainer.py @@ -176,6 +176,12 @@ def _save_checkpoint(self): if tf.math.equal(self._global_step % self._checkpoint_interval, 0): self._checkpointer.save(global_step=self._global_step) + def write_validation_data(self, monitor_dict): + with tf.name_scope('validation/'): + for step, d in monitor_dict.items(): + for key, value in d.items(): + tf.summary.scalar(name=key, data=value, step=int(step)) + def global_step_numpy(self): return self._global_step.numpy()