Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Bucket #1159

Merged
merged 17 commits into from
Jan 24, 2024
18 changes: 15 additions & 3 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
using Nncase.Passes.Rules.Neutral;
using Nncase.Passes.Rules.ShapeBucket;
using Nncase.Passes.Rules.ShapeExpr;
using Nncase.Passes.Rules.WithMarker;
using Nncase.Passes.Transforms;
using Nncase.Quantization;
using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketRegister;
using CombinePadTranspose = Nncase.Passes.Rules.WithMarker.CombinePadTranspose;
using CombineReshapePad = Nncase.Passes.Rules.Neutral.CombineReshapePad;
using FoldConstCall = Nncase.Passes.Rules.Neutral.FoldConstCall;

namespace Nncase.Compiler;
Expand Down Expand Up @@ -97,6 +100,7 @@
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
Expand All @@ -117,8 +121,6 @@
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.FocusFull>();
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
p.Add<Passes.Rules.Neutral.SplitSpaceToBatch>();
p.Add<Passes.Rules.Neutral.SplitBatchToSpace>();
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.Neutral.FoldShapeOf>();
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
Expand All @@ -131,20 +133,28 @@
p.Add<Passes.Rules.Neutral.FoldUnsqueezeSqueeze>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
p.Add<Passes.Rules.Neutral.FoldNopClamp>();
p.Add<Passes.Rules.ShapeBucket.FoldRepeatMarker>();
p.Add<Passes.Rules.Neutral.SqueezeToReshape>();
p.Add<Passes.Rules.Neutral.UnSqueezeToReshape>();
p.Add<Passes.Rules.ShapeExpr.GatherToGetItem>();
p.Add<Passes.Rules.ShapeExpr.FoldGetItemShapeOf>();
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
});

passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
{
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.Neutral.FoldNopTranspose>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
p.Add<FoldRepeatMarker>();
p.Add<Passes.Rules.WithMarker.FoldTransposeActTranspose>();
p.Add<Passes.Rules.WithMarker.FoldTransposeBinaryActTranspose>();
p.Add<Passes.Rules.WithMarker.CombineReshapePad>();
p.Add<Passes.Rules.WithMarker.CombineTransposePad>();
p.Add<Passes.Rules.WithMarker.CombinePadTranspose>();
p.Add<Passes.Rules.Neutral.CombineTransposeUnary>();
p.Add<Passes.Rules.Neutral.CombineTransposePad>();
p.Add<Passes.Rules.Neutral.CombinePadTranspose>();
Expand Down Expand Up @@ -179,12 +189,14 @@
p.Add<Passes.Rules.Neutral.FoldNopSlice>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.SpaceToBatchToPad>();
p.Add<Passes.Rules.Neutral.FoldConv2DAddMul>();
});

_compileSession.Target.RegisterTargetInDependentPass(passManager, _compileSession.CompileOptions);

passManager.AddWithName<DataflowPass>("BroadcastMarker").Configure(p =>
{
p.Add<FoldTransposeActTranspose>();
p.Add<BroadcastInputMarker>();
p.Add<BroadcastOutputMarker>();
});
Expand Down Expand Up @@ -220,8 +232,8 @@
MergeOp(p, true);
ClearMarker(p);
MergeFusion(p, singleVar, true);
Bucket(p);
Rebuild(p, singleVar);
Bucket(p);

Check warning on line 236 in src/Nncase.Compiler/Compiler.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Compiler/Compiler.cs#L236

Added line #L236 was not covered by tests
Simplify(p);
}
else
Expand Down
4 changes: 3 additions & 1 deletion src/Nncase.Evaluator/NN/Activations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@
/// <summary>
/// Evaluator for <see cref="Sigmoid"/>.
/// </summary>
public class SwishEvaluator : IEvaluator<Swish>, ITypeInferencer<Swish>, ICostEvaluator<Swish>, IMetricEvaluator<Swish>
public class SwishEvaluator : IEvaluator<Swish>, ITypeInferencer<Swish>, ICostEvaluator<Swish>, IMetricEvaluator<Swish>, IShapeEvaluator<Swish>
{
/// <inheritdoc/>
public IValue Visit(IEvaluateContext context, Swish swish)
Expand Down Expand Up @@ -560,6 +560,8 @@
};
}

