diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index f0ba99291727..acc33d73e826 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -754,9 +754,13 @@ def relu(self, inputs, input_types):
         return _op.nn.relu(data)
 
     def prelu(self, inputs, input_types):
+        # Reference: https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html#torch.nn.PReLU
         data = inputs[0]
-        alpha = inputs[1]
-        return _op.nn.prelu(data, alpha)
+        dim = self.get_dims(data)
+        ndims = len(dim)
+        axis = 0 if ndims == 1 else 1
+        alpha = _op.broadcast_to(inputs[1], (dim[axis]))
+        return _op.nn.prelu(data, alpha, axis)
 
     def leaky_relu(self, inputs, input_types):
         data = inputs[0]
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 07f0d8e75c4d..be4d74ed205a 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -643,6 +643,10 @@ def test_forward_prelu():
     input_shape = [1, 3, 10, 10]
     input_data = torch.rand(input_shape).float()
     verify_model(torch.nn.PReLU(num_parameters=3).eval(), input_data=input_data)
+    # Test when input channel > 1 and num parameters = 1
+    verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=input_data)
+    # Test when input dims < 2
+    verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=torch.randn(2))
 
 
 @tvm.testing.uses_gpu