Skip to content

Commit

Permalink
Fix Bucket (#1159)
Browse files Browse the repository at this point in the history
* update

* update

* update

* fix

* fix

* fix swish evaluator

* update

* fix

* update

* update

* Apply code-format changes

* update

* fix transpose

* Add test cases for onnx Transpose.

* Apply code-format changes

---------

Co-authored-by: FusionBolt <[email protected]>
Co-authored-by: zhangyang2057 <[email protected]>
Co-authored-by: zhangyang2057 <[email protected]>
  • Loading branch information
4 people authored Jan 24, 2024
1 parent 9647756 commit de6c89d
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 63 deletions.
18 changes: 15 additions & 3 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
using Nncase.Passes.Rules.Neutral;
using Nncase.Passes.Rules.ShapeBucket;
using Nncase.Passes.Rules.ShapeExpr;
using Nncase.Passes.Rules.WithMarker;
using Nncase.Passes.Transforms;
using Nncase.Quantization;
using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketRegister;
using CombinePadTranspose = Nncase.Passes.Rules.WithMarker.CombinePadTranspose;
using CombineReshapePad = Nncase.Passes.Rules.Neutral.CombineReshapePad;
using FoldConstCall = Nncase.Passes.Rules.Neutral.FoldConstCall;

namespace Nncase.Compiler;
Expand Down Expand Up @@ -97,6 +100,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
Expand All @@ -117,8 +121,6 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.FocusFull>();
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
p.Add<Passes.Rules.Neutral.SplitSpaceToBatch>();
p.Add<Passes.Rules.Neutral.SplitBatchToSpace>();
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.Neutral.FoldShapeOf>();
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
Expand All @@ -131,20 +133,28 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldUnsqueezeSqueeze>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
p.Add<Passes.Rules.Neutral.FoldNopClamp>();
p.Add<Passes.Rules.ShapeBucket.FoldRepeatMarker>();
p.Add<Passes.Rules.Neutral.SqueezeToReshape>();
p.Add<Passes.Rules.Neutral.UnSqueezeToReshape>();
p.Add<Passes.Rules.ShapeExpr.GatherToGetItem>();
p.Add<Passes.Rules.ShapeExpr.FoldGetItemShapeOf>();
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
});

passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
{
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.Neutral.FoldNopTranspose>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
p.Add<FoldRepeatMarker>();
p.Add<Passes.Rules.WithMarker.FoldTransposeActTranspose>();
p.Add<Passes.Rules.WithMarker.FoldTransposeBinaryActTranspose>();
p.Add<Passes.Rules.WithMarker.CombineReshapePad>();
p.Add<Passes.Rules.WithMarker.CombineTransposePad>();
p.Add<Passes.Rules.WithMarker.CombinePadTranspose>();
p.Add<Passes.Rules.Neutral.CombineTransposeUnary>();
p.Add<Passes.Rules.Neutral.CombineTransposePad>();
p.Add<Passes.Rules.Neutral.CombinePadTranspose>();
Expand Down Expand Up @@ -179,12 +189,14 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldNopSlice>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.SpaceToBatchToPad>();
p.Add<Passes.Rules.Neutral.FoldConv2DAddMul>();
});

_compileSession.Target.RegisterTargetInDependentPass(passManager, _compileSession.CompileOptions);

passManager.AddWithName<DataflowPass>("BroadcastMarker").Configure(p =>
{
p.Add<FoldTransposeActTranspose>();
p.Add<BroadcastInputMarker>();
p.Add<BroadcastOutputMarker>();
});
Expand Down Expand Up @@ -220,8 +232,8 @@ public void RegisterShapeBucket(IPassManager p)
MergeOp(p, true);
ClearMarker(p);
MergeFusion(p, singleVar, true);
Bucket(p);
Rebuild(p, singleVar);
Bucket(p);
Simplify(p);
}
else
Expand Down
4 changes: 3 additions & 1 deletion src/Nncase.Evaluator/NN/Activations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ private IRType Visit(TensorType input)
/// <summary>
/// Evaluator for <see cref="Sigmoid"/>.
/// </summary>
public class SwishEvaluator : IEvaluator<Swish>, ITypeInferencer<Swish>, ICostEvaluator<Swish>, IMetricEvaluator<Swish>
public class SwishEvaluator : IEvaluator<Swish>, ITypeInferencer<Swish>, ICostEvaluator<Swish>, IMetricEvaluator<Swish>, IShapeEvaluator<Swish>
{
/// <inheritdoc/>
public IValue Visit(IEvaluateContext context, Swish swish)
Expand Down Expand Up @@ -560,6 +560,8 @@ public Metric Visit(IMetricEvaluateContext context, Swish target)
};
}

