From db5569ce8d202f77154f288c21d3f2fa228f9aa3 Mon Sep 17 00:00:00 2001 From: TFiFiE Date: Fri, 3 May 2019 13:10:20 +0200 Subject: [PATCH] Take network size in TFProcess constructor. * Take network size in TFProcess constructor. * Make number of blocks and filters required in call to parse.py. Pull request #2361. --- README.md | 10 +++++----- training/tf/net_to_model.py | 8 +------- training/tf/parse.py | 20 +++++++++++++++++--- training/tf/tfprocess.py | 30 +++++++++++++++--------------- 4 files changed, 38 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 059ba1314..8d2577686 100644 --- a/README.md +++ b/README.md @@ -318,13 +318,13 @@ This requires a working installation of TensorFlow 1.4 or later: src/leelaz -w weights.txt dump_supervised bigsgf.sgf train.out exit - training/tf/parse.py train.out + training/tf/parse.py 6 128 train.out -This will run and regularly dump Leela Zero weight files to disk, as -well as snapshots of the learning state numbered by the batch number. -If interrupted, training can be resumed with: +This will run and regularly dump Leela Zero weight files (of networks with 6 +blocks and 128 filters) to disk, as well as snapshots of the learning state +numbered by the batch number. If interrupted, training can be resumed with: - training/tf/parse.py train.out leelaz-model-batchnumber + training/tf/parse.py 6 128 train.out leelaz-model-batchnumber # Todo diff --git a/training/tf/net_to_model.py b/training/tf/net_to_model.py index 80bd08a27..5a2e58b2c 100755 --- a/training/tf/net_to_model.py +++ b/training/tf/net_to_model.py @@ -23,14 +23,8 @@ blocks //= 8 print("Blocks", blocks) -tfprocess = TFProcess() +tfprocess = TFProcess(blocks, channels) tfprocess.init(batch_size=1, gpus_num=1) -if tfprocess.RESIDUAL_BLOCKS != blocks: - raise ValueError("Number of blocks in tensorflow model doesn't match "\ - "number of blocks in input network") -if tfprocess.RESIDUAL_FILTERS != channels: - raise ValueError("Number of filters in tensorflow model doesn't match "\ - "number of filters in input network") tfprocess.replace_weights(weights) path = os.path.join(os.getcwd(), "leelaz-model") save_path = tfprocess.saver.save(tfprocess.session, path, global_step=0) diff --git a/training/tf/parse.py b/training/tf/parse.py index 04ca24d7f..e67774da8 100755 --- a/training/tf/parse.py +++ b/training/tf/parse.py @@ -107,24 +107,38 @@ def split_chunks(chunks, test_ratio): def main(): parser = argparse.ArgumentParser( description='Train network from game data.') + parser.add_argument("blockspref", + help="Number of blocks", nargs='?', type=int) + parser.add_argument("filterspref", + help="Number of filters", nargs='?', type=int) parser.add_argument("trainpref", help='Training file prefix', nargs='?', type=str) parser.add_argument("restorepref", help='Training snapshot prefix', nargs='?', type=str) + parser.add_argument("--blocks", '-b', + help="Number of blocks", type=int) + parser.add_argument("--filters", '-f', + help="Number of filters", type=int) parser.add_argument("--train", '-t', help="Training file prefix", type=str) parser.add_argument("--test", help="Test file prefix", type=str) parser.add_argument("--restore", type=str, help="Prefix of tensorflow snapshot to restore from") parser.add_argument("--logbase", default='leelalogs', type=str, - help="Log file prefix (for tensorboard)") + help="Log file prefix (for tensorboard) (default: %(default)s)") parser.add_argument("--sample", default=DOWN_SAMPLE, type=int, - help="Rate of data down-sampling to use") + help="Rate of data down-sampling to use (default: %(default)d)") args = parser.parse_args() + blocks = args.blocks or args.blockspref + filters = args.filters or args.filterspref train_data_prefix = args.train or args.trainpref restore_prefix = args.restore or args.restorepref + if not blocks or not filters: + print("Must supply number of blocks and filters") + return + training = get_chunks(train_data_prefix) if not args.test: # Generate test by taking 10% of the training chunks. @@ -150,7 +164,7 @@ def main(): sample=args.sample, batch_size=RAM_BATCH_SIZE).parse() - tfprocess = TFProcess() + tfprocess = TFProcess(blocks, filters) tfprocess.init(RAM_BATCH_SIZE, logbase=args.logbase, macrobatch=BATCH_SIZE // RAM_BATCH_SIZE) diff --git a/training/tf/tfprocess.py b/training/tf/tfprocess.py index 1f0efebcf..be66df9e0 100644 --- a/training/tf/tfprocess.py +++ b/training/tf/tfprocess.py @@ -111,10 +111,10 @@ def elapsed(self): return e class TFProcess: - def __init__(self): + def __init__(self, residual_blocks, residual_filters): # Network structure - self.RESIDUAL_FILTERS = 128 - self.RESIDUAL_BLOCKS = 6 + self.residual_blocks = residual_blocks + self.residual_filters = residual_filters # model type: full precision (fp32) or mixed precision (fp16) self.model_dtype = tf.float32 @@ -602,17 +602,17 @@ def construct_net(self, planes): # Input convolution flow = self.conv_block(x_planes, filter_size=3, input_channels=18, - output_channels=self.RESIDUAL_FILTERS, + output_channels=self.residual_filters, name="first_conv") # Residual tower - for i in range(0, self.RESIDUAL_BLOCKS): + for i in range(0, self.residual_blocks): block_name = "res_" + str(i) - flow = self.residual_block(flow, self.RESIDUAL_FILTERS, + flow = self.residual_block(flow, self.residual_filters, name=block_name) # Policy head conv_pol = self.conv_block(flow, filter_size=1, - input_channels=self.RESIDUAL_FILTERS, + input_channels=self.residual_filters, output_channels=2, name="policy_head") h_conv_pol_flat = tf.reshape(conv_pol, [-1, 2 * 19 * 19]) @@ -624,7 +624,7 @@ def construct_net(self, planes): # Value head conv_val = self.conv_block(flow, filter_size=1, - input_channels=self.RESIDUAL_FILTERS, + input_channels=self.residual_filters, output_channels=1, name="value_head") h_conv_val_flat = tf.reshape(conv_val, [-1, 19 * 19]) @@ -707,21 +707,21 @@ def gen_block(size, f_in, f_out): class TFProcessTest(unittest.TestCase): def test_can_replace_weights(self): - tfprocess = TFProcess() + tfprocess = TFProcess(6, 128) tfprocess.init(batch_size=1) # use known data to test replace_weights() works. - data = gen_block(3, 18, tfprocess.RESIDUAL_FILTERS) # input conv - for _ in range(tfprocess.RESIDUAL_BLOCKS): + data = gen_block(3, 18, tfprocess.residual_filters) # input conv + for _ in range(tfprocess.residual_blocks): data.extend(gen_block(3, - tfprocess.RESIDUAL_FILTERS, tfprocess.RESIDUAL_FILTERS)) + tfprocess.residual_filters, tfprocess.residual_filters)) data.extend(gen_block(3, - tfprocess.RESIDUAL_FILTERS, tfprocess.RESIDUAL_FILTERS)) + tfprocess.residual_filters, tfprocess.residual_filters)) # policy - data.extend(gen_block(1, tfprocess.RESIDUAL_FILTERS, 2)) + data.extend(gen_block(1, tfprocess.residual_filters, 2)) data.append([0.4] * 2*19*19 * (19*19+1)) data.append([0.5] * (19*19+1)) # value - data.extend(gen_block(1, tfprocess.RESIDUAL_FILTERS, 1)) + data.extend(gen_block(1, tfprocess.residual_filters, 1)) data.append([0.6] * 19*19 * 256) data.append([0.7] * 256) data.append([0.8] * 256)