diff --git a/dual_net.py b/dual_net.py index eb57f585d..55004dea2 100644 --- a/dual_net.py +++ b/dual_net.py @@ -136,6 +136,11 @@ 'SE_ratio', 2, help='Squeeze and Excitation ratio.') +flags.DEFINE_bool( + 'use_swish', False, + help=('Use Swish activation function inplace of ReLu. ' + 'https://arxiv.org/pdf/1710.05941.pdf')) + # TODO(seth): Verify if this is still required. flags.register_multi_flags_validator( @@ -406,6 +411,9 @@ def model_inference_fn(features, training, params): data_format="channels_last") def mg_activation(inputs): + if FLAGS.use_swish: + return tf.nn.swish(inputs) + return tf.nn.relu(inputs)