diff --git a/preprocessing.py b/preprocessing.py index 595db3890..bd02da120 100644 --- a/preprocessing.py +++ b/preprocessing.py @@ -262,6 +262,40 @@ def get_tpu_bt_input_tensors(games, games_nr, batch_size, num_repeats=1, return dataset +def get_many_tpu_bt_input_tensors(games, games_nr, batch_size, + start_at, num_datasets, + moves=2**21, + window_size=500e3, + window_increment=25000): + dataset = None + for i in range(num_datasets): + # TODO(amj) mixin calibration games with some math. (from start_at that + # is proportionally along compared to last_game_number? comparing + # timestamps?) + ds = games.moves_from_games(start_at + (i * window_increment), + start_at + (i * window_increment) + window_size, + moves=moves, + shuffle=True, + column_family=bigtable_input.TFEXAMPLE, + column='example') + dataset = dataset.concatenate(ds) if dataset else ds + + dataset = dataset.repeat(1) + dataset = dataset.map(lambda row_name, s: s) + dataset = dataset.batch(batch_size,drop_remainder=False) + dataset = dataset.map( + functools.partial(batch_parse_tf_example, batch_size)) + # Unbatch the dataset so we can rotate it + dataset = dataset.apply(tf.contrib.data.unbatch()) + dataset = dataset.apply(tf.contrib.data.map_and_batch( + _random_rotation_pure_tf, + batch_size, + drop_remainder=True)) + + dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) + return dataset + + def make_dataset_from_selfplay(data_extracts): """ Returns an iterable of tf.Examples. diff --git a/train.py b/train.py index d6b24bcfe..e6300e2f5 100644 --- a/train.py +++ b/train.py @@ -54,6 +54,15 @@ flags.DEFINE_bool('freeze', False, 'Whether to freeze the graph at the end of training.') +flags.DEFINE_bool('train_many', False, + 'Whether to run train repeatedly, automatically incrementing the window') + +flags.DEFINE_integer('window_start_at', 10000000, + 'Used with `train_many`. The game number where the window begins') + +flags.DEFINE_integer('num_datasets', 3, + 'Used with `train_many`. The number of times to increment the window and re-train.') + flags.register_multi_flags_validator( ['use_bt', 'use_tpu'], @@ -139,6 +148,39 @@ def after_run(self, run_context, run_values): self.before_weights = None +def train_many(start_at=1000000, num_datasets=3, moves=2**24): + """ Trains on a set of bt_datasets, skipping eval for now. + (from preprocessing.get_many_tpu_bt_input_tensors) + """ + if not FLAGS.use_tpu and FLAGS.use_bt: + raise ValueError("Only tpu & bt mode supported") + + tf.logging.set_verbosity(tf.logging.INFO) + estimator = dual_net.get_estimator() + effective_batch_size = FLAGS.train_batch_size * FLAGS.num_tpu_cores + + def _input_fn(params): + games = bigtable_input.GameQueue( + FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table) + games_nr = bigtable_input.GameQueue( + FLAGS.cbt_project, FLAGS.cbt_instance, FLAGS.cbt_table + '-nr') + + return preprocessing.get_many_tpu_bt_input_tensors( + games, games_nr, params['batch_size'], + moves=moves, + window_size=FLAGS.window_size, + start_at=start_at, num_datasets=num_datasets) + + hooks = [] + + steps = num_datasets * FLAGS.steps_to_train + logging.info("Training, steps = %s, batch = %s -> %s examples", + steps or '?', effective_batch_size, + (steps * effective_batch_size) if steps else '?') + + estimator.train(_input_fn, steps=steps, hooks=hooks) + + def train(*tf_records: "Records to train on"): """Train on examples.""" tf.logging.set_verbosity(tf.logging.INFO) @@ -209,11 +251,15 @@ def _input_fn(): def main(argv): """Train on examples and export the updated model weights.""" - tf_records = argv[1:] - logging.info("Training on %s records: %s to %s", - len(tf_records), tf_records[0], tf_records[-1]) - with utils.logged_timer("Training"): - train(*tf_records) + if FLAGS.train_many: + with utils.logged_timer("Training"): + train_many(FLAGS.window_start_at, FLAGS.num_datasets) + else: + tf_records = argv[1:] + logging.info("Training on %s records: %s to %s", + len(tf_records), tf_records[0], tf_records[-1]) + with utils.logged_timer("Training"): + train(*tf_records) if FLAGS.export_path: dual_net.export_model(FLAGS.export_path) if FLAGS.freeze: