From 3f2f0a0177074c04afb9444b0c438ddf8e76af11 Mon Sep 17 00:00:00 2001
From: yuheng huang <32429436+YuhengHuang42@users.noreply.github.com>
Date: Sat, 5 Jun 2021 04:10:19 +0800
Subject: [PATCH] Fix prelu bug in pytorch frontend (#8192)

* Fix prelu bug in pytorch frontend

* Fix lint error

* fix lint error

* Fix lint error

* Try to fix lint error

* Fix lint error

Co-authored-by: huangyuheng <32429436+hyhzxhy@users.noreply.github.com>
---
 python/tvm/relay/frontend/pytorch.py          | 8 ++++++--
 tests/python/frontend/pytorch/test_forward.py | 4 ++++
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f0ba99291727..acc33d73e826 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -754,9 +754,13 @@ def relu(self, inputs, input_types):
         return _op.nn.relu(data)
 
     def prelu(self, inputs, input_types):
+        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html#torch.nn.PReLU
         data = inputs[0]
-        alpha = inputs[1]
-        return _op.nn.prelu(data, alpha)
+        dim = self.get_dims(data)
+        ndims = len(dim)
+        axis = 0 if ndims == 1 else 1
+        alpha = _op.broadcast_to(inputs[1], (dim[axis]))
+        return _op.nn.prelu(data, alpha, axis)
 
     def leaky_relu(self, inputs, input_types):
         data = inputs[0]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 07f0d8e75c4d..be4d74ed205a 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -643,6 +643,10 @@ def test_forward_prelu():
     input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.PReLU(num_parameters=3).eval(), input_data=input_data)
+    # Test when input channel > 1 and num parameters = 1
+    verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=input_data)
+    # Test when input dims < 2
+    verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=torch.randn(2))
 
 
 @tvm.testing.uses_gpu
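
For context, not part of the patch: a minimal sketch of the conversion path the new tests exercise, assuming a TVM build that includes this fix and the PyTorch frontend. PyTorch's nn.PReLU(num_parameters=1) keeps a single learnable alpha even for multi-channel input, while Relay's nn.prelu expects one alpha value per channel on the given axis, which is why the converter now broadcasts alpha before lowering. The input name "input0" and the shapes below are illustrative choices, not taken from the patch.

import torch
from tvm import relay

# Single-alpha PReLU applied to a 3-channel input; before this fix the
# converter passed the one-element alpha straight to relay's nn.prelu,
# which expects one alpha per channel on the chosen axis.
model = torch.nn.PReLU(num_parameters=1).eval()
input_data = torch.rand([1, 3, 10, 10]).float()
scripted = torch.jit.trace(model, input_data)

# Import the traced module into Relay; "input0" is an arbitrary input name.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(input_data.shape))])
print(mod)  # the prelu call should now carry a broadcast alpha and axis=1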