Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reward_only to runners #100

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 57 additions & 16 deletions compiler_opt/rl/compilation_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
]
object.__setattr__(self, 'length', sum(lengths))

# TODO: Is it necessary to return keys AND reward_stats (which already has the keys)?
assert (len(self.serialized_sequence_examples) == len(self.rewards) ==
(len(self.keys)))
assert set(self.keys) == set(self.reward_stats.keys())
Expand All @@ -228,6 +229,15 @@ def __post_init__(self, sequence_examples: List[tf.train.SequenceExample]):
class CompilationRunnerStub(metaclass=abc.ABCMeta):
"""The interface of a stub to CompilationRunner, for type checkers."""

@abc.abstractmethod
def collect_results(
    self,
    module_spec: corpus.ModuleSpec,
    tf_policy_path: str,
    collect_default_result: bool,
    reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]:
  """Stub signature for collecting the default and policy results.

  Args:
    module_spec: a ModuleSpec.
    tf_policy_path: path to the tensorflow policy.
    collect_default_result: whether to get the default result as well.
    reward_only: whether to only collect the rewards in the results.

  Returns:
    A tuple of the default result and policy result.

  Raises:
    NotImplementedError: always; the stub itself carries no implementation.
  """
  raise NotImplementedError

@abc.abstractmethod
def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
Expand Down Expand Up @@ -275,6 +285,46 @@ def enable(self):
def cancel_all_work(self):
self._cancellation_manager.kill_all_processes()

@staticmethod
def get_rewards(result: Dict) -> List[float]:
  """Extracts the element at index 1 of every value in `result`.

  Args:
    result: a mapping whose values are indexable; the item at position 1 of
      each value is what gets returned (presumably the reward - confirm
      against the structure produced by _compile_fn).

  Returns:
    A list with one entry per value of `result`, in the dict's iteration
    order.
  """
  rewards = []
  for entry in result.values():
    rewards.append(entry[1])
  return rewards

def collect_results(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is over-engineering to refactor the collect_data function on top of a new collect_results function: the options in collect_results are too complicated, its return type is confusing, and the names collect_results and collect_data are easy to mix up.

I recommend just writing a separate function for your evaluation purpose; that way the function arguments and return types will be much clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable. I'll just make _compile_fn public; that's essentially all I need anyway.

self,
module_spec: corpus.ModuleSpec,
tf_policy_path: str,
collect_default_result: bool,
reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]:
"""Collect data for the given IR file and policy.

Args:
module_spec: a ModuleSpec.
tf_policy_path: path to the tensorflow policy.
collect_default_result: whether to get the default result as well.
reward_only: whether to only collect the rewards in the results.

Returns:
A tuple of the default result and policy result.
"""
default_result = None
policy_result = None
if collect_default_result:
default_result = self._compile_fn(
module_spec,
tf_policy_path='',
reward_only=bool(tf_policy_path) or reward_only,
cancellation_manager=self._cancellation_manager)
policy_result = default_result

if tf_policy_path:
policy_result = self._compile_fn(
module_spec,
tf_policy_path,
reward_only=reward_only,
cancellation_manager=self._cancellation_manager)

return default_result, policy_result

def collect_data(
self, module_spec: corpus.ModuleSpec, tf_policy_path: str,
reward_stat: Optional[Dict[str, RewardStat]]) -> CompilationResult:
Expand All @@ -284,8 +334,6 @@ def collect_data(
module_spec: a ModuleSpec.
tf_policy_path: path to the tensorflow policy.
reward_stat: reward stat of this module, None if unknown.
cancellation_token: a CancellationToken through which workers may be
signaled early termination

Returns:
A CompilationResult. In particular:
Expand All @@ -297,25 +345,18 @@ def collect_data(
compilation_runner.ProcessKilledException is passed through.
ValueError if example under default policy and ml policy does not match.
"""
default_result, policy_result = self.collect_results(
module_spec,
tf_policy_path,
collect_default_result=reward_stat is None,
reward_only=False)
if reward_stat is None:
default_result = self._compile_fn(
module_spec,
tf_policy_path='',
reward_only=bool(tf_policy_path),
cancellation_manager=self._cancellation_manager)
# TODO: Add structure to default_result and policy_result.
# get_rewards above should be updated/removed when this is resolved.
reward_stat = {
k: RewardStat(v[1], v[1]) for (k, v) in default_result.items()
}

if tf_policy_path:
policy_result = self._compile_fn(
module_spec,
tf_policy_path,
reward_only=False,
cancellation_manager=self._cancellation_manager)
else:
policy_result = default_result

sequence_example_list = []
rewards = []
keys = []
Expand Down
23 changes: 23 additions & 0 deletions compiler_opt/rl/compilation_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,29 @@ def test_default(self, mock_compile_fn):
}, data.reward_stats)
self.assertAllClose([0], data.rewards)

@mock.patch(constant.BASE_MODULE_DIR +
            '.compilation_runner.CompilationRunner._compile_fn')
def test_reward_only(self, mock_compile_fn):
  """Checks collect_results when both results are requested reward-only.

  With a policy path supplied and collect_default_result=True, the patched
  _compile_fn is expected to be called twice, and both returned results must
  yield the corresponding reward through get_rewards.
  """
  mock_compile_fn.side_effect = _mock_compile_fn
  runner_under_test = compilation_runner.CompilationRunner(
      moving_average_decay_rate=_MOVING_AVERAGE_DECAY_RATE)

  results = runner_under_test.collect_results(
      module_spec=corpus.ModuleSpec(name='dummy'),
      tf_policy_path='policy_path',
      collect_default_result=True,
      reward_only=True)
  default_result, policy_result = results

  # One compilation for the default policy, one for the supplied policy.
  self.assertEqual(2, mock_compile_fn.call_count)

  self.assertIsNotNone(default_result)
  self.assertIsNotNone(policy_result)

  rewards_of = compilation_runner.CompilationRunner.get_rewards
  self.assertEqual([_DEFAULT_REWARD], rewards_of(default_result))
  self.assertEqual([_POLICY_REWARD], rewards_of(policy_result))

@mock.patch(constant.BASE_MODULE_DIR +
'.compilation_runner.CompilationRunner._compile_fn')
def test_given_default_size(self, mock_compile_fn):
Expand Down