public Expr Visit(IShapeEvaluateContext context, Swish target) => context.GetArgumentShape(target, Swish.Input);

private IRType Visit(IRType input)
{
if (input is DistributedType d && d.NdSBP.Any(s => s is SBPPartialSum))
Expand Down
4 changes: 3 additions & 1 deletion src/Nncase.Importer/Onnx/Transpose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Linq;
using DryIoc.ImTools;
using LanguageExt.UnsafeValueAccess;
using Nncase.IR;
using Nncase.IR.Tensors;
Expand All @@ -16,7 +17,8 @@ public partial class OnnxImporter
private Expr VisitTranspose(NodeProto op)
{
var input = GetSingleInputExpr(op);
var perm = Tensor.From<long>(GetIntsAttribute(op, "perm"));
var defaultPerm = Enumerable.Range(0, input.CheckedShape.Rank).Reverse().ToArray();
var perm = Tensor.From(GetIntsAttribute(op, "perm", defaultPerm));
return F.Tensors.Transpose(input, perm).With(metadata: new IRMetadata() { OutputNames = op.Output, });
}
}
Expand Down
12 changes: 8 additions & 4 deletions src/Nncase.Importer/TFLite/SpaceToBatchND.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using Nncase.IR;
using Nncase.IR.Tensors;
using static Nncase.IR.F.NN;
Expand All @@ -14,15 +16,16 @@ private Expr VisitSpaceToBatchND(in tflite.Operator op)
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var paddings = GetInputExprs(op, 2);
if (input.CheckedShape.Rank == 3)
bool needUnsqueeze = input.CheckedShape.Rank == 3;
if (needUnsqueeze)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
paddings = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, paddings }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var stb = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), blockShape, paddings));
if (input.CheckedShape.Rank == 3)
if (needUnsqueeze)
{
return Squeeze(stb, new[] { 1 });
}
Expand All @@ -34,15 +37,16 @@ private Expr VisitBatchToSpaceND(in tflite.Operator op)
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var crops = GetInputExprs(op, 2);
if (input.CheckedShape.Rank == 3)
bool needUnsqueeze = input.CheckedShape.Rank == 3;
if (needUnsqueeze)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
crops = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, crops }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var bts = NCHWToNHWC(BatchToSpace(NHWCToNCHW(input), blockShape, crops));
if (input.CheckedShape.Rank == 3)
if (needUnsqueeze)
{
return Squeeze(bts, new[] { 1 });
}
Expand Down
3 changes: 2 additions & 1 deletion src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ private static CallPattern Conv2DPattern() =>
(Conv2D.Stride.Index, (Expr)new[] { strideH, strideW }),
(Conv2D.Dilation.Index, (Expr)new[] { dilationH, dilationW }),
};
return ReplaceUtility.ReplaceCallParams(conv, conv.Arguments.ToArray(), pairs).InheritMetaData(btsCall);
var res = ReplaceUtility.ReplaceCallParams(conv.Target, conv.Arguments.ToArray(), pairs).InheritMetaData(btsCall);
return res;
}

