Skip to content

Commit

Permalink
add Razor.Templating.Core (#1169)
Browse files Browse the repository at this point in the history
* add extract constrains

* refactor buffer schedule

* add Razor.Templating.Core

* reorder SwapBinaryArgs

* Apply code-format changes

* Update NuGet.Config

---------

Co-authored-by: zhengqihang <[email protected]>
Co-authored-by: xhuohai <[email protected]>
Co-authored-by: sunnycase <[email protected]>
  • Loading branch information
4 people authored Mar 5, 2024
1 parent bdaf0b1 commit 2498b1b
Show file tree
Hide file tree
Showing 26 changed files with 297 additions and 454 deletions.
2 changes: 0 additions & 2 deletions NuGet.Config
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
<configuration>
<packageSources>
<clear />
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
<add key="design-packages" value="tools/design-packages" />
<add key="sunnycase" value="https://nuget.sunnycase.moe/v3/index.json" />
</packageSources>
<activePackageSource>
<add key="nuget.cnblogs.com" value="https://nuget.cnblogs.com/v3/index.json" protocolVersion="3" />
<add key="nuget.org" value="https://api.nuget.org/v3/index.json" protocolVersion="3" />
<add key="Nncase.Libs" value="https://www.myget.org/F/magicallibs/api/v3/index.json" protocolVersion="3" />
<add key="myget-xunit" value="https://www.myget.org/F/xunit/api/v3/index.json" />
Expand Down
9 changes: 8 additions & 1 deletion modules/Nncase.Modules.StackVM/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )"
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
}
},
"nncase.core": {
Expand Down Expand Up @@ -266,6 +267,12 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Cli/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ private static void ConfigureHost(IHostBuilder hostBuilder)
private static void ConfigureAppConfiguration(HostBuilderContext context, IConfigurationBuilder builder)
{
var baseDirectory = Path.GetDirectoryName(typeof(Program).Assembly.Location);
builder.SetBasePath(baseDirectory)
builder.SetBasePath(baseDirectory!)
.AddJsonFile("config.json", true, false);
}
}
9 changes: 8 additions & 1 deletion src/Nncase.Cli/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,8 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )"
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
}
},
"nncase.compiler": {
Expand Down Expand Up @@ -932,6 +933,12 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",
Expand Down
1 change: 1 addition & 0 deletions src/Nncase.CodeGen/Nncase.CodeGen.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

<ItemGroup>
<PackageReference Include="Extension.Mathematics" />
<PackageReference Include="Razor.Templating.Core" />
</ItemGroup>

<ItemGroup>
Expand Down
6 changes: 6 additions & 0 deletions src/Nncase.CodeGen/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
"resolved": "1.2.12",
"contentHash": "D4mn5Cab4ztPLJ0V8uMErDrO/Y61098nwrvyIOLZymVAYOQcwP1vomVWKbTagf1aPU3cX5Q7adZtQEQwOy6XEg=="
},
"Razor.Templating.Core": {
"type": "Direct",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"StyleCop.Analyzers": {
"type": "Direct",
"requested": "[1.2.0-beta.435, )",
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
Expand Down Expand Up @@ -141,6 +140,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
});

Expand Down
9 changes: 8 additions & 1 deletion src/Nncase.Compiler/packages.lock.json
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,8 @@
"dependencies": {
"Extension.Mathematics": "[1.2.12, )",
"Nncase.Core": "[1.0.0, )",
"Nncase.IO": "[1.0.0, )"
"Nncase.IO": "[1.0.0, )",
"Razor.Templating.Core": "[1.9.0, )"
}
},
"nncase.core": {
Expand Down Expand Up @@ -880,6 +881,12 @@
"libortki": "0.0.2"
}
},
"Razor.Templating.Core": {
"type": "CentralTransitive",
"requested": "[1.9.0, )",
"resolved": "1.9.0",
"contentHash": "eHNqkpmNcPr5rvP/8/FFkddnvzVMH0BSyrq03H0VLZK2r1GUe3RgIgsoIXnImHMIrBzUS8gOwV65MfRPdYRi6g=="
},
"Singulink.Collections.Weak": {
"type": "CentralTransitive",
"requested": "[1.0.2, )",
Expand Down
17 changes: 17 additions & 0 deletions src/Nncase.Core/TIR/TIRUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,21 @@ public static class TIRUtilities
IR.F.Math.Max(0, t.First.Start),
IR.F.Math.Min(t.Second.FixedValue, t.First.Stop),
t.First.Step)).ToArray();

