diff --git a/src/Nncase.Core/IR/Expr.cs b/src/Nncase.Core/IR/Expr.cs index 894ac9ef08..3ae35627d8 100644 --- a/src/Nncase.Core/IR/Expr.cs +++ b/src/Nncase.Core/IR/Expr.cs @@ -116,6 +116,8 @@ public DataType CheckedDataType { case TensorType type: return type.DType; + case DistributedType type: + return type.TensorType.DType; default: if (DumpScope.Current.IsEnabled(DumpFlags.Compile)) { diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs index 13b2870bfb..2061a40958 100644 --- a/src/Nncase.Core/Utilities/DistributedUtility.cs +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -16,7 +16,7 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens var ndsbp = new List(); for (int axis = 0; axis < tensorType.Shape.Rank; axis++) { - if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i])) + if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i])) { ndsbp.Add(SBP.S(axis)); } @@ -28,7 +28,7 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens return ndsbps.CartesianProduct(). Select(ndsbp => ndsbp.ToArray()). - Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)). + Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)). Select(ndsbp => new IRArray(ndsbp)). ToArray(); } @@ -53,7 +53,7 @@ public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedT candidateNdsbps[i].Add(SBP.B); for (int axis = 0; axis < tensorType.Shape.Rank; axis++) { - if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivisible(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis)) + if (tensorType.Shape[axis] is { IsFixed: true, Value: int s } && IsDivideBy(s, placement.Hierarchy[i]) && !innerSplitedAxes.Contains(axis)) { candidateNdsbps[i].Add(SBP.S(axis)); } @@ -67,38 +67,101 @@ public static IReadOnlyList> GetPartialCandidateNDSBPs(DistributedT return candidateNdsbps.CartesianProduct(). Select(ndsbp => ndsbp.ToArray()). - Where(ndsbp => IsDistributable(tensorType, ndsbp, placement, out _)). + Where(ndsbp => IsDistributable(tensorType, ndsbp, placement)). Select(ndsbp => new IRArray(ndsbp)). ToArray(); } - public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement, [MaybeNullWhen(false)] out TensorType distType) + public static bool IsDistributable(TensorType tensorType, ReadOnlySpan ndsbp, Placement placement) { - distType = null; if (!tensorType.Shape.IsFixed) { return false; } - var shape = tensorType.Shape.ToValueArray(); - for (int i = 0; i < ndsbp.Length; i++) + var divisors = GetDivisors(new DistributedType(tensorType, new IRArray(ndsbp.ToArray()), placement)); + return divisors.Select((d, axis) => (d, axis)).All(p => p.d == 0 ? true : IsDivideBy(tensorType.Shape[p.axis].FixedValue, p.d)); + } + + public static IReadOnlyList GetDivisors(DistributedType distributedType) + { + var shape = distributedType.TensorType.Shape.ToValueArray(); + var divisors = Enumerable.Repeat(0, shape.Length).ToArray(); + for (int i = 0; i < distributedType.NdSBP.Count; i++) { - if (ndsbp[i] is SBPSplit { Axis: int axis }) + if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis }) { - if (!IsDivisible(shape[axis], placement.Hierarchy[i])) + if (divisors[axis] == 0) { - return false; + divisors[axis] = 1; } - shape[axis] /= placement.Hierarchy[i]; + divisors[axis] *= distributedType.Placement.Hierarchy[i]; } } - distType = tensorType with { Shape = shape }; - return true; + return divisors; + } + + public static bool TryGetDividedTensorType(DistributedType distributedType, [System.Diagnostics.CodeAnalysis.MaybeNullWhen(false)] out TensorType tensorType) + { + tensorType = null; + var divisors = GetDivisors(distributedType); + if (divisors.Select((d, i) => (d, i)).All(p => p.d == 0 || IsDivideExactly(distributedType.TensorType.Shape[p.i].FixedValue, p.d))) + { + tensorType = new TensorType(distributedType.TensorType.DType, distributedType.TensorType.Shape.Zip(divisors).Select(p => p.Second == 0 ? p.First.FixedValue : p.First.FixedValue / p.Second).ToArray()); + return true; + } + + return false; + } + + public static Expr[] TryGetNonUniformDividedShape(DistributedType distributedType) + { + var shape = distributedType.TensorType.Shape.ToValueArray(); + var hierarchies = Enumerable.Range(0, shape.Length).Select(i => new List()).ToArray(); + var ids = distributedType.Placement.Name.Select(c => new Var(c + "id", TensorType.Scalar(DataTypes.Int32))).ToArray(); + var hierarchyStrides = TensorUtilities.GetStrides(distributedType.Placement.Hierarchy.ToArray()); + for (int i = 0; i < distributedType.NdSBP.Count; i++) + { + if (distributedType.NdSBP[i] is SBPSplit { Axis: int axis }) + { + hierarchies[axis].Add(i); + } + } + + return hierarchies.Select((divs, axis) => + { + Expr dim; + if (divs.Any()) + { + var divsor = (int)TensorUtilities.GetProduct(divs.Select(h => distributedType.Placement.Hierarchy[h]).ToArray()); + var (res, rem) = Math.DivRem(shape[axis], divsor); + dim = IR.F.Math.Select( + TensorUtilities.GetIndex(hierarchyStrides.TakeLast(divs.Count).Select(s => (Expr)s).ToArray(), divs.Select(h => ids[h]).ToArray()) < (divsor - 1), + res, + res + rem); + } + else + { + dim = distributedType.TensorType.Shape[axis].FixedValue; + } + + return dim; + }).ToArray(); + } + + public static bool IsDivideBy(int input, int divisor) + { + if (input >= divisor) + { + return true; + } + + return false; } - public static bool IsDivisible(int input, int divisor) + public static bool IsDivideExactly(int input, int divisor) { if (input >= divisor && input % divisor == 0) {