diff --git a/python/tvm/relay/frontend/tensorflow2_ops.py b/python/tvm/relay/frontend/tensorflow2_ops.py
index 5024c97238ea..945554816984 100644
--- a/python/tvm/relay/frontend/tensorflow2_ops.py
+++ b/python/tvm/relay/frontend/tensorflow2_ops.py
@@ -133,13 +133,21 @@ def _impl(inputs, attr, params, prelude):
             stack_func = prelude.get_global_var("tensor_array_stack", dtype_str)
             out = stack_func(inputs[0])
         else:
-            static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, input_ta_shape)
+            # .get() tolerates a missing attribute and falls back to a dynamic batch dim
+            num_elements = attr.get("num_elements", None)
+            static_tensor_array_ops = StaticTensorArrayOps(
+                prelude, dtype_str, input_ta_shape, num_elements
+            )
             static_tensor_array_ops.register()
             stack_func = prelude.get_global_var_static(
-                "tensor_array_stack", dtype_str, input_ta_shape
+                "tensor_array_stack", dtype_str, input_ta_shape, num_elements
             )
             out_tensor = stack_func(inputs[0])
-            out_shape = (Any(),) + input_ta_shape
+            out_shape = (
+                (num_elements,) + input_ta_shape
+                if num_elements == 1
+                else (Any(),) + input_ta_shape
+            )
             static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str, out_shape)
             static_tensor_array_ops.register()
             get_data_func = prelude.get_global_var_static("tensor_get_data", dtype_str, out_shape)
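Note on the frontend hunk above: TensorListStack carries a num_elements attribute, and when it is statically 1 the converter pins the leading dimension of the stacked result instead of leaving it dynamic, which is what later allows ops with static-shape requirements (such as Unpack) to type-check. A minimal standalone sketch of that shape selection, using a hypothetical stack_output_shape helper and a stand-in Any class in place of tvm.tir.Any:

    class Any:
        # Stand-in for tvm.tir.Any, i.e. a dynamic dimension.
        def __repr__(self):
            return "?"

    def stack_output_shape(input_ta_shape, num_elements=None):
        # Pin the leading dimension only when the list provably holds one element.
        if num_elements == 1:
            return (num_elements,) + tuple(input_ta_shape)
        return (Any(),) + tuple(input_ta_shape)

    print(stack_output_shape((3, 4), num_elements=1))  # (1, 3, 4)
    print(stack_output_shape((3, 4)))                  # (?, 3, 4)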
diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py
index 376bf4a4804d..542980561e78 100644
--- a/python/tvm/relay/prelude.py
+++ b/python/tvm/relay/prelude.py
@@ -73,9 +73,33 @@ def get_tensor_array_shape(expr, dtype, prelude):
     return None
 
 
-def _get_name_static(canonical, dtype, shape):
-    """Get name for static shape tensor array op corresponding
-    to the canonical name"""
+def _get_name_static(canonical, dtype, shape, batch_dim=None):
+    """Get name for static shape tensor array op
+
+    By design, a static ADT tensor in TVM has a type name of the form
+    static_tensor_dim0_dim1_..._dimN_t, or
+    static_tensor_batch1_dim0_dim1_..._dimN_t if the tensorlist stack holds one item.
+
+    Parameters
+    ----------
+    canonical : str
+        Tensor array op name
+
+    dtype : str
+        Data type.
+
+    shape : tuple of (int, Any) or None
+        Tensor array shape
+
+    batch_dim : None or int
+        1 if the tensorlist stack holds exactly one item;
+        None by default.
+
+    Returns
+    -------
+    name : str
+        The tensor array op name
+    """
     dim_names = []
     for dim in shape:
         if isinstance(dim, Any):
@@ -89,26 +113,31 @@ def _get_name_static(canonical, dtype, shape):
         shape_str = "scalar"
     if canonical == "tensor_t":
         return "static_tensor_{}_{}_t".format(dtype, shape_str)
-    return "{}_{}_{}".format(canonical, dtype, shape_str)
+    if batch_dim is None or canonical in ["tensor_constructor", "tensor_nil"]:
+        return "{}_{}_{}".format(canonical, dtype, shape_str)
+    if batch_dim != 1:
+        return "{}_{}_{}".format(canonical, dtype, shape_str)
+    return "{}_{}_batch{}_{}".format(canonical, dtype, batch_dim, shape_str)
 
 
 class StaticTensorArrayOps(object):
     """Contains tensor array related ops for fixed rank tensor array"""
 
-    def __init__(self, prelude, dtype, shape):
+    def __init__(self, prelude, dtype, shape, batch_dim=None):
         """Create tensor array ops registry"""
         self.prelude = prelude
         self.dtype = dtype
         self.shape = shape
+        self.batch_dim = batch_dim
         self.list, self.cons, self.nil = self.prelude.mod.get_type("List")
 
     def get_name(self, canonical):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, self.dtype, self.shape)
+        return _get_name_static(canonical, self.dtype, self.shape, self.batch_dim)
 
     def get_global_var(self, canonical):
         """Get global corresponding to the canonical name"""
-        return self.prelude.get_global_var_static(canonical, self.dtype, self.shape)
+        return self.prelude.get_global_var_static(canonical, self.dtype, self.shape, self.batch_dim)
 
     def get_type(self, canonical):
         """Get type corresponding to the canonical name"""
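The naming scheme above decides which registered global the frontend later looks up, so the batch1 variants live alongside the fully dynamic ones. A simplified mirror of the rule (the tensor_t special case and Any() dimensions are omitted here), with the two names it produces for a (3, 4) float32 array:

    def name_static(canonical, dtype, shape, batch_dim=None):
        # Dims joined with underscores; an empty shape means a scalar.
        shape_str = "_".join(str(dim) for dim in shape) or "scalar"
        if batch_dim is None or batch_dim != 1 or canonical in ("tensor_constructor", "tensor_nil"):
            return "{}_{}_{}".format(canonical, dtype, shape_str)
        return "{}_{}_batch{}_{}".format(canonical, dtype, batch_dim, shape_str)

    print(name_static("tensor_array_stack", "float32", (3, 4)))
    # tensor_array_stack_float32_3_4
    print(name_static("tensor_array_stack", "float32", (3, 4), batch_dim=1))
    # tensor_array_stack_float32_batch1_3_4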
@@ -262,9 +291,10 @@ def define_tensor_expand_dims(self):
 
         # Note: we set the added axis to be Any() instead of 1 due to
         # in stack op, we need to recursively concatenate.
+        new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
         tensor_type_var, tensor_constructor, _ = self._get_adt_by_shape(
             [
-                Any(),
+                new_axis,
             ]
             + list(self.shape)
         )
@@ -573,20 +603,27 @@ def define_tensor_array_stack(self):
         expand_dims_var = self.get_global_var("tensor_expand_dims")
         # Register tensor_concatenate for output_shape
+        new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
         output_shape = [
-            Any(),
+            new_axis,
         ] + list(self.shape)
-
         _, _, output_ops = self._get_adt_by_shape(output_shape)
         output_ops.define_tensor_concatenate()
         concat_var = output_ops.get_global_var("tensor_concatenate")
 
         tensor_array_expand_dims = self.prelude.map(expand_dims_var, tensor_array)
-        tensors = self.prelude.foldl(
-            concat_var,
-            self.prelude.hd(tensor_array_expand_dims),
-            self.prelude.tl(tensor_array_expand_dims),
-        )
+        if self.batch_dim == 1:
+            # The list holds exactly one tensor, so the stacked result is just its head.
+            tensors = self.prelude.id(
+                self.prelude.hd(tensor_array_expand_dims),
+            )
+        else:
+            tensors = self.prelude.foldl(
+                concat_var,
+                self.prelude.hd(tensor_array_expand_dims),
+                self.prelude.tl(tensor_array_expand_dims),
+            )
+
         output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
         self.prelude.mod[stack_var] = Function(
             [tensor_array], tensors, output_tensor_type_var(), []
         )
@@ -599,8 +636,9 @@ def define_tensor_array_gather(self):
         helper_name = self.get_name("tensor_array_gather_helper")
         helper_var = self._create_global_var(helper_name)
 
+        new_axis = Any() if self.batch_dim is None or self.batch_dim != 1 else self.batch_dim
         output_shape = [
-            Any(),
+            new_axis,
         ] + list(self.shape)
         output_tensor_type_var, _, _ = self._get_adt_by_shape(output_shape)
         stack_var = self.get_global_var("tensor_array_stack")
@@ -668,7 +706,7 @@ def register(self):
 
     def _get_adt_by_shape(self, shape):
         """Get ADT type and constructor with given shape."""
-        adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape)
+        adt_ops = StaticTensorArrayOps(self.prelude, self.dtype, shape, self.batch_dim)
         adt_ops.define_tensor_adt()
         tensor_type_var = adt_ops.get_type("tensor_t")
         tensor_constructor = adt_ops.get_ctor("tensor_constructor")
@@ -1482,13 +1520,13 @@ def get_tensor_ctor(self, canonical, dtype):
         ty = self.get_type("tensor_t", dtype)
         return self.get_ctor(ty.name_hint, canonical, dtype)
 
-    def get_name_static(self, canonical, dtype, shape):
+    def get_name_static(self, canonical, dtype, shape, batch_dim=None):
         """Get name corresponding to the canonical name"""
-        return _get_name_static(canonical, dtype, shape)
+        return _get_name_static(canonical, dtype, shape, batch_dim)
 
-    def get_global_var_static(self, canonical, dtype, shape):
+    def get_global_var_static(self, canonical, dtype, shape, batch_dim=None):
         """Get var corresponding to the canonical name"""
-        name = self.get_name_static(canonical, dtype, shape)
+        name = self.get_name_static(canonical, dtype, shape, batch_dim)
         return self.mod.get_global_var(name)
 
     def get_type_static(self, canonical, dtype, shape):
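The prelude hunks above all follow from one observation: with batch_dim=1, tensor_expand_dims already yields a (1, ...) tensor, and a one-element list needs no foldl concatenation, so tensor_array_stack reduces to taking the head of the mapped list via prelude.id. A sketch of exercising the new registration path directly, based only on the signatures in this diff (untested):

    import tvm
    from tvm.relay.prelude import Prelude, StaticTensorArrayOps

    mod = tvm.IRModule()
    prelude = Prelude(mod)

    # batch_dim=1 registers the ops under "...batch1..." names and types the
    # stack output as (1, 3, 4) rather than (?, 3, 4).
    ops = StaticTensorArrayOps(prelude, "float32", (3, 4), batch_dim=1)
    ops.register()

    stack_var = prelude.get_global_var_static("tensor_array_stack", "float32", (3, 4), 1)
    print(stack_var.name_hint)  # expected: tensor_array_stack_float32_batch1_3_4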
diff --git a/tests/python/frontend/tensorflow2/test_functional_models.py b/tests/python/frontend/tensorflow2/test_functional_models.py
index 53353f5ccffb..001ba6de1967 100644
--- a/tests/python/frontend/tensorflow2/test_functional_models.py
+++ b/tests/python/frontend/tensorflow2/test_functional_models.py
@@ -484,8 +484,6 @@ def get_input(self):
                 in_tens[1] = np.zeros((3,), dtype="float32")
                 return in_tens
 
-            """2D array as input"""
-
             @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)])
             def func(self, x):
                 dtype = tf.float32
@@ -513,8 +511,6 @@ def get_input(self):
                 in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
                 return in_tens
 
-            """2D array as input"""
-
             @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
             def func(self, x):
                 dtype = tf.float32
@@ -531,18 +527,8 @@ def func(self, x):
         run_model_graph(TensorList2D)
         run_func_graph(TensorList2D, runtime="vm")
 
-    run_test(
-        (
-            3,
-            4,
-        )
-    )
-    run_test(
-        (
-            -1,
-            -1,
-        )
-    )
+    run_test((3, 4))
+    run_test((-1, -1))
 
 
 def test_tensorlist_stack_2d():
@@ -553,8 +539,6 @@ def get_input(self):
                 in_tens[1, :, :] = np.zeros((3, 4), dtype="float32")
                 return in_tens
 
-            """2D array as input"""
-
             @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32)])
             def func(self, x):
                 dtype = tf.float32
@@ -570,18 +554,35 @@ def func(self, x):
         run_model_graph(TensorListStack2D)
         run_func_graph(TensorListStack2D, runtime="vm")
 
-    run_test(
-        (
-            3,
-            4,
-        )
-    )
-    run_test(
-        (
-            -1,
-            -1,
-        )
-    )
+    run_test((3, 4))
+    run_test((-1, -1))
+
+
+def test_tensorlist_stack_unpack():
+    def run_test(elem_shape):
+        class TensorListStack2D(tf.Module):
+            def get_input(self):
+                in_tens = np.ones((1, 3, 4), dtype="float32")
+                return in_tens
+
+            @tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)])
+            def func(self, x):
+                dtype = tf.float32
+                tl = tf.raw_ops.TensorListReserve(
+                    element_shape=elem_shape, num_elements=1, element_dtype=dtype
+                )
+                tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :])
+                output = tf.raw_ops.TensorListStack(
+                    input_handle=tl, element_shape=elem_shape, element_dtype=dtype, num_elements=1
+                )
+                output = tf.raw_ops.Unpack(value=output, num=1, axis=0)
+                return output
+
+        run_model_graph(TensorListStack2D)
+        run_func_graph(TensorListStack2D, runtime="vm")
+
+    run_test((3, 4))
+    run_test((-1, -1))
 
 
 if __name__ == "__main__":
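For reference, the TensorFlow-level pattern the new test exercises runs with TensorFlow alone; num_elements=1 on TensorListStack is what gives Unpack the static leading dimension it requires:

    import numpy as np
    import tensorflow as tf

    @tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4), dtype=tf.float32)])
    def stack_unpack(x):
        tl = tf.raw_ops.TensorListReserve(
            element_shape=(3, 4), num_elements=1, element_dtype=tf.float32
        )
        tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0, item=x[0, :, :])
        stacked = tf.raw_ops.TensorListStack(
            input_handle=tl, element_shape=(3, 4), element_dtype=tf.float32, num_elements=1
        )
        # Unpack requires a static leading dimension; num_elements=1 provides it.
        return tf.raw_ops.Unpack(value=stacked, num=1, axis=0)

    out = stack_unpack(tf.constant(np.ones((1, 3, 4), dtype="float32")))
    print(out[0].shape)  # (3, 4)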