Skip to content

Commit

Permalink
Fix RemoveBoxing
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Feb 12, 2025
1 parent d2067be commit f1daad0
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 8 deletions.
12 changes: 12 additions & 0 deletions src/Nncase.Core/IR/Expr.Operators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,21 @@ public partial class Expr
IR.Tuple t => t[(int)indices.Single()],
Call { Target: Concat { Axis: 0 } } c when indices.Length == 1 => c[Concat.Input][indices[0]][0],
Call { Target: Reshape } c when c[Reshape.Shape] is TensorConst tc && tc.Value.Length == 1 && tc.Value.ToScalar<long>() == 1 => c[Reshape.Input],
Call { Target: Stack } c when indices.Length == 1 => c[Stack.Inputs][indices[0]],
Call { Target: Fusion } c when indices.Length == 1 => GetItemOfBaseFunction(c, indices[0]),
_ => this[indices.Select(x => (Expr)x).ToArray()],
};

private Expr GetItemOfBaseFunction(Call c, long v)
{
if (c.CheckedType is TupleType tt && tt.Count == 3 && v == 3)
{
Console.WriteLine();
}

return c[(Expr)v];
}

/// <summary>
/// get the item from the expr.
/// </summary>
Expand Down
6 changes: 6 additions & 0 deletions src/Nncase.Core/IR/Shape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,12 @@ public long[] ToValueArray()
return this.Select(x => x.FixedValue).ToArray();
}

public Expr ToValueArrayExpr()
{
var tuple = new IR.Tuple(Dimensions);
return IR.F.Tensors.Stack(tuple, 0);
}

/// <inheritdoc/>
public override string ToString() => Kind switch
{
Expand Down
8 changes: 3 additions & 5 deletions src/Nncase.Evaluator/Tensors/Slice.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,10 @@ private static Dimension TranslateBeginEnd(Dimension x, Dimension dim, long lowe
}
else
{
return ShapeExprUtility.If(
return Select(
x.Value < 0L,
(x, dim) => dim + x,
(x, dim) => Clamp(x, lowerBound, dim + upperBoundBias),
x.Value,
dim.Value);
(dim + x).ToExpr(),
Dimension.Clamp(x, lowerBound, dim + upperBoundBias).ToExpr());
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/GraphPartition/ExprReConstructor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public Expr Construct()
{
var sourceOutIndex = sourceOutVertices.IndexOf(inEdge.Source);
var postResult = ClusterMemo[sourceCluster];
postArg = postResult is IR.Tuple tp ? tp.Fields[sourceOutIndex] : IR.F.Tensors.GetItem(postResult, sourceOutIndex);
postArg = postResult[sourceOutIndex];
}

pairs.Add((inEdge.Source.Expr, postArg));
Expand Down
16 changes: 15 additions & 1 deletion src/Nncase.Passes/Mutators/RemoveBoxingRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,21 @@ public RemoveBoxingCloner()

protected override Expr VisitLeafCall(Call expr, Unit context)
{
return expr.Target is Boxing ? Visit(expr[Boxing.Input], context) : base.VisitLeafCall(expr, context);
if (expr.Target is Boxing boxing)
{
var input = Visit(expr[Boxing.Input], context);
if (boxing.NewType is DistributedType dt && dt.TensorType != input.CheckedType)
{
// Reshape
return IR.F.Tensors.Reshape(input, dt.TensorType.Shape.ToValueArrayExpr());
}
else
{
return input;
}
}

return base.VisitLeafCall(expr, context);
}

protected override Expr VisitLeafTensorConst(TensorConst expr, Unit context)
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/FoldGetItemReshape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public partial class FoldGetItemReshape : RewriteRule<Pattern>

public Pattern ReshapePattern => IsReshape(IsWildcard("input"), new long[] { 1 });

private Expr? GetReplace(Expr input, int index)
private Expr? GetReplace(Expr input)
{
return input;
}
Expand Down

0 comments on commit f1daad0

Please sign in to comment.