public static bool TryGetFixedRegions(TIR.BufferRegion region, out (int Start, int Stop, int Step)[] slice)
{
slice = new (int Start, int Stop, int Step)[region.Region.Length];
for (int i = 0; i < region.Region.Length; i++)
{
var rg = region.Region[i];
if (rg is not Range { Start: IR.TensorConst start, Stop: IR.TensorConst stop, Step: IR.TensorConst step })
{
return false;
}

slice[i] = (start.Value.ToScalar<int>(), stop.Value.ToScalar<int>(), step.Value.ToScalar<int>());
}

return true;
}
}
1 change: 0 additions & 1 deletion src/Nncase.Core/Utilities/ShapeExprUtility.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using GiGraph.Dot.Output.Writers.Edges;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.IR.Tensors;
Expand Down
47 changes: 43 additions & 4 deletions src/Nncase.EGraph/CostModel/EGraphCostPrinter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ internal static DotGraph DumpEgraphAsDot(IEGraph eGraph, CostModel.EGraphCostMod
return printer.SaveToStream(file);
}

/// <summary>
/// find the minCostEnode in eclass.
/// <remarks>
/// the marker first.
/// </remarks>
/// </summary>
internal static ENode MinByWithMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.OrderBy(e => e.Expr, ENodeTypeComparer.Instance).MinBy(x => x.Expr is Marker ? CostModel.Cost.Zero : costModel[x])!;
}

/// <summary>
/// find the minCostEnode in eclass skip marker.
/// </summary>
internal static ENode MinByWithOutMarker(EClass eClass, CostModel.EGraphCostModel costModel)
{
return eClass.Nodes.Where(e => e.Expr is not Marker).MinBy(x => costModel[x])!;
}

private DotGraph AttachEGraphCost(CostModel.EGraphCostModel costModel, EClass entry)
{
// 1. display each enode costs.
Expand Down Expand Up @@ -72,12 +91,12 @@ void Dfs(EClass curclass)
continue;
}

var minCostEnode = parent.MinByWithMarker(costModel);
var minCostEnode = MinByWithMarker(parent, costModel);

// when this marker ecalss has been visited, skip it.
if (markerEclassMemo.Contains(parent))
{
minCostEnode = parent.MinByWithOutMarker(costModel);
minCostEnode = MinByWithOutMarker(parent, costModel);
}

var (minCostDotnode, table) = NodesMap[minCostEnode];
Expand All @@ -93,7 +112,7 @@ void Dfs(EClass curclass)
if (minCostEnode.Expr is Marker && child == parent)
{
markerEclassMemo.Add(child);
var otherminCostENode = child.MinByWithOutMarker(costModel);
var otherminCostENode = MinByWithOutMarker(child, costModel);
var (childDotNode, _) = NodesMap[otherminCostENode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
Expand All @@ -103,7 +122,7 @@ void Dfs(EClass curclass)
}
else
{
var childEnode = child.Find().MinByWithMarker(costModel);
var childEnode = MinByWithMarker(child.Find(), costModel);
var (childDotNode, _) = NodesMap[childEnode];
_dotGraph.Edges.Add(childDotNode, minCostDotnode, edge =>
{
Expand All @@ -126,3 +145,23 @@ void Dfs(EClass curclass)
return _dotGraph;
}
}

internal sealed class ENodeTypeComparer : IComparer<Expr>
{
public static readonly ENodeTypeComparer Instance = new();

public int Compare(Expr? x, Expr? y) => (x, y) switch
{
(null, null) => 0,
(Expr, null) => 1,
(null, Expr) => -1,
(Expr, Expr) => GetPriority(x).CompareTo(GetPriority(y)),
};

private int GetPriority(Expr x) => x switch
{
Marker => 0,
Const => 1,
_ => 2,
};
}
50 changes: 50 additions & 0 deletions src/Nncase.EGraph/Passes/EGraphExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Google.OrTools.Sat;
using Nncase.CostModel;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes;

/// <summary>
/// EGraph extract extensions.
/// </summary>
public static class EGraphExtensions
{
/// <summary>
/// Extract egraph.
/// </summary>
/// <param name="eGraph">egraph.</param>
/// <param name="root">Root eclass.</param>
/// <param name="basefunc_cost_evaluator">base func cost evaluator.</param>
/// <param name="constrains">the cp model constrains.</param>
public static Expr Extract(this IEGraph eGraph, EClass root, Evaluator.IBaseFuncCostEvaluator? basefunc_cost_evaluator, EGraphExtractConstrains[] constrains)
{
// 1. set enode expr with more accuracy type.
foreach (var eclass in eGraph.Classes)
{
foreach (var nodes in eclass.Nodes)
{
if (eclass.CheckedType.CompareTo(nodes.Expr.CheckedType) > 0)
{
nodes.Expr.CheckedType = eclass.CheckedType;
}
}
}

// 2. start the cost evaluator
var costModel = new CostModel.EGraphCostEvaluator(root.Find(), basefunc_cost_evaluator, false).Evaluate();

return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains);
}
}
95 changes: 0 additions & 95 deletions src/Nncase.EGraph/Passes/EGraphExtractExtensions.cs

This file was deleted.

Loading

0 comments on commit 2498b1b

Please sign in to comment.