Skip to content

Commit

Permalink
Fix 3D BatchToSpace (#1058)
Browse files Browse the repository at this point in the history
* fix transpose perm

* Fix nhwc2nchw in 3D

---------

Co-authored-by: lerenhua <[email protected]>
Co-authored-by: 乐仁华 <[email protected]>
  • Loading branch information
3 people authored Aug 17, 2023
1 parent e5d305d commit 2f10458
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/Native/src/kernels/stackvm/reference/batch_to_space.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ dims_t infer_shape(gsl::span<const size_t> origin_in_shape,
gsl::span<const size_t> block_shape,
const paddings_t &crops) {
auto d4 = fixed_dims(0, 2, 3, 1);
auto d3 = fixed_dims(1, 2, 0);
auto d3 = fixed_dims(0, 2, 1);
auto inPerm = origin_in_shape.size() == 4
? gsl::span<const size_t>{d4.data(), d4.size()}
: gsl::span<const size_t>{d3.data(), d3.size()};
Expand All @@ -123,7 +123,7 @@ dims_t infer_shape(gsl::span<const size_t> origin_in_shape,
in_shape.end());
}
auto outd4 = fixed_dims(0, 3, 1, 2);
auto outd3 = fixed_dims(2, 0, 1);
auto outd3 = fixed_dims(0, 2, 1);
auto outPerm = origin_in_shape.size() == 4
? gsl::span<const size_t>{outd4.data(), outd4.size()}
: gsl::span<const size_t>{outd3.data(), outd3.size()};
Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Core/IR/Tensors/Functional.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public static Expr NHWCToNCHW(Expr input)
}
else if (input.CheckedShape.Rank == 3)
{
perm = new[] { 2, 0, 1 };
perm = new[] { 0, 2, 1 };
}
else
{
Expand All @@ -50,7 +50,7 @@ public static Expr NCHWToNHWC(Expr input)
}
else if (input.CheckedShape.Rank == 3)
{
perm = new[] { 1, 2, 0 };
perm = new[] { 0, 2, 1 };
}
else
{
Expand Down
8 changes: 4 additions & 4 deletions src/Nncase.Evaluator/NN/BatchToSpace.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public Expr Visit(IShapeEvaluateContext context, BatchToSpace target)

if (input.CheckedShape.Rank == 3)
{
inShape = Stack(new IR.Tuple(inShape[1], inShape[2], inShape[0]), 0);
inShape = Stack(new IR.Tuple(inShape[0], inShape[2], inShape[1]), 0);
}

var blockShape = context.GetArgument(target, BatchToSpace.BlockShape);
Expand Down Expand Up @@ -142,7 +142,7 @@ public Expr Visit(IShapeEvaluateContext context, BatchToSpace target)

if (input.CheckedShape.Rank == 3)
{
return Stack(new IR.Tuple(outShapeList[2], outShapeList[0], outShapeList[1]), 0);
return Stack(new IR.Tuple(outShapeList[0], outShapeList[2], outShapeList[1]), 0);
}

throw new NotImplementedException();
Expand Down Expand Up @@ -186,7 +186,7 @@ private IRType Visit(ITypeInferenceContext context, BatchToSpace target, TensorT
{
var inShape = input.Shape.Rank == 4
? TypeInference.ApplyPerm(input.Shape, new[] { 0, 2, 3, 1 })
: TypeInference.ApplyPerm(input.Shape, new[] { 1, 2, 0 });
: TypeInference.ApplyPerm(input.Shape, new[] { 0, 2, 1 });
var batch = inShape[0];
if (context.GetArgument(target, BatchToSpace.BlockShape) is TensorConst blockShapeValue &&
context.GetArgument(target, BatchToSpace.Crops) is TensorConst cropsValue)
Expand All @@ -211,7 +211,7 @@ private IRType Visit(ITypeInferenceContext context, BatchToSpace target, TensorT
var outShape =
outShapeList.Length == 4
? TypeInference.ApplyPerm(outShapeList, new[] { 0, 3, 1, 2 })
: TypeInference.ApplyPerm(outShapeList, new[] { 2, 0, 1 });
: TypeInference.ApplyPerm(outShapeList, new[] { 0, 2, 1 });
return input with { Shape = outShape };
}
else
Expand Down

0 comments on commit 2f10458

Please sign in to comment.