Skip to content

Commit

Permalink
Merge branch 'master' into release/2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Jul 3, 2023
2 parents 0c1a646 + c2a438d commit 4a87051
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.Relu6ToClamp>();
p.Add<Passes.Rules.Neutral.FoldNopSlice>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.SpaceToBatchToPad>();
});

// passManager.AddWithName<EGraphPass>("NeutralOptimizeClamp").Configure(p =>
Expand Down
56 changes: 56 additions & 0 deletions src/Nncase.Passes/Rules/Neutral/SpaceToBatchTransform.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Nncase.IR;
using Nncase.IR.Math;
using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.Passes;
using Nncase.PatternMatch;
using static Nncase.IR.F.Math;
using static Nncase.IR.F.NN;
using static Nncase.IR.F.Tensors;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.F.Tensors;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes.Rules.Neutral;

/// <summary>
/// squeeze to reshape.
/// </summary>
[RuleGenerator]
public sealed partial class SpaceToBatchToPad : IRewriteRule
{
/// <inheritdoc/>
public IPattern Pattern { get; } = IsSpaceToBatch(
IsWildcard("input") with { TypePattern = HasFixedShape() },
IsTensorConst("blockShape"),
IsTensorConst("paddings"));

private Expr? GetReplace(Expr input, Tensor<int> blockShape, Tensor<int> paddings)
{
var blockShapeArray = blockShape.ToArray();
var paddingsArray = paddings.ToArray();
if (input.CheckedShape.Rank == 4 && blockShapeArray.Length == 2 && blockShapeArray[0] == 1 && blockShape[1] == 1)
{
var newPaddingsArray = new int[8];
for (var i = 0; i < paddingsArray.Length; i++)
{
newPaddingsArray[i + 2] = paddingsArray[i];
}

var newPaddings = Tensor.From(newPaddingsArray, new[] { 4, 2 });

return Pad(input, newPaddings, PadMode.Constant, 0f);
}

return null;
}
}
54 changes: 54 additions & 0 deletions src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading.Tasks;
using Nncase.Passes;
using Nncase.Passes.Rules.Neutral;
using Xunit;
using Math = Nncase.IR.F.Math;
using NN = Nncase.IR.F.NN;
using Random = Nncase.IR.F.Random;

namespace Nncase.Tests.Rules.NeutralTest;

public class UnitTestSpaceToBatchToPad : TransformTestBase
{
public static IEnumerable<object[]> TestSpaceToBatchToPadPositiveData =>
new[]
{
new object[] { new[] { 1, 128, 128, 3 }, new[] { 1, 1 }, new[,] { { 1, 1 }, { 1, 1 } } },
new object[] { new[] { 3, 64, 64, 16 }, new[] { 1, 1 }, new[,] { { 2, 2 }, { 0, 3 } } },
new object[] { new[] { 3, 32, 32, 16 }, new[] { 1, 1 }, new[,] { { 3, 8 }, { 7, 4 } } },
};

public static IEnumerable<object[]> TestSpaceToBatchToPadNegativeData =>
new[]
{
new object[] { new[] { 1, 128, 128, new IR.Dimension(1) }, new[] { 2, 2 }, new[,] { { 0, 0 }, { 0, 0 } } },
new object[] { new[] { 1, 128, 128, IR.Dimension.Unknown }, new[] { 1, 1 }, new[,] { { 1, 1 }, { 1, 1 } } },
};

[Theory]
[MemberData(nameof(TestSpaceToBatchToPadPositiveData))]
public void TestSpaceToBatchToPadPositive(int[] shape, int[] blockShape, int[,] paddings)
{
var a = Random.Normal(DataTypes.Float32, 0, 1, 0, shape);
var rootPre = NN.SpaceToBatch(a, blockShape, paddings);
TestMatched<SpaceToBatchToPad>(rootPre);
}

[Theory]
[MemberData(nameof(TestSpaceToBatchToPadNegativeData))]
public void TestFlattenToReshapeNegative(IR.Dimension[] shape, int[] blockShape, int[,] paddings)
{
var a = new IR.Var(new IR.TensorType(DataTypes.Float32, shape));
var rootPre = NN.SpaceToBatch(a, blockShape, paddings);
TestNotMatch<SpaceToBatchToPad>(rootPre);
}
}

0 comments on commit 4a87051

Please sign in to comment.