Skip to content

Commit

Permalink
support einsum
Browse files Browse the repository at this point in the history
  • Loading branch information
xhuohai committed Jul 6, 2023
1 parent e969d17 commit e2e386a
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 0 deletions.
93 changes: 93 additions & 0 deletions src/Nncase.Importer/Onnx/Einsum.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;
using Nncase.IR;
using Onnx;
using F = Nncase.IR.F;

namespace Nncase.Importer
{
public partial class OnnxImporter
{
private Expr VisitEinsum(in NodeProto op)
{
// TODO: only support two inputs and '->' can not be ommitted
var equation = GetStringAttribute(op, "equation", string.Empty);
if (string.IsNullOrEmpty(equation) || equation.Count(c => c == ',') != 1)
{
throw new InvalidOperationException("Not Yet Supported Einsum Operation!");
}

var inTerm1 = equation.Split(',')[0];
var remains = equation.Split(',')[1];
var inTerm2 = remains.Substring(0, remains.IndexOf('-', System.StringComparison.Ordinal));
var outTerm = remains.Split('>')[1];
var (lhs, rhs) = GetInputExprs(op, 0, 1);

// i,j->ij
if (inTerm1.Length == 1 && inTerm2.Length == 1 && outTerm.Length == 2 && inTerm1 + inTerm2 == outTerm)
{
return F.Tensors.Unsqueeze(lhs, new[] { 1 }) * F.Tensors.Unsqueeze(rhs, new[] { 0 });
}

// ibh,hnd->ibnd
if (inTerm1.Length == 3 && inTerm2.Length == 3 && outTerm.Length == 4
&& inTerm1.Substring(0, 2) + inTerm2.Substring(1, 2) == outTerm
&& inTerm1.Last() == inTerm2.First())
{
var lhsShape = F.Tensors.ShapeOf(lhs);
var rhsShape = F.Tensors.ShapeOf(rhs);
var mm = F.Math.MatMul(lhs, F.Tensors.Reshape(rhs, F.Tensors.Stack(new IR.Tuple(rhsShape[0], rhsShape[1] * rhsShape[2]), 0)));
return F.Tensors.Reshape(mm, F.Tensors.Stack(new IR.Tuple(lhsShape[0], lhsShape[1], rhsShape[1], rhsShape[2]), 0));
}

// ibnd,jbnd->bnij
if (inTerm1.Length == 4 && inTerm2.Length == 4 && outTerm.Length == 4
&& inTerm1.Substring(1, 2) + inTerm1.First() + inTerm2.First() == outTerm
&& inTerm1.Substring(1, 2) == inTerm2.Substring(1, 2)
&& inTerm1.Last() == inTerm2.Last())
{
var lhsShape = F.Tensors.ShapeOf(lhs);
var rhsShape = F.Tensors.ShapeOf(rhs);
var mm = F.Math.MatMul(
F.Tensors.Transpose(F.Tensors.Reshape(lhs, F.Tensors.Stack(new IR.Tuple(lhsShape[0], lhsShape[1] * lhsShape[2], lhsShape[3]), 0)), new[] { 1, 0, 2 }),
F.Tensors.Transpose(F.Tensors.Reshape(rhs, F.Tensors.Stack(new IR.Tuple(rhsShape[0], rhsShape[1] * rhsShape[2], rhsShape[3]), 0)), new[] { 1, 2, 0 }));

return F.Tensors.Reshape(mm, F.Tensors.Stack(new IR.Tuple(lhsShape[1], lhsShape[2], lhsShape[0], rhsShape[0]), 0));
}

// bnij,jbnd->ibnd
if (inTerm1.Length == 4 && inTerm2.Length == 4 && outTerm.Length == 4
&& inTerm1[2] + inTerm1.Substring(0, 2) + inTerm2.Last() == outTerm
&& inTerm1.Substring(0, 2) == inTerm2.Substring(1, 2)
&& inTerm1.Last() == inTerm2.First())
{
var lhsShape = F.Tensors.ShapeOf(lhs);
var rhsShape = F.Tensors.ShapeOf(rhs);
var mm = F.Math.MatMul(
F.Tensors.Reshape(lhs, F.Tensors.Stack(new IR.Tuple(lhsShape[0] * lhsShape[1], lhsShape[2], lhsShape[3]), 0)),
F.Tensors.Transpose(F.Tensors.Reshape(rhs, F.Tensors.Stack(new IR.Tuple(rhsShape[0], rhsShape[1] * rhsShape[2], rhsShape[3]), 0)), new[] { 1, 0, 2 }));

return F.Tensors.Reshape(F.Tensors.Transpose(mm, new[] { 1, 0, 2 }), F.Tensors.Stack(new IR.Tuple(lhsShape[2], lhsShape[0], lhsShape[1], rhsShape[3]), 0));
}

// ibnd,hnd->ibh
if (inTerm1.Length == 4 && inTerm2.Length == 3 && outTerm.Length == 3
&& inTerm1.Substring(0, 2) + inTerm2.First() == outTerm
&& inTerm1.Substring(2, 2) == inTerm2.Substring(1, 2))
{
var lhsShape = F.Tensors.ShapeOf(lhs);
var rhsShape = F.Tensors.ShapeOf(rhs);
var mm = F.Math.MatMul(
F.Tensors.Reshape(lhs, F.Tensors.Stack(new IR.Tuple(lhsShape[0], lhsShape[1], lhsShape[2] * lhsShape[3]), 0)),
F.Tensors.Transpose(F.Tensors.Reshape(rhs, F.Tensors.Stack(new IR.Tuple(rhsShape[0], rhsShape[1] * rhsShape[2]), 0)), new[] { 1, 0 }));

return F.Tensors.Reshape(mm, F.Tensors.Stack(new IR.Tuple(lhsShape[0], lhsShape[1], rhsShape[0]), 0));
}

throw new InvalidOperationException("Not Yet Supported Einsum Operation!");
}
}
}
1 change: 1 addition & 0 deletions src/Nncase.Importer/Onnx/OnnxImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ private void Visit(NodeProto op)
"DequantizeLinear" => VisitDequantizeLinear(op),
"Div" => VisitBinary(op, BinaryOp.Div),
"Dropout" => VisitDropout(op),
"Einsum" => VisitEinsum(op),
"Elu" => VisitElu(op),
"Equal" => VisitCompare(op, CompareOp.Equal),
"Exp" => VisitUnary(op, UnaryOp.Exp),
Expand Down
56 changes: 56 additions & 0 deletions tests/importer/onnx_/basic/test_einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import pytest
import torch
import numpy as np
from onnx_test_runner import OnnxTestRunner