public Expr Visit(IShapeEvaluateContext context, Swish target) => context.GetArgumentShape(target, Swish.Input);

Check warning on line 563 in src/Nncase.Evaluator/NN/Activations.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Evaluator/NN/Activations.cs#L563

Added line #L563 was not covered by tests

private IRType Visit(IRType input)
{
if (input is DistributedType d && d.NdSBP.Any(s => s is SBPPartialSum))
Expand Down
4 changes: 3 additions & 1 deletion src/Nncase.Importer/Onnx/Transpose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Linq;
using DryIoc.ImTools;
using LanguageExt.UnsafeValueAccess;
using Nncase.IR;
using Nncase.IR.Tensors;
Expand All @@ -16,7 +17,8 @@ public partial class OnnxImporter
private Expr VisitTranspose(NodeProto op)
{
var input = GetSingleInputExpr(op);
var perm = Tensor.From<long>(GetIntsAttribute(op, "perm"));
var defaultPerm = Enumerable.Range(0, input.CheckedShape.Rank).Reverse().ToArray();
var perm = Tensor.From(GetIntsAttribute(op, "perm", defaultPerm));
return F.Tensors.Transpose(input, perm).With(metadata: new IRMetadata() { OutputNames = op.Output, });
}
}
Expand Down
12 changes: 8 additions & 4 deletions src/Nncase.Importer/TFLite/SpaceToBatchND.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using Nncase.IR;
using Nncase.IR.Tensors;
using static Nncase.IR.F.NN;
Expand All @@ -14,15 +16,16 @@
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var paddings = GetInputExprs(op, 2);
if (input.CheckedShape.Rank == 3)
bool needUnsqueeze = input.CheckedShape.Rank == 3;

Check warning on line 19 in src/Nncase.Importer/TFLite/SpaceToBatchND.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Importer/TFLite/SpaceToBatchND.cs#L19

Added line #L19 was not covered by tests
if (needUnsqueeze)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
paddings = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, paddings }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var stb = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), blockShape, paddings));
if (input.CheckedShape.Rank == 3)
if (needUnsqueeze)
{
return Squeeze(stb, new[] { 1 });
}
Expand All @@ -34,15 +37,16 @@
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var crops = GetInputExprs(op, 2);
if (input.CheckedShape.Rank == 3)
bool needUnsqueeze = input.CheckedShape.Rank == 3;
if (needUnsqueeze)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
crops = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, crops }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var bts = NCHWToNHWC(BatchToSpace(NHWCToNCHW(input), blockShape, crops));
if (input.CheckedShape.Rank == 3)
if (needUnsqueeze)
{
return Squeeze(bts, new[] { 1 });
}
Expand Down
3 changes: 2 additions & 1 deletion src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@
(Conv2D.Stride.Index, (Expr)new[] { strideH, strideW }),
(Conv2D.Dilation.Index, (Expr)new[] { dilationH, dilationW }),
};
return ReplaceUtility.ReplaceCallParams(conv, conv.Arguments.ToArray(), pairs).InheritMetaData(btsCall);
var res = ReplaceUtility.ReplaceCallParams(conv.Target, conv.Arguments.ToArray(), pairs).InheritMetaData(btsCall);
return res;

Check warning on line 101 in src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs#L100-L101

Added lines #L100 - L101 were not covered by tests
}

private (int[] Begin, int[] End) GetBeginEnd(int[] btsBlockShape, int[,] crop, int[] btsInputShape)
Expand Down
1 change: 1 addition & 0 deletions src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction main, RunPassCon
var memo = EvaluatorUtil.GetMemo(body, ConcatDictionary(input, varValues));
var f = new FusionShapeUpdater(ConcatDictionary(memo, exprValues));
f.Visit(main);
GC.Collect();
return f.FusionShape;
}).SelectMany(x => x)
.ToLookup(x => x.Key, x => x.Value)
Expand Down
76 changes: 44 additions & 32 deletions src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,28 +1020,49 @@
x.InputShapes.Select(iShape => iShape.AsTensor().ToArray<int>().ToArray()).ToArray()).ToArray();
if (!SingleDimVar(options))
{
for (int i = 0; i < shapeInfos.Length; i++)
for (int j = 0; j < allFixedShapes.Length; j++)
{
for (int j = 0; j < allFixedShapes.Length; j++)
{
context.FixedShapeCache[j] = allFixedShapes[j];
}
context.FixedShapeCache[j] = allFixedShapes[j];
}
}
else
{
allFixedShapes = new[] { allFixedShapes[0] }.Concat(allFixedShapes.Reverse()).ToArray();
var tmpAllFixedShapes = new[] { allFixedShapes[0] }.Concat(allFixedShapes.Reverse()).ToArray();
var segments = context.DimVarValues.First().Value.Reverse().ToArray();

for (int i = 0; i < segments.Length; i++)
{
context.FixedShapeCache[segments.Length - 1 - i] = allFixedShapes[segments[i]];
context.FixedShapeCache[segments.Length - 1 - i] = tmpAllFixedShapes[segments[i]];
}
}

