
Commit

Remove metrics placeholder from training loop (blue-oil#959)
Taketoshi Fujiwara authored and oatawa1 committed Mar 31, 2020
1 parent e21d0ae commit e5d138a
Showing 3 changed files with 14 additions and 43 deletions.
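In outline, the change drops the placeholder indirection for metrics summaries: instead of fetching the metric values and feeding them back through dedicated placeholders, the summary ops are now attached directly to the metric tensors, so a single run of the merged op evaluates and serializes them. A minimal sketch of the old and new call patterns, assuming a metrics_ops_dict that maps names to metric value tensors as in the diffs below:

    # Old pattern: fetch the metric values, then feed them into placeholders
    # that are paired with them only by dict ordering.
    metrics_summary_op, metrics_placeholders = executor.prepare_metrics(metrics_ops_dict)
    metrics_values = sess.run(list(metrics_ops_dict.values()))
    metrics_summary = sess.run(
        metrics_summary_op,
        feed_dict=dict(zip(metrics_placeholders, metrics_values)),
    )

    # New pattern: the summaries read the metric tensors directly,
    # so no intermediate fetch or feed_dict is needed.
    metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)
    metrics_summary = sess.run(metrics_summary_op)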
15 changes: 3 additions & 12 deletions blueoil/cmd/evaluate.py
@@ -103,8 +103,7 @@ def evaluate(config, restore_path, output_dir):
model.summary(output, labels_placeholder)

summary_op = tf.compat.v1.summary.merge_all()
-
- metrics_summary_op, metrics_placeholders = executor.prepare_metrics(metrics_ops_dict)
+ metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

init_op = tf.compat.v1.global_variables_initializer()
reset_metrics_op = tf.compat.v1.local_variables_initializer()
@@ -140,14 +139,7 @@ def evaluate(config, restore_path, output_dir):
else:
sess.run([metrics_update_op], feed_dict=feed_dict)

- metrics_values = sess.run(list(metrics_ops_dict.values()))
- metrics_feed_dict = {
- # TODO: Fix to avoid the implementation depended on the order of dict implicitly
- placeholder: value for placeholder, value in zip(metrics_placeholders, metrics_values)
- }
- metrics_summary, = sess.run(
- [metrics_summary_op], feed_dict=metrics_feed_dict,
- )
+ metrics_summary = sess.run(metrics_summary_op)
validation_writer.add_summary(metrics_summary, last_step)

is_tfds = "TFDS_KWARGS" in config.DATASET
@@ -160,8 +152,7 @@ def evaluate(config, restore_path, output_dir):
'dataset_name': dataset_name,
'dataset_path': dataset_path,
'last_step': int(last_step),
- # TODO: Fix to avoid the implementation depended on the order of dict implicitly
- 'metrics': {k: float(v) for k, v in zip(list(metrics_ops_dict.keys()), metrics_values)},
+ 'metrics': {k: float(sess.run(op)) for k, op in metrics_ops_dict.items()},
}
save_json(output_dir, json.dumps(metrics_dict, indent=4,), metrics_dict["last_step"])
validation_dataset.close()
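One note on the metrics dict in the block above: float(sess.run(op)) runs the graph once per metric. tf.compat.v1.Session.run also accepts a dict of fetches, so a single-fetch variant (a sketch using the same names as the diff above) could look like:

    # Evaluate every metric tensor in one Session.run call; the result is a
    # dict with the same keys as metrics_ops_dict.
    metric_values = sess.run(metrics_ops_dict)
    metrics = {k: float(v) for k, v in metric_values.items()}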
18 changes: 3 additions & 15 deletions blueoil/cmd/train.py
@@ -115,8 +115,7 @@ def start_training(config):
model.summary(output, labels_placeholder)

summary_op = tf.compat.v1.summary.merge_all()
-
- metrics_summary_op, metrics_placeholders = executor.prepare_metrics(metrics_ops_dict)
+ metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

init_op = tf.compat.v1.global_variables_initializer()
reset_metrics_op = tf.compat.v1.local_variables_initializer()
@@ -223,12 +222,7 @@ def start_training(config):
# train_writer.add_run_metadata(run_metadata, "step: {}".format(step + 1))
train_writer.add_summary(summary, step + 1)

- metrics_values = sess.run(list(metrics_ops_dict.values()))
- metrics_feed_dict = {placeholder: value for placeholder, value in zip(metrics_placeholders, metrics_values)}
-
- metrics_summary, = sess.run(
- [metrics_summary_op], feed_dict=metrics_feed_dict,
- )
+ metrics_summary = sess.run(metrics_summary_op)
train_writer.add_summary(metrics_summary, step + 1)
train_writer.flush()
else:
@@ -261,13 +255,7 @@ def start_training(config):
else:
sess.run([metrics_update_op], feed_dict=feed_dict)

- metrics_values = sess.run(list(metrics_ops_dict.values()))
- metrics_feed_dict = {
- placeholder: value for placeholder, value in zip(metrics_placeholders, metrics_values)
- }
- metrics_summary, = sess.run(
- [metrics_summary_op], feed_dict=metrics_feed_dict,
- )
+ metrics_summary = sess.run(metrics_summary_op)
if rank == 0:
val_writer.add_summary(metrics_summary, step + 1)
val_writer.flush()
24 changes: 8 additions & 16 deletions blueoil/utils/executor.py
@@ -103,28 +103,20 @@ def save_pb_file(sess, output_dir, output_node_names=["output"], pb_name="minima
return pb_name


- def prepare_metrics(metrics_ops_dict):
- """Create summary_op and placeholders for training metrics.
+ def metrics_summary_op(metrics_ops_dict):
+ """Create summary_op for training metrics.
Args:
metrics_ops_dict (dict): dict of name and metrics_op.
Returns:
metrics_summary_op: summary op of metrics.
- metrics_placeholders: list of metrics placeholder.
"""
with tf.compat.v1.name_scope("metrics"):
- metrics_placeholders = []
- metrics_summaries = []
- for (metrics_key, metrics_op) in metrics_ops_dict.items():
- metrics_placeholder = tf.compat.v1.placeholder(
- tf.float32, name="{}_placeholder".format(metrics_key)
- )
- summary = tf.compat.v1.summary.scalar(metrics_key, metrics_placeholder)
- metrics_placeholders.append(metrics_placeholder)
- metrics_summaries.append(summary)
-
- metrics_summary_op = tf.compat.v1.summary.merge(metrics_summaries)
-
- return metrics_summary_op, metrics_placeholders
+ metrics_summaries = [
+ tf.compat.v1.summary.scalar(metrics_key, metrics_op)
+ for (metrics_key, metrics_op) in metrics_ops_dict.items()
+ ]
+
+ return tf.compat.v1.summary.merge(metrics_summaries)
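A minimal usage sketch of the new helper, assuming a toy metrics_ops_dict built from a tf.compat.v1.metrics value tensor; labels, predictions, feed_dict, writer, and step are illustrative names, not taken from this diff:

    # tf.compat.v1.metrics.accuracy returns (value_tensor, update_op).
    accuracy, accuracy_update_op = tf.compat.v1.metrics.accuracy(labels, predictions)
    metrics_ops_dict = {"accuracy": accuracy}
    metrics_summary_op = executor.metrics_summary_op(metrics_ops_dict)

    sess.run(tf.compat.v1.local_variables_initializer())  # metric state lives in local variables
    sess.run(accuracy_update_op, feed_dict=feed_dict)      # accumulate statistics for a batch
    summary = sess.run(metrics_summary_op)                 # evaluate and serialize the current values
    writer.add_summary(summary, step)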