def _make_module(case):
class EinsumModule(torch.nn.Module):
def __init__(self):
super(EinsumModule, self).__init__()
self.v = torch.from_numpy(np.random.rand(*case[1]).astype(np.float32))

def forward(self, x):
outs = []
outs.append(torch.einsum(case[2], x, self.v))
return outs

return EinsumModule()


cases = [
[[3], [4], "i,j->ij"],
[[2, 4, 6], [6, 4, 3], "ibh,hnd->ibnd"],
[[4, 2, 5, 6], [3, 2, 5, 6], "ibnd,jbnd->bnij"],
[[2, 5, 4, 6], [6, 2, 5, 3], "bnij,jbnd->ibnd"],
[[5, 2, 3, 4], [6, 3, 4], "ibnd,hnd->ibh"]
]


@pytest.mark.parametrize('case', cases)
def test_einsum(case, request):
module = _make_module(case)

runner = OnnxTestRunner(request.node.name)
model_file = runner.from_torch(module, case[0], 12)
runner.run(model_file)


if __name__ == "__main__":
pytest.main(
['-vv', 'test_einsum.py'])
2 changes: 2 additions & 0 deletions tests/importer/onnx_/basic/test_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def _make_module(in_shape, padding, constant_value, mode, op_version, value_form
nodes = []

out_shape = in_shape.copy()
out_shape[0] += padding[0] + padding[4]
out_shape[1] += padding[1] + padding[5]
out_shape[2] += padding[2] + padding[6]
out_shape[3] += padding[3] + padding[7]

Expand Down

0 comments on commit e2e386a

Please sign in to comment.