
Commit

Generalise the setting of the width and height of the input parameters
samhodge committed Feb 24, 2018
1 parent 48a0e1a commit 1d72d60
Showing 2 changed files with 10 additions and 5 deletions.
example/gluon/style_transfer/main.py (4 additions, 1 deletion)

@@ -140,7 +140,9 @@ def evaluate(args):
     style_image = utils.tensor_load_rgbimage(args.style_image, ctx, size=args.style_size)
     style_image = utils.preprocess_batch(style_image)
     # model
-    style_model = net.Net(ngf=args.ngf)
+    WIDTH = content_image.shape[2]
+    HEIGHT = content_image.shape[3]
+    style_model = net.Net(ngf=args.ngf,width=WIDTH,height=HEIGHT)
     style_model.collect_params().load(args.model, ctx=ctx)
     # forward
     style_model.setTarget(style_image)
@@ -156,6 +158,7 @@ def evaluate(args):
     style_model.save_params("MODEL.params")
     graph = mx.viz.plot_network(y,save_format='pdf')
     graph.render()
+    mx.visualization.print_summary(y,{'data':(1,3,WIDTH,HEIGHT)})
 
 
 def optimize(args):
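A note on the indexing above: with MXNet's default NCHW layout, content_image.shape is (batch, channels, spatial axis 2, spatial axis 3), so the diff simply reads the two spatial sizes off the loaded content image instead of assuming a fixed 1080p frame. A minimal sketch of that indexing, using a stand-in tensor rather than the example's image loader:

import mxnet as mx

# Stand-in for the preprocessed content batch: an NCHW tensor sized like a 1080p frame.
content_image = mx.nd.zeros((1, 3, 1080, 1920))

WIDTH = content_image.shape[2]   # spatial axis 2 (1080 for this tensor)
HEIGHT = content_image.shape[3]  # spatial axis 3 (1920 for this tensor)

# The same pair is forwarded to the network and to the summary's input-shape dict,
# e.g. {'data': (1, 3, WIDTH, HEIGHT)}, so the shapes stay consistent end to end.
print(WIDTH, HEIGHT)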
example/gluon/style_transfer/net.py (6 additions, 4 deletions)

@@ -254,7 +254,7 @@ def hybrid_forward(self, F, x):
 
 class Net(HybridBlock):
     def __init__(self, input_nc=3, output_nc=3, ngf=64,
-                 norm_layer=InstanceNorm, n_blocks=6, gpu_ids=[],ctx=mx.cpu(0)):
+                 norm_layer=InstanceNorm, n_blocks=6, gpu_ids=[],ctx=mx.cpu(0),width=1920,height=1080):
         super(Net, self).__init__()
         self.gpu_ids = gpu_ids
         self.gram = GramMatrix()
@@ -265,7 +265,7 @@ def __init__(self, input_nc=3, output_nc=3, ngf=64,
 
         with self.name_scope():
             self.model1 = nn.HybridSequential()
-            self.ins = Inspiration(ngf*expansion,ctx=ctx)
+            self.ins = Inspiration(ngf*expansion,ctx=ctx,width=width,height=height)
             self.model = nn.HybridSequential()
 
             self.model1.add(ConvLayer(input_nc, 64, kernel_size=7, stride=1))
@@ -302,7 +302,7 @@ class Inspiration(HybridBlock):
     tuning the featuremap with target Gram Matrix
     ref https://arxiv.org/abs/1703.06953
     """
-    def __init__(self, C, B=1,ctx=mx.cpu(0)):
+    def __init__(self, C, B=1,ctx=mx.cpu(0),width=1920,height=1080):
         super(Inspiration, self).__init__()
         # B is equal to 1 or input mini_batch
         self.C = C
@@ -317,6 +317,8 @@ def __init__(self, C, B=1,ctx=mx.cpu(0)):
                                      lr_mult=0)
         self.weight.initialize(ctx=ctx)
         self.gram.initialize(ctx=ctx)
+        self.width=width
+        self.height=height
 
     def setTarget(self, target):
         self.gram.set_data(target)
@@ -325,7 +327,7 @@ def hybrid_forward(self, F, X, gram, weight):
         if not isinstance(X,mx.symbol.Symbol):
             return F.batch_dot(F.SwapAxis(self.P,1,2).broadcast_to((X.shape[0], self.C, self.C)), X.reshape((0,0,X.shape[2]*X.shape[3]))).reshape(X.shape)
         else:
-            arg_shapes ,out_shapes,aux_shapes=X.infer_shape_partial(data=(1,3,1920,1080)) #1 , RGB, Width, Height, Based on 1080p resolution.
+            arg_shapes ,out_shapes,aux_shapes=X.infer_shape_partial(data=(1,3,self.width,self.height)) #1 , RGB, Width, Height, Based on 1080p resolution.
             return F.batch_dot(F.SwapAxis(self.P,1,2).broadcast_to((out_shapes[0][0], self.C, self.C)), X.reshape((0,0,out_shapes[0][2]*out_shapes[0][3]))).reshape(out_shapes[0])
 
     def __repr__(self):
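The symbolic branch of Inspiration.hybrid_forward needs these stored sizes because a Symbol carries no concrete shape at graph-construction time; infer_shape_partial derives the downstream shapes once the 'data' placeholder is given a full (N, C, H, W) tuple. A standalone sketch of that call on a hypothetical one-layer graph, not the example's network:

import mxnet as mx

# Toy symbolic graph: a single convolution hanging off the 'data' placeholder.
data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, num_filter=64, kernel=(7, 7), pad=(3, 3), no_bias=True)

# Supplying the input shape lets MXNet infer every other shape in the graph.
arg_shapes, out_shapes, aux_shapes = conv.infer_shape_partial(data=(1, 3, 256, 256))
print(out_shapes[0])  # (1, 64, 256, 256): batch, channels, and the two spatial axes

# Inspiration uses out_shapes[0] the same way, to size the broadcast_to and the
# reshape that cannot be read directly from a Symbol.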
