From 6cf1eb63c73535b33ed4ac6b797994ca3aac89bc Mon Sep 17 00:00:00 2001 From: huochenghai Date: Sat, 8 Feb 2025 13:17:34 +0800 Subject: [PATCH 01/18] unfold distributed const --- .../Passes/Rules/CPU/FoldBoxingConst.cs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs index f360987b06..59316661ef 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs @@ -32,3 +32,21 @@ public partial class FoldBoxingConst : RewriteRule return new TensorConst(input, type.NdSBP, type.Placement); } } + +[RuleGenerator] +public partial class UnfoldDistributedConst : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsTensorConst("input"); + + private Expr? GetReplace(TensorConst input) + { + var type = input.CheckedType; + if (type is DistributedType) + { + return IR.F.CPU.Boxing(input.Value, type); + } + + return null; + } +} From 64a8e9be8c76724e8a8b9475914d01deccb07dd1 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Sat, 8 Feb 2025 13:18:23 +0800 Subject: [PATCH 02/18] add HierarchyLatencies for xpu toml config --- tests/config.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/config.toml b/tests/config.toml index f25539e50f..367767e054 100644 --- a/tests/config.toml +++ b/tests/config.toml @@ -154,6 +154,7 @@ HierarchyNames = "cdxyt" HierarchySizes = [268435456, 1048576] MemoryCapacities = [262144, 67108864] MemoryBandWidths = [64, 32] +HierarchyLatencies = [10000, 10000, 10000, 10000, 10000] UnifiedMemoryArch = false Packing = true HierarchyKind = "nncase.HierarchyKind.SMT" \ No newline at end of file From f4e95cfafc92c3fae86a580e73c309e41f168569 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Sat, 8 Feb 2025 15:19:15 +0800 Subject: [PATCH 03/18] fix auto-dist of resize --- .../Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index f044ab939b..7c2dec1ae6 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -706,7 +706,7 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster) CostModel.Cost cost; switch (enode.Expr) { - case Const or Var or If or IR.Tuple or BaseFunction: + case Const or Var or If or IR.Tuple or BaseFunction or IR.None: cost = new CostModel.Cost() { [CostModel.CostFactorNames.CPUCycles] = 1 }; break; case Op op: From 852ba16cdebe34db778faa1acc669eddd4db2a7a Mon Sep 17 00:00:00 2001 From: huochenghai Date: Sat, 8 Feb 2025 15:19:46 +0800 Subject: [PATCH 04/18] Split PartialAndReshardBoxing --- ...{FoldBoxingConst.cs => BoxingTransform.cs} | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) rename modules/Nncase.Modules.CPU/Passes/Rules/CPU/{FoldBoxingConst.cs => BoxingTransform.cs} (59%) diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/BoxingTransform.cs similarity index 59% rename from modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs rename to modules/Nncase.Modules.CPU/Passes/Rules/CPU/BoxingTransform.cs index 59316661ef..00e1fb66a8 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/FoldBoxingConst.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/BoxingTransform.cs @@ -50,3 +50,30 @@ public partial class UnfoldDistributedConst : RewriteRule return null; } } + +[RuleGenerator] +public partial class SplitPartialAndReshardBoxing : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsBoxing( + target_name: "boxing", + call_name: "call", + _ => true, + IsWildcard("input")); + + private Expr? GetReplace(Call call, Expr input) + { + if (input.CheckedType is DistributedType it && it.NdSBP.Any(sbp => sbp is SBPPartial) && call.CheckedType is DistributedType ot) + { + var newSBPs = it.NdSBP.Select(sbp => sbp is SBPPartial ? SBP.B : sbp).ToArray(); + if (newSBPs.Length != ot.NdSBP.Count || Enumerable.Range(0, newSBPs.Length).Any(i => newSBPs[i] != ot.NdSBP[i])) + { + return IR.F.CPU.Boxing(IR.F.CPU.Boxing(input, new DistributedType(it.TensorType, newSBPs, it.Placement)), ot); + } + + return null; + } + + return null; + } +} From 0a1d476200af29afd5ee0bccf3abec51e7d8b1dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Sat, 8 Feb 2025 16:32:43 +0800 Subject: [PATCH 05/18] fix multi-branch --- .../Passes/Distributed/AutoDistributed.cs | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 7c2dec1ae6..e709e37966 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -829,7 +829,7 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster) Dump(stream, picks, costMemo); } - return new ExprBuildVisitor(_rootSearchGraph, picks).Visit(rootCluster.Clusters.OfType()); + return new ExprBuildVisitor(_rootSearchGraph, picks).Visit(rootCluster.Clusters.OfType(), null); } private HyperGraph ToHyperGraph(DistributedSearchGraph root, DistributedSearchGraph rootCluster) @@ -881,23 +881,29 @@ public ExprBuildVisitor(DistributedSearchGraph rootSearchGraph, Dictionary rootBuckets) + public Expr Visit(IEnumerable currentBuckets, SearchableNode? parent) { - var rootPicks = rootBuckets.SelectMany(b => b.Vertices).Where(v => _picks.TryGetValue(v, out var pick) && pick).ToArray(); - if (rootPicks.Length != 1) + var currentPicks = currentBuckets.SelectMany(b => b.Vertices).Where(v => _picks.TryGetValue(v, out var pick) && pick).ToArray(); + if (currentPicks.Length != 1 && parent is null) { - throw new InvalidProgramException("the one cluster only can pick one vertex!"); + throw new InvalidProgramException("the root cluster only can pick one vertex!"); } - var root = rootPicks[0]; - if (!_memo.TryGetValue(root, out var expr)) + if (currentPicks.Length > 1 && parent is not null) { - _rootSearchGraph.TryGetOutEdges(root, out var edges); - var children = edges.GroupBy(e => e.InputIndex).Select(g => Visit(g.Select(e => e.InputGraph))).ToArray(); - switch (root.Expr) + currentPicks = currentPicks.Where(cur => _rootSearchGraph.TryGetEdge(parent, cur, out _)).ToArray(); + } + + // todo is currentPicks still > 1, we should find the low cost one. + var current = currentPicks[0]; + if (!_memo.TryGetValue(current, out var expr)) + { + _rootSearchGraph.TryGetOutEdges(current, out var edges); + var children = edges.GroupBy(e => e.InputIndex).Select(g => Visit(g.Select(e => e.InputGraph), current)).ToArray(); + switch (current.Expr) { case Var or TensorConst or TupleConst or None: - expr = root.Expr; + expr = current.Expr; break; case BaseFunction func: expr = new Call(target: func, arguments: children); @@ -912,10 +918,10 @@ public Expr Visit(IEnumerable rootBuckets) expr = @if.With(condition: children[^3], then: children[^2], @else: children[^1], paramList: children[..^3].ToArray()); break; default: - throw new NotSupportedException(root.Expr.GetType().Name); + throw new NotSupportedException(current.Expr.GetType().Name); } - _memo.Add(root, expr); + _memo.Add(current, expr); } return expr; From caf1def60a170df7eb981236780844e6c22dbb14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Sat, 8 Feb 2025 17:34:32 +0800 Subject: [PATCH 06/18] fix tuple search --- .../Passes/Distributed/AutoDistributed.cs | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index e709e37966..e171f2c6bf 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -328,6 +328,11 @@ protected override Unit VisitLeafCall(Call expr) } } + if (callCluster.VertexCount == 0) + { + throw new InvalidDataException($"Can't add any valid candidates for {expr.Target}"); + } + _inferedMemo.Add(expr, callCluster); if (!isSupported) @@ -453,18 +458,16 @@ private DistributedSearchGraph CreateOriginatorCluster(Expr expr, bool init) if (expr is IR.Tuple tp) { var distCluster = _rootSearchGraph.CreateCluster(SearchGraphKind.DistributedCluster); - var buckets = new DistributedSearchGraph[tp.Fields.Length]; - foreach (var (f, fGraph, i) in tp.Fields.AsValueEnumerable().Select((f, i) => (f, Visit(f), i))) - { - buckets[i] = TryAddOriginator(f).Clusters.OfType().First(); - } - var tpnode = new SearchableNode(new IR.Tuple(), new TupleType(buckets.Select(g => g.Vertices.First().IRType).ToArray())); - var bucket = distCluster.CreateCluster(SearchGraphKind.Bucket); - bucket.AddVertex(tpnode); - for (int i = 0; i < tp.Fields.Length; i++) + foreach (var buckets in tp.Fields.ToArray().Select(f => TryAddOriginator(f).Clusters.OfType()).CartesianProduct().Select(x => x.ToArray())) { - _rootSearchGraph.AddEdge(new(tpnode, buckets[i].Vertices.First(), i, buckets[i])); + var tpnode = new SearchableNode(new IR.Tuple(), new TupleType(buckets.Select(g => g.Vertices.First().IRType).ToArray())); + var bucket = distCluster.CreateCluster(SearchGraphKind.Bucket); + bucket.AddVertex(tpnode); + for (int i = 0; i < buckets.Length; i++) + { + _rootSearchGraph.AddEdge(new(tpnode, buckets[i].Vertices.First(), i, buckets[i])); + } } return distCluster; @@ -711,7 +714,7 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster) break; case Op op: { - if (!_rootSearchGraph.TryGetOutEdges(enode, out var edges)) + if (!_rootSearchGraph.TryGetOutEdges(enode, out var edges) || !edges.Any()) { throw new NotSupportedException("graph doesn't contain the vertex."); } @@ -757,15 +760,18 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster) } // 3. no cycle - foreach (var cluster in _rootSearchGraph.Clusters.OfType()) + if (Bidirectional) { - foreach (var sourceBucket in cluster.Clusters.OfType()) + foreach (var cluster in _rootSearchGraph.Clusters.OfType()) { - foreach (var destBucket in cluster.Clusters.OfType().Where(b => !ReferenceEquals(b, sourceBucket))) + foreach (var sourceBucket in cluster.Clusters.OfType()) { - foreach (var (src, dest) in sourceBucket.Vertices.Where(v => v.IsBidirect).Zip(destBucket.Vertices.Where(v => v.IsBidirect))) + foreach (var destBucket in cluster.Clusters.OfType().Where(b => !ReferenceEquals(b, sourceBucket))) { - cpmodel.AddBoolAnd([varMemo[src].Not(), varMemo[dest].Not()]); + foreach (var (src, dest) in sourceBucket.Vertices.Where(v => v.IsBidirect).Zip(destBucket.Vertices.Where(v => v.IsBidirect))) + { + cpmodel.AddBoolAnd([varMemo[src].Not(), varMemo[dest].Not()]); + } } } } From 3453a2b17b73fce9347b3f0a6425a2e43a5ead97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Sat, 8 Feb 2025 18:38:38 +0800 Subject: [PATCH 07/18] add callback --- .../Passes/Distributed/AutoDistributed.cs | 73 ++++++++++--------- src/Nncase.Core/ITarget.cs | 4 + 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index e171f2c6bf..af016e0e95 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -18,7 +18,7 @@ namespace Nncase.Passes.Distributed; -internal enum SearchGraphKind : int +public enum SearchGraphKind : int { Root, DistributedCluster, @@ -51,8 +51,34 @@ public AutoDistributedPass(bool bidirectional, string moduleKind, CompileOptions _moduleKind = moduleKind; } + public event Action, CompileOptions, ICpuTargetOptions>? OnExtract; + public bool Bidirectional { get; } + public static void AllConstantMemoryConstrains(CpModel model, DistributedSearchGraph searchGraph, Dictionary vars, CompileOptions compileOptions, ICpuTargetOptions targetOptions) + { + var consts = vars.Keys.Where(k => k.Expr is Call { Target: IR.CPU.Boxing { NewType: DistributedType } } call && call.Arguments[0] is TensorConst tc && tc.Value.Length >= 8).ToArray(); + model.Add(LinearExpr.WeightedSum(consts.Select(k => vars[k]), consts.Select(k => + { + var type = DistributedUtility.GetDividedTensorType((DistributedType)k.Expr.CheckedType); + return TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; + })) < (2L * 512L * 1024L * 1024L)); + } + + public static void SingleMemoryConstrains(CpModel model, DistributedSearchGraph searchGraph, Dictionary vars, CompileOptions compileOptions, ICpuTargetOptions targetOptions) + { + // var cpuTargetOptions = targetOptions; + foreach (var searchableNode in searchGraph.Vertices.Where(x => x.IRType is DistributedType)) + { + if (targetOptions.HierarchySizes.Length > 1) + { + var type = DistributedUtility.GetDividedTensorType((DistributedType)searchableNode.IRType); + var size = TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; + model.Add(vars[searchableNode] * size < targetOptions.HierarchySizes[^2] / targetOptions.Hierarchies[0][^1]); + } + } + } + protected override Task RunCoreAsync(BaseFunction input, RunPassContext context) { if (input.Metadata is AutoDistributedMetaData { Skip: true }) @@ -61,11 +87,12 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo } var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions(), _moduleKind, _bidirectional); + rewriter.OnExtract += OnExtract; return Task.FromResult(rewriter.Rewirte(input)); } } -internal sealed class SearchableNode +public sealed class SearchableNode { public SearchableNode(Expr expr, IRType type, bool isBidirect = false) { @@ -81,7 +108,7 @@ public SearchableNode(Expr expr, IRType type, bool isBidirect = false) public bool IsBidirect { get; } } -internal sealed record CrossEdge : IEdge +public sealed record CrossEdge : IEdge { public CrossEdge(SearchableNode root, SearchableNode input, int inputIndex, DistributedSearchGraph inputGraph) { @@ -104,7 +131,7 @@ public CrossEdge(SearchableNode root, SearchableNode input, int inputIndex, Dist public SearchableNode Target => Input; } -internal sealed class DistributedSearchGraph : TieredAdjacencyGraph +public sealed class DistributedSearchGraph : TieredAdjacencyGraph { public DistributedSearchGraph([NotNull] AdjacencyGraph wrappedGraph, SearchGraphKind kind) : base(wrappedGraph) @@ -159,6 +186,8 @@ public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions t _bidirectional = bidirectional; } + public event Action, CompileOptions, ICpuTargetOptions>? OnExtract; + public IRArray Placements { get; } public bool Bidirectional { get; } @@ -169,42 +198,11 @@ public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions t public IReadOnlyDictionary NdSBP, Placement Placement)> Scheme { get; } - public static void MemoryExtractConstrains(CpModel model, IReadOnlyDictionary vars) - { - var consts = vars.Keys.Where(k => k.Expr is Call { Target: IR.CPU.Boxing { NewType: DistributedType } } call && call.Arguments[0] is TensorConst tc && tc.Value.Length >= 8).ToArray(); - model.Add(LinearExpr.WeightedSum(consts.Select(k => vars[k]), consts.Select(k => - { - var type = DistributedUtility.GetDividedTensorType((DistributedType)k.Expr.CheckedType); - return TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; - })) < (2L * 512L * 1024L * 1024L)); - } - public static IReadOnlyList GetLeafCandidateDistTypes(TensorType tensorType, IEnumerable placements) { return placements.Select(placement => DistributedUtility.GetLeafCandidateNDSBPs(tensorType, placement).Select(ndsbp => new DistributedType(tensorType, ndsbp, placement))).SelectMany(e => e).ToArray(); } - public void SingleNodeMemoryExtractConstrains(CpModel model, IReadOnlyDictionary vars) - { - var distTypes = vars.Keys.Where(k => k.Expr.CheckedType is DistributedType dt).ToArray(); - foreach (var k in distTypes) - { - if (TargetOptions.HierarchySizes.Length > 1) - { - var type = DistributedUtility.GetDividedTensorType((DistributedType)k.Expr.CheckedType); - var size = TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; - - if (k.Expr is Call { Target: IR.CPU.Boxing boxing } call && boxing.NewType is DistributedType distributedType && call.Arguments[0].CheckedType is DistributedType inType && inType.NdSBP.Any(sbp => sbp is SBPPartial) && distributedType != call.Arguments[0].CheckedType) - { - type = DistributedUtility.GetDividedTensorType(inType); - size += TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; - } - - model.Add(vars[k] * size < TargetOptions.HierarchySizes[^2] / TargetOptions.Hierarchies[0][^1]); - } - } - } - public void FilterByScheme(Expr expr, DistributedSearchGraph cluster) { bool Matched(SearchableNode node, (IRArray NdSBP, Placement Placement) tp) @@ -777,7 +775,10 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster) } } - // 3. add pick weights for all enode. + // 4. add custom constraints. + OnExtract?.Invoke(cpmodel, _rootSearchGraph, varMemo, CompileOptions, TargetOptions); + + // 5. add pick weights for all enode. cpmodel.Minimize(LinearExpr.WeightedSum(_rootSearchGraph.Vertices.Select(n => varMemo[n]), _rootSearchGraph.Vertices.Select(n => checked((long)costMemo[n].Score)))); if (cpmodel.Validate().Any()) diff --git a/src/Nncase.Core/ITarget.cs b/src/Nncase.Core/ITarget.cs index 8d5f1af5ed..41058bf3b2 100644 --- a/src/Nncase.Core/ITarget.cs +++ b/src/Nncase.Core/ITarget.cs @@ -23,6 +23,10 @@ public interface ICpuTargetOptions : ITargetOptions public int[] HierarchyBandWidths { get; } + public int[] HierarchySizes { get; } + + public int[][] Hierarchies { get; } + bool UnifiedMemoryArch { get; } } From c5b36c2798b98633309b381be85133d6edd7aeb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Mon, 10 Feb 2025 13:58:48 +0800 Subject: [PATCH 08/18] add distributed search stragegy --- .../Passes/Distributed/AutoDistributed.cs | 89 ++++++++++++------- .../Nncase.Modules.CPU/Targets/CPUTarget.cs | 2 +- .../Targets/CPUTargetOptions.cs | 13 +++ .../Distributed/UnitTestDistributeSchema.cs | 2 +- .../Targets/UnitTestCPUKernels.cs | 38 ++++++++ 5 files changed, 112 insertions(+), 32 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index af016e0e95..0e500893d8 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -39,22 +39,16 @@ public sealed partial class AutoDistributedPass : FunctionPass { private readonly CompileOptions _compileOptions; - private readonly bool _bidirectional; - private readonly string _moduleKind; - public AutoDistributedPass(bool bidirectional, string moduleKind, CompileOptions compileOptions) + public AutoDistributedPass(string moduleKind, CompileOptions compileOptions) { - Bidirectional = bidirectional; _compileOptions = compileOptions; - _bidirectional = bidirectional; _moduleKind = moduleKind; } public event Action, CompileOptions, ICpuTargetOptions>? OnExtract; - public bool Bidirectional { get; } - public static void AllConstantMemoryConstrains(CpModel model, DistributedSearchGraph searchGraph, Dictionary vars, CompileOptions compileOptions, ICpuTargetOptions targetOptions) { var consts = vars.Keys.Where(k => k.Expr is Call { Target: IR.CPU.Boxing { NewType: DistributedType } } call && call.Arguments[0] is TensorConst tc && tc.Value.Length >= 8).ToArray(); @@ -70,11 +64,13 @@ public static void SingleMemoryConstrains(CpModel model, DistributedSearchGraph // var cpuTargetOptions = targetOptions; foreach (var searchableNode in searchGraph.Vertices.Where(x => x.IRType is DistributedType)) { - if (targetOptions.HierarchySizes.Length > 1) + var distType = (DistributedType)searchableNode.IRType; + if (targetOptions.HierarchySizes.Length > 1 && searchableNode.Expr is TensorConst && distType.TensorType.Shape.ToValueArray().SequenceEqual(new[] { 32, 128, 2 }) && distType.NdSBP.All(sbp => sbp is SBPBroadCast)) { - var type = DistributedUtility.GetDividedTensorType((DistributedType)searchableNode.IRType); - var size = TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; - model.Add(vars[searchableNode] * size < targetOptions.HierarchySizes[^2] / targetOptions.Hierarchies[0][^1]); + // var type = DistributedUtility.GetDividedTensorType((DistributedType)searchableNode.IRType); + // var size = TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; + // model.Add(vars[searchableNode] * size < targetOptions.HierarchySizes[^2] / targetOptions.Hierarchies[0][^1]); + model.AddAssumption(vars[searchableNode].Not()); } } } @@ -86,7 +82,7 @@ protected override Task RunCoreAsync(BaseFunction input, RunPassCo return Task.FromResult(input); } - var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions(), _moduleKind, _bidirectional); + var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions(), _moduleKind); rewriter.OnExtract += OnExtract; return Task.FromResult(rewriter.Rewirte(input)); } @@ -160,12 +156,9 @@ internal sealed class AutoDistributedRewriter : ExprVisitor private readonly string _moduleKind; - private readonly bool _bidirectional; - - public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions, string moduleKind = "cpu", bool bidirectional = false) + public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions, string moduleKind = "cpu") { Placements = targetOptions.Hierarchies.Select(h => new Placement(h, targetOptions.HierarchyNames)).ToArray(); - Bidirectional = bidirectional; CompileOptions = compileOptions; TargetOptions = targetOptions; _moduleKind = moduleKind; @@ -183,15 +176,12 @@ public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions t _rootGraph = new(true); _rootSearchGraph = new(_rootGraph, SearchGraphKind.Root); _moduleKind = moduleKind; - _bidirectional = bidirectional; } public event Action, CompileOptions, ICpuTargetOptions>? OnExtract; public IRArray Placements { get; } - public bool Bidirectional { get; } - public CompileOptions CompileOptions { get; } public CpuTargetOptions TargetOptions { get; } @@ -338,21 +328,60 @@ protected override Unit VisitLeafCall(Call expr) return default; } - // 3. add bidirectional connections. - if (Bidirectional) + // 3. add expand connections. + switch (TargetOptions.DistributedSearchStrategy) { - foreach (var (lType, lBucket) in bucketMemo.Where(kv => kv.Key is DistributedType)) - { - foreach (var (rType, rBucket) in bucketMemo.Where(kv => kv.Key is DistributedType distributedType && distributedType != lType)) + case AutoDistributedSearchStrategy.ExpandAll: + foreach (var (lType, lBucket) in bucketMemo.Where(kv => kv.Key is DistributedType)) { - if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is not InvalidType) + foreach (var (rType, rBucket) in bucketMemo.Where(kv => kv.Key is DistributedType distributedType && distributedType != lType)) { - var rnode = new SearchableNode(new Boxing(rType), rType, true); - rBucket.AddVertex(rnode); - callCluster.AddEdge(new(rnode, lBucket.Vertices.First(), 0, lBucket)); + if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is not InvalidType) + { + var rnode = new SearchableNode(new Boxing(rType), rType, true); + rBucket.AddVertex(rnode); + callCluster.AddEdge(new(rnode, lBucket.Vertices.First(), 0, lBucket)); + } } } - } + + break; + case AutoDistributedSearchStrategy.ExpandPartial: +#pragma warning disable SA1008 // Opening parenthesis should be spaced correctly + // partial -> broadcast + foreach (var (lType, lBucket) in bucketMemo.Where(kv => kv.Key is DistributedType ldistType && ldistType.NdSBP.Any(sbp => sbp is SBPPartial))) + { + var ldistType = (DistributedType)lType; + foreach (var (rType, rBucket) in bucketMemo.Where(kv => kv.Key is DistributedType rdistType && ldistType.NdSBP.Zip(rdistType.NdSBP).All(p => (p.First, p.Second) switch { (SBPPartial, SBPBroadCast) => true, (SBP a, SBP b) => a == b }))) + { + if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is not InvalidType) + { + var rnode = new SearchableNode(new Boxing(rType), rType); + rBucket.AddVertex(rnode); + callCluster.AddEdge(new(rnode, lBucket.Vertices.First(), 0, lBucket)); + } + } + } + + // broadcast -> split + foreach (var (lType, lBucket) in bucketMemo.Where(kv => kv.Key is DistributedType ldistType && ldistType.NdSBP.All(sbp => sbp is SBPSplit or SBPBroadCast))) + { + var ldistType = (DistributedType)lType; + foreach (var (rType, rBucket) in bucketMemo.Where(kv => kv.Key is DistributedType rdistType && ldistType.NdSBP.Zip(rdistType.NdSBP).All(p => (p.First, p.Second) switch { (SBPBroadCast, SBPBroadCast or SBPSplit) => true, (SBP a, SBP b) => a == b }) && ldistType != rdistType)) + { + if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is not InvalidType) + { + var rnode = new SearchableNode(new Boxing(rType), rType); + rBucket.AddVertex(rnode); + callCluster.AddEdge(new(rnode, lBucket.Vertices.First(), 0, lBucket)); + } + } + } +#pragma warning restore SA1008 // Opening parenthesis should be spaced correctly + + break; + default: + throw new ArgumentOutOfRangeException($"{TargetOptions.DistributedSearchStrategy}"); } // 4. add not infered type in search space. @@ -758,7 +787,7 @@ private Expr SolveAndExtract(DistributedSearchGraph rootCluster) } // 3. no cycle - if (Bidirectional) + if (TargetOptions.DistributedSearchStrategy is AutoDistributedSearchStrategy.ExpandAll) { foreach (var cluster in _rootSearchGraph.Clusters.OfType()) { diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs index 776348b700..e13f2b9ef3 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs @@ -114,7 +114,7 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp }); } - passManager.Add(true, Kind); + passManager.Add(Kind); passManager.Add(); diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs index d0f14b2ef7..fd913c9527 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptions.cs @@ -30,6 +30,13 @@ public enum NocArchitecture : byte CrossBar = 1, } +public enum AutoDistributedSearchStrategy : uint +{ + ExpandPartial = 0, + ExpandAll, + NoExpand, +} + public class CpuTargetOptions : ICpuTargetOptions { [DisplayName("--model-name")] @@ -112,6 +119,12 @@ public class CpuTargetOptions : ICpuTargetOptions [DefaultValue("")] public string DistributedScheme { get; set; } = string.Empty; + [DisplayName("--distributed-search-strategy")] + [Description("the distributed search strategy.")] + [DefaultValue(AutoDistributedSearchStrategy.ExpandPartial)] + [CommandLine.FromAmong(AutoDistributedSearchStrategy.ExpandPartial, AutoDistributedSearchStrategy.ExpandAll, AutoDistributedSearchStrategy.NoExpand)] + public AutoDistributedSearchStrategy DistributedSearchStrategy { get; set; } = AutoDistributedSearchStrategy.ExpandPartial; + [DisplayName("--custom-op-scheme")] [Description("the custom-op scheme path.")] [DefaultValue("")] diff --git a/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs b/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs index 15418ce940..f4602c089f 100644 --- a/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs +++ b/src/Nncase.Tests/Distributed/UnitTestDistributeSchema.cs @@ -82,7 +82,7 @@ public async Task TestLoadScheme() func = new(output); } - var pass = new AutoDistributedPass(true, "cpu", CompileOptions); + var pass = new AutoDistributedPass("cpu", CompileOptions); var result = await pass.RunAsync(func, new()); diff --git a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs index cdf791907c..3824f067d8 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs @@ -647,6 +647,40 @@ public async Task TestMatMulReshapeUnary(int[] lhsShape, int[] rhsShape, int[] n await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{number}"), feedDict, new[] { unary }); } + [Theory] + [InlineData([true, new[] { 4, 4 }, 0])] // enable packing + public async Task TestLlamaBMM(bool packing, int[] hierarchy, int number) + { + var options = (CpuTargetOptions)CompileOptions.TargetOptions; + options.Packing = packing; + options.Hierarchies[0] = hierarchy; + options.HierarchyNames = string.Join(string.Empty, "cbt".TakeLast(hierarchy.Length)); + options.HierarchySizes = Enumerable.Repeat((int)MathF.Pow(2, 30), hierarchy.Length).ToArray(); + options.HierarchyLatencies = Enumerable.Repeat(10000, hierarchy.Length).ToArray(); + + var var_5 = new Var("var_5", new TensorType(DataTypes.Float32, new[] { 1, 384, 4096 })); + var var_9 = new Var("var_9", new TensorType(DataTypes.Float32, new[] { 1, 384, 4096 })); + Expr pre; + { + var v33 = IR.F.Math.MatMul(var_9, IR.F.Random.Normal(DataTypes.Float32, 0, 0.1, 5, new[] { 4096, 4096 }).Evaluate().AsTensor()); + var v34 = IR.F.Math.Binary(BinaryOp.Add, var_5, v33); + var v35 = IR.F.NN.LayerNorm(2, 1E-06f, v34, IR.F.Random.Normal(DataTypes.Float32, 0, 0.1, 6, new[] { 4096 }).Evaluate().AsTensor(), IR.F.Random.Normal(DataTypes.Float32, 0, 0.1, 7, new[] { 4096 }).Evaluate().AsTensor(), false); + var v36 = IR.F.Math.MatMul(v35, IR.F.Random.Normal(DataTypes.Float32, 0, 0.1, 8, new[] { 4096, 11008 }).Evaluate().AsTensor()); + pre = v36; + } + + var input_tensor1 = IR.F.Random.Normal(DataTypes.Float32, 0, 1, 100, new[] { 1, 384, 4096 }).Evaluate().AsTensor(); + var input_tensor5 = IR.F.Random.Normal(DataTypes.Float32, 0, 1, 104, new[] { 1, 384, 4096 }).Evaluate().AsTensor(); + var feedDict = new Dictionary() + { + { var_5, Value.FromTensor(input_tensor1) }, + { var_9, Value.FromTensor(input_tensor5) }, + }; + + var posts = new[] { pre }; + await RunCases(Path.Join(CompileOptions.DumpDir.ToString(), $"Theory{number}"), feedDict, posts); + } + [Theory(Skip = "ToBig")] [InlineData(new object[] { false, 0 })] [InlineData(new object[] { true, 1 })] // enable packing @@ -837,6 +871,10 @@ internal async Task Run(string dumpDir, CpuKernelCase kernelCase) private async Task Compile(IRModule module) { var pmgr = CompileSession.CreatePassManager("pmgr"); + pmgr.AddWithName("PreProcess").Configure(p => + { + p.Add(); + }); CompileSession.Target.RegisterTargetDependentAfterQuantPass(pmgr, CompileSession.CompileOptions); CompileSession.Target.RegisterTargetDependentBeforeCodeGen(pmgr, CompileSession.CompileOptions); await pmgr.RunAsync(module); From 3ee97fa73d668e1789023b0c3f60f1e96d2b08ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Mon, 10 Feb 2025 14:24:46 +0800 Subject: [PATCH 09/18] add capi --- .../Passes/Distributed/AutoDistributed.cs | 2 +- .../Targets/CPUTargetOptionsCommand.cs | 11 +++++-- python/_nncase.pyi | 13 +++++--- src/Native/include/nncase/compiler.h | 27 +++++++++++++---- src/Nncase.Compiler/Interop/CApi.cs | 30 ++++++++++++++----- tests/config.toml | 3 +- tools/stackvm_gen/CApiGen/packages.lock.json | 6 ++-- 7 files changed, 69 insertions(+), 23 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 0e500893d8..1c6299f0d5 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -352,7 +352,7 @@ protected override Unit VisitLeafCall(Call expr) foreach (var (lType, lBucket) in bucketMemo.Where(kv => kv.Key is DistributedType ldistType && ldistType.NdSBP.Any(sbp => sbp is SBPPartial))) { var ldistType = (DistributedType)lType; - foreach (var (rType, rBucket) in bucketMemo.Where(kv => kv.Key is DistributedType rdistType && ldistType.NdSBP.Zip(rdistType.NdSBP).All(p => (p.First, p.Second) switch { (SBPPartial, SBPBroadCast) => true, (SBP a, SBP b) => a == b }))) + foreach (var (rType, rBucket) in bucketMemo.Where(kv => kv.Key is DistributedType rdistType && ldistType.NdSBP.Zip(rdistType.NdSBP).All(p => (p.First, p.Second) switch { (SBPPartial, SBPBroadCast) => true, (SBP a, SBP b) => a == b }) && !rdistType.NdSBP.Any(sbp => sbp is SBPPartial))) { if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is not InvalidType) { diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs index 3bb2fb8772..dd0eb03c18 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -/* This file is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ +/* This file is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ using System; using System.Collections.Generic; @@ -13,7 +13,6 @@ using System.Threading.Tasks; using Nncase; using Nncase.CommandLine; -using Nncase.IR; namespace Nncase.Targets; @@ -110,6 +109,11 @@ public CpuTargetOptionsCommand(string name) description: "the distributed scheme path.", getDefaultValue: () => string.Empty); Add(DistributedSchemeOption); + DistributedSearchStrategyOption = new Option( + name: "--distributed-search-strategy", + description: "the distributed search strategy.", + getDefaultValue: () => AutoDistributedSearchStrategy.ExpandPartial); + Add(DistributedSearchStrategyOption); CustomOpSchemeOption = new Option( name: "--custom-op-scheme", description: "the custom-op scheme path.", @@ -145,6 +149,8 @@ public CpuTargetOptionsCommand(string name) public Option DistributedSchemeOption { get; } + public Option DistributedSearchStrategyOption { get; } + public Option CustomOpSchemeOption { get; } } @@ -175,6 +181,7 @@ public CpuTargetOptions GetBoundValue(InvocationContext context) MemoryCapacities = context.ParseResult.GetValueForOption(_cmd.MemoryCapacitiesOption)!.ToArray(), MemoryBandWidths = context.ParseResult.GetValueForOption(_cmd.MemoryBandWidthsOption)!.ToArray(), DistributedScheme = context.ParseResult.GetValueForOption(_cmd.DistributedSchemeOption)!, + DistributedSearchStrategy = context.ParseResult.GetValueForOption(_cmd.DistributedSearchStrategyOption)!, CustomOpScheme = context.ParseResult.GetValueForOption(_cmd.CustomOpSchemeOption)!, }; } diff --git a/python/_nncase.pyi b/python/_nncase.pyi index ae69363e28..9378034040 100644 --- a/python/_nncase.pyi +++ b/python/_nncase.pyi @@ -3,7 +3,7 @@ from typing import Any, List, BinaryIO, Enum import numpy -""" This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ +""" This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. """ class MemoryAccessArchitecture(Enum): UMA = 0 NUMA = 1 @@ -13,9 +13,13 @@ class NocArchitecture(Enum): class HierarchyKind(Enum): Parallel = 0 SMT = 1 -""" end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ +class AutoDistributedSearchStrategy(Enum): + ExpandPartial = 0 + ExpandAll = 1 + NoExpand = 2 +""" end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. """ -""" This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ +""" This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. """ class CpuTargetOptions: def __init__(self) -> None: ... ModelName: str @@ -32,8 +36,9 @@ class CpuTargetOptions: MemoryCapacities: List[int] MemoryBandWidths: List[int] DistributedScheme: str + DistributedSearchStrategy: AutoDistributedSearchStrategy CustomOpScheme: str -""" end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 5:27:07 PM +08:00. """ +""" end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. """ class CompileOptions: benchmark_only: bool diff --git a/src/Native/include/nncase/compiler.h b/src/Native/include/nncase/compiler.h index 5ab4c0f321..38d9315448 100644 --- a/src/Native/include/nncase/compiler.h +++ b/src/Native/include/nncase/compiler.h @@ -81,7 +81,7 @@ typedef enum { } nncase_input_type_t; // clang-format off -/* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ +/* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ enum memory_access_architecture_t : uint8_t { memory_access_architecture_uma = 0, memory_access_architecture_numa = 1, @@ -94,7 +94,12 @@ enum hierarchy_kind_t : uint8_t { hierarchy_kind_parallel = 0, hierarchy_kind_smt = 1, }; -/* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ +enum auto_distributed_search_strategy_t : uint8_t { + auto_distributed_search_strategy_expand_partial = 0, + auto_distributed_search_strategy_expand_all = 1, + auto_distributed_search_strategy_no_expand = 2, +}; +/* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ // clang-format on typedef struct { @@ -241,7 +246,7 @@ typedef struct { clr_object_handle_t shape_bucket_options, const char *fix_var_map, size_t fix_var_map_size); // clang-format off - /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ clr_object_handle_t (*cpu_target_options_create)(); void (*cpu_target_options_set_model_name)(clr_object_handle_t handle, const char* value, size_t length); void (*cpu_target_options_set_packing)(clr_object_handle_t handle, uint8_t value); @@ -260,8 +265,10 @@ typedef struct { void (*cpu_target_options_set_memory_capacities)(clr_object_handle_t handle, int32_t* value, size_t shape0); void (*cpu_target_options_set_memory_band_widths)(clr_object_handle_t handle, int32_t* value, size_t shape0); void (*cpu_target_options_set_distributed_scheme)(clr_object_handle_t handle, const char* value, size_t length); + uint32_t (*cpu_target_options_get_distributed_search_strategy)(clr_object_handle_t handle); + void (*cpu_target_options_set_distributed_search_strategy)(clr_object_handle_t handle, uint32_t value); void (*cpu_target_options_set_custom_op_scheme)(clr_object_handle_t handle, const char* value, size_t length); - /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ // clang-format on clr_object_handle_t (*rtvalue_from_handle)(nncase::value_node *value); @@ -505,7 +512,7 @@ class shape_bucket_options : public clr_object_base { }; // clang-format off -/* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ +/* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ class cpu_target_options : public clr_object_base { public: using clr_object_base::clr_object_base; @@ -628,11 +635,19 @@ class cpu_target_options : public clr_object_base { nncase_clr_api()->cpu_target_options_set_distributed_scheme(obj_.get(), value.data(), value.length()); } + auto_distributed_search_strategy_t distributed_search_strategy() { + return (auto_distributed_search_strategy_t)nncase_clr_api()->cpu_target_options_get_distributed_search_strategy(obj_.get()); + } + + void distributed_search_strategy(auto_distributed_search_strategy_t value) { + nncase_clr_api()->cpu_target_options_set_distributed_search_strategy(obj_.get(), value); + } + void custom_op_scheme(std::string_view value) { nncase_clr_api()->cpu_target_options_set_custom_op_scheme(obj_.get(), value.data(), value.length()); } }; -/* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ +/* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ // clang-format on class cstream : public clr_object_base { diff --git a/src/Nncase.Compiler/Interop/CApi.cs b/src/Nncase.Compiler/Interop/CApi.cs index 2946c732f7..fc9a5775d2 100644 --- a/src/Nncase.Compiler/Interop/CApi.cs +++ b/src/Nncase.Compiler/Interop/CApi.cs @@ -96,7 +96,7 @@ public unsafe struct CApiMT public delegate* unmanaged ShapeBucketOptionsSetRangeInfoPtr; public delegate* unmanaged ShapeBucketOptionsSetSegmentsCountPtr; public delegate* unmanaged ShapeBucketOptionsSetFixVarMapPtr; - /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:31:31 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ public delegate* unmanaged CpuTargetOptionsCreatePtr; public delegate* unmanaged CpuTargetOptionsSetModelNamePtr; public delegate* unmanaged CpuTargetOptionsSetPackingPtr; @@ -115,8 +115,10 @@ public unsafe struct CApiMT public delegate* unmanaged CpuTargetOptionsSetMemoryCapacitiesPtr; public delegate* unmanaged CpuTargetOptionsSetMemoryBandWidthsPtr; public delegate* unmanaged CpuTargetOptionsSetDistributedSchemePtr; + public delegate* unmanaged CpuTargetOptionsGetDistributedSearchStrategyPtr; + public delegate* unmanaged CpuTargetOptionsSetDistributedSearchStrategyPtr; public delegate* unmanaged CpuTargetOptionsSetCustomOpSchemePtr; - /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 5:31:31 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ public delegate* unmanaged RTValueFromHandlePtr; public delegate* unmanaged RTValueGetHandlePtr; public delegate* unmanaged StreamCreatePtr; @@ -190,8 +192,8 @@ public static void Initialize(CApiMT* mt) mt->ShapeBucketOptionsSetRangeInfoPtr = &ShapeBucketOptionsSetRangeInfo; mt->ShapeBucketOptionsSetSegmentsCountPtr = &ShapeBucketOptionsSetSegmentsCount; mt->ShapeBucketOptionsSetFixVarMapPtr = &ShapeBucketOptionsSetFixVarMap; - /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ - mt->CpuTargetOptionsCreatePtr = &CpuTargetOptionsCreate; + /* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ + mt->CpuTargetOptionsCreatePtr = &CpuTargetOptionsCreate; mt->CpuTargetOptionsSetModelNamePtr = &CpuTargetOptionsSetModelName; mt->CpuTargetOptionsSetPackingPtr = &CpuTargetOptionsSetPacking; mt->CpuTargetOptionsSetUnifiedMemoryArchPtr = &CpuTargetOptionsSetUnifiedMemoryArch; @@ -209,8 +211,10 @@ public static void Initialize(CApiMT* mt) mt->CpuTargetOptionsSetMemoryCapacitiesPtr = &CpuTargetOptionsSetMemoryCapacities; mt->CpuTargetOptionsSetMemoryBandWidthsPtr = &CpuTargetOptionsSetMemoryBandWidths; mt->CpuTargetOptionsSetDistributedSchemePtr = &CpuTargetOptionsSetDistributedScheme; + mt->CpuTargetOptionsGetDistributedSearchStrategyPtr = &CpuTargetOptionsGetDistributedSearchStrategy; + mt->CpuTargetOptionsSetDistributedSearchStrategyPtr = &CpuTargetOptionsSetDistributedSearchStrategy; mt->CpuTargetOptionsSetCustomOpSchemePtr = &CpuTargetOptionsSetCustomOpScheme; - /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ mt->RTValueFromHandlePtr = &RTValueFromHandle; mt->RTValueGetHandlePtr = &RTValueGetHandle; mt->StreamCreatePtr = &StreamCreate; @@ -792,7 +796,7 @@ private static void ShapeBucketOptionsSetFixVarMap(IntPtr shapeBucketOptionsHand Get(shapeBucketOptionsHandle).FixVarMap = fixVarMapStruct; } - /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 5:31:31 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ [UnmanagedCallersOnly] private static IntPtr CpuTargetOptionsCreate() { @@ -901,13 +905,25 @@ private static void CpuTargetOptionsSetDistributedScheme(IntPtr handle, byte* va Get(handle).DistributedScheme = ToString(value, length); } + [UnmanagedCallersOnly] + private static AutoDistributedSearchStrategy CpuTargetOptionsGetDistributedSearchStrategy(IntPtr handle) + { + return Get(handle).DistributedSearchStrategy; + } + + [UnmanagedCallersOnly] + private static void CpuTargetOptionsSetDistributedSearchStrategy(IntPtr handle, AutoDistributedSearchStrategy value) + { + Get(handle).DistributedSearchStrategy = value; + } + [UnmanagedCallersOnly] private static void CpuTargetOptionsSetCustomOpScheme(IntPtr handle, byte* value, nuint length) { Get(handle).CustomOpScheme = ToString(value, length); } - /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ [UnmanagedCallersOnly] private static IntPtr RTValueFromHandle(IntPtr handle) diff --git a/tests/config.toml b/tests/config.toml index 367767e054..9bd9196a51 100644 --- a/tests/config.toml +++ b/tests/config.toml @@ -157,4 +157,5 @@ MemoryBandWidths = [64, 32] HierarchyLatencies = [10000, 10000, 10000, 10000, 10000] UnifiedMemoryArch = false Packing = true -HierarchyKind = "nncase.HierarchyKind.SMT" \ No newline at end of file +HierarchyKind = "nncase.HierarchyKind.SMT" +DistributedSearchStrategy = "nncase.AutoDistributedSearchStrategy.ExpandPartial" \ No newline at end of file diff --git a/tools/stackvm_gen/CApiGen/packages.lock.json b/tools/stackvm_gen/CApiGen/packages.lock.json index bcc9281162..f0b9c857e6 100644 --- a/tools/stackvm_gen/CApiGen/packages.lock.json +++ b/tools/stackvm_gen/CApiGen/packages.lock.json @@ -317,7 +317,8 @@ "Google.OrTools": "[9.4.1874, )", "NetFabric.Hyperlinq": "[3.0.0-beta48, )", "Nncase.Core": "[1.0.0, )", - "Nncase.Evaluator": "[1.0.0, )" + "Nncase.Evaluator": "[1.0.0, )", + "Nncase.Graph": "[1.0.0, )" } }, "nncase.evaluator": { @@ -332,7 +333,8 @@ "dependencies": { "Nncase.Core": "[1.0.0, )", "Nncase.Evaluator": "[1.0.0, )", - "QuikGraph": "[2.5.0, )" + "QuikGraph": "[2.5.0, )", + "QuikGraph.Graphviz": "[2.5.0, )" } }, "nncase.io": { From 858f88bce8ccb91b995565ceee9209afb1fc803c Mon Sep 17 00:00:00 2001 From: zhen8838 Date: Mon, 10 Feb 2025 06:27:33 +0000 Subject: [PATCH 10/18] Apply code-format changes --- src/Nncase.Compiler/Interop/CApi.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Nncase.Compiler/Interop/CApi.cs b/src/Nncase.Compiler/Interop/CApi.cs index fc9a5775d2..6f335886a9 100644 --- a/src/Nncase.Compiler/Interop/CApi.cs +++ b/src/Nncase.Compiler/Interop/CApi.cs @@ -193,7 +193,7 @@ public static void Initialize(CApiMT* mt) mt->ShapeBucketOptionsSetSegmentsCountPtr = &ShapeBucketOptionsSetSegmentsCount; mt->ShapeBucketOptionsSetFixVarMapPtr = &ShapeBucketOptionsSetFixVarMap; /* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ - mt->CpuTargetOptionsCreatePtr = &CpuTargetOptionsCreate; + mt->CpuTargetOptionsCreatePtr = &CpuTargetOptionsCreate; mt->CpuTargetOptionsSetModelNamePtr = &CpuTargetOptionsSetModelName; mt->CpuTargetOptionsSetPackingPtr = &CpuTargetOptionsSetPacking; mt->CpuTargetOptionsSetUnifiedMemoryArchPtr = &CpuTargetOptionsSetUnifiedMemoryArch; @@ -214,7 +214,7 @@ public static void Initialize(CApiMT* mt) mt->CpuTargetOptionsGetDistributedSearchStrategyPtr = &CpuTargetOptionsGetDistributedSearchStrategy; mt->CpuTargetOptionsSetDistributedSearchStrategyPtr = &CpuTargetOptionsSetDistributedSearchStrategy; mt->CpuTargetOptionsSetCustomOpSchemePtr = &CpuTargetOptionsSetCustomOpScheme; - /* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ mt->RTValueFromHandlePtr = &RTValueFromHandle; mt->RTValueGetHandlePtr = &RTValueGetHandle; mt->StreamCreatePtr = &StreamCreate; From df340872cbd734693c7777cbccdd933d8b2ddb8d Mon Sep 17 00:00:00 2001 From: huochenghai Date: Mon, 10 Feb 2025 14:48:54 +0800 Subject: [PATCH 11/18] fix build --- modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs index dd0eb03c18..94d807321a 100644 --- a/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs +++ b/modules/Nncase.Modules.CPU/Targets/CPUTargetOptionsCommand.cs @@ -13,6 +13,7 @@ using System.Threading.Tasks; using Nncase; using Nncase.CommandLine; +using Nncase.IR; namespace Nncase.Targets; From 20c2423b0b3aaa6edc941dd34793e879eab317a6 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Mon, 10 Feb 2025 15:32:59 +0800 Subject: [PATCH 12/18] skip reduce mean split on axis --- modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs | 5 +++++ src/Nncase.Evaluator/Math/Reduce.cs | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs index acf04ded9a..e8fea5d8d8 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/PackedReduce.cs @@ -115,6 +115,11 @@ private IRType Visit(ITypeInferenceContext context, PackedReduce target, Distrib switch (input.NdSBP[i]) { case SBPSplit { Axis: int ix } when axes.Contains(ix): + if (target.ReduceOp is ReduceOp.Mean) + { + return new InvalidType($"Not support reduce mean for now."); + } + ndsbp[i] = SBP.P(target.ReduceOp); break; default: diff --git a/src/Nncase.Evaluator/Math/Reduce.cs b/src/Nncase.Evaluator/Math/Reduce.cs index 1efa2b174a..d39abb7dbf 100644 --- a/src/Nncase.Evaluator/Math/Reduce.cs +++ b/src/Nncase.Evaluator/Math/Reduce.cs @@ -190,6 +190,11 @@ private IRType Visit(ITypeInferenceContext context, Reduce target, DistributedTy switch (input.NdSBP[i]) { case SBPSplit { Axis: int ix } when axes.Contains(ix): + if (target.ReduceOp is ReduceOp.Mean) + { + return new InvalidType($"Not support reduce mean for now."); + } + ndsbp[i] = SBP.P(target.ReduceOp); break; default: From 23cf4364271cb6011edbdb55fc98a3ab16139b78 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Mon, 10 Feb 2025 19:17:34 +0800 Subject: [PATCH 13/18] fix ffi --- python/nncase/__init__.py | 2 +- python/nncase/native/ffi.cpp | 13 +++++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/nncase/__init__.py b/python/nncase/__init__.py index fa9f8a38ab..7841f4029a 100644 --- a/python/nncase/__init__.py +++ b/python/nncase/__init__.py @@ -33,7 +33,7 @@ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' import _nncase -from _nncase import RuntimeTensor, TensorDesc, Simulator, CpuTargetOptions, NocArchitecture, HierarchyKind, MemoryAccessArchitecture +from _nncase import RuntimeTensor, TensorDesc, Simulator, CpuTargetOptions, NocArchitecture, HierarchyKind, MemoryAccessArchitecture, AutoDistributedSearchStrategy def _initialize(): diff --git a/python/nncase/native/ffi.cpp b/python/nncase/native/ffi.cpp index 9f9cd8e904..00e1bcd336 100644 --- a/python/nncase/native/ffi.cpp +++ b/python/nncase/native/ffi.cpp @@ -235,7 +235,7 @@ PYBIND11_MODULE(_nncase, m) { &shape_bucket_options::fix_var_map)); // clang-format off - /* This block is generated by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ + /* This block is generated by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ py::enum_(m, "MemoryAccessArchitecture") .value("UMA", memory_access_architecture_uma) @@ -249,6 +249,11 @@ PYBIND11_MODULE(_nncase, m) { .value("Parallel", hierarchy_kind_parallel) .value("SMT", hierarchy_kind_smt); + py::enum_(m, "AutoDistributedSearchStrategy") + .value("ExpandPartial", auto_distributed_search_strategy_expand_partial) + .value("ExpandAll", auto_distributed_search_strategy_expand_all) + .value("NoExpand", auto_distributed_search_strategy_no_expand); + py::class_(m, "CpuTargetOptions") .def(py::init()) @@ -308,12 +313,16 @@ PYBIND11_MODULE(_nncase, m) { "DistributedScheme", []() {}, py::overload_cast(&cpu_target_options::distributed_scheme)) + .def_property( + "DistributedSearchStrategy", + py::overload_cast<>(&cpu_target_options::distributed_search_strategy), + py::overload_cast(&cpu_target_options::distributed_search_strategy)) .def_property( "CustomOpScheme", []() {}, py::overload_cast(&cpu_target_options::custom_op_scheme)) ; - /* end the auto generated block by tools/stackvm_gen/CApiGen at 12/20/2024 3:41:05 PM +08:00. */ + /* end the auto generated block by tools/stackvm_gen/CApiGen at 2/10/2025 1:59:47 PM +08:00. */ // clang-format on py::class_(m, "CalibrationDatasetProvider") From 323c01803d6d3110d59984a0f9aa97e097a77240 Mon Sep 17 00:00:00 2001 From: huochenghai Date: Mon, 10 Feb 2025 19:17:47 +0800 Subject: [PATCH 14/18] fix reshape type infer --- src/Nncase.Evaluator/Tensors/Reshape.cs | 3 ++- .../Distributed/UnitTestDistributedTypeInfer.cs | 16 +++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index 4d1e9423c6..28b4ef3a7c 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -33,6 +33,7 @@ public static IRType VisitDistributedType(DistributedType inType, int[] newShape } var (forwardDict, backwardDict) = IRUtility.ShapeMapMatrixAsDict(mat); + var splitedShape = DistributedUtility.GetDividedTensorType(inType).Shape.ToValueArray(); var ndsbp = new SBP[inType.NdSBP.Count]; for (int meshAxis = 0; meshAxis < inType.NdSBP.Count; meshAxis++) { @@ -49,7 +50,7 @@ public static IRType VisitDistributedType(DistributedType inType, int[] newShape var firstValidAxis = mapedOutAxes.Where(axis => newShape[axis] > 1).First(); var restAxes = mapedOutAxes.Skip(mapedOutAxes.IndexOf(firstValidAxis) + 1).ToArray(); var restSize = restAxes.Aggregate(1, (x, i) => x * newShape[i]); - if (restSize < (inShape[si.Axis] / inType.Placement.Hierarchy[meshAxis])) + if (restSize < splitedShape[si.Axis]) { ndsbp[meshAxis] = SBP.S(firstValidAxis); } diff --git a/src/Nncase.Tests/Distributed/UnitTestDistributedTypeInfer.cs b/src/Nncase.Tests/Distributed/UnitTestDistributedTypeInfer.cs index 30a0bd4fcf..733d2dcddb 100644 --- a/src/Nncase.Tests/Distributed/UnitTestDistributedTypeInfer.cs +++ b/src/Nncase.Tests/Distributed/UnitTestDistributedTypeInfer.cs @@ -68,19 +68,19 @@ public sealed class UnitTestDistributedTypeInfer }, { // mesh dim 0 split on first merged-by-reshape axis. - new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 64, 16 }), new SBP[] { SBP.S(1), SBP.S(2) }, new(new[] { 8 }, "t")), + new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 64, 16 }), new SBP[] { SBP.S(1), SBP.S(2) }, new(new[] { 4, 8 }, "bt")), new[] { 1, 48, 1024 }, - new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 1024 }), new SBP[] { SBP.S(1), SBP.S(2) }, new(new[] { 8 }, "t")) + new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 1024 }), new SBP[] { SBP.S(1), SBP.S(2) }, new(new[] { 4, 8 }, "bt")) }, { // mesh dim 1 split on first merged-by-reshape axis. - new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 64, 16 }), new SBP[] { SBP.S(2), SBP.S(1), }, new(new[] { 8 }, "t")), + new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 64, 16 }), new SBP[] { SBP.S(2), SBP.S(1), }, new(new[] { 4, 8 }, "bt")), new[] { 1, 48, 1024 }, - new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 1024 }), new SBP[] { SBP.S(2), SBP.S(1), }, new(new[] { 8 }, "t")) + new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 1024 }), new SBP[] { SBP.S(2), SBP.S(1), }, new(new[] { 4, 8 }, "bt")) }, { // split on second merged-by-reshape axis. - new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 64, 16 }), new SBP[] { SBP.S(1), SBP.S(3) }, new(new[] { 8 }, "t")), + new DistributedType(new(DataTypes.Float32, new[] { 1, 48, 64, 16 }), new SBP[] { SBP.S(1), SBP.S(3) }, new(new[] { 4, 8 }, "bt")), new[] { 1, 48, 1024 }, new InvalidType("not support") }, @@ -90,6 +90,12 @@ public sealed class UnitTestDistributedTypeInfer new[] { 3, 20 }, new InvalidType("unmapable") }, + { + // insufficient-data reshape + new DistributedType(new(DataTypes.Float32, new[] { 512, 4096 }), new SBP[] { SBP.B, SBP.S(0), SBP.S(1), SBP.S(1), SBP.S(1) }, new(new[] { 1, 2, 8, 4, 4 }, "cdxyt")), + new[] { 512, 64, 64 }, + new InvalidType("insufficient data") + }, }; [Fact] From f64fe5a2bdb99bb4dbe13e7f0db92f71ab28f5e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Tue, 11 Feb 2025 16:28:16 +0800 Subject: [PATCH 15/18] add dump sub search graph --- .../Passes/Distributed/AutoDistributed.cs | 137 +++++++++--------- 1 file changed, 72 insertions(+), 65 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 1c6299f0d5..68d96a5af0 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -421,6 +421,73 @@ protected override Unit VisitLeafCall(Call expr) return default; } + private static void DumpAction(GraphvizAlgorithm alg, IReadOnlyDictionary? pickMemo = null, IReadOnlyDictionary? costMemo = null) + { + alg.GraphFormat.RankDirection = QuikGraph.Graphviz.Dot.GraphvizRankDirection.LR; + alg.FormatCluster += (_, arg) => + { + if (arg.Cluster is DistributedSearchGraph tg) + { + arg.GraphFormat.LabelLocation = QuikGraph.Graphviz.Dot.GraphvizLabelLocation.T; + arg.GraphFormat.LabelJustification = QuikGraph.Graphviz.Dot.GraphvizLabelJustification.L; + arg.GraphFormat.Label = tg.Kind.ToString(); + if (tg.Kind is SearchGraphKind.Bucket) + { + arg.GraphFormat.Label += ": " + tg.Vertices.First().IRType.ToString(); + } + } + }; + + alg.FormatVertex += (_, arg) => + { + var row0 = new QuikGraph.Graphviz.Dot.GraphvizRecordCell(); + var col1 = new QuikGraph.Graphviz.Dot.GraphvizRecordCell(); + row0.Cells.Add(col1); + + col1.Cells.Add(new() { Text = CompilerServices.Print(arg.Vertex.Expr) }); + if (arg.Vertex.Expr is IR.Tuple && arg.Vertex.IRType is TupleType tpTuple) + { + for (int i = 0; i < tpTuple.Fields.Count; i++) + { + col1.Cells.Add(new() { Text = i.ToString(), Port = $"P{i}" }); + } + } + else if (arg.Vertex.Expr is Op op) + { + for (int i = 0; i < op.Parameters.Count(); i++) + { + col1.Cells.Add(new() { Text = i.ToString(), Port = $"P{i}" }); + } + } + + arg.VertexFormat.Record.Cells.Add(row0); + arg.VertexFormat.Shape = QuikGraph.Graphviz.Dot.GraphvizVertexShape.Record; + arg.VertexFormat.Style = QuikGraph.Graphviz.Dot.GraphvizVertexStyle.Filled; + if (costMemo is not null && costMemo.TryGetValue(arg.Vertex, out var cost)) + { + var row1 = new QuikGraph.Graphviz.Dot.GraphvizRecordCell(); + foreach (var (k, v) in cost.Factors) + { + row1.Cells.Add(new() { Text = $"{k}: {v}" }); + } + + row1.Cells.Add(new() { Text = $"Score: {cost.Score}" }); + col1.Cells.Add(row1); + } + + if (pickMemo is not null && pickMemo.TryGetValue(arg.Vertex, out var picked) && picked == true) + { + arg.VertexFormat.FillColor = QuikGraph.Graphviz.Dot.GraphvizColor.SkyBlue; + } + }; + + alg.FormatEdge += (_, arg) => + { + arg.EdgeFormat.Direction = QuikGraph.Graphviz.Dot.GraphvizEdgeDirection.Back; + arg.EdgeFormat.TailPort = $"P{arg.Edge.InputIndex}"; + }; + } + /// /// some times we didn't use all args. /// @@ -653,72 +720,12 @@ private DistributedSearchGraph TryInstertTerminator(Expr expr) private void Dump(Stream stream, IReadOnlyDictionary pickMemo, IReadOnlyDictionary costMemo) { using var writer = new StreamWriter(stream); - writer.Write(_rootSearchGraph.ToGraphviz(alg => - { - alg.GraphFormat.RankDirection = QuikGraph.Graphviz.Dot.GraphvizRankDirection.LR; - alg.FormatCluster += (_, arg) => - { - if (arg.Cluster is DistributedSearchGraph tg) - { - arg.GraphFormat.LabelLocation = QuikGraph.Graphviz.Dot.GraphvizLabelLocation.T; - arg.GraphFormat.LabelJustification = QuikGraph.Graphviz.Dot.GraphvizLabelJustification.L; - arg.GraphFormat.Label = tg.Kind.ToString(); - if (tg.Kind is SearchGraphKind.Bucket) - { - arg.GraphFormat.Label += ": " + tg.Vertices.First().IRType.ToString(); - } - } - }; - - alg.FormatVertex += (_, arg) => - { - var row0 = new QuikGraph.Graphviz.Dot.GraphvizRecordCell(); - var col1 = new QuikGraph.Graphviz.Dot.GraphvizRecordCell(); - row0.Cells.Add(col1); - - col1.Cells.Add(new() { Text = CompilerServices.Print(arg.Vertex.Expr) }); - if (arg.Vertex.Expr is IR.Tuple && arg.Vertex.IRType is TupleType tpTuple) - { - for (int i = 0; i < tpTuple.Fields.Count; i++) - { - col1.Cells.Add(new() { Text = i.ToString(), Port = $"P{i}" }); - } - } - else if (arg.Vertex.Expr is Op op) - { - for (int i = 0; i < op.Parameters.Count(); i++) - { - col1.Cells.Add(new() { Text = i.ToString(), Port = $"P{i}" }); - } - } - - arg.VertexFormat.Record.Cells.Add(row0); - arg.VertexFormat.Shape = QuikGraph.Graphviz.Dot.GraphvizVertexShape.Record; - arg.VertexFormat.Style = QuikGraph.Graphviz.Dot.GraphvizVertexStyle.Filled; - if (costMemo.TryGetValue(arg.Vertex, out var cost)) - { - var row1 = new QuikGraph.Graphviz.Dot.GraphvizRecordCell(); - foreach (var (k, v) in cost.Factors) - { - row1.Cells.Add(new() { Text = $"{k}: {v}" }); - } - - row1.Cells.Add(new() { Text = $"Score: {cost.Score}" }); - col1.Cells.Add(row1); - } - - if (pickMemo.TryGetValue(arg.Vertex, out var picked) && picked == true) - { - arg.VertexFormat.FillColor = QuikGraph.Graphviz.Dot.GraphvizColor.SkyBlue; - } - }; + writer.Write(_rootSearchGraph.ToGraphviz(alg => DumpAction(alg, pickMemo, costMemo))); + } - alg.FormatEdge += (_, arg) => - { - arg.EdgeFormat.Direction = QuikGraph.Graphviz.Dot.GraphvizEdgeDirection.Back; - arg.EdgeFormat.TailPort = $"P{arg.Edge.InputIndex}"; - }; - })); + private void Dump(DistributedSearchGraph searchGraph, string name, IReadOnlyDictionary? pickMemo = null, IReadOnlyDictionary? costMemo = null) + { + searchGraph.Dump(name, alg => DumpAction(alg, pickMemo, costMemo)); } private Expr SolveAndExtract(DistributedSearchGraph rootCluster) From 43abc317c07ea697b886b9fe4e85af8aecbf38cd Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Wed, 12 Feb 2025 02:39:34 +0000 Subject: [PATCH 16/18] revert to old strategy --- .../Passes/Distributed/AutoDistributed.cs | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 68d96a5af0..e0b6616f93 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -348,6 +348,52 @@ protected override Unit VisitLeafCall(Call expr) break; case AutoDistributedSearchStrategy.ExpandPartial: #pragma warning disable SA1008 // Opening parenthesis should be spaced correctly +#if true + foreach (var (lType, lBucket) in bucketMemo.Where(kv => kv.Key is DistributedType ldistType && ldistType.NdSBP.Any(sbp => sbp is SBPPartial))) + { + var ldistType = (DistributedType)lType; + foreach (var rType in GetLeafCandidateDistTypes(expr.CheckedTensorType, Placements).Where(rdistType => ldistType.NdSBP.Zip(rdistType.NdSBP).All(p => (p.First, p.Second) switch { (SBPPartial, SBPBroadCast) => true, (SBP a, SBP b) => a == b }))) + { + if (!bucketMemo.TryGetValue(rType, out var rbucket)) + { + rbucket = callCluster.CreateCluster(SearchGraphKind.Bucket); + bucketMemo.Add(rType, rbucket); + } + + if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(lType, rType) is InvalidType) + { + throw new InvalidOperationException("the partial to broadcast shouldn't be invalid!"); + } + + var rnode = new SearchableNode(new Boxing(rType), rType); + rbucket.AddVertex(rnode); + callCluster.AddEdge(new(rnode, lBucket.Vertices.First(), 0, lBucket)); + } + } + + { + if (bucketMemo.Count == 1 && bucketMemo.Keys.First() is DistributedType ldistType && ldistType.NdSBP.All(sbp => sbp is SBPBroadCast)) + { + foreach (var rdistType in GetLeafCandidateDistTypes(expr.CheckedTensorType, Placements).Where(rdistType => ldistType.NdSBP.Zip(rdistType.NdSBP).All(p => (p.First, p.Second) switch { (SBPBroadCast, SBPSplit) => true, (SBP a, SBP b) => a == b } && ldistType != rdistType))) + { + if (!bucketMemo.TryGetValue(rdistType, out var rbucket)) + { + rbucket = callCluster.CreateCluster(SearchGraphKind.Bucket); + bucketMemo.Add(rdistType, rbucket); + } + + if (Evaluator.IR.CPU.BoxingEvaluator.VisitType(ldistType, rdistType) is InvalidType) + { + throw new InvalidOperationException("the broadcast to split shouldn't be invalid!"); + } + + var rnode = new SearchableNode(new Boxing(rdistType), rdistType); + rbucket.AddVertex(rnode); + callCluster.AddEdge(new(rnode, bucketMemo[ldistType].Vertices.First(), 0, bucketMemo[ldistType])); + } + } + } +#else // partial -> broadcast foreach (var (lType, lBucket) in bucketMemo.Where(kv => kv.Key is DistributedType ldistType && ldistType.NdSBP.Any(sbp => sbp is SBPPartial))) { @@ -377,6 +423,7 @@ protected override Unit VisitLeafCall(Call expr) } } } +#endif #pragma warning restore SA1008 // Opening parenthesis should be spaced correctly break; @@ -385,6 +432,7 @@ protected override Unit VisitLeafCall(Call expr) } // 4. add not infered type in search space. +#if false var addedBuckets = bucketMemo.Values.ToArray(); foreach (var nType in GetLeafCandidateDistTypes(expr.CheckedTensorType, Placements)) { @@ -415,6 +463,7 @@ protected override Unit VisitLeafCall(Call expr) } } } +#endif // 5. filter FilterByScheme(expr, callCluster); From c60784e7f085b14fd8adad82c57d4f1c1213456d Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Wed, 12 Feb 2025 08:32:17 +0000 Subject: [PATCH 17/18] revert autodist --- .../Passes/Distributed/AutoDistributed.cs | 704 ++++++++++++++++++ .../Utilities/DistributedUtility.cs | 37 + src/Nncase.EGraph/Passes/EGraphExtensions.cs | 37 +- 3 files changed, 776 insertions(+), 2 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index e0b6616f93..79e4dc3b1b 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -1,6 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +#if false using System.Diagnostics.CodeAnalysis; using System.Reactive; using System.Runtime.CompilerServices; @@ -1099,3 +1100,706 @@ public override void OnSolutionCallback() } } } +#else +using System.Reactive; +using System.Runtime.CompilerServices; +using Google.OrTools.Sat; +using NetFabric.Hyperlinq; +using Nncase.CodeGen; +using Nncase.IR; +using Nncase.IR.CPU; +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using Nncase.Targets; +using Nncase.Utilities; +using static Nncase.PatternMatch.Utility; + +[assembly: InternalsVisibleTo("Nncase.Tests")] + +namespace Nncase.Passes.Distributed; + +public interface IEquality +{ +} + +public record EqualityNode(Expr Expr) : IEquality +{ +} + +public record EqualityClass(bool Tuple, List Children) : IEquality +{ +} + +public sealed class AutoDistributedMetaData : IRMetadata +{ + public bool Skip { get; set; } +} + +/// +/// auto distributed the xpu fusion. +/// +[RuleGenerator] +public sealed partial class AutoDistributedPass : FunctionPass +{ + private readonly CompileOptions _compileOptions; + private readonly string _moduleKind; + + public AutoDistributedPass(string moduleKind, CompileOptions compileOptions) + { + _compileOptions = compileOptions; + _moduleKind = moduleKind; + } + + protected override Task RunCoreAsync(BaseFunction input, RunPassContext context) + { + if (input.Metadata is AutoDistributedMetaData { Skip: true }) + { + return Task.FromResult(input); + } + + var rewriter = new AutoDistributedRewriter(_compileOptions, _compileOptions.TargetOptions is CpuTargetOptions options ? options : new CpuTargetOptions(), _moduleKind); + return Task.FromResult(rewriter.Rewirte(input)); + } +} + +internal sealed class AutoDistributedRewriter : ExprVisitor>, Unit> +{ + private readonly Dictionary _equalMemo = new(); + + private readonly string _moduleKind; + + public AutoDistributedRewriter(CompileOptions compileOptions, CpuTargetOptions targetOptions, string moduleKind = "cpu") + { + Placements = targetOptions.Hierarchies.Select(h => new Placement(h, targetOptions.HierarchyNames, targetOptions.HierarchyKind)).ToArray(); + CompileOptions = compileOptions; + TargetOptions = targetOptions; + if (Path.Exists(TargetOptions.DistributedScheme) && System.Text.Json.JsonSerializer.Deserialize(File.ReadAllText(TargetOptions.DistributedScheme)) is DistributedSchema scheme) + { + Scheme = scheme.Outputs.ToDictionary(n => n.Name, n => (new IRArray(n.NdSBP), new Placement(n.Hierarchy, n.HierarchyName, targetOptions.HierarchyKind))); + } + else + { + Scheme = new Dictionary NdSBP, Placement Placement)>(); + } + + _moduleKind = moduleKind; + } + + public IRArray Placements { get; } + + public CompileOptions CompileOptions { get; } + + public CpuTargetOptions TargetOptions { get; } + + public IReadOnlyDictionary NdSBP, Placement Placement)> Scheme { get; } + + public static void MemoryExtractConstrains(CpModel model, IReadOnlyDictionary vars) + { + var consts = vars.Keys.Where(k => k.Expr is Call { Target: IR.CPU.Boxing { NewType: DistributedType } } call && call.Arguments[0] is TensorConst tc && tc.Value.Length >= 8).ToArray(); + model.Add(LinearExpr.WeightedSum(consts.Select(k => vars[k]), consts.Select(k => + { + var type = DistributedUtility.GetDividedTensorType((DistributedType)k.Expr.CheckedType); + return TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; + })) < (2L * 512L * 1024L * 1024L)); + } + + public static IReadOnlyList GetLeafCandidateBoxings(Expr expr, IEnumerable placements) + { + if (expr.CheckedType is InvalidType) + { + return [expr]; + } + + if (expr is IR.Tuple tuple) + { + return tuple.Fields.ToArray(). + Select(e => IsDistributed(e.CheckedType) ? [e] : GetLeafCandidateBoxings(e, placements)). + CartesianProduct(). + Select(fs => new IR.Tuple(fs.ToArray())). + ToArray(); + } + else + { + // Don't use expr.CheckedTensorType + return placements.Select( + placement => + DistributedUtility.GetLeafCandidateNDSBPs((TensorType)expr.CheckedType, placement). + Select(ndsbp => + expr is TensorConst tc ? (Expr)new TensorConst(tc.Value, ndsbp, placement) + : IR.F.CPU.Boxing(expr, new DistributedType((TensorType)expr.CheckedType, ndsbp, placement)))). + SelectMany(e => e).ToArray(); + } + } + + public static IReadOnlyList> GetDiverseCandidateSBPs(DistributedType distributedType, IEnumerable placements) + { + return placements.Select( + placement => + DistributedUtility.GetLeafCandidateNDSBPs(distributedType.TensorType, placement). + Where(ndsbp => ndsbp != distributedType.NdSBP)). + SelectMany(e => e).ToArray(); + } + + public void SingleNodeMemoryExtractConstrains(CpModel model, IReadOnlyDictionary vars) + { + var distTypes = vars.Keys.Where(k => k.Expr.CheckedType is DistributedType dt).ToArray(); + foreach (var k in distTypes) + { + if (TargetOptions.HierarchySizes.Length > 1) + { + var type = DistributedUtility.GetDividedTensorType((DistributedType)k.Expr.CheckedType); + var size = TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; + + if (k.Expr is Call { Target: IR.CPU.Boxing boxing } call && boxing.NewType is DistributedType distributedType && call.Arguments[0].CheckedType is DistributedType inType && inType.NdSBP.Any(sbp => sbp is SBPPartial) && distributedType != call.Arguments[0].CheckedType) + { + type = DistributedUtility.GetDividedTensorType(inType); + size += TensorUtilities.GetProduct(type.Shape.ToValueArray()) * type.DType.SizeInBytes; + } + + model.Add(vars[k] * size < TargetOptions.HierarchySizes[^2] / TargetOptions.Hierarchies[0][^1]); + } + } + } + + public void FilterByScheme(Expr expr, Dictionary> result) + { + foreach (var name in expr.Metadata.OutputNames ?? Array.Empty()) + { + if (Scheme.TryGetValue(name, out var tp)) + { + var keys = result.Keys.ToArray(); + foreach (var key in keys) + { + if (!(key is DistributedType dtype && dtype.NdSBP == tp.NdSBP && dtype.Placement == tp.Placement)) + { + result.Remove(key); + } + } + } + } + } + + public BaseFunction Rewirte(BaseFunction input) + { + if (input is Function || input is Fusion) + { + var body = input is Function ? ((Function)input).Body : ((Fusion)input).Body; + var typeEquivalents = Visit(body); + + if (body is IR.Tuple tp) + { + var outputs = new List(); + var equ = _equalMemo[tp]; + + void Dfs(IEquality equality) + { + switch (equality) + { + case EqualityNode n: + outputs.Add(n.Expr); + break; + case EqualityClass tp: + foreach (var item in tp.Children) + { + Dfs(item); + } + + break; + } + } + + Dfs(equ); + + using (new ExprPinner(outputs.ToArray())) + { + BranchCut(); + } + } + else + { + var outputs = typeEquivalents.Select(g => InstertTerminator(g.Value[0])) + .Select(e => new EqualityNode(e)) + .OfType().ToList(); + + if (outputs.Any()) + { + _equalMemo.Add(body, new EqualityClass(false, outputs)); + + using (new ExprPinner(outputs.Select(e => ((EqualityNode)e).Expr).ToArray())) + { + BranchCut(); + } + } + else + { + return input; + } + } + + var graph = new EGraph(); + foreach (var (exprKey, buckets) in ExprMemo.Where(kv => kv.Key is not Op)) + { + foreach (var (typeKey, bucket) in buckets.Where(kv => kv.Value.Any())) + { + Unions(graph, bucket); + } + } + + var equivalents = _equalMemo[body]; + EClass Ddfs(IEquality equival) + { + switch (equival) + { + case EqualityNode n: + return graph.Add(n.Expr); + case EqualityClass tp: + var eids = tp.Children.Select(Ddfs).ToArray(); + if (tp.Tuple) + { + return graph.AddENode(new IR.Tuple(), eids); + } + else + { + foreach (var cls in eids.Skip(1)) + { + graph.Union(eids[0], cls); + } + + graph.Rebuild(); + return eids[0]; + } + + default: + throw new NotSupportedException(); + } + } + + var root = Ddfs(equivalents); + if (Diagnostics.DumpScope.Current.IsEnabled(Diagnostics.DumpFlags.EGraphCost)) + { + using (var stream = Diagnostics.DumpScope.Current.OpenFile("egraph.dot")) + { + EGraphPrinter.DumpEgraphAsDot(graph, stream); + } + } + + var constrains = new EGraphExtractConstrains[] { SingleNodeMemoryExtractConstrains }; + var post = graph.Extract(root, CompileOptions, null, constrains); + + if (input is Function) + { + return ((Function)input).With(body: post); + } + else + { + return ((Fusion)input).With(body: post); + } + } + + return input; + } + + protected override Dictionary> DefaultVisitLeaf(Expr expr) + { + return new(); + } + + protected override Dictionary> VisitLeafIf(If expr) + { + return new() { { expr.CheckedType, new() { expr } } }; + } + + protected override Dictionary> VisitLeafTuple(IR.Tuple expr) + { + if (ReferenceEquals(expr, VisitRoot)) + { + var fileds = new List(); + foreach (var i in Enumerable.Range(0, expr.Fields.Length)) + { + var boxings = Visit(expr.Fields[i]).Values. + Select(l => l.Select(e => e.CheckedType is DistributedType dt ? (dt.NdSBP.Any(s => s is SBPPartial) ? IR.F.CPU.Boxing(IR.F.CPU.Boxing(e, new DistributedType(dt.TensorType, dt.NdSBP.Select(s => s is SBPPartial ? SBP.B : s).ToArray(), dt.Placement)), dt.TensorType) : IR.F.CPU.Boxing(e, dt.TensorType)) : e).ToArray()). + SelectMany(e => e).Select(e => new EqualityNode(e)).OfType().ToList(); + fileds.Add(new EqualityClass(false, boxings)); + } + + _equalMemo.Add(expr, new EqualityClass(true, fileds)); + return new Dictionary> { }; // return empty. + } + + return expr.Fields.ToArray(). + Select(Visit). + CartesianProduct(). + Select(e => new IR.Tuple(e.Select(e => e.Value[0]).ToArray())). + GroupBy(tp => tp.CheckedType). + ToDictionary(g => g.Key, g => g.ToList()); + } + + protected override Dictionary> VisitLeafCall(Call expr) + { + if (expr.Target is Fusion fusion) + { + foreach (var idx in Enumerable.Range(0, fusion.Parameters.Length)) + { + VisitLeafArgument(ParameterKind.Input, expr.Arguments[idx], false); + } + + var rewriter = new AutoDistributedRewriter(CompileOptions, TargetOptions); + var post = rewriter.Rewirte(fusion); + var ret = expr.Arguments.ToArray(). + Select(Visit). + CartesianProduct(). + Select(args => args.ToArray()). + Select(args => args.Select(kv => kv.Value[0]).Select(arg => arg.CheckedType switch + { + DistributedType d => GetDiverseCandidateSBPs(d, Placements).Select(ndsbp => IR.F.CPU.Boxing(arg, new DistributedType(d.TensorType, ndsbp, d.Placement))).Concat(new[] { arg }).ToArray(), + _ => new[] { arg }, + }).ToList().CartesianProduct().Select(arg => expr.With(target: post, arguments: arg.ToArray())).ToArray()). + SelectMany(i => i). + GroupBy(c => c.CheckedType). + ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList()); + + return ret; + } + + if (expr.Target is not Op op) + { + return new Dictionary> { { expr.CheckedType, new() { expr } } }; + } + + var isSupported = PassUtility.IsCpuSupported(op, expr, expr.Arguments.ToArray(), _moduleKind); + foreach (var param in op.Parameters) + { + VisitLeafArgument(param.ParameterKind, expr.Arguments[param.Index], isSupported); + } + + Dictionary> results; + if (TargetOptions.DistributedSearchStrategy is AutoDistributedSearchStrategy.ExpandAll && isSupported) + { + results = expr.Arguments.ToArray(). + Select(Visit). + CartesianProduct(). + Select(args => args.ToArray()). + Select(args => args.Select(kv => kv.Value[0]).Select(arg => arg.CheckedType switch + { + DistributedType d => GetDiverseCandidateSBPs(d, Placements).Select(ndsbp => IR.F.CPU.Boxing(arg, new DistributedType(d.TensorType, ndsbp, d.Placement))).Concat(new[] { arg }).ToArray(), + _ => new[] { arg }, + }).ToList().CartesianProduct().Select(arg => BuildEquivalCalls(op, arg.ToArray())).SelectMany(i => i).ToArray()). + SelectMany(i => i). + GroupBy(c => c.CheckedType). + ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList()); + } + else + { + results = expr.Arguments.ToArray(). + Select(Visit). + CartesianProduct(). + Select(args => args.ToArray()). + Select(args => isSupported ? BuildEquivalCalls(op, args.Select(kv => kv.Value[0]).ToArray()).ToArray() : + BuildNotSupportedCalls(op, args.Select(kv => kv.Value[0]).ToArray())). + SelectMany(i => i). + GroupBy(c => c.CheckedType). + ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList()); + } + + if (results.Count == 0) + { + return expr.Arguments.ToArray(). + Select(Visit). + CartesianProduct(). + Select(args => args.ToArray()). + Select(args => new[] { new Call(op, args.Select(kv => kv.Value[0]).Select(arg => arg.CheckedType switch + { + DistributedType d => d.NdSBP.All(sbp => sbp is SBPBroadCast) ? arg : IR.F.CPU.Boxing(arg, d with { NdSBP = new(Enumerable.Repeat(SBP.B, d.NdSBP.Count)) }), + _ => arg, + }).ToArray()), }). + SelectMany(i => i). + GroupBy(c => c.CheckedType). + ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList()); + } + + // TODO: refactor here + if (expr.Target is not ScatterND && expr.Target is not Boxing && (expr.CheckedType is TensorType or DistributedType) && expr.CheckedShape.All(x => x != 0) && results.Count == 1 && results.First().Key is DistributedType dt && dt.NdSBP.All(sbp => sbp is SBPBroadCast)) + { + return expr.Arguments.ToArray(). + Select(Visit). + CartesianProduct(). + Select(args => args.ToArray()). + Select(args => args.Select(kv => kv.Value[0]).Select(arg => arg.CheckedType switch + { + DistributedType d => GetDiverseCandidateSBPs(d, Placements).Select(ndsbp => IR.F.CPU.Boxing(arg, new DistributedType(d.TensorType, ndsbp, d.Placement))).Concat(new[] { arg }).ToArray(), + _ => new[] { arg }, + }).ToList().CartesianProduct().Select(arg => BuildEquivalCalls(op, arg.ToArray())).SelectMany(i => i).ToArray()). + SelectMany(i => i). + GroupBy(c => c.CheckedType). + ToDictionary(g => g.Key, g => g.OrderByDescending(e => e.Users.Count()).ToList()); + } + + FilterByScheme(expr, results); + return results; + } + + private static bool IsDistributed(IRType type) => type switch + { + DistributedType => true, + TupleType t => t.All(IsDistributed), + _ => false, + }; + + private Dictionary> VisitLeafArgument(ParameterKind parameterKind, Expr expr, bool isSupported) + { + var updateBuckets = (Dictionary> buckets, IEnumerable equivalents) => + { + foreach (var eq in equivalents) + { + if (!buckets.TryGetValue(eq.CheckedType, out var bucket)) + { + bucket = new(); + buckets.Add(eq.CheckedType, bucket); + } + + bucket.Add(eq); + } + + FilterByScheme(expr, buckets); + }; + + var buckets = ExprMemo[expr]; + if (!buckets.Any()) + { + switch (parameterKind, expr) + { + case (ParameterKind.Input, Expr e) when e is Const or Var: + updateBuckets(buckets, isSupported ? GetLeafCandidateBoxings(e, Placements) : new[] { e }); + break; + case (ParameterKind.Input, Expr e) when e is IR.Tuple tp: + foreach (var f in tp.Fields) + { + VisitLeafArgument(parameterKind, f, isSupported); + } + + foreach (var (k, v) in VisitLeafTuple(tp)) + { + buckets.Add(k, v); + } + + break; + case (ParameterKind.Attribute, Var e): + updateBuckets(buckets, new[] { e }); + break; + case (ParameterKind.Attribute, TensorConst e): + updateBuckets(buckets, new[] { e.With() }); // remove all old users. + break; + case (ParameterKind.Attribute, None e): + updateBuckets(buckets, new[] { e.With() }); + break; + default: + throw new InvalidOperationException(); + } + } + else if (parameterKind == ParameterKind.Input) + { + if (isSupported) + { + if (!buckets.Keys.Any(IsDistributed)) + { + var results = buckets.Select(kv => GetLeafCandidateBoxings(kv.Value[0], Placements)).SelectMany(i => i).ToArray(); + updateBuckets(buckets, results); + } + } + else + { + if (buckets.Keys.All(IsDistributed)) + { + var results = buckets.Select(kv => InstertTerminator(kv.Value[0])).ToArray(); + updateBuckets(buckets, results); + } + } + } + + if (!buckets.Any()) + { + throw new InvalidOperationException(); + } + + return buckets; + } + + private Call[] BuildNotSupportedCalls(Op target, Expr[] args) + { + if (target.Parameters.Where(p => p.ParameterKind == ParameterKind.Input).Any(p => IsDistributed(args[p.Index].CheckedType))) + { + return Array.Empty(); + } + + return new[] { new Call(target, args) }; + } + + private IEnumerable BuildEquivalCalls(Op target, Expr[] args) + { + if (!target.Parameters.Where(p => p.ParameterKind == ParameterKind.Input).All(p => IsDistributed(args[p.Index].CheckedType))) + { + return Array.Empty(); + } + + var calls = new List(); + var call = new Call(target, args); + var valid = call.InferenceType(); + if (!valid) + { + if (target is Reshape && args[0].CheckedType is DistributedType inType && args[1] is TensorConst constNewShape) + { + // var newShape = constNewShape.Value.ToArray(); + // var tensorType = new TensorType(inType.TensorType.DType, newShape); + // foreach (var boxing in DistributedUtility.GetLeafCandidateNDSBPs(tensorType, inType.Placement). + // Select(ndsbp => IR.F.CPU.Boxing(args[0], new DistributedType(tensorType, ndsbp, inType.Placement)))) + // { + // if (boxing.CheckedType is not InvalidType) + // { + // calls.Add(boxing); + // } + // } + } + else + { + // todo expand search space. + // calls.AddRange(DistributedUtility.GetLeafCandidateNDSBPs(tensorType, inType.Placement). + // Select(ndsbp => IR.F.CPU.Boxing(args[0], new DistributedType(tensorType, ndsbp, inType.Placement)))); + } + } + else + { + calls.Add(call); + if (call.CheckedType is DistributedType distType) + { + // boxing for partialsum + var partialBoxings = DistributedUtility.GetPartialCandidateNDSBPs(distType). + Select(ndsbp => (ndsbp, IR.F.CPU.Boxing(call, distType with { NdSBP = ndsbp }))).Select(p => + { + var lastSbp = p.ndsbp; + var reduced = p.Item2; + return DistributedUtility.GetLeafCandidateNDSBPs(distType.TensorType, distType.Placement).Where(ndsbp => lastSbp != ndsbp).Select(ndsbp => IR.F.CPU.Boxing(reduced, distType with { NdSBP = ndsbp })).ToArray(); + }).SelectMany(i => i).ToArray(); + calls.AddRange(partialBoxings); + + using var pinner = new ExprPinner(calls.ToArray()); + var getExtraBoxings = (Expr expr) => Placements. + Where(p => p != distType.Placement). + Select(p => DistributedUtility.GetLeafCandidateNDSBPs(distType.TensorType, p). + Select(ndsbp => IR.F.CPU.Boxing(expr, new DistributedType(distType.TensorType, ndsbp, p)))). + SelectMany(b => b); + + // boxing for other placements + var extraBoxings = partialBoxings.Any() ? partialBoxings.Select(getExtraBoxings).SelectMany(i => i) : getExtraBoxings(call); + foreach (var boxing in extraBoxings) + { + if (boxing.CheckedType is not InvalidType) + { + calls.Add(boxing); + } + } + } + } + + // GC.Collect(); + return calls; + } + + private IReadOnlyList GetReBoxings(Expr expr) + { + if (expr is IR.Tuple tuple) + { + var candidates = tuple.Fields.ToArray(). + Select(GetReBoxings). + CartesianProduct(); + return candidates.Any() ? candidates. + Select(fs => new IR.Tuple(fs.ToArray())). + ToArray() : Array.Empty(); + } + + var type = (DistributedType)expr.CheckedType; + var tensorType = type.TensorType; + var candidateNdsbps = new List[type.Placement.Rank]; + for (int i = 0; i < type.Placement.Rank; i++) + { + candidateNdsbps[i] = new List { SBP.B }; + for (int axis = 0; axis < tensorType.Shape.Rank; axis++) + { + if (tensorType.Shape[axis] is { IsFixed: true, FixedValue: int s } && DistributedUtility.IsDivideExactly(s, type.Placement.Hierarchy[i])) + { + candidateNdsbps[i].Add(SBP.S(axis)); + } + } + } + + return candidateNdsbps.CartesianProduct(). + Select(ndsbp => new IRArray(ndsbp)). + Where(ndsbp => ndsbp != type.NdSBP). + Select(ndsbp => new DistributedType(tensorType, new IRArray(ndsbp), type.Placement)). + Select(disttype => IR.F.CPU.Boxing(expr, disttype)).ToArray(); + } + + private Expr InstertTerminator(Expr expr) + { + Expr CreateFinalBoxing(Expr e, DistributedType type) + { + if (type.NdSBP.Any(s => s is SBPPartial)) + { + var boxingP2B = IR.F.CPU.Boxing(e, new DistributedType(type.TensorType, type.NdSBP.Select(s => s is SBPPartial ? SBP.B : s).ToArray(), type.Placement)); + return IR.F.CPU.Boxing(boxingP2B, type.TensorType); + } + + return IR.F.CPU.Boxing(e, type.TensorType); + } + + return (expr, expr.CheckedType) switch + { + (IR.Tuple tp, TupleType tptype) => new IR.Tuple(tp.Fields.ToArray().Select(InstertTerminator).ToArray()), + (Expr e, DistributedType type) => CreateFinalBoxing(e, type), + (Expr e, TensorType type) => e, + (Expr e, AnyType type) => e, + (_, _) => throw new NotSupportedException(), + }; + } + + private EClass Unions(EGraph graph, IEnumerable equivalents) + { + var eids = equivalents.Select(graph.Add).ToArray(); + foreach (var cls in eids.Skip(1)) + { + graph.Union(eids[0], cls); + } + + graph.Rebuild(); + return eids[0]; + } + + private void BranchCut() + { + GC.Collect(); + bool changed = true; + while (changed) + { + changed = false; + foreach (var (e, bukets) in ExprMemo) + { + foreach (var (_, buket) in bukets.Where(kv => kv.Value.Any())) + { + if (!buket[0].Users.Any()) + { + foreach (var item in buket) + { + if (item.Users.Any()) + { + throw new InvalidOperationException("this item can't have more than zero users!"); + } + } + + buket.Clear(); + changed = true; + } + } + } + } + } +} +#endif diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs index 68392d33a5..538d1786ca 100644 --- a/src/Nncase.Core/Utilities/DistributedUtility.cs +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -33,6 +33,43 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens return ndsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray(); } + public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedType distributedType) + { + IRArray ndsbp = distributedType.NdSBP; + TensorType tensorType = distributedType.TensorType; + Placement placement = distributedType.Placement; + if (!ndsbp.Any(sbp => sbp is SBPPartial)) + { + return Array.Empty>(); + } + + var candidateNdsbps = new List[placement.Rank]; + for (int i = 0; i < placement.Rank; i++) + { + candidateNdsbps[i] = new List(); + var innerSplitedAxes = distributedType.NdSBP.Skip(i + 1).OfType().Select(sbp => sbp.Axis).ToList(); + if (ndsbp[i] is SBPPartial) + { + candidateNdsbps[i].Add(SBP.B); + + // note separate reduce boxing and reshard boxing. + // for (int axis = 0; axis < tensorType.Shape.Rank; axis++) + // { + // if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && placement.Hierarchy[i] > 1 && IsDivideExactly(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis)) + // { + // candidateNdsbps[i].Add(SBP.S(axis)); + // } + // } + } + else + { + candidateNdsbps[i].Add(ndsbp[i]); + } + } + + return candidateNdsbps.CartesianProduct().Select(ndsbp => ndsbp.ToArray()).Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)).Select(ndsbp => new IRArray(ndsbp)).ToArray(); + } + public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement) { if (!tensorType.Shape.IsFixed) diff --git a/src/Nncase.EGraph/Passes/EGraphExtensions.cs b/src/Nncase.EGraph/Passes/EGraphExtensions.cs index 43c56ccd43..8cf5ac0a0a 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtensions.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtensions.cs @@ -44,8 +44,41 @@ public static Expr Extract(this IEGraph eGraph, EClass root, CompileOptions comp } // 2. start the cost evaluator - var costModel = new CostModel.EGraphCostEvaluator(root.Find(), compileOptions, basefunc_cost_evaluator, false).Evaluate(); + // var costModel = new CostModel.EGraphCostEvaluator(root.Find(), compileOptions, basefunc_cost_evaluator, false).Evaluate(); + var enodeCostMemo = new Dictionary(); + var opCostMemo = new Dictionary(); + foreach (var enode in eGraph.Nodes) + { + switch (enode.Expr) + { + case Call { Target: Expr target } call: + switch (target) + { + case Op op: + var returnType = enode.Expr.CheckedType; + var key = new CostMemoKey(enode, new CostMemoKeyPartial(op, returnType, enode.Children.Skip(1).Select(x => x.CheckedType).ToArray())); + if (!opCostMemo.TryGetValue(key, out var newCost)) + { + var context = new EGraphOpCostEvaluateContext(returnType, enode.Children.Skip(1).Select(x => x.CheckedType).ToArray(), enode.Children.Skip(1).ToArray(), compileOptions); + newCost = CompilerServices.EvaluateOpCost(op, context); + opCostMemo.Add(key, newCost); + } + + enodeCostMemo[enode] = Cost.Zero; + break; + default: + enodeCostMemo[enode] = Cost.Zero; + break; + } + + break; + default: + enodeCostMemo[enode] = Cost.Zero; + break; + } + } - return new EGraphExtractor(costModel).Extract(root.Find(), eGraph, constrains ?? Array.Empty()); + var egraphCostModel = new EGraphCostModel(enodeCostMemo); + return new EGraphExtractor(egraphCostModel).Extract(root.Find(), eGraph, constrains ?? Array.Empty()); } } From 3ec5164e561f3b8023cee48e7bd574daedd696da Mon Sep 17 00:00:00 2001 From: zhengqihang <597323109@qq.com> Date: Wed, 12 Feb 2025 08:33:01 +0000 Subject: [PATCH 18/18] enable autodist reshape --- .../Passes/Distributed/AutoDistributed.cs | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 79e4dc3b1b..7b5f5595c8 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -1648,16 +1648,16 @@ private IEnumerable BuildEquivalCalls(Op target, Expr[] args) { if (target is Reshape && args[0].CheckedType is DistributedType inType && args[1] is TensorConst constNewShape) { - // var newShape = constNewShape.Value.ToArray(); - // var tensorType = new TensorType(inType.TensorType.DType, newShape); - // foreach (var boxing in DistributedUtility.GetLeafCandidateNDSBPs(tensorType, inType.Placement). - // Select(ndsbp => IR.F.CPU.Boxing(args[0], new DistributedType(tensorType, ndsbp, inType.Placement)))) - // { - // if (boxing.CheckedType is not InvalidType) - // { - // calls.Add(boxing); - // } - // } + var newShape = constNewShape.Value.ToArray(); + var tensorType = new TensorType(inType.TensorType.DType, newShape); + foreach (var boxing in DistributedUtility.GetLeafCandidateNDSBPs(tensorType, inType.Placement). + Select(ndsbp => IR.F.CPU.Boxing(args[0], new DistributedType(tensorType, ndsbp, inType.Placement)))) + { + if (boxing.CheckedType is not InvalidType) + { + calls.Add(boxing); + } + } } else {