This repository has been archived by the owner on Mar 11, 2021. It is now read-only.

"sliding window" bigtable training mode #713

Open · wants to merge 17 commits into master
34 changes: 34 additions & 0 deletions preprocessing.py
@@ -261,6 +261,40 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1,
    return dataset


def get_many_tpu_bt_input_tensors(games, games_nr, batch_size,
                                  start_at, num_datasets,
                                  moves=2**21,
                                  window_size=500e3,
                                  window_increment=25000):
    dataset = None
    for i in range(num_datasets):
        # TODO(amj) mix in calibration games with some math. (from start_at
        # that is proportionally along compared to last_game_number?
        # comparing timestamps?)
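        # With the defaults (window_size=500e3, window_increment=25000),
        # dataset i covers games [start_at + 25000*i, start_at + 25000*i + 500000),
        # i.e. consecutive windows overlap by 475,000 games and each dataset
        # slides the window forward by 25,000 games.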
        ds = games.moves_from_games(start_at + (i * window_increment),
                                    start_at + (i * window_increment) + window_size,
                                    moves=moves,
                                    shuffle=True,
                                    column_family=bigtable_input.TFEXAMPLE,
                                    column='example')
        dataset = dataset.concatenate(ds) if dataset else ds

Contributor:
Regarding the general approach: if the training loop does multiple scans, I would expect to create a new dataset for each pass, rather than try to create a single enormous dataset, which I imagine would be harder to debug, inspect, etc.

Contributor Author:

Yes, but multiple calls to TPUEstimator.train will create new graphs :( I am not sure what a good solution for lazily evaluating these Datasets would be. As it is, it takes a really long time to build the datasets before training even starts -- I suspect the concatenate is doing something bad, since each added window makes things slower and slower. (A sketch of the per-pass alternative follows this function below.)


    dataset = dataset.repeat(1)
    dataset = dataset.map(lambda row_name, s: s)
    dataset = dataset.batch(batch_size, drop_remainder=False)
    dataset = dataset.map(
        functools.partial(batch_parse_tf_example, batch_size))
    # Unbatch the dataset so we can rotate it
    dataset = dataset.apply(tf.contrib.data.unbatch())
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        _random_rotation_pure_tf,
        batch_size,
        drop_remainder=True))

    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)
    return dataset
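For reference, a minimal sketch of the per-pass alternative discussed in the thread above: build a fresh, small dataset for each window and call estimator.train once per window, instead of concatenating everything up front. The make_window_input_fn helper and the caller loop are hypothetical (not part of this PR); the trade-off is the per-call graph rebuild noted in the thread. The sketch reuses only calls that already appear in this file.

def make_window_input_fn(games, start_at, i, moves=2**21,
                         window_size=500e3, window_increment=25000):
    """Returns an input_fn that reads only window i's games."""
    def _input_fn(params):
        lo = start_at + i * window_increment
        ds = games.moves_from_games(lo, lo + window_size,
                                    moves=moves,
                                    shuffle=True,
                                    column_family=bigtable_input.TFEXAMPLE,
                                    column='example')
        ds = ds.map(lambda row_name, s: s)
        ds = ds.batch(params['batch_size'], drop_remainder=True)
        ds = ds.map(functools.partial(batch_parse_tf_example,
                                      params['batch_size']))
        # (unbatch / random-rotation steps omitted for brevity)
        return ds.prefetch(tf.contrib.data.AUTOTUNE)
    return _input_fn

# Caller side (e.g. in train.py): one train call per window. Each call
# rebuilds the graph, which is the cost the author notes, but each dataset
# stays small enough to inspect and debug on its own.
for i in range(num_datasets):
    estimator.train(make_window_input_fn(games, start_at, i),
                    steps=FLAGS.steps_to_train)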


def make_dataset_from_selfplay(data_extracts):
    '''
    Returns an iterable of tf.Examples.
30 changes: 30 additions & 0 deletions train.py
@@ -132,6 +132,36 @@ def after_run(self, run_context, run_values):
        self.before_weights = None


def train_many(start_at=1000000, num_datasets=3):

Contributor:

Can you expose moves here also?

Contributor Author:

What do you mean? The number of steps? (One way to expose moves is sketched after this function below.)

""" Trains on a set of bt_datasets, skipping eval for now.
(from preprocessing.get_many_tpu_bt_input_tensors)
"""
    if not (FLAGS.use_tpu and FLAGS.use_bt):
        raise ValueError("Only tpu & bt mode supported")

    tf.logging.set_verbosity(tf.logging.INFO)
    estimator = dual_net.get_estimator()
    effective_batch_size = FLAGS.train_batch_size * FLAGS.num_tpu_cores

    def _input_fn(params):
        games = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table)
        games_nr = bigtable_input.GameQueue(
            FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr')

        return preprocessing.get_many_tpu_bt_input_tensors(
            games, games_nr, params['batch_size'],
            start_at=start_at, num_datasets=num_datasets)

    hooks = []
    steps = num_datasets * FLAGS.steps_to_train
    logging.info("Training, steps = %s, batch = %s -> %s examples",
                 steps or '?', effective_batch_size,
                 (steps * effective_batch_size) if steps else '?')

    estimator.train(_input_fn, steps=steps, hooks=hooks)
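
Regarding the moves question in the thread above: a minimal sketch of one way to expose it, assuming the kwarg should simply thread through to preprocessing.get_many_tpu_bt_input_tensors (which already accepts moves). Only the changed lines are shown; the rest of train_many is unchanged.

def train_many(start_at=1000000, num_datasets=3, moves=2**21):
    ...
    def _input_fn(params):
        ...
        # Thread the new kwarg through to the window-dataset builder.
        return preprocessing.get_many_tpu_bt_input_tensors(
            games, games_nr, params['batch_size'],
            start_at=start_at, num_datasets=num_datasets,
            moves=moves)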


def train(*tf_records: "Records to train on"):
    """Train on examples."""
    tf.logging.set_verbosity(tf.logging.INFO)