From 21199deca54e68c00d325208b4f316a54f7d9935 Mon Sep 17 00:00:00 2001 From: Xingyu Zhou Date: Sun, 25 Jul 2021 00:54:22 -0700 Subject: [PATCH] [Frontend, Tensorflow2] Added support for TensorList ops (#8454) --- python/tvm/relay/frontend/tensorflow2.py | 206 +++++++++++++++++- python/tvm/relay/frontend/tensorflow2_ops.py | 179 +++++++++++++++ python/tvm/relay/frontend/tensorflow_ops.py | 12 + .../tensorflow2/test_functional_models.py | 136 ++++++++++++ .../tensorflow2/test_sequential_models.py | 55 +++++ 5 files changed, 583 insertions(+), 5 deletions(-) create mode 100644 python/tvm/relay/frontend/tensorflow2_ops.py diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index e5339b33c4e9..db900428d06d 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except, too-many-nested-blocks """Tensorflow2.x graph to relay converter. If model is constructed using tf2.x API, then use this converter: @@ -38,12 +38,20 @@ from .common import infer_type as _infer_type from .tensorflow_ops import _convert_map as _convert_map_common -from .tensorflow_ops import _need_prelude_for_shape_inference +from .tensorflow_ops import _get_more_static_shape_rank +from .tensorflow2_ops import _convert_map as _convert_map_tf2 +from .tensorflow2_ops import _need_prelude_for_shape_inference from ..ty import Any __all__ = ["from_tensorflow"] +# A map to record tensor list write ops and input tl/tensor indices +# Value is (index of tensor list, index of written node) +_tensor_list_write_ops = { + "TensorListSetItem": (0, 2), +} + def _infer_type_with_prelude(val, prelude): body = _infer_type(val, prelude.mod) @@ -66,6 +74,11 @@ def set_span(sym, node_name): return sym +def is_tensor_list_constuctor(tf_node): + """Check whether is tensor list constructor node.""" + return tf_node.op == "TensorListReserve" + + def convert_const_node(node, shape): """convert tf const node into relay const or var""" @@ -196,6 +209,10 @@ def __init__(self, module): self._output_shapes = {} self._tf_node_map = {} self._gdef_lib = {} + self._tensor_list_shapes = {} + self._tensor_list_shape_nodes = {} + self._sub_map = {} + self._sub_input_idx_map = {} def from_tensorflow( self, graph, layout="NHWC", shape=None, outputs=None, input_types=None, gdef_lib=None @@ -215,10 +232,134 @@ def from_tensorflow( ) return func, self._params + def _analysis_tensor_list_op( + self, + graph, + node, + tl_write_nodes, + tl_stack_nodes, + tl_construct_nodes, + sub_func_name="", + root_node="", + ): + if sub_func_name and sub_func_name not in self._sub_input_idx_map: + self._sub_input_idx_map[sub_func_name] = {} + + if node.op == "Placeholder": + # record placeholder node in sub functions + self._sub_map[sub_func_name] = node + self._sub_input_idx_map[sub_func_name][node.name] = len( + self._sub_input_idx_map[sub_func_name] + ) + + if node.op.startswith("TensorList"): + if is_tensor_list_constuctor(node): + tl_construct_nodes.append(node) + else: + for tl_write_name, idx in _tensor_list_write_ops.items(): + if node.op.startswith(tl_write_name): + tl_write_nodes.append((node, idx, sub_func_name, root_node)) + if 
node.op.startswith("TensorListStack"): + tl_stack_nodes.append(node) + elif node.op.startswith("StatelessWhile"): + root_node = node.name + cond_fn_name, body_fn_name = [ + parse_attr(node.attr).get(x).name for x in ["cond", "body"] + ] + for fn_name in [cond_fn_name, body_fn_name]: + subfunction = self._gdef_lib[fn_name] + sub_func_name = fn_name + for sub_node in subfunction.node: + # bypass const node + if sub_node.op == "Const": + continue + self._tf_node_map[sub_node.name] = sub_node + self._analysis_tensor_list_op( + subfunction, + sub_node, + tl_write_nodes, + tl_stack_nodes, + tl_construct_nodes, + sub_func_name=sub_func_name, + root_node=root_node, + ) + + def _infer_static_shape_stack_node(self, tl_stack_nodes): + for stack_node in tl_stack_nodes: + if len(stack_node.input) < 2: + # Stack node does not have shape + continue + input_shape_name = stack_node.input[1].split(":")[0] + input_shape_node = self._tf_node_map[input_shape_name] + stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]] + in_idx = -1 + while stack: + cnode = stack.pop(0) + if not cnode.op.startswith("TensorList"): + if in_idx and cnode.op.startswith("StatelessWhile"): + stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]]) + else: + for iname in cnode.input: + if self._tf_node_map[iname.split(":")[0]].op.startswith( + "StatelessWhile" + ): + # identify input index based on output index + if iname.split(":")[1]: + in_idx = int(iname.split(":")[1]) + stack.append(self._tf_node_map[iname.split(":")[0]]) + # identify the corresponding constructor node and add shape to _tensor_list_shapes + elif cnode.name != stack_node.name: + if is_tensor_list_constuctor(cnode): + shape_attr = parse_attr(input_shape_node.attr) + if "value" not in shape_attr: + continue + raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"]) + elem_shape = [] + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(int(dim)) + self._tensor_list_shapes[cnode.name] = elem_shape + break + + def _infer_static_shape_write_node(self, tl_write_nodes): + for item in tl_write_nodes: + wnode = item[0] + ta_idx, inode_idx = item[1] + sub_func_name = item[2] + root_name = item[3] + stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]] + while stack: + cnode = stack.pop(0) + + if not cnode.op.startswith("TensorList"): + if cnode.op == "Placeholder" and sub_func_name: + # need to map subfunction + input_idx = self._sub_input_idx_map[sub_func_name][cnode.name] + stack.append( + self._tf_node_map[ + self._tf_node_map[root_name].input[input_idx].split(":")[0] + ] + ) + else: + for iname in cnode.input: + stack.append(self._tf_node_map[iname.split(":")[0]]) + # identify the corresponding constructor node and add it to _tensor_list_shape_nodes + elif cnode.name != wnode.name: + if is_tensor_list_constuctor(cnode): + inode = self._tf_node_map[wnode.input[inode_idx].split(":")[0]] + tn = wnode.input[inode_idx].split(":") + output_index = int(tn[1]) if len(tn) > 1 else 0 + self._tensor_list_shape_nodes[cnode.name] = (inode, wnode.op, output_index) + break + def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_types=None): if input_types is None: input_types = {} - + tl_write_nodes = [] + tl_stack_nodes = [] + tl_construct_nodes = [] self._layout = layout for node in graph.node: name = node.name @@ -235,6 +376,18 @@ def _get_relay_func(self, graph, layout="NHWC", shape=None, outputs=None, input_ self._nodes[node.name] = sym if param: self._params[node.name] = param 
+ # recursivly iterate tensorlist op if seen while loop + else: + self._analysis_tensor_list_op( + graph, node, tl_write_nodes, tl_stack_nodes, tl_construct_nodes + ) + + # Use tensor list stack to infer static tensor list shape + self._infer_static_shape_stack_node(tl_stack_nodes) + + # Fetch node contains static tensor list shape + self._infer_static_shape_write_node(tl_write_nodes) + for node in graph.node: self._backtrack_construct(graph, node.name) @@ -321,16 +474,36 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs): gdef_lib=self._gdef_lib, ) elif op_name in _convert_map_common: + # assert op are exclusive + assert not set(_convert_map_common.keys()) & set(_convert_map_tf2.keys()) if _need_prelude_for_shape_inference(op_name): sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude) else: sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod) + elif op_name in _convert_map_tf2: + if _need_prelude_for_shape_inference(op_name): + sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._prelude) + else: + sym = _convert_map_tf2[op_name](inputs, attrs, self._params, self._module.mod) else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) sym = set_span(sym, node_name) return sym + def _parse_element_shape(self, elem_shape, shape_attr): + if "value" in shape_attr: + raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"]) + + if raw_elem_shape.size == 1 and raw_elem_shape == -1: + elem_shape.append(Any()) + else: + for dim in raw_elem_shape: + if dim < 0: + elem_shape.append(Any()) + else: + elem_shape.append(dim) + def _backtrack_construct(self, graph, node_name): """Convert a specific tensorflow node to relay expression. @@ -370,8 +543,8 @@ def _backtrack_construct(self, graph, node_name): CallNode(Op(add), [Var(x, ty=TensorType([], float32)), Constant(1.0)], (nullptr), []) """ - input_op_name = node_name.split(":")[0].split("^")[-1] + if input_op_name not in self._nodes: node = self._tf_node_map[input_op_name] attr = parse_attr(node.attr) @@ -386,8 +559,31 @@ def _backtrack_construct(self, graph, node_name): attr["_node_name"] = node.name attr["_target_layout"] = self._layout inputs = [self._backtrack_construct(graph, iname) for iname in node.input] - op = self._convert_operator(graph, node.op, node.name, inputs, attr) + # infer shape for TensorList op + if is_tensor_list_constuctor(node): + input_shape_name = ( + node.input[1] if "TensorListFromTensor" in node.op else node.input[0] + ) + input_shape_name = input_shape_name.split(":")[0] + input_shape_node = self._tf_node_map[input_shape_name] + shape_attr = parse_attr(input_shape_node.attr) + elem_shape = [] + + self._parse_element_shape(elem_shape, shape_attr) + + if elem_shape: + attr["shape"] = elem_shape + if ( + "identical_element_shapes" in attr and attr["identical_element_shapes"] + ) or elem_shape: + shape = elem_shape + if node.name in self._tensor_list_shapes: + preset_shape = self._tensor_list_shapes[node.name] + shape = _get_more_static_shape_rank(shape, preset_shape) + attr["shape"] = shape + + op = self._convert_operator(graph, node.op, node.name, inputs, attr) if isinstance(op, np.ndarray): self._params[node.name] = tvm.nd.array(op) op = [ diff --git a/python/tvm/relay/frontend/tensorflow2_ops.py b/python/tvm/relay/frontend/tensorflow2_ops.py new file mode 100644 index 000000000000..5024c97238ea --- /dev/null +++ b/python/tvm/relay/frontend/tensorflow2_ops.py @@ -0,0 +1,179 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument, too-many-lines, len-as-condition, broad-except +"""Tensorflow2.x to relay converter ops and helper""" +import tvm +from tvm.relay.prelude import StaticTensorArrayOps, get_tensor_array_shape + +from .. import op as _op +from ..ty import Any +from .common import infer_value as _infer_value +from .common import infer_type as _infer_type +from .tensorflow_ops import _get_more_static_shape_rank + + +def _infer_type_with_prelude(val, prelude): + body = _infer_type(val, prelude.mod) + return body.checked_type + + +def _need_prelude_for_shape_inference(op): + return "TensorList" in op or "TensorArray" in op + + +def _tensorlist_reserve(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("element_dtype").name + elem_shape = _infer_value(inputs[0], params, prelude.mod) + elem_shape = tuple(elem_shape.asnumpy().astype("int32").flatten()) + + if elem_shape or "shape" in attr: + shape = attr["shape"] if "shape" in attr else elem_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, shape) + static_tensor_array_ops.register() + tensor_array_constructor = static_tensor_array_ops.get_global_var("tensor_array") + tensor_array = tensor_array_constructor(inputs[1]) + else: + tensor_array_constructor = prelude.get_global_var("tensor_array", dtype_str) + tensor_array = tensor_array_constructor(inputs[1]) + return tensor_array + + return _impl + + +def _tensorlist_set_item(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr.get("element_dtype").name + input_ta = inputs[0] + input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude) + input_t_shape = _infer_type_with_prelude(inputs[2], prelude).shape + input_rank = len(input_t_shape) + + if input_ta_shape is None: + tensor_name = "tensor{}".format(input_rank) + tensor_func = prelude.get_tensor_ctor(tensor_name, dtype_str) + v = tensor_func(inputs[2]) + write_func = prelude.get_global_var("tensor_array_write", dtype_str) + out = write_func(input_ta, inputs[1], v) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + tensor_func = static_tensor_array_ops.get_ctor("tensor_constructor") + v = tensor_func(inputs[2]) + # Write tensor with more static shape + # convert shape with -1 to any() + input_ta_shape_a = [] + for dim in input_ta_shape: + if isinstance(dim, (int, tvm.tir.expr.IntImm)): + if dim < 0: + input_ta_shape_a.append(Any()) + else: + input_ta_shape_a.append(dim) + else: + input_ta_shape_a.append(dim) + actual_shape = _get_more_static_shape_rank(input_t_shape, input_ta_shape_a) + if actual_shape != input_ta_shape_a: + new_shape = [] + num_any_dim = 0 + for dim in actual_shape: + if not isinstance(dim, int): + 
num_any_dim += 1 + new_shape.append(dim if isinstance(dim, int) else -1) + if num_any_dim <= 1: + v = tensor_func(_op.reshape(inputs[2], new_shape)) + write_func = prelude.get_global_var_static( + "tensor_array_write", dtype_str, input_ta_shape_a + ) + out = write_func(input_ta, inputs[1], v) + return out + + return _impl + + +def _tensorlist_get_item(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude) + + if input_shape is None: + read_func = prelude.get_global_var("tensor_array_read", dtype_str) + out = read_func(inputs[0], _op.take(inputs[1], tvm.relay.const(0))) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_shape) + static_tensor_array_ops.register() + read_func = static_tensor_array_ops.get_global_var("tensor_array_read") + out_tensor = read_func(inputs[0], _op.take(inputs[1], tvm.relay.const(0))) + get_data_func = static_tensor_array_ops.get_global_var("tensor_get_data") + out = get_data_func(out_tensor) + return out + + return _impl + + +def _tensorlist_stack(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_ta_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude) + + if input_ta_shape is None: + stack_func = prelude.get_global_var("tensor_array_stack", dtype_str) + out = stack_func(inputs[0]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + stack_func = prelude.get_global_var_static( + "tensor_array_stack", dtype_str, input_ta_shape + ) + out_tensor = stack_func(inputs[0]) + out_shape = (Any(),) + input_ta_shape + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape) + static_tensor_array_ops.register() + get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape) + out = get_data_func(out_tensor) + + return out + + return _impl + + +def _tensorlist_from_tensor(): + def _impl(inputs, attr, params, prelude): + dtype_str = attr["element_dtype"].name + input_ta_shape = _infer_type_with_prelude(inputs[0], prelude).shape + + if input_ta_shape is None: + unstack_func = prelude.get_global_var("tensor_array_unstack", dtype_str) + out = unstack_func(inputs[0]) + else: + static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape) + static_tensor_array_ops.register() + unstack_func = prelude.get_global_var_static( + "tensor_array_unstack", dtype_str, input_ta_shape + ) + out = unstack_func(inputs[0]) + return out + + return _impl + + +_convert_map = { + "TensorListFromTensor": _tensorlist_from_tensor(), + "TensorListGetItem": _tensorlist_get_item(), + "TensorListReserve": _tensorlist_reserve(), + "TensorListSetItem": _tensorlist_set_item(), + "TensorListStack": _tensorlist_stack(), +} diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py index 8753f73ebd85..607769d261e1 100644 --- a/python/tvm/relay/frontend/tensorflow_ops.py +++ b/python/tvm/relay/frontend/tensorflow_ops.py @@ -138,6 +138,18 @@ def _get_more_static_shape(shape0, shape1): return shape1 +def _get_more_static_shape_rank(shape0, shape1): + """Compare two shapes with different rank, + and return the one with fewer symbolic dimension. 
+ """ + num_sym_dim0 = sum([not isinstance(dim, (int, tvm.tir.expr.IntImm)) for dim in list(shape0)]) + num_sym_dim1 = sum([not isinstance(dim, (int, tvm.tir.expr.IntImm)) for dim in list(shape1)]) + + if num_sym_dim0 < num_sym_dim1: + return shape0 + return shape1 + + def _rsqrt(): def _impl(inputs, attr, params, mod): inputs.append(tvm.relay.const(-0.5, attr["T"].name)) diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py index a39ecb411f15..53353f5ccffb 100644 --- a/tests/python/frontend/tensorflow2/test_functional_models.py +++ b/tests/python/frontend/tensorflow2/test_functional_models.py @@ -448,5 +448,141 @@ def func(self, x): run_model_graph(StatelessWhile2Var, outputs=["Identity:output:0"]) +def test_tensorlist(): + def run_test(elem_shape): + class TensorList(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3), dtype="float32") + in_tens[1, :] = np.zeros((3,), dtype="float32") + return in_tens + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :]) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=1, item=x[1, :]) + output = tf.raw_ops.TensorListGetItem( + input_handle=tl, index=0, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorList) + run_func_graph(TensorList, runtime="vm") + + run_test((3,)) + run_test((-1,)) + + +def test_tensorlist_stack(): + def run_test(elem_shape): + class TensorListStack(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3), dtype="float32") + in_tens[1] = np.zeros((3,), dtype="float32") + return in_tens + + """2D array as input""" + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListFromTensor(tensor=x, element_shape=elem_shape) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorListStack) + run_func_graph(TensorListStack, runtime="vm") + + run_test((3,)) + run_test((-1,)) + + +def test_tensorlist_2d(): + def run_test(elem_shape): + class TensorList2D(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3, 4), dtype="float32") + in_tens[1, :, :] = np.zeros((3, 4), dtype="float32") + return in_tens + + """2D array as input""" + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :]) + tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=1, item=x[1, :, :]) + output = tf.raw_ops.TensorListGetItem( + input_handle=tl, index=0, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorList2D) + run_func_graph(TensorList2D, runtime="vm") + + run_test( + ( + 3, + 4, + ) + ) + run_test( + ( + -1, + -1, + ) + ) + + +def test_tensorlist_stack_2d(): + def run_test(elem_shape): + class TensorListStack2D(tf.Module): + def get_input(self): + in_tens = np.ones((2, 3, 4), 
dtype="float32") + in_tens[1, :, :] = np.zeros((3, 4), dtype="float32") + return in_tens + + """2D array as input""" + + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)]) + def func(self, x): + dtype = tf.float32 + tl = tf.raw_ops.TensorListReserve( + element_shape=elem_shape, num_elements=2, element_dtype=dtype + ) + tl = tf.raw_ops.TensorListFromTensor(tensor=x, element_shape=elem_shape) + output = tf.raw_ops.TensorListStack( + input_handle=tl, element_shape=elem_shape, element_dtype=dtype + ) + return output + + run_model_graph(TensorListStack2D) + run_func_graph(TensorListStack2D, runtime="vm") + + run_test( + ( + 3, + 4, + ) + ) + run_test( + ( + -1, + -1, + ) + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py index 394a49d0f2e9..1b5a6342f07d 100644 --- a/tests/python/frontend/tensorflow2/test_sequential_models.py +++ b/tests/python/frontend/tensorflow2/test_sequential_models.py @@ -109,5 +109,60 @@ def maxpool_batchnorm_model(input_shape, pool_size=(2, 2)): run_sequential_model(maxpool_batchnorm_model, input_shape=(1, 32, 32, 3)) +def test_tensorlist_stack_model(): + def tensorlist_stack_model(input_shape): + class TensorArrayStackLayer(tf.keras.layers.Layer): + def __init__(self): + super().__init__() + + def call(self, inputs): + inputs = tf.squeeze(inputs) + outputs = tf.TensorArray( + tf.float32, + size=inputs.shape[0], + infer_shape=False, + element_shape=inputs.shape[1:], + ) + outputs = outputs.unstack(inputs) + + return outputs.stack() + + input_shape = (3, 32) + model = tf.keras.Sequential( + [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayStackLayer()] + ) + return model + + run_sequential_model(tensorlist_stack_model, input_shape=(3, 32)) + + +def test_tensorlist_read_model(): + def tensorlist_read_model(input_shape): + class TensorArrayReadLayer(tf.keras.layers.Layer): + def __init__(self): + super().__init__() + + def call(self, inputs): + inputs = tf.squeeze(inputs) + outputs = tf.TensorArray( + tf.float32, + size=inputs.shape[0], + infer_shape=False, + element_shape=inputs.shape[1:], + ) + for i in range(inputs.shape[0]): + outputs = outputs.write(i, inputs[i, :]) + + return outputs.read(0) + + input_shape = (3, 32) + model = tf.keras.Sequential( + [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayReadLayer()] + ) + return model + + run_sequential_model(tensorlist_read_model, input_shape=(3, 32)) + + if __name__ == "__main__": pytest.main([__file__])