Skip to content

Commit

Permalink
FIX: Invariant input_spec for WeightNormalization (#687)
Browse files Browse the repository at this point in the history
* Make the first dimension of the `input_spec` shape `None` so the wrapper
  accepts any batch size (batch-size invariant), not only the one seen at
  build time.
* Add test case to check compatibility of WeightNormalization with
  TimeDistributed.
  • Loading branch information
Squadrick authored and seanpmorgan committed Nov 8, 2019
1 parent 895d11d commit 776b751
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tensorflow_addons/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def __init__(self, layer, data_init=True, **kwargs):
def build(self, input_shape):
"""Build `Layer`"""
input_shape = tf.TensorShape(input_shape).as_list()
self.input_spec = tf.keras.layers.InputSpec(shape=input_shape)
self.input_spec = tf.keras.layers.InputSpec(
shape=[None] + input_shape[1:])

if not self.layer.built:
self.layer.build(input_shape)
Expand Down
8 changes: 8 additions & 0 deletions tensorflow_addons/layers/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ def test_weightnorm_non_kernel_layer(self):
wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
wn_wrapper(images)

def test_weightnorm_with_time_dist(self):
    """WeightNormalization must compose with TimeDistributed.

    TimeDistributed re-invokes the wrapped layer on reshaped inputs whose
    batch dimension differs from the one seen at build time, so the
    wrapper's `input_spec` must leave the batch dimension unconstrained
    (`None`) or this graph construction raises an incompatible-shape error.
    """
    batch_shape = (32, 16, 64, 64, 3)  # (batch, time, H, W, C)
    inputs = tf.keras.layers.Input(batch_shape=batch_shape)
    conv = tf.keras.layers.Conv2D(3, 5)
    wn_conv = wrappers.WeightNormalization(conv)
    out = tf.keras.layers.TimeDistributed(wn_conv)(inputs)
    model = tf.keras.Model(inputs, out)
    # Original test only checked that construction does not raise; also pin
    # the resulting shape so a spec/shape regression is caught explicitly.
    # 5x5 valid conv: 64 - 5 + 1 = 60; time dimension is preserved.
    self.assertEqual(model.output_shape, (32, 16, 60, 60, 3))


if __name__ == "__main__":
tf.test.main()

0 comments on commit 776b751

Please sign in to comment.