Skip to content

Commit

Permalink
Move infer_value to _get_list_param
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris committed May 14, 2021
1 parent 3bf65b7 commit c1321b9
Showing 1 changed file with 31 additions and 52 deletions.
83 changes: 31 additions & 52 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,11 @@ def _get_num_param(params, input_node):
return _get_param(params, input_node).item()


def _get_list_param(params, input_node):
return _get_param(params, input_node).tolist()
def _get_list_param(params, input_node, mod):
try:
return _get_param(params, input_node).tolist()
except (IndexError, KeyError, AttributeError):
return _infer_value(input_node, params, mod).asnumpy().tolist()


def _get_tuple_param(params, input_node):
Expand Down Expand Up @@ -913,10 +916,7 @@ def _crop_and_resize():
def _impl(inputs, attr, params, mod):
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try:
crop_size = _get_list_param(params, inputs[3])
except (IndexError, KeyError):
crop_size = _infer_value(inputs[3], params, mod).asnumpy().tolist()
crop_size = _get_list_param(params, inputs[3], mod)

method = attr["method"].decode()
method = "nearest_neighbor" if method == "nearest" else method
Expand Down Expand Up @@ -1658,7 +1658,7 @@ def _impl(inputs, attr, params, mod):
np_reps = _infer_value(reps_input, params, mod).asnumpy()
reps = [np_reps.flatten()[i] for i in range(np_reps.flatten().shape[0])]
else:
reps = _get_list_param(params, reps_input)
reps = _get_list_param(params, reps_input, mod)
new_input = [inputs.pop(0)]

return AttrCvt(op_name="tile", extras={"reps": tuple(reps)}, ignores=["Tmultiples"])(
Expand All @@ -1671,21 +1671,15 @@ def _impl(inputs, attr, params, mod):
def _slice():
def _impl(inputs, attr, params, mod):
try:
begin = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
begin = _get_list_param(params, inputs[1], mod)
except Exception:
# Handle symbolic begin
try:
begin = _infer_value(inputs[1], params, mod).asnumpy().tolist()
except Exception:
begin = inputs[1]
begin = inputs[1]
try:
size = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
size = _get_list_param(params, inputs[2], mod)
except Exception:
# Handle symbolic size
try:
size = _infer_value(inputs[2], params, mod).asnumpy().tolist()
except Exception:
size = inputs[2]
size = inputs[2]

# Align begin and strides for dynamic shape.
data_dim = len(_infer_shape(inputs[0], mod))
Expand Down Expand Up @@ -1962,7 +1956,7 @@ def _impl(inputs, attr, params, mod):

def _reduce(op):
def _impl(inputs, attr, params, mod):
axis = _get_list_param(params, inputs[1])
axis = _get_list_param(params, inputs[1], mod)
axis = tuple(axis)
if not axis:
axis = None
Expand All @@ -1978,7 +1972,7 @@ def _impl(inputs, attr, params, mod):

def _euclidean_norm():
def _impl(inputs, attr, params, mod):
axis = tuple(_get_list_param(params, inputs[1]))
axis = tuple(_get_list_param(params, inputs[1], mod))
keep_dims = bool(attr.get("keep_dims", False))
return _op.sqrt(
_op.cast(_op.reduce.sum(_op.multiply(inputs[0], inputs[0]), axis, keep_dims), "float32")
Expand Down Expand Up @@ -2039,9 +2033,9 @@ def _impl(inputs, attr, params, mod):
Tensorflow mask validation: https://github.com/tensorflow/tensorflow/blob/master/
tensorflow/core/util/strided_slice_op.cc#L147-L368
"""
begin = _get_list_param(params, inputs[1])
end = _get_list_param(params, inputs[2])
stride = _get_list_param(params, inputs[3])
begin = _get_list_param(params, inputs[1], mod)
end = _get_list_param(params, inputs[2], mod)
stride = _get_list_param(params, inputs[3], mod)

begin_mask = int(attr.get("begin_mask", 0))
end_mask = int(attr.get("end_mask", 0))
Expand Down Expand Up @@ -2243,10 +2237,7 @@ def _transpose():
def _impl(inputs, attr, params, mod):
# If perm is not specified, axes is left empty,
# otherwise its value is get from params
try:
axes = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
axes = _infer_value(inputs[1], params, mod).asnumpy().tolist()
axes = _get_list_param(params, inputs[1], mod)
return _op.transpose(inputs[0], axes=axes)

return _impl
Expand Down Expand Up @@ -2536,19 +2527,13 @@ def _impl(inputs, attr, params, mod):

def _space_to_batch_nd():
def _impl(inputs, attr, params, mod):
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist()
block_shape = _get_list_param(params, inputs[1], mod)

try:
paddings = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
paddings = _infer_value(inputs[2], params, mod).asnumpy()
paddings = np.squeeze(paddings)
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()
paddings = _get_list_param(params, inputs[2], mod)
paddings = np.squeeze(paddings)
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()

attr["block_shape"] = block_shape
attr["paddings"] = paddings
Expand All @@ -2561,19 +2546,13 @@ def _impl(inputs, attr, params, mod):

def _batch_to_space_nd():
def _impl(inputs, attr, params, mod):
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
block_shape = _infer_value(inputs[1], params, mod).asnumpy().tolist()
block_shape = _get_list_param(params, inputs[1], mod)

try:
crops = _get_list_param(params, inputs[2])
except (IndexError, KeyError, AttributeError):
crops = _infer_value(inputs[2], params, mod).asnumpy()
crops = np.squeeze(crops)
if len(crops.shape) == 1:
crops = np.expand_dims(crops, axis=0)
crops = crops.tolist()
crops = _get_list_param(params, inputs[2], mod)
crops = np.squeeze(crops)
if len(crops.shape) == 1:
crops = np.expand_dims(crops, axis=0)
crops = crops.tolist()

attr["block_shape"] = block_shape
attr["crops"] = crops
Expand Down

0 comments on commit c1321b9

Please sign in to comment.