From 01580ae463e2694e1c8a5f5bb7179f453d334df2 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 18 Mar 2019 17:40:18 +0000 Subject: [PATCH 1/4] Allow converting keras.layers.Sequential --- python/tvm/relay/frontend/keras.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index f6f2d99e2ea5..939f8761f9e5 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -661,12 +661,15 @@ def from_keras(model, shape=None): raise ValueError("Keras frontend currently supports data_format = channels_last only.") _check_unsupported_layers(model) + def _convert_input_layer(keras_layer): + input_name = keras_layer.name + input_shape = shape[input_name] if shape is not None and input_name in shape else None + etab.set_expr(input_name, _expr.var(input_name, shape=input_shape)) + etab = ExprTable() for keras_layer in model.layers: if isinstance(keras_layer, keras.engine.InputLayer): - input_name = keras_layer.name - input_shape = shape[input_name] if shape is not None and input_name in shape else None - etab.set_expr(input_name, _expr.var(input_name, shape=input_shape)) + _convert_input_layer(keras_layer) else: inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \ else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \ @@ -690,6 +693,7 @@ def from_keras(model, shape=None): for n_idx, t_idx, inbound_layer in zip_node: if isinstance(inbound_layer, keras.engine.InputLayer): expr_name = inbound_layer.name + _convert_input_layer(inbound_layer) else: expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx) expr = etab.get_expr(expr_name) From c51518e24d32b064df80fea7703bd86c815370fb Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 18 Mar 2019 21:33:02 +0000 Subject: [PATCH 2/4] Use existing new_var function --- python/tvm/relay/frontend/keras.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index 939f8761f9e5..a865f08243eb 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -7,7 +7,7 @@ from .. import expr as _expr from .. import op as _op from ... import nd as _nd -from .common import ExprTable +from .common import ExprTable, new_var __all__ = ['from_keras'] @@ -664,7 +664,7 @@ def from_keras(model, shape=None): def _convert_input_layer(keras_layer): input_name = keras_layer.name input_shape = shape[input_name] if shape is not None and input_name in shape else None - etab.set_expr(input_name, _expr.var(input_name, shape=input_shape)) + etab.set_expr(input_name, new_var(input_name, shape=input_shape)) etab = ExprTable() for keras_layer in model.layers: From a744cc3159272c93a6e7f01145a9d0324cfb6aa9 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 25 Mar 2019 07:30:05 +0000 Subject: [PATCH 3/4] Only update expr when missing --- python/tvm/relay/frontend/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index ef9f63f3cd95..2871b7f73163 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -255,7 +255,8 @@ def get_expr(self, name): def set_expr(self, name, expr): assert isinstance(expr, _expr.Expr) - self.exprs[name] = expr + if name not in self.exprs: + self.exprs[name] = expr def set_padding(self, paddings): self.paddings = paddings From fea2f1ac2effd3fbfc2aa8a27a06afa2d87abdc9 Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Mon, 25 Mar 2019 18:35:56 +0000 Subject: [PATCH 4/4] Add test --- tests/python/frontend/keras/test_forward.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py index baa2e4fc203f..90c07ac09042 100644 --- a/tests/python/frontend/keras/test_forward.py +++ b/tests/python/frontend/keras/test_forward.py @@ -106,6 +106,17 @@ def test_forward_dense(): verify_keras_frontend(keras_model) +def test_forward_sequential(): + keras_model = keras.models.Sequential([ + keras.layers.Dense(16, input_dim=32, activation='relu'), + keras.layers.Dropout(0.5), + keras.layers.Dense(8, activation='relu'), + keras.layers.Dropout(0.5), + keras.layers.Dense(1, activation='sigmoid') + ]) + verify_keras_frontend(keras_model) + + def test_forward_pool(): data = keras.layers.Input(shape=(32,32,1)) # maxpool @@ -244,6 +255,7 @@ def test_forward_mobilenet(): test_forward_merge() test_forward_activations() test_forward_dense() + test_forward_sequential() test_forward_pool() test_forward_conv() test_forward_upsample(interpolation='nearest')