Skip to content

Commit

Permalink
Fix FInferShape for some ops to support partial type inference (apach…
Browse files Browse the repository at this point in the history
…e#18348)

* Fix FInferShape for some ops to support partial type inference

Signed-off-by: Serge Panev <[email protected]>

* Add missing ndim check in in matrix_op-inl.h

Signed-off-by: Serge Panev <[email protected]>
  • Loading branch information
Kh4L authored May 22, 2020
1 parent 2cb6153 commit d9fc74e
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 22 deletions.
7 changes: 3 additions & 4 deletions src/operator/contrib/batch_norm_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
CHECK_EQ(out_shape->size(), 4U);
const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData);
if (!mxnet::ndim_is_known(dshape)) {
return false;
}

const size_t channelAxis = static_cast<size_t>(param.axis < 0
? static_cast<int>(dshape.ndim()) + param.axis
Expand All @@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,

const int channelCount = dshape[channelAxis];

if (!mxnet::ndim_is_known(dshape)) {
return false;
}

in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean
Expand Down
7 changes: 3 additions & 4 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
CHECK_EQ(out_shape->size(), 3U);
const mxnet::TShape &dshape = in_shape->at(batchnorm::kData);
if (!mxnet::ndim_is_known(dshape)) {
return false;
}

const size_t channelAxis = static_cast<size_t>(param.axis < 0
? static_cast<int>(dshape.ndim()) + param.axis
Expand All @@ -331,10 +334,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,

const index_t channelCount = dshape[channelAxis];

if (!mxnet::ndim_is_known(dshape)) {
return false;
}

in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean
Expand Down
8 changes: 4 additions & 4 deletions src/operator/nn/group_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
CHECK_GE(dshape.ndim(), 3U);
const int num_groups = param.num_groups;
CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";

if (!mxnet::ndim_is_known(dshape)) {
return false;
}

CHECK_GE(dshape.ndim(), 3U);
const int num_groups = param.num_groups;
CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";

in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));

Expand Down
7 changes: 4 additions & 3 deletions src/operator/nn/layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
if (!mxnet::ndim_is_known(dshape)) {
return false;
}

int axis = GetRealAxis(param.axis, dshape.ndim());
CHECK(axis >= 0 && axis < dshape.ndim())
<< "Channel axis out of range: axis=" << param.axis;

const index_t channelCount = dshape[axis];

if (!mxnet::ndim_is_known(dshape)) {
return false;
}
SHAPE_ASSIGN_CHECK(*in_shape,
layernorm::kGamma,
mxnet::TShape(Shape1(channelCount)));
Expand Down
9 changes: 7 additions & 2 deletions src/operator/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
mxnet::ShapeVector *out_shape) {
const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
const mxnet::TShape &dshape = (*in_shape)[0];
if (!mxnet::ndim_is_known(dshape)) {
return false;
}

if (param.pool_type == pool_enum::kLpPooling) {
CHECK(param.p_value.has_value());
}
const mxnet::TShape &dshape = (*in_shape)[0];

if (param.pooling_convention == pool_enum::kSame) {
CHECK_EQ(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
Expand All @@ -114,7 +119,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
if (!mxnet::ndim_is_known(dshape)) return false;

int layout = param.GetLayout(dshape.ndim());
if (param.global_pool) {
mxnet::TShape oshape = dshape;
Expand Down
22 changes: 17 additions & 5 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& shp = (*in_attrs)[0];
mxnet::TShape& out_shp = (*out_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
if (shp.ndim() == -1 && out_shp.ndim() == -1)
if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
return false; // none of the shapes is known
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
CHECK_EQ(out_shp.ndim(), shp.ndim());
mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
Expand Down Expand Up @@ -513,12 +513,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
mxnet::TShape& ishape = (*in_attrs)[0];
mxnet::TShape& oshape = (*out_attrs)[0];
if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) {
return false;
}

mxnet::TShape& ishape = (*in_attrs)[0];
mxnet::TShape& oshape = (*out_attrs)[0];
int indim = ishape.ndim();
bool unknown_ishape = false;
if (-1 == indim) {
Expand Down Expand Up @@ -1441,6 +1441,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
mxnet::TShape& from_shape = (*in_attrs)[1];
if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) {
return false;
}
if (param.axes.ndim() == 0) {
CHECK_EQ(ishape.ndim(), from_shape.ndim())
<< "By default slice_axis performs slice on all axes, but ndim mismatch "
Expand Down Expand Up @@ -1749,6 +1752,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
if (!mxnet::ndim_is_known(ishape)) {
return false;
}
int repeats = 0;
dmlc::optional<int> axisOpt;
GetRepeatParams(param, ishape, &repeats, &axisOpt);
Expand Down Expand Up @@ -2427,6 +2433,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape expected_out(4, -1);

mxnet::TShape& in_shape = in_attrs->at(0);
if (!mxnet::ndim_is_known(in_shape)) {
return false;
}
int block = param.block_size;
CHECK_NE(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
Expand Down Expand Up @@ -2591,6 +2600,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

mxnet::TShape& in_shape = in_attrs->at(0);
if (!mxnet::ndim_is_known(in_shape)) {
return false;
}
int block = param.block_size;
CHECK_NE(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[0], 0)
Expand Down

0 comments on commit d9fc74e

Please sign in to comment.