diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4bd332fa01599..f0fd0b5dfb203 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -113,8 +113,11 @@ def _get_num_param(params, input_node): return _get_param(params, input_node).item() -def _get_list_param(params, input_node): - return _get_param(params, input_node).tolist() +def _get_list_param(params, input_node, mod): + try: + return _get_param(params, input_node).tolist() + except (IndexError, KeyError, AttributeError): + return _infer_value(input_node, params, mod).asnumpy().tolist() def _get_tuple_param(params, input_node): @@ -913,10 +916,7 @@ def _crop_and_resize(): def _impl(inputs, attr, params, mod): # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] - try: - crop_size = _get_list_param(params, inputs[3]) - except (IndexError, KeyError): - crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist() + crop_size = _get_list_param(params, inputs[3], mod) method = attr["method"].decode() method = "nearest_neighbor" if method == "nearest" else method @@ -1658,7 +1658,7 @@ def _impl(inputs, attr, params, mod): np_reps = _infer_value(reps_input, params, mod).asnumpy() reps = [np_reps.flatten()[i] for i in range(np_reps.flatten().shape[0])] else: - reps = _get_list_param(params, reps_input) + reps = _get_list_param(params, reps_input, mod) new_input = [inputs.pop(0)] return AttrCvt(op_name="tile", extras={"reps": tuple(reps)}, ignores=["Tmultiples"])( @@ -1671,21 +1671,14 @@ def _impl(inputs, attr, params, mod): def _slice(): def _impl(inputs, attr, params, mod): try: - begin = _get_list_param(params, inputs[1]) - except (IndexError, KeyError, AttributeError): + begin = _get_list_param(params, inputs[1], mod) + except Exception: # Handle symbolic begin - try: - begin = _infer_value(inputs[1], params, mod).asnumpy().tolist() - except Exception: - begin = inputs[1] + begin = inputs[1] try: - size = _get_list_param(params, inputs[2]) - except (IndexError, KeyError, AttributeError): - # Handle symbolic size - try: - size = _infer_value(inputs[2], params, mod).asnumpy().tolist() - except Exception: - size = inputs[2] + size = _get_list_param(params, inputs[2], mod) + except Exception: + size = inputs[2] # Align begin and strides for dynamic shape. data_dim = len(_infer_shape(inputs[0], mod)) @@ -1962,7 +1955,7 @@ def _impl(inputs, attr, params, mod): def _reduce(op): def _impl(inputs, attr, params, mod): - axis = _get_list_param(params, inputs[1]) + axis = _get_list_param(params, inputs[1], mod) axis = tuple(axis) if not axis: axis = None @@ -1978,7 +1971,7 @@ def _impl(inputs, attr, params, mod): def _euclidean_norm(): def _impl(inputs, attr, params, mod): - axis = tuple(_get_list_param(params, inputs[1])) + axis = tuple(_get_list_param(params, inputs[1]), mod) keep_dims = bool(attr.get("keep_dims", False)) return _op.sqrt( _op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis, keep_dims), "float32") @@ -2039,9 +2032,9 @@ def _impl(inputs, attr, params, mod): Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/ tensorflow/core/util/strided_slice_op.cc#L147-L368 """ - begin = _get_list_param(params, inputs[1]) - end = _get_list_param(params, inputs[2]) - stride = _get_list_param(params, inputs[3]) + begin = _get_list_param(params, inputs[1], mod) + end = _get_list_param(params, inputs[2], mod) + stride = _get_list_param(params, inputs[3], mod) begin_mask = int(attr.get("begin_mask", 0)) end_mask = int(attr.get("end_mask", 0)) @@ -2243,10 +2236,7 @@ def _transpose(): def _impl(inputs, attr, params, mod): # If perm is not specified, axes is left empty, # otherwise its value is get from params - try: - axes = _get_list_param(params, inputs[1]) - except (IndexError, KeyError, AttributeError): - axes = _infer_value(inputs[1], params, mod).asnumpy().tolist() + axes = _get_list_param(params, inputs[1], mod) return _op.transpose(inputs[0], axes=axes) return _impl @@ -2536,19 +2526,13 @@ def _impl(inputs, attr, params, mod): def _space_to_batch_nd(): def _impl(inputs, attr, params, mod): - try: - block_shape = _get_list_param(params, inputs[1]) - except (IndexError, KeyError, AttributeError): - block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1], mod) - try: - paddings = _get_list_param(params, inputs[2]) - except (IndexError, KeyError, AttributeError): - paddings = _infer_value(inputs[2], params, mod).asnumpy() - paddings = np.squeeze(paddings) - if len(paddings.shape) == 1: - paddings = np.expand_dims(paddings, axis=0) - paddings = paddings.tolist() + paddings = _get_list_param(params, inputs[2], mod) + paddings = np.squeeze(paddings) + if len(paddings.shape) == 1: + paddings = np.expand_dims(paddings, axis=0) + paddings = paddings.tolist() attr["block_shape"] = block_shape attr["paddings"] = paddings @@ -2561,19 +2545,13 @@ def _impl(inputs, attr, params, mod): def _batch_to_space_nd(): def _impl(inputs, attr, params, mod): - try: - block_shape = _get_list_param(params, inputs[1]) - except (IndexError, KeyError, AttributeError): - block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist() + block_shape = _get_list_param(params, inputs[1], mod) - try: - crops = _get_list_param(params, inputs[2]) - except (IndexError, KeyError, AttributeError): - crops = _infer_value(inputs[2], params, mod).asnumpy() - crops = np.squeeze(crops) - if len(crops.shape) == 1: - crops = np.expand_dims(crops, axis=0) - crops = crops.tolist() + crops = _get_list_param(params, inputs[2], mod) + crops = np.squeeze(crops) + if len(crops.shape) == 1: + crops = np.expand_dims(crops, axis=0) + crops = crops.tolist() attr["block_shape"] = block_shape attr["crops"] = crops