From 67bba9032577025419dc0e110fdf4b08c5f66895 Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov
Date: Tue, 2 Mar 2021 21:45:51 -0800
Subject: [PATCH] [torch] Add linear operator support (#7569)

---
 python/tvm/relay/frontend/pytorch.py          | 15 ++++++++
 tests/python/frontend/pytorch/test_forward.py | 34 +++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 3c61749fc203..dcf2f08caeef 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1374,6 +1374,20 @@ def avg_pool3d(self, inputs, input_types):
             count_include_pad=count_include_pad,
         )
 
+    def linear(self, inputs, input_types):
+        # https://pytorch.org/docs/stable/nn.functional.html#linear
+        # 0 - input
+        # 1 - weight
+        bias = inputs[2]
+        mm_out = self.matmul(inputs[:2], input_types[:2])
+        if isinstance(bias, _expr.Expr):
+            bias_ndims = len(self.infer_shape_with_prelude(bias))
+            if bias_ndims == 1:
+                return _op.nn.bias_add(mm_out, bias)
+            mm_dtype = self.infer_type_with_prelude(mm_out).dtype
+            return self.add([mm_out, bias], [mm_dtype, input_types[2]])
+        return mm_out
+
     def dropout(self, inputs, input_types):
         data = inputs[0]
         rate = float(inputs[1])
@@ -2289,6 +2303,7 @@ def create_convert_map(self):
             "aten::softplus": self.softplus,
             "aten::avg_pool2d": self.avg_pool2d,
             "aten::avg_pool3d": self.avg_pool3d,
+            "aten::linear": self.linear,
             "aten::dropout": self.dropout,
             "aten::dropout_": self.dropout,
             "aten::feature_dropout": self.dropout,
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9f035ade7a21..54bf2fd49acb 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -24,6 +24,7 @@
 import torch
 import torchvision
 from torch.nn import Module
+from torch.nn import functional as F
 import tvm
 from tvm import relay
 from tvm.contrib import graph_runtime
@@ -1459,6 +1460,39 @@ def forward(self, *args):
     assert not any([op.name == "multiply" for op in list_ops(mod["main"])])
 
 
+@tvm.testing.uses_gpu
+def test_forward_linear():
+    torch.set_grad_enabled(False)
+
+    class Linear(Module):
+        def forward(self, input, weight, bias):
+            return F.linear(input, weight, bias)
+
+    class LinearNoBias(Module):
+        def forward(self, input, weight):
+            return F.linear(input, weight)
+
+    input2d = torch.rand([2, 2]).float()
+    weight1d = torch.rand([2]).float()
+    weight2d = torch.rand([2, 2]).float()
+    bias1d = torch.rand([2]).float()
+    bias2d = torch.rand([2, 2]).float()
+    # 2D input, 2D weight, 1D bias
+    verify_model(Linear(), input_data=[input2d, weight2d, bias1d])
+    # 2D input, 2D weight, 2D bias
+    verify_model(Linear(), input_data=[input2d, weight2d, bias2d])
+    # 2D input, 2D weight, no bias
+    verify_model(LinearNoBias(), input_data=[input2d, weight2d])
+    # 2D input, 1D weight, 1D bias is not supported by torch.linear()
+    # 2D input, 1D weight, no bias
+    verify_model(LinearNoBias(), input_data=[input2d, weight1d])
+    # TODO: Add the following cases when matmul(1D, _) is supported by TVM
+    # 1D input, 2D weight, 1D bias
+    # 1D input, 2D weight, no bias
+    # 1D input, 1D weight, scalar bias
+    # 1D input, 1D weight, no bias
+
+
 @tvm.testing.uses_gpu
 def test_forward_dropout():
     torch.set_grad_enabled(False)
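
Note (reviewer sketch, not part of the patch): the converter above lowers
aten::linear through the frontend's existing matmul conversion, then appends
nn.bias_add when the bias is 1D, or a broadcast add for any other bias shape.
A minimal, self-contained way to exercise the new path end to end might look
like the snippet below; the module and variable names are illustrative, and it
assumes the relay.frontend.from_pytorch API as it stood at the time of this
patch.

    import torch
    from torch.nn import functional as F

    import tvm
    from tvm import relay


    class Dense(torch.nn.Module):
        def forward(self, x, w, b):
            # Traces to a single aten::linear node, which the frontend now
            # converts instead of reporting an unsupported operator.
            return F.linear(x, w, b)


    x, w, b = torch.rand(2, 2), torch.rand(2, 2), torch.rand(2)
    traced = torch.jit.trace(Dense(), (x, w, b))
    mod, params = relay.frontend.from_pytorch(
        traced, [("x", (2, 2)), ("w", (2, 2)), ("b", (2,))]
    )
    # With a 1D bias, the printed IR should show the matmul conversion
    # followed by nn.bias_add, matching the converter's 1D-bias branch.
    print(mod["main"])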