Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcast marker #1033

Merged
merged 4 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/Nncase.Core/Utilities/ReplaceUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ public static Call ReplaceCallParams(Expr target, IReadOnlyList<Expr> oldParams,
return new Call(target, ReplaceItems(oldParams, pairs));
}

public static Call ReplaceCallParams(Call call, params (int, Expr)[] pairs)
{
return new Call(call.Target, ReplaceItems(call.Arguments.ToArray(), pairs));
}

/// <summary>
/// replace the call params with parameter info.
/// </summary>
Expand All @@ -117,6 +122,11 @@ public static Call ReplaceCallParams(Expr target, IReadOnlyList<Expr> oldParams,
public static Call ReplaceCallFirstParam(Expr target, IReadOnlyList<Expr> oldParams, Expr expr) =>
ReplaceCallParams(target, oldParams, (oldParams[0], expr));

public static Expr ReplaceCallFirstParam(Call call, Expr expr)
{
return ReplaceCallFirstParam(call.Target, call.Arguments.ToArray(), expr);
}

/// <summary>
/// Replace target in body with expr.
/// </summary>
Expand Down
79 changes: 79 additions & 0 deletions src/Nncase.Passes/Rules/Lower/BroadcastMarker.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using Nncase.IR;
using Nncase.IR.Math;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using Nncase.Utilities;
using static Nncase.Passes.Rules.Lower.BroadcastMarkerHelper;
using static Nncase.PatternMatch.Utility;
using static Nncase.Utilities.ReplaceUtility;

namespace Nncase.Passes.Rules.Lower;

// e.g. matmul(reshape(marker(x))) -> matmul(marker(reshape(marker(x))))
[RuleGenerator]
public partial class BroadcastInputMarker : RewriteRule<Pattern>
{
public override Pattern Pattern => IsCallWildcard(
"outer",
IsWildcard(),
InputPattern);

public Pattern InputPattern => IsCallWildcard(
"call",
IsWildcard(),
IsRangeOfMarker(
"marker",
IsWildcard(),
IsWildcard()));

public Expr? GetReplace(Call outer, Call call, Marker marker)
{
if (!NotChangeRangeOp(call.Target))
{
return null;
}

if (outer.Target is MatMul && CompilerServices.TryMatchRoot(outer.Arguments[1], InputPattern, new(), out var matchResult))
{
var rhsMarker = (Marker)matchResult["marker"];
var rhsCall = (Call)matchResult["call"];
var lhs = marker.With(target: ReplaceCallFirstParam(call, marker));
var rhs = rhsMarker.With(target: ReplaceCallFirstParam(rhsCall, rhsMarker));
return ReplaceCallParams(outer, (0, lhs), (1, rhs));
}

return ReplaceCallFirstParam(outer, marker.With(target: ReplaceCallFirstParam(call, marker)));
}
}

// e.g. marker(reshape(matmul(x))) -> marker(reshape(marker(matmul(x))))
[RuleGenerator]
public partial class BroadcastOutputMarker : RewriteRule<Pattern>
{
public override Pattern Pattern => IsRangeOfMarker(
"marker",
IsCallWildcard("input", IsWildcard(), IsCallWildcard(null, IsWildcard())),
IsWildcard());

public Expr? GetReplace(Call input, Marker marker)
{
if (!NotChangeRangeOp(input.Target))
{
return null;
}

return ReplaceCallFirstParam(input, marker.With(target: input.Arguments[0]));
}
}

internal static class BroadcastMarkerHelper
{
public static bool NotChangeRangeOp(Expr op)
{
return op is Squeeze || op is Unsqueeze || op is Reshape || op is Broadcast;
}
}
35 changes: 35 additions & 0 deletions src/Nncase.Tests/Rules/UnitTestBroadcastMarker.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Nncase.IR;
using Nncase.Passes.Rules.Lower;
using Nncase.Tests.TestFixture;
using Xunit;
using static Nncase.IR.F.Math;
using static Nncase.IR.F.Tensors;

namespace Nncase.Tests.Rules;

[AutoSetupTestMethod(InitSession = true)]
public class UnitTestBroadcastMarker : TransformTestBase
{
[Fact]
public void TestBroadcastInputMarker()
{
var input = Testing.Rand<float>(1, 3, 24, 24);
var a = IR.F.Math.MatMul(
Reshape(new Marker(WellknownMarkerNames.RangeOf, input, new[] { -1f, 1f }), input.Shape),
Reshape(new Marker(WellknownMarkerNames.RangeOf, input, new[] { -2f, 2f }), input.Shape));
var result = TestMatched<BroadcastInputMarker>(a);
TestNotMatch<BroadcastInputMarker>(result);
}

[Fact]
public void TestBroadcastOutputMarker()
{
var input = Testing.Rand<float>(1, 3, 24, 24);
var a = new Marker(WellknownMarkerNames.RangeOf, Reshape(Abs(input), input.Shape), new[] { -1f, 1f });
var result = TestMatched<BroadcastOutputMarker>(a);
TestNotMatch<BroadcastOutputMarker>(result);
}
}