Reproduces the work in arXiv:1904.04971v2 with an MXNet-Gluon implementation.
I use groupwise convolution to implement CondConv easily, in either of two ways:
- Combine kernels then do convolution
  - Reshape `x` from `(bs, c, h, w)` to `(1, bs*c, h, w)`
  - Combine `weight` from `(k, oc, c, kh, kw)` to `(bs, oc, c, kh, kw)`, then reshape to `(bs*oc, c, kh, kw)`
  - Combine `bias` from `(k, oc)` to `(bs, oc)`, then reshape to `(bs*oc,)`
  - Do convolution with `num_filter=bs*oc` and `num_group=bs`, getting outputs with shape `(1, bs*oc, oh, ow)`
  - Reshape outputs to `(bs, oc, oh, ow)`, which are the final results for CondConv
- Do convolution then combine outputs
  - Tile `x` `k` times on the second axis, getting a new `x` with shape `(bs, k*c, h, w)`
  - Reshape `weight` from `(k, oc, c, kh, kw)` to `(k*oc, c, kh, kw)`
  - Reshape `bias` from `(k, oc)` to `(k*oc,)`
  - Do convolution with `num_filter=k*oc` and `num_group=k`, getting outputs with shape `(bs, k*oc, oh, ow)`
  - Reshape outputs to `(bs, k, oc, oh, ow)` and combine them to `(bs, oc, oh, ow)`, which are the final results for CondConv
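The two paths above can be sketched in NumPy. This is a minimal illustration, not the MXNet code itself: the naive stride-1, no-padding `conv2d` helper, the variable names, and the explicit per-example `routing` matrix are assumptions for the demo. By linearity of convolution, both paths produce identical outputs:

```python
import numpy as np

def conv2d(x, w, b, groups=1):
    """Naive grouped 2D convolution (stride 1, no padding).
    x: (n, c, h, w), w: (oc, c // groups, kh, kw), b: (oc,)."""
    n, c, h, wd = x.shape
    oc, cg, kh, kw = w.shape
    oh, ow = h - kh + 1, wd - kw + 1
    ocg = oc // groups                          # output channels per group
    out = np.zeros((n, oc, oh, ow))
    for g in range(groups):
        xs = x[:, g * cg:(g + 1) * cg]          # input channels of this group
        ws = w[g * ocg:(g + 1) * ocg]           # filters of this group
        for i in range(oh):
            for j in range(ow):
                patch = xs[:, :, i:i + kh, j:j + kw]
                out[:, g * ocg:(g + 1) * ocg, i, j] = \
                    np.einsum('ncij,ocij->no', patch, ws)
    return out + b.reshape(1, oc, 1, 1)

def condconv_combine_then_conv(x, weight, bias, routing):
    """Method 1: combine expert kernels per example, then one grouped conv."""
    bs, c, h, w = x.shape
    k, oc, _, kh, kw = weight.shape
    w_comb = np.einsum('bk,kocij->bocij', routing, weight)  # (bs, oc, c, kh, kw)
    b_comb = routing @ bias                                 # (bs, oc)
    y = conv2d(x.reshape(1, bs * c, h, w),
               w_comb.reshape(bs * oc, c, kh, kw),
               b_comb.reshape(bs * oc),
               groups=bs)                                   # (1, bs*oc, oh, ow)
    return y.reshape(bs, oc, y.shape[2], y.shape[3])

def condconv_conv_then_combine(x, weight, bias, routing):
    """Method 2: run all k experts via grouped conv, then combine outputs."""
    bs, c, h, w = x.shape
    k, oc, _, kh, kw = weight.shape
    x_t = np.tile(x, (1, k, 1, 1))                          # (bs, k*c, h, w)
    y = conv2d(x_t,
               weight.reshape(k * oc, c, kh, kw),
               bias.reshape(k * oc),
               groups=k)                                    # (bs, k*oc, oh, ow)
    y = y.reshape(bs, k, oc, y.shape[2], y.shape[3])
    return np.einsum('bk,bkoij->boij', routing, y)          # (bs, oc, oh, ow)

rng = np.random.default_rng(0)
bs, c, h, w, k, oc, kh, kw = 2, 3, 6, 6, 4, 5, 3, 3
x = rng.standard_normal((bs, c, h, w))
weight = rng.standard_normal((k, oc, c, kh, kw))
bias = rng.standard_normal((k, oc))
routing = rng.random((bs, k))                               # per-example expert weights
y1 = condconv_combine_then_conv(x, weight, bias, routing)
y2 = condconv_conv_then_combine(x, weight, bias, routing)
```

Both `y1` and `y2` have shape `(bs, oc, oh, ow)` and agree numerically, since `sum_e r_e * (conv(x, w_e) + b_e) == conv(x, sum_e r_e * w_e) + sum_e r_e * b_e`.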
For small `k` (< 8), training with the latter method is faster; for large `k` (>= 8), the former method is suggested.
| num_experts | Parameters | FLOPS | Top-1 Acc |
| --- | --- | --- | --- |
| (baseline) | 274,042 | 41,013,878 | 91.51% |
| 4 | 1,078,402 (+293%) | 42,087,854 (+2.6%) | 91.77% |
| 8 | 2,150,026 (+684%) | 43,161,830 (+5.2%) | 91.81% |
| 16 | 4,293,274 (+1467%) | 45,309,782 (+10.5%) | 91.89% |
| 32 | 8,579,770 (+3031%) | 49,605,686 (+20.9%) | 92.26% |
| (resnet56) | 860,026 (+314%) | 126,292,598 (+308%) | 92.85% |
For more details, refer to the blog post *CondConv: 按需定制的卷积权重* ("CondConv: Convolution Weights Customized on Demand") at Hey~YaHei.