Skip to content

Commit

Permalink
replace DepthwiseConv1D by stream.Stream(cell=tf.keras.layers.Depthwi…
Browse files Browse the repository at this point in the history
…seConv2D()) it simplifies the design and inference engine calls DepthwiseConv2D anyway

PiperOrigin-RevId: 373460893
  • Loading branch information
rybakov authored and copybara-github committed May 12, 2021
1 parent 5a0601a commit affafc6
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 50 deletions.
39 changes: 26 additions & 13 deletions kws_streaming/layers/svdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# limitations under the License.

"""SVDF layer."""
from kws_streaming.layers import depthwise_conv1d
from kws_streaming.layers import modes
from kws_streaming.layers import non_scaling_dropout
from kws_streaming.layers import stream
from kws_streaming.layers.compat import tf


Expand Down Expand Up @@ -84,13 +84,17 @@ def build(self, input_shape):
self.dropout1 = tf.keras.layers.Lambda(lambda x, training: x)
self.dense1 = tf.keras.layers.Dense(
units=self.units1, use_bias=self.use_bias1)
self.depth_cnn1 = depthwise_conv1d.DepthwiseConv1D(
memory_size=self.memory_size,
self.depth_cnn1 = stream.Stream(
cell=tf.keras.layers.DepthwiseConv2D(
kernel_size=(self.memory_size, 1),
strides=(1, 1),
padding='valid',
dilation_rate=(1, 1),
use_bias=self.use_bias),
inference_batch_size=self.inference_batch_size,
use_bias=self.use_bias,
mode=self.mode,
pad=self.pad,
state_name_tag=self.state_name_tag)
use_one_step=False,
pad_time_dim=self.pad)
if self.units2 > 0:
self.dense2 = tf.keras.layers.Dense(units=self.units2, use_bias=True)
else:
Expand All @@ -114,13 +118,22 @@ def compute_output_shape(self, input_shape):
return output_shape

def call(self, inputs, training=None):
output = self.dropout1(inputs, training=training)
output = self.dense1(output)
output = self.depth_cnn1(output)
output = self.batch_norm(output, training=training)
output = self.activation(output)
output = self.dense2(output)
return output
net = inputs

# add fake dim [batch, time, 1, feature]
net = tf.keras.backend.expand_dims(net, axis=2)

net = self.dropout1(net, training=training)
net = self.dense1(net)
net = self.depth_cnn1(net)
net = self.batch_norm(net, training=training)
net = self.activation(net)
net = self.dense2(net)

# [batch, time, feature]
net = tf.squeeze(net, [2])

return net

def get_config(self):
config = {
Expand Down
88 changes: 51 additions & 37 deletions kws_streaming/layers/svdf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def _run_non_stream_model(self):
None,
self.input_data.shape[2],
))

svdf_layer = svdf.Svdf(
units1=self.weights[0].shape[1],
memory_size=self.memory_size,
Expand All @@ -47,7 +46,10 @@ def _run_non_stream_model(self):
mode=mode)
output_tf = svdf_layer(inputs=input_tf)
svdf_layer.dense1.set_weights([self.weights[0]])
svdf_layer.depth_cnn1.set_weights([self.weights[1], self.weights[2]])
depth_cnn_weight = self.weights[1]
depth_cnn_weight = np.expand_dims(depth_cnn_weight, 1)
depth_cnn_weight = np.expand_dims(depth_cnn_weight, 3)
svdf_layer.depth_cnn1.cell.set_weights([depth_cnn_weight, self.weights[2]])
svdf_layer.dense2.set_weights([self.weights[3], self.weights[4]])

model_tf = tf.keras.models.Model(input_tf, output_tf)
Expand All @@ -63,7 +65,7 @@ def test_streaming_inference_internal_state(self):
input_tf = tf.keras.layers.Input(shape=(
1,
self.input_data.shape[2],
))
), batch_size=None)

svdf_layer = svdf.Svdf(
units1=self.weights[0].shape[1],
Expand All @@ -74,12 +76,15 @@ def test_streaming_inference_internal_state(self):
mode=mode)
output_tf = svdf_layer(inputs=input_tf)

input_states_np = np.zeros(
[self.batch_size, self.memory_size, self.weights[1].shape[-1]])

svdf_layer.dense1.set_weights([self.weights[0]])
depth_cnn_weight = self.weights[1]
depth_cnn_weight = np.expand_dims(depth_cnn_weight, 1)
depth_cnn_weight = np.expand_dims(depth_cnn_weight, 3)

input_states_np = np.zeros(svdf_layer.depth_cnn1.get_weights()[2].shape)

svdf_layer.depth_cnn1.set_weights(
[self.weights[1], self.weights[2], input_states_np])
[depth_cnn_weight, self.weights[2], input_states_np])
svdf_layer.dense2.set_weights([self.weights[3], self.weights[4]])
model = tf.keras.models.Model(input_tf, output_tf)

Expand All @@ -92,36 +97,45 @@ def test_streaming_inference_internal_state(self):

def test_streaming_inference_external_state(self):

output_non_stream_np, model_tf = self._run_non_stream_model()

# input data for streaming stateless model
input_tensors = [
tf.keras.layers.Input(
shape=(
1,
self.input_data.shape[2],
),
batch_size=self.batch_size,
dtype=tf.float32)
]

# convert non streaming trainable model to streaming one with external state
mode = modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE
model_stream = utils.convert_to_inference_model(model_tf, input_tensors,
mode)

input_states_np = np.zeros(
[self.batch_size, self.memory_size, self.weights[1].shape[-1]])

# streaming emulation: loop over every element in time
for i in range(self.input_data.shape[1]):
input_batch_np = self.input_data[:, i, :]
input_batch_np = np.expand_dims(input_batch_np, 1)
output_np, output_states_np = model_stream.predict(
[input_batch_np, input_states_np])
input_states_np = output_states_np
for b in range(self.input_data.shape[0]): # loop over batch
self.assertAllClose(output_np[b][0], output_non_stream_np[b][i])
with tf1.Session() as sess:
output_non_stream_np, model_tf = self._run_non_stream_model()

# input data for streaming stateless model
input_tensors = [
tf.keras.layers.Input(
shape=(
1,
self.input_data.shape[2],
),
batch_size=self.batch_size,
dtype=tf.float32)
]

# convert non streaming model to streaming one with external state
mode = modes.Modes.STREAM_EXTERNAL_STATE_INFERENCE
model_stream = utils.convert_to_inference_model(model_tf, input_tensors,
mode)

# validate that model is convertable to tflite
converter = tf1.lite.TFLiteConverter.from_session(
sess, model_stream.inputs, model_stream.outputs)
self.assertTrue(converter.convert())

inputs = []
for s in range(len(model_stream.inputs)):
inputs.append(np.zeros(model_stream.inputs[s].shape, dtype=np.float32))

# streaming emulation: loop over every element in time
for i in range(self.input_data.shape[1]):
input_batch_np = self.input_data[:, i, :]
input_batch_np = np.expand_dims(input_batch_np, 1)
inputs[0] = input_batch_np
outputs = model_stream.predict(inputs)
# input_states_np = output_states_np
for s in range(1, len(model_stream.inputs)):
inputs[s] = outputs[s]
for b in range(self.input_data.shape[0]): # loop over batch
self.assertAllClose(outputs[0][b][0], output_non_stream_np[b][i])

def test_training(self):
# Test stateful svdf layer in training mode.
Expand Down

0 comments on commit affafc6

Please sign in to comment.