Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][TensorFlow] Move infer_value to _get_list_param #8051

Merged
merged 1 commit into from
May 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 31 additions & 52 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,11 @@ def _get_num_param(params, input_node):
return _get_param(params, input_node).item()


def _get_list_param(params, input_node):
return _get_param(params, input_node).tolist()
def _get_list_param(params, input_node, mod):
try:
return _get_param(params, input_node).tolist()
except (IndexError, KeyError, AttributeError):
return _infer_value(input_node, params, mod).asnumpy().tolist()


def _get_tuple_param(params, input_node):
Expand Down Expand Up @@ -913,10 +916,7 @@ def _crop_and_resize():
def _impl(inputs, attr, params, mod):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try:
crop_size = _get_list_param(params, inputs[3])
except (IndexError, KeyError):
crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist()
crop_size = _get_list_param(params, inputs[3], mod)

method = attr["method"].decode()
method = "nearest_neighbor" if method == "nearest" else method
Expand Down Expand Up @@ -1658,7 +1658,7 @@ def _impl(inputs, attr, params, mod):
np_reps = _infer_value(reps_input, params, mod).asnumpy()
reps = [np_reps.flatten()[i] for i in range(np_reps.flatten().shape[0])]
else:
reps = _get_list_param(params, reps_input)
reps = _get_list_param(params, reps_input, mod)
new_input = [inputs.pop(0)]

return AttrCvt(op_name="tile", extras={"reps": tuple(reps)}, ignores=["Tmultiples"])(
Expand All @@ -1671,21 +1671,15 @@ def _impl(inputs, attr, params, mod):
def _slice():
def _impl(inputs, attr, params, mod):
try:
begin = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
begin = _get_list_param(params, inputs[1], mod)
except Exception:
# Handle symbolic begin
try:
begin = _infer_value(inputs[1], params, mod).asnumpy().tolist()
except Exception:
begin = inputs[1]
begin = inputs[1]
try:
size = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
size = _get_list_param(params, inputs[2], mod)
except Exception:
# Handle symbolic size
try:
size = _infer_value(inputs[2], params, mod).asnumpy().tolist()
except Exception:
size = inputs[2]
size = inputs[2]

# Align begin and strides for dynamic shape.
data_dim = len(_infer_shape(inputs[0], mod))
Expand Down Expand Up @@ -1962,7 +1956,7 @@ def _impl(inputs, attr, params, mod):

def _reduce(op):
def _impl(inputs, attr, params, mod):
axis = _get_list_param(params, inputs[1])
axis = _get_list_param(params, inputs[1], mod)
axis = tuple(axis)
if not axis:
axis = None
Expand All @@ -1978,7 +1972,7 @@ def _impl(inputs, attr, params, mod):

def _euclidean_norm():
def _impl(inputs, attr, params, mod):
axis = tuple(_get_list_param(params, inputs[1]))
axis = tuple(_get_list_param(params, inputs[1], mod))
keep_dims = bool(attr.get("keep_dims", False))
return _op.sqrt(
_op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis, keep_dims), "float32")
Expand Down Expand Up @@ -2039,9 +2033,9 @@ def _impl(inputs, attr, params, mod):
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
begin = _get_list_param(params, inputs[1])
end = _get_list_param(params, inputs[2])
stride = _get_list_param(params, inputs[3])
begin = _get_list_param(params, inputs[1], mod)
end = _get_list_param(params, inputs[2], mod)
stride = _get_list_param(params, inputs[3], mod)

begin_mask = int(attr.get("begin_mask", 0))
end_mask = int(attr.get("end_mask", 0))
Expand Down Expand Up @@ -2243,10 +2237,7 @@ def _transpose():
def _impl(inputs, attr, params, mod):
# If perm is not specified, axes is left empty,
# otherwise its value is get from params
try:
axes = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
axes = _infer_value(inputs[1], params, mod).asnumpy().tolist()
axes = _get_list_param(params, inputs[1], mod)
return _op.transpose(inputs[0], axes=axes)

return _impl
Expand Down Expand Up @@ -2536,19 +2527,13 @@ def _impl(inputs, attr, params, mod):

def _space_to_batch_nd():
def _impl(inputs, attr, params, mod):
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist()
block_shape = _get_list_param(params, inputs[1], mod)

try:
paddings = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
paddings = _infer_value(inputs[2], params, mod).asnumpy()
paddings = np.squeeze(paddings)
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()
paddings = _get_list_param(params, inputs[2], mod)
paddings = np.squeeze(paddings)
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()

attr["block_shape"] = block_shape
attr["paddings"] = paddings
Expand All @@ -2561,19 +2546,13 @@ def _impl(inputs, attr, params, mod):

def _batch_to_space_nd():
def _impl(inputs, attr, params, mod):
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist()
block_shape = _get_list_param(params, inputs[1], mod)

try:
crops = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
crops = _infer_value(inputs[2], params, mod).asnumpy()
crops = np.squeeze(crops)
if len(crops.shape) == 1:
crops = np.expand_dims(crops, axis=0)
crops = crops.tolist()
crops = _get_list_param(params, inputs[2], mod)
crops = np.squeeze(crops)
if len(crops.shape) == 1:
crops = np.expand_dims(crops, axis=0)
crops = crops.tolist()

attr["block_shape"] = block_shape
attr["crops"] = crops
Expand Down