diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index dc5938931ed0..3887b40141c7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1698,6 +1698,8 @@ def matmul(self, inputs, input_types): return output elif len(a_shape) > 2: inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]]) + elif len(a_shape) == 1: + return _op.squeeze(_op.nn.matmul(_op.expand_dims(inputs_0, axis=0), inputs_1), axis=[0]) if len(b_shape) > 2: trans_axes = list(range(len(b_shape))) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 1abd59dce811..642beb015fec 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3511,6 +3511,11 @@ def forward(self, *args): tensor2 = torch.randn(4) verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) + # vector x matrix + tensor1 = torch.randn(4) + tensor2 = torch.randn(4, 3) + verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2]) + # matrix x matrix tensor1 = torch.randn(10, 4) tensor2 = torch.randn(4, 10)