
Commit

Introduce Validation Data Collector
- Allows concurrent evaluation of models on a separate dataset during training, with --validation_data_path
- This is done with minimal impact on training time: the validation dataset only uses the CPU while it is mostly idle inside tf.train(), and processes are pinned to specific CPUs
- The amount of impact can be adjusted via the gin config in cpu_affinity.py
- CPU affinities are currently only optimized for internal AMD Zen-based systems, but can be extended in the future.
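As a rough illustration of how this fits together (only --validation_data_path and the set_and_get configurable come from this commit; the binding keys and values below are assumptions), the CPU split could be tuned through gin before launching training:

```python
# Illustrative sketch only: bind the gin-configurable set_and_get() from
# compiler_opt/rl/cpu_affinity.py to control how many cores the trainer keeps.
# The binding keys assume gin registers the function under its bare name;
# the values are made up.
import gin

gin.bind_parameter('set_and_get.min_main_cpu', 16)  # cores kept by the trainer
gin.bind_parameter('set_and_get.max_cpus', 64)      # total cores the job may touch

# The separate validation corpus itself is pointed at with the new flag, e.g.
#   --validation_data_path=/path/to/validation/corpus
```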
Northbadge committed Aug 10, 2022
1 parent 99671e4 commit 2523ea1
Showing 11 changed files with 376 additions and 28 deletions.
41 changes: 40 additions & 1 deletion compiler_opt/distributed/local/local_worker_manager.py
@@ -32,14 +32,17 @@
import functools
import multiprocessing
import threading
import os
import psutil
import signal

from absl import logging
# pylint: disable=unused-import
from compiler_opt.distributed.worker import Worker

from contextlib import AbstractContextManager
from multiprocessing import connection
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, List


@dataclasses.dataclass(frozen=True)
@@ -131,6 +134,7 @@ def __init__(self):
# when we stop.
self._lock = threading.Lock()
self._map: Dict[int, concurrent.futures.Future] = {}
self.is_paused = False

# thread draining the pipe
self._pump = threading.Thread(target=self._msg_pump)
@@ -205,10 +209,37 @@ def shutdown(self):
try:
# Killing the process triggers observer exit, which triggers msg_pump
# exit
self.resume()
self._process.kill()
except: # pylint: disable=bare-except
pass

def pause(self):
if self.is_paused:
return
self.is_paused = True
# used to send the STOP signal; does not actually kill the process
os.kill(self._process.pid, signal.SIGSTOP)

def resume(self):
if not self.is_paused:
return
self.is_paused = False
# used to send the CONTINUE signal; does not actually kill the process
os.kill(self._process.pid, signal.SIGCONT)

def set_nice(self, val: int):
"""Sets the nice-ness of the process, this modifies how the OS
schedules it. Only works on Unix, since val is presumed to be an int.
"""
psutil.Process(self._process.pid).nice(val)

def set_affinity(self, val: List[int]):
"""Sets the CPU affinity of the process, this modifies which cores the OS
schedules it on.
"""
psutil.Process(self._process.pid).cpu_affinity(val)

def join(self):
self._observer.join()
self._pump.join()
@@ -242,3 +273,11 @@ def __exit__(self, *args):
# now wait for the message pumps to indicate they exit.
for s in self._stubs:
s.join()

def __del__(self):
self.__exit__()

@property
def stubs(self):
# Return a shallow copy, to avoid callers messing up the internal list
return list(self._stubs)
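The pause/resume/set_nice/set_affinity hooks added above are the mechanism that keeps validation off the critical path: validation workers are frozen while the CPU is needed for training-data collection and woken up while the trainer sits in tf.train(). A minimal sketch of that pattern, assuming `stubs` is the list returned by the new `stubs` property (the function name and phase flag are illustrative):

```python
# Sketch only: throttle a pool of validation worker stubs around the trainer's
# phases, using the pause()/resume()/set_nice() methods added in this commit.
def throttle_validation_workers(stubs, trainer_is_in_tf_train: bool):
  """`stubs` is assumed to be the list returned by the new `stubs` property."""
  for stub in stubs:
    if trainer_is_in_tf_train:
      stub.set_nice(19)  # lowest Unix priority: any contention favors training
      stub.resume()      # SIGCONT: the cores are mostly idle during tf.train()
    else:
      stub.pause()       # SIGSTOP: training-data collection needs the cores now
```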
11 changes: 9 additions & 2 deletions compiler_opt/distributed/worker.py
@@ -15,17 +15,24 @@
"""Common abstraction for a worker contract."""

import abc
from typing import Generic, Iterable, Optional, TypeVar
from typing import Generic, Iterable, Optional, TypeVar, Protocol, runtime_checkable


class Worker:
class Worker(Protocol):

@classmethod
def is_priority_method(cls, method_name: str) -> bool:
_ = method_name
return False


@runtime_checkable
class ContextAwareWorker(Worker, Protocol):

def set_context(self, local: bool) -> None:
return


T = TypeVar('T')


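ContextAwareWorker is marked @runtime_checkable, so a collector can probe a worker with isinstance() and only call set_context() on implementations that support it. A hypothetical implementation and check might look like this (the class and its behavior are assumptions, only the protocol comes from this commit):

```python
# Hypothetical worker implementing the ContextAwareWorker protocol above.
from compiler_opt.distributed.worker import ContextAwareWorker


class LoggingWorker:

  def __init__(self):
    self._is_local = False

  @classmethod
  def is_priority_method(cls, method_name: str) -> bool:
    return False

  def set_context(self, local: bool) -> None:
    # e.g. adjust logging or sandboxing when running in-process
    self._is_local = local


# Structural isinstance() check, courtesy of @runtime_checkable:
w = LoggingWorker()
if isinstance(w, ContextAwareWorker):
  w.set_context(local=True)
```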
29 changes: 18 additions & 11 deletions compiler_opt/rl/compilation_runner.py
@@ -212,10 +212,12 @@ class CompilationResult:

def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
object.__setattr__(self, 'serialized_sequence_examples',
[x.SerializeToString() for x in sequence_examples])
[(x.SerializeToString() if x is not None else None)
for x in sequence_examples])
lengths = [
len(next(iter(x.feature_lists.feature_list.values())).feature)
for x in sequence_examples
if x is not None
]
object.__setattr__(self, 'length', sum(lengths))

@@ -229,10 +231,9 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):
"""The interface of a stub to CompilationRunner, for type checkers."""

@abc.abstractmethod
def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]
) -> WorkerFuture[CompilationResult]:
def collect_data(self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]],
raw_reward_only: bool) -> WorkerFuture[CompilationResult]:
raise NotImplementedError()

@abc.abstractmethod
@@ -275,17 +276,18 @@ def enable(self):
def cancel_all_work(self):
self._cancellation_manager.kill_all_processes()

def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
def collect_data(self,
module_spec: corpus.ModuleSpec,
tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]],
raw_reward_only=False) -> CompilationResult:
"""Collect data for the given IR file and policy.
Args:
module_spec: a ModuleSpec.
tf_policy_path: path to the tensorflow policy.
reward_stat: reward stat of this module, None if unknown.
cancellation_token: a CancellationToken through which workers may be
signaled early termination
raw_reward_only: whether to return the raw reward value without examples
Returns:
A CompilationResult. In particular:
@@ -311,7 +313,7 @@ def collect_data(
policy_result = self._compile_fn(
module_spec,
tf_policy_path,
reward_only=False,
reward_only=raw_reward_only,
cancellation_manager=self._cancellation_manager)
else:
policy_result = default_result
@@ -327,6 +329,11 @@ def collect_data(
(f'Example {k} does not exist under default policy for '
f'module {module_spec.name}'))
default_reward = reward_stat[k].default_reward
if raw_reward_only:
sequence_example_list.append(None)
rewards.append(policy_reward)
keys.append(k)
continue
moving_average_reward = reward_stat[k].moving_average_reward
sequence_example = _overwrite_trajectory_reward(
sequence_example=sequence_example,
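The new raw_reward_only path still compiles the module under the given policy and records its reward, but skips building sequence examples (a None placeholder is appended instead), which is all an evaluation pass needs. A hedged sketch of the intended call, with an already-constructed runner, module spec, and policy path assumed:

```python
# Sketch only: `runner` is an existing CompilationRunner and `module_spec` a
# corpus.ModuleSpec; the policy path is illustrative.
result = runner.collect_data(
    module_spec=module_spec,
    tf_policy_path='/tmp/eval_policy',
    reward_stat=None,       # unknown -> the default-policy baseline is computed first
    raw_reward_only=True)   # rewards only; no training examples
print(result.rewards)                       # raw per-example policy rewards
print(result.serialized_sequence_examples)  # None placeholders in this mode
```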
52 changes: 52 additions & 0 deletions compiler_opt/rl/cpu_affinity.py
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions to set cpu affinities when operating main and subprocesses
simultaneously."""
import gin
import psutil
import itertools

N = psutil.cpu_count()

CPU_CONFIG = { # List of CPU numbers in cache-sharing order.
# 'google-epyc' assumes logical core 0 and N/2 are the same physical core.
# Also, L3 cache is assumed to be shared between consecutive core numbers.
'google-epyc': list(itertools.chain(*zip(range(N // 2), range(N // 2, N))))
}


@gin.configurable
def set_and_get(is_main_process: bool,
max_cpus=N,
min_main_cpu: int = 32,
arch: str = 'google-epyc'):
"""
Sets the cpu affinity of the current process to appropriate values, and
returns the list of cpus the process is set to use.
Args:
is_main_process: whether the caller is the main process.
max_cpus: maximal number of cpus to use
min_main_cpu: number of cpus to assign to the main process.
arch: the system type, used to infer the cpu cache architecture.
"""
config = CPU_CONFIG[arch][:max_cpus]
if is_main_process:
cpus = config[:min_main_cpu]
else:
cpus = config[min_main_cpu:]
if len(cpus) == 0:
raise ValueError('Attempting to set cpu affinity of process to nothing.')
psutil.Process().cpu_affinity(cpus)
return list(cpus)
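The 'google-epyc' ordering interleaves each logical core with its SMT sibling (core i pairs with core i + N/2, and neighbouring numbers share L3 per the comment above), so the main-process prefix and the subprocess suffix each cover whole physical cores. A toy illustration with made-up core counts:

```python
# Toy illustration of the 'google-epyc' ordering with N = 8 logical CPUs.
import itertools

N = 8
config = list(itertools.chain(*zip(range(N // 2), range(N // 2, N))))
print(config)  # [0, 4, 1, 5, 2, 6, 3, 7] - each core next to its SMT sibling

# With min_main_cpu = 4, set_and_get() would pin:
min_main_cpu = 4
print(config[:min_main_cpu])  # main process  -> [0, 4, 1, 5]
print(config[min_main_cpu:])  # subprocesses  -> [2, 6, 3, 7]
```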
17 changes: 11 additions & 6 deletions compiler_opt/rl/local_data_collector.py
@@ -18,7 +18,7 @@
import itertools
import random
import time
from typing import Callable, Dict, Iterator, List, Tuple, Optional
from typing import Callable, Dict, Iterator, List, Tuple, Optional, Any

from absl import logging
from tf_agents.trajectories import trajectory
@@ -55,7 +55,7 @@ def __init__(
# We remove this activity from the critical path by running it concurrently
# with the training phase - i.e. whatever happens between successive data
# collection calls. Subsequent runs will wait for these to finish.
self._reset_workers: concurrent.futures.Future = None
self._reset_workers: Optional[concurrent.futures.Future] = None
self._current_work: List[Tuple[corpus.ModuleSpec, worker.WorkerFuture]] = []
self._pool = concurrent.futures.ThreadPoolExecutor()

@@ -77,20 +77,25 @@ def _join_pending_jobs(self):
logging.info('Waiting for pending work from last iteration took %f',
time.time() - t1)

def _create_jobs(
self, policy_path: str, sampled_modules: List[corpus.ModuleSpec]
) -> Tuple[List[Tuple[Any, ...]], List[Optional[Dict]]]:
return [(module_spec, policy_path, self._reward_stat_map[module_spec.name])
for module_spec in sampled_modules], [{}] * len(sampled_modules)

def _schedule_jobs(
self, policy_path: str, sampled_modules: List[corpus.ModuleSpec]
) -> List[worker.WorkerFuture[compilation_runner.CompilationResult]]:
# by now, all the pending work, which was signaled to cancel, must've
# finished
self._join_pending_jobs()
jobs = [(module_spec, policy_path, self._reward_stat_map[module_spec.name])
for module_spec in sampled_modules]
args, kwargs = self._create_jobs(policy_path, sampled_modules)

# Naive load balancing.
ret = []
for i in range(len(jobs)):
for i, (arg, kwarg) in enumerate(zip(args, kwargs)):
ret.append(self._worker_pool[i % len(self._worker_pool)].collect_data(
*(jobs[i])))
*arg, **kwarg))
return ret

def collect_data(
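Splitting job construction (_create_jobs) from scheduling (_schedule_jobs) lets a subclass change what is sent to the workers without touching the load-balancing loop. The validation collector itself is not in the loaded part of the diff, but an override in its spirit might look like this (class names are assumptions; only _create_jobs and the raw_reward_only keyword come from this commit):

```python
# Hypothetical subclass of the collector defined in this file.
class ValidationDataCollector(LocalDataCollector):

  def _create_jobs(self, policy_path, sampled_modules):
    args = [(module_spec, policy_path,
             self._reward_stat_map[module_spec.name])
            for module_spec in sampled_modules]
    # Request rewards only: validation does not need training examples.
    kwargs = [{'raw_reward_only': True}] * len(sampled_modules)
    return args, kwargs
```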
6 changes: 5 additions & 1 deletion compiler_opt/rl/local_data_collector_test.py
@@ -80,7 +80,11 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat):
class Sleeper(compilation_runner.CompilationRunner):
"""Test CompilationRunner that just sleeps."""

def collect_data(self, module_spec, tf_policy_path, reward_stat):
def collect_data(self,
module_spec,
tf_policy_path,
reward_stat,
raw_reward_only=False):
_ = module_spec, tf_policy_path, reward_stat
compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
self._cancellation_manager)