Skip to content

Commit

Permalink
Add raw_reward_only to runners
Browse files Browse the repository at this point in the history
- Will be useful for the validation runner and the generate_default_trace tool
- Allows None values in sequence_examples
  • Loading branch information
Northbadge committed Aug 11, 2022
1 parent 99671e4 commit 096d1fb
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 15 deletions.
38 changes: 25 additions & 13 deletions compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,20 @@ class CompilationResult:
keys: List[str]

def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
object.__setattr__(self, 'serialized_sequence_examples',
[x.SerializeToString() for x in sequence_examples])
object.__setattr__(
self, 'serialized_sequence_examples',
[x.SerializeToString() for x in sequence_examples if x is not None])
lengths = [
len(next(iter(x.feature_lists.feature_list.values())).feature)
for x in sequence_examples
if x is not None
]
object.__setattr__(self, 'length', sum(lengths))

assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
(len(self.keys)))
# TODO: is it necessary to return both keys and reward_stats (which already
# has the keys)? The length of sequence_examples could also be left unchecked,
# which would let raw_reward_only do less work.
assert (len(sequence_examples) == len(self.rewards) == (len(self.keys)))
assert set(self.keys) == set(self.reward_stats.keys())
assert not hasattr(self, 'sequence_examples')

Expand All @@ -230,9 +234,11 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):

@abc.abstractmethod
def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]
) -> WorkerFuture[CompilationResult]:
self,
module_spec: corpus.ModuleSpec,
tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]],
raw_reward_only: bool = False) -> WorkerFuture[CompilationResult]:
raise NotImplementedError()

@abc.abstractmethod
Expand Down Expand Up @@ -275,17 +281,18 @@ def enable(self):
def cancel_all_work(self):
self._cancellation_manager.kill_all_processes()

def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
def collect_data(self,
module_spec: corpus.ModuleSpec,
tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]],
raw_reward_only=False) -> CompilationResult:
"""Collect data for the given IR file and policy.
Args:
module_spec: a ModuleSpec.
tf_policy_path: path to the tensorflow policy.
reward_stat: reward stat of this module, None if unknown.
cancellation_token: a CancellationToken through which workers may be
signaled early termination
raw_reward_only: whether to return the raw reward value without examples.
Returns:
A CompilationResult. In particular:
Expand All @@ -311,7 +318,7 @@ def collect_data(
policy_result = self._compile_fn(
module_spec,
tf_policy_path,
reward_only=False,
reward_only=raw_reward_only,
cancellation_manager=self._cancellation_manager)
else:
policy_result = default_result
Expand All @@ -326,6 +333,11 @@ def collect_data(
raise ValueError(
(f'Example {k} does not exist under default policy for '
f'module {module_spec.name}'))
if raw_reward_only:
sequence_example_list.append(None)
rewards.append(policy_reward)
keys.append(k)
continue
default_reward = reward_stat[k].default_reward
moving_average_reward = reward_stat[k].moving_average_reward
sequence_example = _overwrite_trajectory_reward(
Expand Down
25 changes: 25 additions & 0 deletions compiler_opt/rl/compilation_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,31 @@ def test_default(self, mock_compile_fn):
}, data.reward_stats)
self.assertAllClose([0], data.rewards)

@mock.patch(constant.BASE_MODULE_DIR +
'.compilation_runner.CompilationRunner._compile_fn')
def test_raw_reward_only(self, mock_compile_fn):
mock_compile_fn.side_effect = _mock_compile_fn
runner = compilation_runner.CompilationRunner(
moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)
data = runner.collect_data(
module_spec=corpus.ModuleSpec(name='dummy'),
tf_policy_path='policy_path',
reward_stat=None,
raw_reward_only=True)
self.assertEqual(2, mock_compile_fn.call_count)

self.assertLen(data.serialized_sequence_examples, 0)

self.assertEqual(0, data.length)
self.assertCountEqual(
{
'default':
compilation_runner.RewardStat(
default_reward=_DEFAULT_REWARD,
moving_average_reward=_DEFAULT_REWARD)
}, data.reward_stats)
self.assertAllClose([_POLICY_REWARD], data.rewards)

@mock.patch(constant.BASE_MODULE_DIR +
'.compilation_runner.CompilationRunner._compile_fn')
def test_given_default_size(self, mock_compile_fn):
Expand Down
6 changes: 5 additions & 1 deletion compiler_opt/rl/local_data_collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def mock_collect_data(module_spec, tf_policy_dir, reward_stat):
class Sleeper(compilation_runner.CompilationRunner):
"""Test CompilationRunner that just sleeps."""

def collect_data(self, module_spec, tf_policy_path, reward_stat):
def collect_data(self,
module_spec,
tf_policy_path,
reward_stat,
raw_reward_only=False):
_ = module_spec, tf_policy_path, reward_stat
compilation_runner.start_cancellable_process(['sleep', '3600s'], 3600,
self._cancellation_manager)
Expand Down
6 changes: 5 additions & 1 deletion compiler_opt/tools/generate_default_trace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
class MockCompilationRunner(compilation_runner.CompilationRunner):
"""A compilation runner just for test."""

def collect_data(self, module_spec, tf_policy_path, reward_stat):
def collect_data(self,
module_spec,
tf_policy_path,
reward_stat,
raw_reward_only=False):
sequence_example_text = """
feature_lists {
feature_list {
Expand Down

0 comments on commit 096d1fb

Please sign in to comment.