return allFixedShapes;
}

public static bool ShouldRestore(Call outerCall, BucketFusion fusion)
{
if (CallValidator.IsSimple(fusion))
{
return true;

Check warning on line 1046 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs#L1046

Added line #L1046 was not covered by tests
}

if (outerCall.CheckedType is TupleType tt)
{
if (tt.Fields.All(f => f is TensorType t && t.Shape.Rank < 2))
{
return true;

Check warning on line 1053 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs#L1053

Added line #L1053 was not covered by tests
}
}

if (outerCall.Arguments.ToArray().Any(arg =>
arg.CheckedType is TupleType))
{
return true;

Check warning on line 1060 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs#L1060

Added line #L1060 was not covered by tests
}

return false;
}

public Expr? GetReplace(Call outerCall, BucketFusion fusion, Expr fusionBody)
{
if (ShouldRestore(outerCall, fusion))
Expand Down Expand Up @@ -1073,7 +1094,7 @@
int[][][] allFixedShapes = UpdateShapeCache(shapeInfos, options, context);

var minFixedShapeList = allFixedShapes[^1];
var maxFixedShapeList = allFixedShapes[1];
var maxFixedShapeList = allFixedShapes[0];

// PrintMinMaxShape(minFixedShapeList, maxFixedShapeList, _relPath);

Expand Down Expand Up @@ -1183,30 +1204,6 @@
totalCount == 0 || (minFixedShapeList[0].SequenceEqual(maxFixedShapeList[0]) &&
minFixedShapeList[1].SequenceEqual(maxFixedShapeList[1]));

private static bool ShouldRestore(Call outerCall, BucketFusion fusion)
{
if (CallValidator.IsSimple(fusion))
{
return true;
}

if (outerCall.CheckedType is TupleType tt)
{
if (tt.Fields.All(f => f is TensorType t && t.Shape.Rank < 2))
{
return true;
}
}

if (outerCall.Arguments.ToArray().Any(arg =>
arg.CheckedType is TupleType))
{
return true;
}

return false;
}

private static void PrintMinMaxShape(int[][] minFixedShapeList, int[][] maxFixedShapeList, string relPath)
{
string str = string.Empty;
Expand Down Expand Up @@ -1303,10 +1300,16 @@
_shapeInfo = shapeInfo;
}

// todo: pattern not match??
public override Pattern Pattern => FusionBucket.BucketFusionPattern;

public Expr? GetReplace(Call outerCall, BucketFusion fusion, Expr fusionBody)
{
if (FusionBucket.ShouldRestore(outerCall, fusion))
{
return FusionBucket.RestoreBodyWithArgs(outerCall.Arguments.ToArray(), fusion.Parameters.ToArray(), fusion.Body);

Check warning on line 1310 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs#L1310

Added line #L1310 was not covered by tests
}

// only once RecordShape
var options = CompileSession.CompileOptions.ShapeBucketOptions;

Expand Down Expand Up @@ -1397,6 +1400,15 @@
}
}

if (expr.Target is Call { Target: IR.Tensors.Reshape })
{
var type = expr.Arguments[IR.Tensors.Reshape.Shape.Index].CheckedType;

Check warning on line 1405 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs#L1405

Added line #L1405 was not covered by tests
if (type is TensorType { Shape.IsFixed: false })
{
_hasDynamic = true;

Check warning on line 1408 in src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs#L1408

Added line #L1408 was not covered by tests
}
}

return expr;
}
}
Expand Down
Loading
Loading