From 8cd20dda6f4031c63a179a5c7430671c3248706e Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 1 Jul 2022 01:34:05 -0700
Subject: [PATCH 1/3] add trilu

Add a shared `trilu` converter to the PyTorch frontend that lowers to
`matrix_set_diag` with a zero-filled diagonal input, expose it through
`tril` / `triu` wrappers, and register both as `aten::tril` and
`aten::triu`. Add frontend tests for the default diagonal and for
diagonal offset 1.

---
NOTE(review): this series was recovered from a copy whose newlines and
indentation had been collapsed. Line breaks and blank lines were placed so
that every hunk header count and diffstat total is satisfied; leading
indentation of context lines was inferred from Python structure and should
be confirmed against the target tree (or applied with
`git apply --ignore-whitespace`).

 python/tvm/relay/frontend/pytorch.py          | 39 +++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 36 +++++++++++++++++
 2 files changed, 75 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 123b0299839e..1b89a5528a66 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -45,6 +45,8 @@
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import lstm_cell, try_infer_value, unbind
+from .common import infer_type as _infer_type
+from .common import fold_constant as _fold_constant
 from .pytorch_utils import is_version_greater_than, getattr_attr_name
 
 __all__ = ["from_pytorch"]
@@ -319,6 +321,41 @@ def square(self, inputs, input_types):
         (dtype,) = input_types
         return _op.power(inputs[0], _expr.const(2, dtype))
 
+    def trilu(self, upper, inputs, input_types):
+        data = inputs[0]
+        if len(inputs) == 2:
+            k_value = inputs[1]
+        else:
+            k_value = 0
+        k_tensor = _op.const(np.asarray(k_value), dtype=np.int64)
+
+        input_shape = _op.shape_of(data, input_types[0])
+        input_dims = _infer_shape(input_shape)[0]
+        data_type = _infer_type(data).checked_type.dtype
+        diag_input = _op.zeros(_fold_constant(input_shape), dtype=data_type)
+
+        if upper == 0:  # tril case
+            k1 = _op.add(k_tensor, _op.const(1, dtype="int64"))
+            k1 = _op.expand_dims(k1, axis=0)
+            k2 = _op.take(input_shape, _op.const(input_dims - 1, dtype="int32"))
+            k2 = _op.expand_dims(k2, axis=0)
+        elif upper == 1:  # triu case
+            k1 = _op.take(input_shape, _op.const(input_dims - 2, dtype="int32"))
+            k1 = _op.multiply(k1, _op.const(-1, dtype="int64"))
+            k1 = _op.subtract(k1, _op.const(1, dtype="int64"))
+            k1 = _op.expand_dims(k1, axis=0)
+            k2 = _op.subtract(k_tensor, _op.const(1, dtype="int64"))
+            k2 = _op.expand_dims(k2, axis=0)
+        else:
+            raise ValueError("Upper argument for trilu can only be 0/1.")
+        return _op.matrix_set_diag(data, diag_input, (k1, k2))
+
+    def tril(self, inputs, input_types):
+        return self.trilu(0, inputs, input_types)
+
+    def triu(self, inputs, input_types):
+        return self.trilu(1, inputs, input_types)
+
     def arange(self, inputs, input_types):
         def _get_value(val, dtype):
             # dtype is a tvm dtype
@@ -3328,6 +3365,8 @@ def create_convert_map(self):
             "aten::sqrt": self.make_unary("sqrt"),
             "aten::rsqrt": self.make_unary("rsqrt"),
             "aten::square": self.square,
