Skip to content

Commit

Permalink
rename flag in convolution.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
crazydemo committed Mar 4, 2022
1 parent 329dcab commit 0357162
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);

bool is_group = false;
bool is_dnnl_group_conv = false;
if (param->groups > 1 && kernel_layout.name().find("G") != std::string::npos) {
kOIHW = Layout("GOIHW");
is_group = true;
is_dnnl_group_conv = true;
}

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
Expand Down Expand Up @@ -250,7 +250,7 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
ICHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape;

if (is_group) {
if (is_dnnl_group_conv) {
// infer weight's shape for group convolution
wshape = {{param->groups, indexdiv(param->channels, param->groups),
indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
Expand Down Expand Up @@ -752,10 +752,10 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);

bool is_group = false;
bool is_dnnl_group_conv = false;
if (param->groups > 1 && kernel_layout.name().find("G") != std::string::npos) {
kIOHW = Layout("GIOHW");
is_group = true;
is_dnnl_group_conv = true;
}

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
Expand Down Expand Up @@ -784,7 +784,7 @@ bool Conv2DTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
ICHECK_EQ(param->dilation.size(), 2);

Array<IndexExpr> wshape;
if (is_group) {
if (is_dnnl_group_conv) {
// infer weight's shape for group convolution
wshape = {{param->groups, indexdiv(dshape_nchw[1], param->groups),
indexdiv(param->channels, param->groups), param->kernel_size[0],
Expand Down

0 comments on commit 0357162

Please sign in to comment.