Merge pull request #1131 from huihuifan/weightnorm
adding weight norm container
soumith authored Feb 21, 2017
2 parents c3cdec6 + 28d67b0 commit 1b95f7e
Showing 4 changed files with 252 additions and 26 deletions.
165 changes: 165 additions & 0 deletions WeightNorm.lua
@@ -0,0 +1,165 @@
-- Weight Normalization
-- https://arxiv.org/pdf/1602.07868v3.pdf
local WeightNorm, parent = torch.class("nn.WeightNorm", "nn.Container")

function WeightNorm:__init(module, outputDim)
    -- This container applies Weight Normalization to any module it wraps.
    -- It accepts a parameter `outputDim` giving the output dimension of the wrapped weight.
    -- If outputDim is not 1, the container transposes the weight;
    -- if the weight is not 2D, the container views it as a 2D matrix of shape
    -- nOut x (nIn x kw x dw x ...)
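    -- e.g. (hypothetical sizes) a convolution weight of shape nOut x nIn x kH x kW with
    -- outputDim = 1 is viewed as an nOut x (nIn*kH*kW) matrix for the computations below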

    parent.__init(self)
    assert(module.weight)

    if module.bias then
        self.bias = module.bias
        self.gradBias = module.gradBias
    end
    self.gradWeight = module.gradWeight
    self.weight = module.weight

    self.outputDim = outputDim or 1

    -- track the non-output weight dimensions
    self.otherDims = 1
    for i = 1, self.weight:dim() do
        if i ~= self.outputDim then
            self.otherDims = self.otherDims * self.weight:size(i)
        end
    end

    -- view size for weight norm 2D calculations
    self.viewIn = torch.LongStorage({self.weight:size(self.outputDim), self.otherDims})

    -- view size back to original weight
    self.viewOut = self.weight:size()

    -- bubble outputDim size up to the front
    for i = self.outputDim - 1, 1, -1 do
        self.viewOut[i], self.viewOut[i + 1] = self.viewOut[i + 1], self.viewOut[i]
    end

    -- weight is reparametrized to decouple the length from the direction
    -- such that w = g * ( v / ||v|| )
    self.v = torch.Tensor(self.viewIn[1], self.viewIn[2])
    self.g = torch.Tensor(self.viewIn[1])

    self._norm = torch.Tensor(self.viewIn[1])
    self._scale = torch.Tensor(self.viewIn[1])

    -- gradient of g
    self.gradG = torch.Tensor(self.viewIn[1]):zero()
    -- gradient of v
    self.gradV = torch.Tensor(self.viewIn)

    self.modules[1] = module
    self:resetInit()
end

function WeightNorm:permuteIn(inpt)
    local ans = inpt
    for i = self.outputDim - 1, 1, -1 do
        ans = ans:transpose(i, i+1)
    end
    return ans
end

function WeightNorm:permuteOut(inpt)
    local ans = inpt
    for i = 1, self.outputDim - 1 do
        ans = ans:transpose(i, i+1)
    end
    return ans
end

function WeightNorm:resetInit(inputSize, outputSize)
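    -- data-independent init: v ~ N(0, sqrt(2/fanIn)) and g = ||v|| row-wise,
    -- so the initial effective weight w = g * v / ||v|| equals v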
    self.v:normal(0, math.sqrt(2/self.viewIn[2]))
    self.g:norm(self.v, 2, 2)
    if self.bias then
        self.bias:zero()
    end
end

function WeightNorm:updateOutput(input)
    -- view the weight as 2D while the container operates on it
    self.gradV:copy(self:permuteIn(self.weight))
    self.gradV = self.gradV:view(self.viewIn)

    -- ||v|| (row-wise, with a small epsilon for numerical stability)
    self._norm:norm(self.v, 2, 2):pow(2):add(10e-5):sqrt()
    -- w = g * v / ||v||
    self.gradV:copy(self.v)
    self._scale:copy(self.g):cdiv(self._norm)
    self.gradV:cmul(self._scale:view(self.viewIn[1], 1)
                        :expand(self.viewIn[1], self.viewIn[2]))

    -- view back to the size of the original module weight
    self.gradV = self.gradV:view(self.viewOut)

    self.weight:copy(self:permuteOut(self.gradV))
    self.output:set(self.modules[1]:updateOutput(input))
    return self.output
end

function WeightNorm:accGradParameters(input, gradOutput, scale)
    scale = scale or 1
    self.modules[1]:accGradParameters(input, gradOutput, scale)
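    -- gradients of the reparametrization w = g * v / ||v|| (Salimans & Kingma, 2016):
    --   dL/dg = dL/dw . (v / ||v||)
    --   dL/dv = (g / ||v||) * dL/dw - (g * dL/dg / ||v||^2) * v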

    self.weight:copy(self:permuteIn(self.weight))
    self.gradV:copy(self:permuteIn(self.gradWeight))
    self.weight = self.weight:view(self.viewIn)

    local norm = self._norm:view(self.viewIn[1], 1):expand(self.viewIn[1], self.viewIn[2])
    local scale = self._scale:view(self.viewIn[1], 1):expand(self.viewIn[1], self.viewIn[2])

    -- dL/dg = dL/dw . (v / ||v||)
    self.weight:copy(self.gradV)
    self.weight:cmul(self.v):cdiv(norm)
    self.gradG:sum(self.weight, 2)

    -- dL/dw * g / ||v||
    self.gradV:cmul(scale)

    -- dL/dg * (v * g / ||v||^2)
    self.weight:copy(self.v):cmul(scale):cdiv(norm)
    self.weight:cmul(self.gradG:view(self.viewIn[1], 1)
                         :expand(self.viewIn[1], self.viewIn[2]))

    -- dL/dv update
    self.gradV:add(-1, self.weight)

    self.gradV = self.gradV:view(self.viewOut)
    self.weight = self.weight:view(self.viewOut)
    self.gradWeight:copy(self:permuteOut(self.gradV))
end

function WeightNorm:updateGradInput(input, gradOutput)
    self.gradInput:set(self.modules[1]:updateGradInput(input, gradOutput))
    return self.gradInput
end

function WeightNorm:zeroGradParameters()
    self.modules[1]:zeroGradParameters()
    self.gradV:zero()
    self.gradG:zero()
end

function WeightNorm:updateParameters(lr)
    self.modules[1]:updateParameters(lr)
    self.g:add(-lr, self.gradG)
    self.v:add(-lr, self.gradV)
end

function WeightNorm:parameters()
    if self.bias then
        return {self.v, self.g, self.bias}, {self.gradV, self.gradG, self.gradBias}
    else
        return {self.v, self.g}, {self.gradV, self.gradG}
    end
end

function WeightNorm:__tostring__()
    local str = 'nn.WeightNorm [' .. tostring(self.modules[1]) .. ']'
    return str
end
63 changes: 37 additions & 26 deletions doc/containers.md
@@ -8,14 +8,14 @@ Complex neural networks are easily built using container classes:
* [Concat](#nn.Concat) : concatenates in one layer several modules along dimension `dim` ;
* [DepthConcat](#nn.DepthConcat) : like Concat, but adds zero-padding when non-`dim` sizes don't match;
* [Bottle](#nn.Bottle) : allows any dimensionality input to be forwarded through a module ;

See also the [Table Containers](#nn.TableContainers) for manipulating tables of [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md).

<a name="nn.Container"></a>
## Container ##

This is an abstract [Module](module.md#nn.Module) class which declares methods defined in all containers.
It reimplements many of the Module methods such that calls are propagated to the
contained modules. For example, a call to [zeroGradParameters](module.md#nn.Module.zeroGradParameters)
will be propagated to all contained modules.

@@ -37,7 +37,7 @@ Returns the number of contained modules.
Sequential provides a means to plug layers together
in a feed-forward fully connected manner.

E.g.
creating a one hidden-layer multi-layer perceptron is thus just as easy as:
```lua
mlp = nn.Sequential()
@@ -104,17 +104,17 @@ nn.Sequential {

`module` = `Parallel(inputDimension,outputDimension)`

Creates a container module that applies its `ith` child module to the `ith` slice of the input Tensor by using [select](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor-selectdim-index)
on dimension `inputDimension`. It concatenates the results of its contained modules together along dimension `outputDimension`.

Example:
```lua
mlp = nn.Parallel(2,1); -- Parallel container will associate a module to each slice of dimension 2
-- (column space), and concatenate the outputs over the 1st dimension.

mlp:add(nn.Linear(10,3)); -- Linear module (input 10, output 3), applied on 1st slice of dimension 2
mlp:add(nn.Linear(10,2)) -- Linear module (input 10, output 2), applied on 2nd slice of dimension 2

-- After going through the Linear module the outputs are
-- concatenated along the unique dimension, to form 1D Tensor
> mlp:forward(torch.randn(10,2)) -- of size 5.
@@ -131,8 +131,8 @@ A more complicated example:

mlp = nn.Sequential();
c = nn.Parallel(1,2) -- Parallel container will associate a module to each slice of dimension 1
-- (row space), and concatenate the outputs over the 2nd dimension.

for i=1,10 do -- Add 10 Linear+Reshape modules in parallel (input = 3, output = 2x1)
local t=nn.Sequential()
t:add(nn.Linear(3,2)) -- Linear module (input = 3, output = 2)
@@ -165,7 +165,7 @@ for i = 1, 10000 do -- Train for a few iterations
local err = criterion:forward(pred,y)
local gradCriterion = criterion:backward(pred,y);
mlp:zeroGradParameters();
mlp:backward(x, gradCriterion);
mlp:updateParameters(0.01);
print(err)
end
@@ -209,16 +209,16 @@ module = nn.DepthConcat(dim)
DepthConcat concatenates the output of one layer of "parallel" modules along the
provided dimension `dim`: they take the same inputs, and their output is
concatenated. For dimensions other than `dim` having different sizes,
the smaller tensors are copied in the center of the output tensor,
effectively padding the borders with zeros.

The module is particularly useful for concatenating the output of [Convolutions](convolution.md)
along the depth dimension (i.e. `nOutputFrame`).
This is used to implement the *DepthConcat* layer
of the [Going deeper with convolutions](http://arxiv.org/pdf/1409.4842v1.pdf) article.
The normal [Concat](#nn.Concat) Module can't be used since the spatial
dimensions (height and width) of the output Tensors requiring concatenation
may have different values. To deal with this, the output uses the largest
spatial dimensions and adds zero-padding around the smaller Tensors.
```lua
inputSize = 3
@@ -231,7 +231,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 3, 3))
mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))

> print(mlp:forward(input))
(1,.,.) =
-0.2874 0.6255 1.1122 0.4768 0.9863 -0.2201 -0.1516
0.2779 0.9295 1.1944 0.4457 1.1470 0.9693 0.1654
-0.5769 -0.4730 0.3283 0.6729 1.3574 -0.6610 0.0265
@@ -240,7 +240,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
0.4147 0.5062 0.6251 0.4374 0.3252 0.3478 0.0046
0.7845 -0.0902 0.3499 0.0342 1.0706 -0.0605 0.5525

(2,.,.) =
-0.7351 -0.9327 -0.3092 -1.3395 -0.4596 -0.6377 -0.5097
-0.2406 -0.2617 -0.3400 -0.4339 -0.3648 0.1539 -0.2961
-0.7124 -1.2228 -0.2632 0.1690 0.4836 -0.9469 -0.7003
@@ -249,7 +249,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
-0.3086 -0.0298 -0.2031 0.1026 -0.5785 -0.3275 -0.1630
0.0596 -0.6097 0.1443 -0.8603 -0.2774 -0.4506 -0.5367

(3,.,.) =
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
0.0000 -0.7326 0.3544 0.1821 0.4796 1.0164 0.0000
0.0000 -0.9195 -0.0567 -0.1947 0.0169 0.1924 0.0000
@@ -258,7 +258,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
0.0000 -0.1911 0.2912 0.5092 0.2955 0.7171 0.0000
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000

(4,.,.) =
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
0.0000 -0.8263 0.3646 0.6750 0.2062 0.2785 0.0000
0.0000 -0.7572 0.0432 -0.0821 0.4871 1.9506 0.0000
@@ -267,7 +267,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
0.0000 0.2570 0.4694 -0.1262 0.5602 0.0821 0.0000
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000

(5,.,.) =
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
0.0000 0.3158 0.4389 -0.0485 -0.2179 0.0000 0.0000
0.0000 0.1966 0.6185 -0.9563 -0.3365 0.0000 0.0000
@@ -276,7 +276,7 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000

(6,.,.) =
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
0.0000 1.1148 0.2324 -0.1093 0.5024 0.0000 0.0000
0.0000 -0.2624 -0.5863 0.3444 0.3506 0.0000 0.0000
@@ -286,11 +286,11 @@ mlp:add(nn.SpatialConvolutionMM(inputSize, outputSize, 4, 4))
0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
[torch.DoubleTensor of dimension 6x7x7]
```
Note how the last 2 of 6 filter maps have 1 column of zero-padding
on the left and top, as well as 2 on the right and bottom.
This is inevitable when the component modules' output tensors' non-`dim` sizes aren't all odd or all even.
To keep the mappings aligned, one need only ensure that these sizes are all odd (or all even).

<a name="nn.Bottle"></a>
@@ -323,6 +323,17 @@ mlp = nn.Bottle(nn.Linear(10, 2))
[torch.LongStorage of size 4]
```

<a name="nn.WeightNorm"></a>
## Weight Normalization ##

```lua
module = nn.WeightNorm(module)
```

WeightNorm implements the reparametrization presented in [Weight Normalization](https://arxiv.org/pdf/1602.07868v3.pdf), which decouples the length of neural network weight vectors from their direction. The weight vector `w` is instead determined by parameters `g` and `v` such that `w = g * v / ||v||`, where `||v||` is the Euclidean norm of `v`. This container can wrap nn layers that have weights.

It accepts a parameter `outputDim` (default 1) that indicates the output dimension of the wrapped module's weight. If `outputDim` is not 1, the container transposes the weight appropriately. If the module weight is not 2D, the container views it as a 2D matrix based on the specified `outputDim`.
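A minimal usage sketch (the layer sizes below are arbitrary):

```lua
require 'nn'

-- wrap individual layers; each wrapped weight is reparametrized as w = g * v / ||v||
mlp = nn.Sequential()
mlp:add(nn.WeightNorm(nn.Linear(10, 20)))
mlp:add(nn.ReLU())
mlp:add(nn.WeightNorm(nn.Linear(20, 2)))

> mlp:forward(torch.randn(10))
```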

<a name="nn.TableContainers"></a>
## Table Containers ##
While the above containers are used for manipulating input [Tensors](https://github.com/torch/torch7/blob/master/doc/tensor.md), table containers are used for manipulating tables:
1 change: 1 addition & 0 deletions init.lua
@@ -16,6 +16,7 @@ require('nn.Parallel')
require('nn.Sequential')
require('nn.DepthConcat')
require('nn.Bottle')
require('nn.WeightNorm')

require('nn.Linear')
require('nn.Bilinear')