+            "aten::tril": self.tril,
+            "aten::triu": self.triu,
             "aten::ceil": self.make_unary("ceil"),
             "aten::floor": self.make_unary("floor"),
             "aten::round": self.make_unary("round"),
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 4f42c183b66a..ddc0d068aeab 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3423,6 +3423,42 @@ def forward(self, *args):
     verify_model(Neg1().float().eval(), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_forward_tril():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+
+    class Tril1(Module):
+        def forward(self, *args):
+            return torch.tril(args[0])
+
+    class Tril2(Module):
+        def forward(self, *args):
+            return torch.tril(args[0], 1)
+
+    verify_model(Tril1().float().eval(), input_data=input_data)
+    verify_model(Tril2().float().eval(), input_data=input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_triu():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3, 10, 10]
+    input_data = torch.rand(input_shape).float()
+
+    class Triu1(Module):
+        def forward(self, *args):
+            return torch.triu(args[0])
+
+    class Triu2(Module):
+        def forward(self, *args):
+            return torch.triu(args[0], 1)
+
+    verify_model(Triu1().float().eval(), input_data=input_data)
+    verify_model(Triu2().float().eval(), input_data=input_data)
+
+
 @tvm.testing.uses_gpu
 def test_forward_where():
     torch.set_grad_enabled(False)

From 9f81737c88b3b4431b2388da2117eff704c45ce3 Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 1 Jul 2022 12:31:36 -0700
Subject: [PATCH 2/3] update triu and tril; fix empty

Replace the shared `trilu` helper with standalone `tril` and `triu`
converters that take the input shape from `self.infer_shape` and pass the
diagonal band to `matrix_set_diag` as plain Python values.

In the tests, `verify_model_with_input` gains keyword-only options
(`input_dict`, `custom_convert_map`, `rtol`, `atol`, `assert_shape_only`).
The tril/triu tests now cover 2-D and 4-D inputs with diagonal offsets
0, 1 and -1. `test_empty` and `test_empty_like` compare output shapes only,
and the skip marker on `test_empty_like` is removed.

---
NOTE(review): the added line `if assert_shape_only == False:` was rewritten
as `if not assert_shape_only:` (comparison to a boolean literal); hunk sizes
are unchanged.
NOTE(review): the indentation of the `torch.cuda.empty_cache()` context line
in the first test_forward.py hunk could not be recovered from the collapsed
source; 4 spaces is assumed - confirm against the target tree.

 python/tvm/relay/frontend/pytorch.py          | 44 ++++------
 tests/python/frontend/pytorch/test_forward.py | 85 +++++++++++++------
 2 files changed, 75 insertions(+), 54 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1b89a5528a66..553602a233a9 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -321,40 +321,30 @@ def square(self, inputs, input_types):
         (dtype,) = input_types
         return _op.power(inputs[0], _expr.const(2, dtype))
 
-    def trilu(self, upper, inputs, input_types):
+    def tril(self, inputs, input_types):
         data = inputs[0]
         if len(inputs) == 2:
             k_value = inputs[1]
         else:
             k_value = 0
-        k_tensor = _op.const(np.asarray(k_value), dtype=np.int64)
-
-        input_shape = _op.shape_of(data, input_types[0])
-        input_dims = _infer_shape(input_shape)[0]
-        data_type = _infer_type(data).checked_type.dtype
-        diag_input = _op.zeros(_fold_constant(input_shape), dtype=data_type)
-
-        if upper == 0:  # tril case
-            k1 = _op.add(k_tensor, _op.const(1, dtype="int64"))
-            k1 = _op.expand_dims(k1, axis=0)
-            k2 = _op.take(input_shape, _op.const(input_dims - 1, dtype="int32"))
-            k2 = _op.expand_dims(k2, axis=0)
-        elif upper == 1:  # triu case
-            k1 = _op.take(input_shape, _op.const(input_dims - 2, dtype="int32"))
-            k1 = _op.multiply(k1, _op.const(-1, dtype="int64"))
-            k1 = _op.subtract(k1, _op.const(1, dtype="int64"))
-            k1 = _op.expand_dims(k1, axis=0)
-            k2 = _op.subtract(k_tensor, _op.const(1, dtype="int64"))
-            k2 = _op.expand_dims(k2, axis=0)
-        else:
-            raise ValueError("Upper argument for trilu can only be 0/1.")
-        return _op.matrix_set_diag(data, diag_input, (k1, k2))
-
-    def tril(self, inputs, input_types):
-        return self.trilu(0, inputs, input_types)
+        input_shape = self.infer_shape(data)
+        k1, k2 = input_shape[-2:]
+        k1 = k_value + 1
+        diag_input = _op.zeros(input_shape, dtype=input_types[0])
+        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
 
     def triu(self, inputs, input_types):
-        return self.trilu(1, inputs, input_types)
+        data = inputs[0]
+        if len(inputs) == 2:
+            k_value = inputs[1]
+        else:
+            k_value = 0
+        input_shape = self.infer_shape(data)
+        k1, k2 = input_shape[-2:]
+        k1 = (k1 * -1) - 1
+        k2 = k_value - 1
+        diag_input = _op.zeros(input_shape, dtype=input_types[0])
+        return _op.matrix_set_diag(data, diag_input, k=(k1, k2))
 
     def arange(self, inputs, input_types):
         def _get_value(val, dtype):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index ddc0d068aeab..80a5cd07f7b6 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -199,12 +199,21 @@ def visit(op):
     torch.cuda.empty_cache()
 
 
-def verify_model_with_input(test_func, input_data, input_dict={}):
+def verify_model_with_input(
+    test_func,
+    input_data,
+    *,
+    input_dict={},
+    custom_convert_map={},
+    rtol=1e-5,
+    atol=1e-5,
+    assert_shape_only=False,
+):
     baseline_outputs = test_func(*input_data)
     trace = torch.jit.trace(test_func, [input.clone() for input in input_data])
     input_names = ["input{}".format(idx) for idx, inp in enumerate(input_data)]
     input_shapes = list(zip(input_names, [inp.shape for inp in input_data]))
-    mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map)
     with tvm.transform.PassContext(opt_level=3):
         for target in ["llvm", "cuda"]:
             if not tvm.runtime.enabled(target):
@@ -218,7 +227,8 @@ def verify_model_with_input(test_func, input_data, input_dict={}):
 
             compiled_output = relay_model.get_output(0).numpy()
             assert_shapes_match(baseline_outputs, compiled_output)
