Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Melgan #1168

Merged
merged 2 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,17 @@
p.Add<Passes.Rules.WithMarker.FoldTransposeActTranspose>();
p.Add<Passes.Rules.WithMarker.FoldTransposeBinaryActTranspose>();
p.Add<Passes.Rules.WithMarker.CombineReshapePad>();
p.Add<Passes.Rules.WithMarker.CombineTransposePad>();
p.Add<Passes.Rules.WithMarker.CombinePadTranspose>();
p.Add<Passes.Rules.Neutral.CombineTransposeUnary>();
p.Add<Passes.Rules.Neutral.CombineTransposePad>();
if (_compileSession.CompileOptions.ShapeBucketOptions.Enable)
{
p.Add<Passes.Rules.WithMarker.CombineTransposePad>();

Check warning on line 160 in src/Nncase.Compiler/Compiler.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Compiler/Compiler.cs#L160

Added line #L160 was not covered by tests
}
else
{
p.Add<Passes.Rules.Neutral.CombineTransposePad>();
}

p.Add<Passes.Rules.Neutral.CombinePadTranspose>();
p.Add<Passes.Rules.Neutral.CombineBinaryTranspose>();
p.Add<Passes.Rules.Neutral.CombineConstBinaryTranspose>();
Expand Down
2 changes: 2 additions & 0 deletions src/Nncase.Core/PatternMatch/PatternUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,5 +288,7 @@ public static Pattern IsCallWildcardMaybeSwappable<TOp>(string callName, Pattern

public static Pattern MaybeMarker(Pattern input) => IsAlt(input, IsRangeOfMarker(input, IsWildcard()));

public static Pattern MaybeMarker(Pattern input, string markerName) => IsAlt(input, IsRangeOfMarker(markerName, input, IsWildcard()));

public static Pattern HasMarker(Pattern input, string? markerName = null) => IsRangeOfMarker(markerName, input, IsWildcard());
}
64 changes: 50 additions & 14 deletions src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,41 @@
{
/// <inheritdoc/>
public IPattern Pattern { get; } =
IsReshape(
MaybeMarker(
IsReshape(
"reshape",
"reshapeCall",
_ => true,
HasMarker(IsPad("pad", "padCall", _ => true, HasMarker(IsWildcard("input"), "marker"), IsTensorConst("pads"), IsTensorConst("value")) with { TypePattern = HasFixedShape() }, "padOutMarker"),
IsWildcard("shape")) with
{ TypePattern = HasFixedShape() };
{ TypePattern = HasFixedShape() },
"outMarker");

private Expr? GetReplace(Reshape reshape, Call reshapeCall, Pad pad, Call padCall, Expr input, Expr shape, int[] pads, Expr value, Marker marker)
private Expr? GetReplace(Reshape reshape, Call reshapeCall, Pad pad, Call padCall, Expr input, Expr shape, int[] pads, Expr value, Marker marker, IMatchResult result)
{
// only support pattern like melgan
var reshapeRank = reshapeCall.CheckedShape.Rank;
var padRank = padCall.CheckedShape.Rank;
if (reshapeRank >= padRank
&& Enumerable.SequenceEqual(reshapeCall.CheckedShape.ToValueArray()[(reshapeRank - padRank)..], padCall.CheckedShape.ToValueArray()))
{
return Pad(
marker.With(target: Reshape(input, Enumerable.Repeat(1, reshapeRank - padRank).Concat(input.CheckedShape.ToValueArray()).ToArray()).InheritMetaData(reshapeCall)),
Tensor.From(Enumerable.Repeat(0, (reshapeRank - padRank) * 2).Concat(pads).ToArray(), new[] { reshapeRank, 2 }),
var newPad = Pad(
marker.With(target: Reshape(
marker.With(target: input),
Enumerable.Repeat(1, reshapeRank - padRank).Concat(input.CheckedShape.ToValueArray()).ToArray())
.InheritMetaData(reshapeCall)),
Tensor.From(
Enumerable.Repeat(0, (reshapeRank - padRank) * 2).Concat(pads).ToArray(),
new[] { reshapeRank, 2 }),

Check warning on line 60 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L53-L60

Added lines #L53 - L60 were not covered by tests
pad.PadMode,
value).InheritMetaData(padCall);
var outMarker = result.GetValueOrDefault("outMarker");

Check warning on line 63 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L63

Added line #L63 was not covered by tests
if (outMarker != null)
{
return ((Marker)outMarker).With(target: newPad);

Check warning on line 66 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L66

Added line #L66 was not covered by tests
}

return newPad;

Check warning on line 69 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L69

Added line #L69 was not covered by tests
}

return null;
Expand All @@ -67,15 +81,17 @@
public sealed partial class CombineTransposePad : IRewriteRule
{
/// <inheritdoc/>
public IPattern Pattern { get; } = IsPad(
public IPattern Pattern { get; } = MaybeMarker(
IsPad(

Check warning on line 85 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L85

Added line #L85 was not covered by tests
"pad",
"padCall",
x => true,
HasMarker(IsTranspose(IsWildcard("input"), IsTensorConst("perm")), "marker"),
IsTensorConst("pads"),
IsWildcard("padValue"));
IsWildcard("padValue")),
"outMarker");

Check warning on line 92 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L91-L92

Added lines #L91 - L92 were not covered by tests

private Expr GetReplace(Pad pad, Call padCall, Expr input, int[] perm, Expr pads, Expr padValue, Marker marker)
private Expr GetReplace(Pad pad, Call padCall, Expr input, int[] perm, Expr pads, Expr padValue, Marker marker, IMatchResult result)
{
var inv_perm = perm.Select((p, i) => (p, i)).OrderBy(tp => tp.p).ToArray();
var newPads = new List<Expr>();
Expand All @@ -87,7 +103,14 @@
}

var p = Pad(input, Stack(new IR.Tuple(newPads.ToArray()), 0).Evaluate().AsTensor(), pad.PadMode, padValue).InheritMetaData(padCall);
return Transpose(marker.With(target: p), perm);
var newTranspose = Transpose(marker.With(target: p), perm);
var outMarker = result.GetValueOrDefault("outMarker");

Check warning on line 107 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L106-L107

Added lines #L106 - L107 were not covered by tests
if (outMarker != null)
{
return ((Marker)outMarker).With(target: newTranspose);

Check warning on line 110 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L110

Added line #L110 was not covered by tests
}

return newTranspose;

Check warning on line 113 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L113

Added line #L113 was not covered by tests
}
}

Expand All @@ -99,7 +122,8 @@
public sealed partial class CombinePadTranspose : IRewriteRule
{
/// <inheritdoc/>
public IPattern Pattern { get; } = IsTranspose(
public IPattern Pattern { get; } = MaybeMarker(
IsTranspose(
"transpose",
x => true,
HasMarker(
Expand All @@ -111,9 +135,10 @@
IsTensorConst("pads"),
IsTensorConst("padValue")),
"marker"),
IsTensorConst("perm"));
IsTensorConst("perm")),
"outMarker");

private Expr GetReplace(Pad pad, Call padCall, Expr input, int[] perm, Expr pads, Expr padValue, Marker marker)
private Expr GetReplace(Pad pad, Call padCall, Expr input, int[] perm, Expr pads, Expr padValue, Marker marker, IMatchResult result)
{
var newPads = new List<int>();
for (int i = 0; i < perm.Length; i++)
Expand All @@ -122,6 +147,17 @@
newPads.Add(((TensorConst)pads).Value.ToArray<int>()[(perm[i] * 2) + 1]);
}

return Pad(marker.With(target: Transpose(input, perm)), Tensor.From<int>(newPads.ToArray(), pads.CheckedShape), pad.PadMode, padValue).InheritMetaData(padCall);
var newPad = Pad(
marker.With(target: Transpose(input, perm)),
Tensor.From<int>(newPads.ToArray(), pads.CheckedShape),
pad.PadMode,
padValue).InheritMetaData(padCall);
var outMarker = result.GetValueOrDefault("outMarker");

Check warning on line 155 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L150-L155

Added lines #L150 - L155 were not covered by tests
if (outMarker != null)
{
return ((Marker)outMarker).With(target: newPad);

Check warning on line 158 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L158

Added line #L158 was not covered by tests
}

return newPad;

Check warning on line 161 in src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs

View check run for this annotation

Codecov / codecov/patch

src/Nncase.Passes/Rules/WithMarker/CombineReshapePad.cs#L161

Added line #L161 was not covered by tests
}
}
Loading