Skip to content

Commit

Permalink
Fix prelu bug in pytorch frontend (apache#8192)
Browse files Browse the repository at this point in the history
* Fix prelu bug in pytorch frontend

* Fix lint error

* fix lint error

* Fix lint error

* Try to fix lint error

* Fix lint error

Co-authored-by: huangyuheng <[email protected]>
  • Loading branch information
2 people authored and trevor-m committed Jun 17, 2021
1 parent 36c249e commit 3f2f0a0
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,13 @@ def relu(self, inputs, input_types):
return _op.nn.relu(data)

def prelu(self, inputs, input_types):
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.PReLU.html#torch.nn.PReLU
data = inputs[0]
alpha = inputs[1]
return _op.nn.prelu(data, alpha)
dim = self.get_dims(data)
ndims = len(dim)
axis = 0 if ndims == 1 else 1
alpha = _op.broadcast_to(inputs[1], (dim[axis]))
return _op.nn.prelu(data, alpha, axis)

def leaky_relu(self, inputs, input_types):
data = inputs[0]
Expand Down
4 changes: 4 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,10 @@ def test_forward_prelu():
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(torch.nn.PReLU(num_parameters=3).eval(), input_data=input_data)
# Test when input channel > 1 and num parameters = 1
verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=input_data)
# Test when input dims < 2
verify_model(torch.nn.PReLU(num_parameters=1).eval(), input_data=torch.randn(2))


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 3f2f0a0

Please sign in to comment.