-            tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=1e-5, atol=1e-5)
+            if not assert_shape_only:
+                tvm.testing.assert_allclose(baseline_outputs, compiled_output, rtol=rtol, atol=atol)
 
 
 # Single operator tests
@@ -1304,7 +1314,7 @@ def test_func(input_tensor, other_tensor):
 
     input_data = [torch.rand([2, 1, 10, 1, 10]), torch.rand([2, 1, 10, 10])]
 
-    verify_model_with_input(test_func, input_data, {"input0": input_data[0]})
+    verify_model_with_input(test_func, input_data, input_dict={"input0": input_data[0]})
 
 
 @tvm.testing.uses_gpu
@@ -3426,37 +3436,59 @@ def forward(self, *args):
 @tvm.testing.uses_gpu
 def test_forward_tril():
     torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-    input_data = torch.rand(input_shape).float()
 
-    class Tril1(Module):
-        def forward(self, *args):
-            return torch.tril(args[0])
+    def test_func(input_data):
+        return torch.tril(input_data)
 
-    class Tril2(Module):
-        def forward(self, *args):
-            return torch.tril(args[0], 1)
+    input_data = torch.rand([3, 3]).float()
+    verify_model(test_func, input_data=input_data)
+    input_data = torch.rand([1, 3, 10, 10]).float()
+    verify_model(test_func, input_data=input_data)
+
+    def test_func1(input_data):
+        return torch.tril(input_data, 1)
+
+    input_data = torch.rand([3, 3]).float()
+    verify_model(test_func1, input_data=input_data)
+    input_data = torch.rand([1, 3, 10, 10]).float()
+    verify_model(test_func1, input_data=input_data)
+
+    def test_func2(input_data):
+        return torch.tril(input_data, -1)
 
-    verify_model(Tril1().float().eval(), input_data=input_data)
-    verify_model(Tril2().float().eval(), input_data=input_data)
+    input_data = torch.rand([3, 3]).float()
+    verify_model(test_func2, input_data=input_data)
+    input_data = torch.rand([1, 3, 10, 10]).float()
+    verify_model(test_func2, input_data=input_data)
 
 
 @tvm.testing.uses_gpu
 def test_forward_triu():
     torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
-    input_data = torch.rand(input_shape).float()
 
-    class Triu1(Module):
-        def forward(self, *args):
-            return torch.triu(args[0])
+    def test_func(input_data):
+        return torch.triu(input_data)
 
-    class Triu2(Module):
-        def forward(self, *args):
-            return torch.triu(args[0], 1)
+    input_data = torch.rand([3, 3]).float()
+    verify_model(test_func, input_data=input_data)
+    input_data = torch.rand([1, 3, 10, 10]).float()
+    verify_model(test_func, input_data=input_data)
 
-    verify_model(Triu1().float().eval(), input_data=input_data)
-    verify_model(Triu2().float().eval(), input_data=input_data)
+    def test_func1(input_data):
+        return torch.triu(input_data, 1)
+
+    input_data = torch.rand([3, 3]).float()
+    verify_model(test_func1, input_data=input_data)
+    input_data = torch.rand([1, 3, 10, 10]).float()
+    verify_model(test_func1, input_data=input_data)
+
+    def test_func2(input_data):
+        return torch.triu(input_data, -1)
+
+    input_data = torch.rand([3, 3]).float()
+    verify_model(test_func2, input_data=input_data)
+    input_data = torch.rand([1, 3, 10, 10]).float()
+    verify_model(test_func2, input_data=input_data)
 
 
 @tvm.testing.uses_gpu
@@ -3853,15 +3885,14 @@ def test_empty():
     def test_func():
         return torch.empty([1, 3, 10, 10])
 
-    verify_model_with_input(test_func, [])
+    verify_model_with_input(test_func, [], assert_shape_only=True)
 
 
-@pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11967")
 def test_empty_like():
     def test_func(data):
         return torch.empty_like(data)
 
-    verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()])
+    verify_model_with_input(test_func, [torch.rand([1, 3, 10, 10]).float()], assert_shape_only=True)
 
 
 def test_forward_pretrained_bert_base_uncased():

From d710b506ad4d462413783cc2041f6b04e38c2b3f Mon Sep 17 00:00:00 2001
From: YJ Shi
Date: Fri, 1 Jul 2022 12:33:01 -0700
Subject: [PATCH 3/3] fix lint

Drop the `infer_type` and `fold_constant` imports added in patch 1; their
only users in this series were in `trilu`, which patch 2 removed.

---
 python/tvm/relay/frontend/pytorch.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 553602a233a9..cb5392fa16ab 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -45,8 +45,6 @@
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import lstm_cell, try_infer_value, unbind
-from .common import infer_type as _infer_type
-from .common import fold_constant as _fold_constant
 from .pytorch_utils import is_version_greater_than, getattr_attr_name
 
 __all__ = ["from_pytorch"]