Skip to content

Commit

Permalink
Add raw_reward_only to runners
Browse files Browse the repository at this point in the history
- Will be useful for the validation runner and for generating a default trace
- Allows None values in sequence_examples
  • Loading branch information
Northbadge committed Aug 11, 2022
1 parent 99671e4 commit 468adc3
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 16 deletions.
73 changes: 57 additions & 16 deletions compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
]
object.__setattr__(self, 'length', sum(lengths))

# TODO: is it necessary to return keys AND reward_stats(which has the keys)?
assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
(len(self.keys)))
assert set(self.keys) == set(self.reward_stats.keys())
Expand All @@ -228,6 +229,14 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
class CompilationRunnerStub(metaclass=abc.ABCMeta):
"""The interface of a stub to CompilationRunner, for type checkers."""

@abc.abstractmethod
def collect_results(self,
                    module_spec: corpus.ModuleSpec,
                    tf_policy_path: str,
                    collect_default_result: bool,
                    reward_only: bool = False) -> Tuple[Dict, Dict]:
  """Compile the given module, optionally under default and given policies.

  Args:
    module_spec: a ModuleSpec identifying the module to compile.
    tf_policy_path: path to the tensorflow policy; may be empty.
    collect_default_result: whether to also collect the default-policy result.
    reward_only: whether to only collect the rewards in the results.

  Returns:
    A (default_result, policy_result) tuple of dicts. NOTE(review): presumably
    either element may be None when that compilation was not requested —
    confirm against the CompilationRunner implementation.
  """
  raise NotImplementedError()

@abc.abstractmethod
def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
Expand Down Expand Up @@ -275,6 +284,47 @@ def enable(self):
def cancel_all_work(self):
self._cancellation_manager.kill_all_processes()

@staticmethod
def get_rewards(result: Dict) -> List[float]:
if len(result) == 0:
return []
return [v[1] for v in result.values()]

def collect_results(self,
                    module_spec: corpus.ModuleSpec,
                    tf_policy_path: str,
                    collect_default_result: bool,
                    reward_only: bool = False) -> Tuple[Dict, Dict]:
  """Compile the given IR module and gather default and/or policy results.

  Args:
    module_spec: a ModuleSpec.
    tf_policy_path: path to the tensorflow policy; if empty, no separate
      policy compilation happens and the default result is reused as the
      policy result.
    collect_default_result: whether to get the default result as well.
    reward_only: whether to only collect the rewards in the results.

  Returns:
    A (default_result, policy_result) tuple. Either element is None when
    the corresponding compilation was not requested.
  """
  default_result = None
  if collect_default_result:
    # When a policy compilation follows, only the rewards of the default
    # compilation are needed, regardless of `reward_only`.
    default_result = self._compile_fn(
        module_spec,
        tf_policy_path='',
        reward_only=reward_only or bool(tf_policy_path),
        cancellation_manager=self._cancellation_manager)

  if not tf_policy_path:
    # No policy to run: the policy result is whatever the default produced
    # (possibly None when no default collection was requested either).
    return default_result, default_result

  policy_result = self._compile_fn(
      module_spec,
      tf_policy_path,
      reward_only=reward_only,
      cancellation_manager=self._cancellation_manager)
  return default_result, policy_result

def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
Expand All @@ -284,8 +334,6 @@ def collect_data(
module_spec: a ModuleSpec.
tf_policy_path: path to the tensorflow policy.
reward_stat: reward stat of this module, None if unknown.
cancellation_token: a CancellationToken through which workers may be
signaled early termination
Returns:
A CompilationResult. In particular:
Expand All @@ -297,25 +345,18 @@ def collect_data(
compilation_runner.ProcessKilledException is passed through.
ValueError if example under default policy and ml policy does not match.
"""
default_result, policy_result = self.collect_results(
module_spec,
tf_policy_path,
collect_default_result=reward_stat is None,
reward_only=False)
if reward_stat is None:
default_result = self._compile_fn(
module_spec,
tf_policy_path='',
reward_only=bool(tf_policy_path),
cancellation_manager=self._cancellation_manager)
# TODO: Add structure to default_result and policy_result.
# get_rewards above should be updated/removed when this is resolved.
reward_stat = {
k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
}

if tf_policy_path:
policy_result = self._compile_fn(
module_spec,
tf_policy_path,
reward_only=False,
cancellation_manager=self._cancellation_manager)
else:
policy_result = default_result

sequence_example_list = []
rewards = []
keys = []
Expand Down
23 changes: 23 additions & 0 deletions compiler_opt/rl/compilation_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,29 @@ def test_default(self, mock_compile_fn):
}, data.reward_stats)
self.assertAllClose([0], data.rewards)

@mock.patch(constant.BASE_MODULE_DIR +
            '.compilation_runner.CompilationRunner._compile_fn')
def test_reward_only(self, mock_compile_fn):
  """collect_results runs both compilations and surfaces both rewards."""
  mock_compile_fn.side_effect = _mock_compile_fn
  runner = compilation_runner.CompilationRunner(
      moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)

  default_result, policy_result = runner.collect_results(
      module_spec=corpus.ModuleSpec(name='dummy'),
      tf_policy_path='policy_path',
      collect_default_result=True,
      reward_only=True)

  # One compilation under the default policy plus one under the ML policy.
  self.assertEqual(2, mock_compile_fn.call_count)
  self.assertIsNotNone(default_result)
  self.assertIsNotNone(policy_result)

  get_rewards = compilation_runner.CompilationRunner.get_rewards
  self.assertEqual([_DEFAULT_REWARD], get_rewards(default_result))
  self.assertEqual([_POLICY_REWARD], get_rewards(policy_result))

@mock.patch(constant.BASE_MODULE_DIR +
'.compilation_runner.CompilationRunner._compile_fn')
def test_given_default_size(self, mock_compile_fn):
Expand Down

0 comments on commit 468adc3

Please sign in to comment.