Skip to content

Commit

Permalink
[Caffe Frontend] Add support for Permute layer (apache#9157)
Browse files Browse the repository at this point in the history
* Add support for Permute layer

* Add test for Permute layer

* Fix alignment
  • Loading branch information
mshr-h authored and crazydemo committed Jan 27, 2022
1 parent 8e45881 commit 1cdd308
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/tvm/relay/frontend/caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, init_layer_dict, predict_layer, exp_tab):
"InnerProduct": self.convert_innerproduct,
"Input": None,
"LRN": self.convert_lrn,
"Permute": self.convert_permute,
"Pooling": self.convert_pooling,
"PReLU": self.convert_prelu,
"ReLU": self.convert_relu,
Expand Down Expand Up @@ -597,6 +598,17 @@ def convert_crop(self, op):
out = _op.slice_like(in_expr_a_stride, in_expr_b, axes=to_crop_axis)
return out

def convert_permute(self, op):
"""Convert Permute layer"""
inputs = op.bottom
in_expr = self.exp_tab.get_expr(inputs[0])

# parse permute params
permute_param = op.permute_param
axes = list(getattr(permute_param, "order", 0))
out = _op.transpose(in_expr, axes)
return out

def convert_embed(self, op):
"""Convert Embed layer"""
inputs = op.bottom
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/caffe/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,27 @@ def test_forward_LRN():
_test_lrn(data, local_size=3, alpha=2.0, beta=0.5, k=2.0)


#######################################################################
# Permute
# -------


def _test_permute(data, **kwargs):
"""One iteration of Permute."""
_test_op(data, L.Permute, "Permute", **kwargs)


def test_forward_Permute():
"""Permute"""
data = np.random.rand(2, 3, 4).astype(np.float32)
_test_permute(data, permute_param={"order": [0, 1, 2]})
_test_permute(data, permute_param={"order": [0, 2, 1]})
_test_permute(data, permute_param={"order": [1, 0, 2]})
_test_permute(data, permute_param={"order": [1, 2, 0]})
_test_permute(data, permute_param={"order": [2, 0, 1]})
_test_permute(data, permute_param={"order": [2, 1, 0]})


#######################################################################
# Pooling
# -----------
Expand Down

0 comments on commit 1cdd308

Please sign in to comment.