Skip to content

Commit

Permalink
Merge branch 'master' into feature/non_uniform
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 authored Nov 15, 2023
2 parents d1f4078 + 734cce0 commit 351e795
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/Nncase.Passes/Rules/Neutral/AddPreProcess.cs
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,26 @@ protected override Task<IRModule> RunCoreAsync(IRModule module, RunPassContext o
// Normalization
if (mean.Length != 0)
{
newInput = mean.Length switch
Expr meanCall;
Expr stdCall;
switch (mean.Length)
{
3 when inputShape.Length == 4 => (newInput - Tensor.From(mean, new[] { 1, mean.Length, 1, 1 })) /
Tensor.From(std, new[] { 1, std.Length, 1, 1 }),
_ => (newInput - Tensor.From(new float[] { mean[0] }, new[] { 1 })) /
Tensor.From(new float[] { std[0] }, new[] { 1 }),
};
case 3 when inputShape.Length == 4:
meanCall = (Expr)Tensor.From(mean, new[] { 1, mean.Length, 1, 1 });
stdCall = (Expr)Tensor.From(std, new[] { 1, std.Length, 1, 1 });
break;

default:
meanCall = (Expr)Tensor.From(new float[] { mean[0] }, new[] { 1 });
stdCall = (Expr)Tensor.From(new float[] { std[0] }, new[] { 1 });
break;
}

meanCall.Metadata.OutputNames = new[] { "Mean" };
stdCall.Metadata.OutputNames = new[] { "Std" };
var subMean = (newInput - meanCall).With(metadata: new IRMetadata() { OutputNames = new[] { input.Metadata.OutputNames?[0] + "_SubMean" } });
var divStd = (subMean / stdCall).With(metadata: new IRMetadata() { OutputNames = new[] { input.Metadata.OutputNames?[0] + "_DivStd" } });
newInput = divStd;

// newInput = Binary(BinaryOp.Div, Binary(BinaryOp.Sub, newInput, Tensor.From(mean, new []{1,3,1,1})), Const.FromTensor(std) );
}
Expand Down

0 comments on commit 351e795

Please sign in to comment.