Skip to content
This repository has been archived by the owner on Mar 11, 2021. It is now read-only.

Add an average_winrate_observed metric to tensorboard #896

Merged
merged 2 commits into from
Sep 24, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions dual_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,9 @@ def model_fn(features, labels, mode, params):
train_op = optimizer.minimize(combined_cost, global_step=global_step)

# Computations to be executed on CPU, outside of the main TPU queues.
def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cost,
value_cost, l2_cost, combined_cost, step,
def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor,
value_tensor, policy_cost, value_cost,
l2_cost, combined_cost, step,
est_mode=tf.estimator.ModeKeys.TRAIN):
policy_entropy = -tf.reduce_mean(tf.reduce_sum(
policy_output * tf.log(policy_output), axis=1))
Expand All @@ -299,6 +300,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
tf.one_hot(policy_target_top_1, tf.shape(policy_output)[1]))

value_cost_normalized = value_cost / params['value_cost_weight']
avg_value_observed = tf.reduce_mean(value_tensor)

with tf.variable_scope('metrics'):
metric_ops = {
Expand All @@ -308,7 +310,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
'l2_cost': tf.metrics.mean(l2_cost),
'policy_entropy': tf.metrics.mean(policy_entropy),
'combined_cost': tf.metrics.mean(combined_cost),

'avg_value_observed': tf.metrics.mean(avg_value_observed),
'policy_accuracy_top_1': tf.metrics.mean(policy_output_in_top1),
'policy_accuracy_top_3': tf.metrics.mean(policy_output_in_top3),
'policy_top_1_confidence': tf.metrics.mean(policy_top_1_confidence),
Expand Down Expand Up @@ -345,6 +347,7 @@ def eval_metrics_host_call_fn(policy_output, value_output, pi_tensor, policy_cos
policy_output,
value_output,
labels['pi_tensor'],
labels['value_tensor'],
tf.reshape(policy_cost, [1]),
tf.reshape(value_cost, [1]),
tf.reshape(l2_cost, [1]),
Expand Down