private (int[] Begin, int[] End) GetBeginEnd(int[] btsBlockShape, int[,] crop, int[] btsInputShape)
Expand Down
1 change: 1 addition & 0 deletions src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction main, RunPassCon
var memo = EvaluatorUtil.GetMemo(body, ConcatDictionary(input, varValues));
var f = new FusionShapeUpdater(ConcatDictionary(memo, exprValues));
f.Visit(main);
GC.Collect();
return f.FusionShape;
}).SelectMany(x => x)
.ToLookup(x => x.Key, x => x.Value)
Expand Down
76 changes: 44 additions & 32 deletions src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,28 +1020,49 @@ public static int[][][] UpdateShapeCache(FusionShapeData[] shapeInfos, ShapeBuck
x.InputShapes.Select(iShape => iShape.AsTensor().ToArray<int>().ToArray()).ToArray()).ToArray();
if (!SingleDimVar(options))
{
for (int i = 0; i < shapeInfos.Length; i++)
for (int j = 0; j < allFixedShapes.Length; j++)
{
for (int j = 0; j < allFixedShapes.Length; j++)
{
context.FixedShapeCache[j] = allFixedShapes[j];
}
context.FixedShapeCache[j] = allFixedShapes[j];
}
}
else
{
allFixedShapes = new[] { allFixedShapes[0] }.Concat(allFixedShapes.Reverse()).ToArray();
var tmpAllFixedShapes = new[] { allFixedShapes[0] }.Concat(allFixedShapes.Reverse()).ToArray();
var segments = context.DimVarValues.First().Value.Reverse().ToArray();

for (int i = 0; i < segments.Length; i++)
{
context.FixedShapeCache[segments.Length - 1 - i] = allFixedShapes[segments[i]];
context.FixedShapeCache[segments.Length - 1 - i] = tmpAllFixedShapes[segments[i]];
}
}

return allFixedShapes;
}

public static bool ShouldRestore(Call outerCall, BucketFusion fusion)
{
if (CallValidator.IsSimple(fusion))
{
return true;
}

if (outerCall.CheckedType is TupleType tt)
{
if (tt.Fields.All(f => f is TensorType t && t.Shape.Rank < 2))
{
return true;
}
}

if (outerCall.Arguments.ToArray().Any(arg =>
arg.CheckedType is TupleType))
{
return true;
}

return false;
}

public Expr? GetReplace(Call outerCall, BucketFusion fusion, Expr fusionBody)
{
if (ShouldRestore(outerCall, fusion))
Expand Down Expand Up @@ -1073,7 +1094,7 @@ public static int[][][] UpdateShapeCache(FusionShapeData[] shapeInfos, ShapeBuck
int[][][] allFixedShapes = UpdateShapeCache(shapeInfos, options, context);

var minFixedShapeList = allFixedShapes[^1];
var maxFixedShapeList = allFixedShapes[1];
var maxFixedShapeList = allFixedShapes[0];

// PrintMinMaxShape(minFixedShapeList, maxFixedShapeList, _relPath);

Expand Down Expand Up @@ -1183,30 +1204,6 @@ private static bool IsFixed(int totalCount, int[][] minFixedShapeList, int[][] m
totalCount == 0 || (minFixedShapeList[0].SequenceEqual(maxFixedShapeList[0]) &&
minFixedShapeList[1].SequenceEqual(maxFixedShapeList[1]));

private static bool ShouldRestore(Call outerCall, BucketFusion fusion)
{
if (CallValidator.IsSimple(fusion))
{
return true;
}

if (outerCall.CheckedType is TupleType tt)
{
if (tt.Fields.All(f => f is TensorType t && t.Shape.Rank < 2))
{
return true;
}
}

if (outerCall.Arguments.ToArray().Any(arg =>
arg.CheckedType is TupleType))
{
return true;
}

return false;
}

private static void PrintMinMaxShape(int[][] minFixedShapeList, int[][] maxFixedShapeList, string relPath)
{
string str = string.Empty;
Expand Down Expand Up @@ -1303,10 +1300,16 @@ public RebuildBucket(Dictionary<BucketFusion, FusionShapeData[]> shapeInfo)
_shapeInfo = shapeInfo;
}

// todo: pattern not match??
public override Pattern Pattern => FusionBucket.BucketFusionPattern;

public Expr? GetReplace(Call outerCall, BucketFusion fusion, Expr fusionBody)
{
if (FusionBucket.ShouldRestore(outerCall, fusion))
{
return FusionBucket.RestoreBodyWithArgs(outerCall.Arguments.ToArray(), fusion.Parameters.ToArray(), fusion.Body);
}

// only once RecordShape
var options = CompileSession.CompileOptions.ShapeBucketOptions;

Expand Down Expand Up @@ -1397,6 +1400,15 @@ protected override Expr VisitLeafCall(Call expr)
}
}

if (expr.Target is Call { Target: IR.Tensors.Reshape })
{
var type = expr.Arguments[IR.Tensors.Reshape.Shape.Index].CheckedType;
if (type is TensorType { Shape.IsFixed: false })
{
_hasDynamic = true;
}
}

return expr;
}
}
Expand Down
Loading

0 comments on commit de6c89d

Please sign in to comment.