
Commit

model and hybrid block
zhanghang1989 committed May 25, 2018
1 parent 5eac829 commit f50115a
Showing 2 changed files with 104 additions and 40 deletions.
12 changes: 11 additions & 1 deletion example/gluon/style_transfer/models/download_model.py
@@ -15,7 +15,17 @@
# specific language governing permissions and limitations
# under the License.

import os
import zipfile
import shutil
from mxnet.test_utils import download

download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/21styles-2cb88353.zip', 'models/21styles.params')
zip_file_path = 'models/21styles.zip'
download('https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/21styles-2cb88353.zip', zip_file_path)

with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall()

os.remove(zip_file_path)

shutil.move('21styles-2cb88353.params', 'models/21styles.params')
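
As orientation (not part of this commit), the downloaded parameters would typically be consumed along these lines; the ngf value and the load_params call are assumptions about how this example loads the weights, not taken from this diff:

    # Illustrative sketch only; run from example/gluon/style_transfer after download_model.py.
    from net import Net

    style_model = Net(ngf=128)                         # ngf=128 is an assumption
    style_model.load_params('models/21styles.params')  # weights fetched by the script above
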
132 changes: 93 additions & 39 deletions example/gluon/style_transfer/net.py
@@ -60,7 +60,7 @@ def forward(self, x):
return F.pad(x, mode='reflect', pad_width=self.pad_width)


class Bottleneck(Block):
class Bottleneck(HybridBlock):
""" Pre-activation residual block
Identity Mappings in Deep Residual Networks
ref https://arxiv.org/abs/1603.05027
@@ -73,7 +73,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=Insta
self.residual_layer = nn.Conv2D(in_channels=inplanes,
channels=planes * self.expansion,
kernel_size=1, strides=(stride, stride))
self.conv_block = nn.Sequential()
self.conv_block = nn.HybridSequential()
with self.conv_block.name_scope():
self.conv_block.add(norm_layer(in_channels=inplanes))
self.conv_block.add(nn.Activation('relu'))
@@ -89,15 +89,15 @@ def __init__(self, inplanes, planes, stride=1, downsample=None, norm_layer=Insta
channels=planes * self.expansion,
kernel_size=1))

def forward(self, x):
def hybrid_forward(self, F, x):
if self.downsample is not None:
residual = self.residual_layer(x)
else:
residual = x
return residual + self.conv_block(x)


class UpBottleneck(Block):
class UpBottleneck(HybridBlock):
""" Up-sample residual block (from MSG-Net paper)
Enables passing identity all the way through the generator
ref https://arxiv.org/abs/1703.06953
@@ -107,7 +107,7 @@ def __init__(self, inplanes, planes, stride=2, norm_layer=InstanceNorm):
self.expansion = 4
self.residual_layer = UpsampleConvLayer(inplanes, planes * self.expansion,
kernel_size=1, stride=1, upsample=stride)
self.conv_block = nn.Sequential()
self.conv_block = nn.HybridSequential()
with self.conv_block.name_scope():
self.conv_block.add(norm_layer(in_channels=inplanes))
self.conv_block.add(nn.Activation('relu'))
@@ -122,11 +122,11 @@ def __init__(self, inplanes, planes, stride=2, norm_layer=InstanceNorm):
channels=planes * self.expansion,
kernel_size=1))

def forward(self, x):
def hybrid_forward(self, F, x):
return self.residual_layer(x) + self.conv_block(x)


class ConvLayer(Block):
class ConvLayer(HybridBlock):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(ConvLayer, self).__init__()
padding = int(np.floor(kernel_size / 2))
@@ -135,13 +135,13 @@ def __init__(self, in_channels, out_channels, kernel_size, stride):
kernel_size=kernel_size, strides=(stride,stride),
padding=0)

def forward(self, x):
def hybrid_forward(self, F, x):
x = self.pad(x)
out = self.conv2d(x)
return out


class UpsampleConvLayer(Block):
class UpsampleConvLayer(HybridBlock):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
@@ -152,41 +152,57 @@ def __init__(self, in_channels, out_channels, kernel_size,
stride, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
"""
if upsample:
self.upsample_layer = torch.nn.UpsamplingNearest2d(scale_factor=upsample)
"""
self.reflection_padding = int(np.floor(kernel_size / 2))
self.conv2d = nn.Conv2D(in_channels=in_channels,
channels=out_channels,
kernel_size=kernel_size, strides=(stride,stride),
padding=self.reflection_padding)

def forward(self, x):
def hybrid_forward(self, F, x):
if self.upsample:
x = F.UpSampling(x, scale=self.upsample, sample_type='nearest')
"""
if self.reflection_padding != 0:
x = self.reflection_pad(x)
"""
out = self.conv2d(x)
return out
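
For reference, a minimal sketch (not part of this commit) of the upsample-then-convolve pattern this layer implements, using the same UpSampling call as the code above; shapes and channel counts are made up:

    # Illustrative only: nearest-neighbour upsampling followed by a stride-1 convolution,
    # instead of a single strided transposed convolution.
    import mxnet as mx
    from mxnet.gluon import nn

    x = mx.nd.random.uniform(shape=(1, 16, 32, 32))
    up = mx.nd.UpSampling(x, scale=2, sample_type='nearest')               # (1, 16, 64, 64)
    conv = nn.Conv2D(in_channels=16, channels=8, kernel_size=3, padding=1)
    conv.initialize()
    y = conv(up)                                                           # (1, 8, 64, 64)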


def gram_matrix(y):
(b, ch, h, w) = y.shape
features = y.reshape((b, ch, w * h))
#features_t = F.SwapAxis(features,1, 2)
gram = F.batch_dot(features, features, transpose_b=True) / (ch * h * w)
return gram
class gram_matrix(mx.operator.CustomOp):
def forward(self, is_train, req, in_data, out_data, aux):
x = in_data[0]
_, ch, h, w = x.shape
features = x.reshape((0, 0, -1))
y = F.batch_dot(features, features, transpose_b=True) / (ch * h * w)
self.assign(out_data[0], req[0], y)

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
dy = out_grad[0]
x = in_data[0]
_, ch, h, w = x.shape
features = x.reshape((0, 0, -1))
dx = F.batch_dot(dy, features) + F.batch_dot(dy, features, transpose_a=True)
dx = dx.reshape((0, 0, h, w)) / (ch * h * w)
self.assign(in_grad[0], req[0], dx)


@mx.operator.register("gram_matrix")
class GramProp(mx.operator.CustomOpProp):
def list_arguments(self):
return ['data']

def infer_shape(self, in_shapes):
data_shape = in_shapes[0]
output_shape = (data_shape[0], data_shape[1], data_shape[1])
return (data_shape, ), (output_shape,), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
return gram_matrix()

class GramMatrix(Block):
def forward(self, x):
gram = gram_matrix(x)
return gram

class Net(Block):
class GramMatrix(HybridBlock):
def hybrid_forward(self, F, x):
return mx.nd.Custom(x, op_type='gram_matrix')
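
A quick sanity check of the custom op (illustrative only; it assumes net.py is importable from the example directory so that the 'gram_matrix' op registered above is available):

    import mxnet as mx
    import net  # noqa: F401 -- importing registers the custom ops defined above

    x = mx.nd.random.uniform(shape=(2, 3, 4, 4))
    g_custom = mx.nd.Custom(x, op_type='gram_matrix')
    f = x.reshape((0, 0, -1))
    g_direct = mx.nd.batch_dot(f, f, transpose_b=True) / (3 * 4 * 4)
    print(g_custom.shape)                     # (2, 3, 3)
    print((g_custom - g_direct).abs().max())  # ~0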


class Net(HybridBlock):
def __init__(self, input_nc=3, output_nc=3, ngf=64,
norm_layer=InstanceNorm, n_blocks=6, gpu_ids=[]):
super(Net, self).__init__()
@@ -198,9 +214,9 @@ def __init__(self, input_nc=3, output_nc=3, ngf=64,
expansion = 4

with self.name_scope():
self.model1 = nn.Sequential()
self.model1 = nn.HybridSequential()
self.ins = Inspiration(ngf*expansion)
self.model = nn.Sequential()
self.model = nn.HybridSequential()

self.model1.add(ConvLayer(input_nc, 64, kernel_size=7, stride=1))
self.model1.add(norm_layer(in_channels=64))
@@ -227,11 +243,43 @@ def set_target(self, Xs):
G = self.gram(F)
self.ins.set_target(G)

def forward(self, input):
def hybrid_forward(self, F, input):
return self.model(input)


class Inspiration(Block):
class broadcast_like(mx.operator.CustomOp):
def forward(self, is_train, req, in_data, out_data, aux):
x = in_data[0]
z = in_data[1]
b, c, _, _ = z.shape
y = F.broadcast_to(x, (b, c, c))
self.assign(out_data[0], req[0], y)

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
dy = out_grad[0]
x = in_data[0]
z = in_data[1]
# hacky solution: only supports expanding along the batch dimension
dx = F.mean(dy, axis=0)
self.assign(in_grad[0], req[0], dx)


@mx.operator.register("broadcast_like")
class BroadcastLikeProp(mx.operator.CustomOpProp):
def list_arguments(self):
return ['data', 'target']

def infer_shape(self, in_shapes):
input_shape = in_shapes[1]
output_shape = (input_shape[0], input_shape[1], input_shape[1])
# return three lists: input shapes, output shapes, and aux data shapes.
return in_shapes, (output_shape,), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
return broadcast_like()
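
Likewise, a shape-only check of this op (illustrative; the same net import assumption applies):

    import mxnet as mx
    import net  # noqa: F401 -- registers 'broadcast_like'

    p = mx.nd.random.uniform(shape=(1, 4, 4))
    x = mx.nd.random.uniform(shape=(3, 4, 16, 16))
    out = mx.nd.Custom(p, x, op_type='broadcast_like')
    print(out.shape)  # (3, 4, 4): p is repeated across x's batch dimension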


class Inspiration(HybridBlock):
""" Inspiration Layer (from MSG-Net paper)
tuning the featuremap with target Gram Matrix
ref https://arxiv.org/abs/1703.06953
@@ -240,6 +288,7 @@ def __init__(self, C, B=1):
super(Inspiration, self).__init__()
# B is either 1 or the input mini-batch size
self.C = C
self.B = B
self.weight = self.params.get('weight', shape=(1,C,C),
init=mx.initializer.Uniform(),
allow_deferred_init=True)
@@ -248,17 +297,22 @@ def __init__(self, C, B=1):
def set_target(self, target):
self.gram = target

def forward(self, X):
def hybrid_forward(self, F, X, weight):
# input X is a 3D feature map
self.P = F.batch_dot(F.broadcast_to(self.weight.data(), shape=(self.gram.shape)), self.gram)
return F.batch_dot(F.SwapAxis(self.P,1,2).broadcast_to((X.shape[0], self.C, self.C)), X.reshape((0,0,X.shape[2]*X.shape[3]))).reshape(X.shape)
P = F.batch_dot(
F.broadcast_to(weight, shape=(self.B, self.C, self.C)), self.gram)
P = F.SwapAxis(P,1,2)
return F.batch_dot(
#P.broadcast_to((X.shape[0], self.C, self.C)),
mx.nd.Custom(P, X, op_type='broadcast_like'),
X.reshape((0, 0, -1))).reshape_like(X)

def __repr__(self):
return self.__class__.__name__ + '(' \
+ 'N x ' + str(self.C) + ')'
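
A small usage sketch of the layer (illustrative only; the channel count and the random stand-in for the target Gram matrix are made up, and net.py is assumed importable):

    import mxnet as mx
    from net import Inspiration

    layer = Inspiration(C=4, B=1)
    layer.initialize()
    layer.set_target(mx.nd.random.uniform(shape=(1, 4, 4)))  # stand-in target Gram matrix
    x = mx.nd.random.uniform(shape=(2, 4, 8, 8))
    y = layer(x)
    print(y.shape)  # (2, 4, 8, 8): output keeps the input feature-map shape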


class Vgg16(Block):
class Vgg16(HybridBlock):
def __init__(self):
super(Vgg16, self).__init__()
self.conv1_1 = nn.Conv2D(in_channels=3, channels=64, kernel_size=3, strides=1, padding=1)
@@ -279,7 +333,7 @@ def __init__(self):
self.conv5_2 = nn.Conv2D(in_channels=512, channels=512, kernel_size=3, strides=1, padding=1)
self.conv5_3 = nn.Conv2D(in_channels=512, channels=512, kernel_size=3, strides=1, padding=1)

def forward(self, X):
def hybrid_forward(self, F, X):
h = F.Activation(self.conv1_1(X), act_type='relu')
h = F.Activation(self.conv1_2(h), act_type='relu')
relu1_2 = h
@@ -301,4 +355,4 @@ def forward(self, X):
h = F.Activation(self.conv4_3(h), act_type='relu')
relu4_3 = h

return [relu1_2, relu2_2, relu3_3, relu4_3]
return relu1_2, relu2_2, relu3_3, relu4_3
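
Finally, an end-to-end sketch of the generator after this change (illustrative only; random weights and random images stand in for the pretrained checkpoint and real inputs, and net.py is assumed importable):

    import mxnet as mx
    from net import Net

    model = Net(ngf=64)
    model.initialize()
    style = mx.nd.random.uniform(shape=(1, 3, 256, 256))    # stand-in style image
    content = mx.nd.random.uniform(shape=(1, 3, 256, 256))  # stand-in content image
    model.set_target(style)    # stores the style's Gram matrix in the Inspiration layer
    stylized = model(content)  # (1, 3, 256, 256)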
