Skip to content

Commit

Permalink
Add support for aten::dot (apache#9893)
Browse files Browse the repository at this point in the history
* Add support for aten::dot

This implements dot product as a composite of of multiply + sum

* address comments

Co-authored-by: driazati <[email protected]>
  • Loading branch information
2 people authored and ylc committed Feb 16, 2022
1 parent ba6cd7a commit d7e2dbb
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def try_resolve_var_to_const(x, graph_params):


def set_span(sym, node_name):
"""Set up the sapn of relay expression(s) while converting OP"""
"""Set up the span of relay expression(s) while converting OP"""

class SpanFiller(ExprMutator):
"""SpanFiller"""
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2852,6 +2852,10 @@ def einsum(self, inputs, input_types):
equation, data = inputs
return _op.einsum(data, equation)

def dot(self, inputs, _):
lhs, rhs = inputs
return _op.sum(_op.multiply(lhs, rhs))

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -3076,6 +3080,7 @@ def create_convert_map(self):
"aten::bucketize": self.bucketize,
"aten::roll": self.roll,
"aten::einsum": self.einsum,
"aten::dot": self.dot,
}

def update_convert_map(self, custom_map):
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4087,5 +4087,14 @@ def test_fn(equation):
verify_model(test_fn("ij,jk,km->im"), [x, y, z])


@tvm.testing.uses_gpu
def test_dot():
def test_fn(x):
return x.dot(x)

x = torch.randn([4])
verify_model(test_fn, [x])


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit d7e2dbb

Please sign in to comment.