Skip to content

Commit

Permalink
now running on mini-batches
Browse files Browse the repository at this point in the history
  • Loading branch information
viorik committed Oct 5, 2015
1 parent e890207 commit ff9e475
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
26 changes: 20 additions & 6 deletions Diag.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,29 @@ end

--- Forward pass: scale each feature map of `input` by its diagonal weight.
-- Layout handling (post mini-batch commit):
--   * 4D input (batch x nFeatures x h x w): feature maps are on dim 2
--   * 3D input (nFeatures x h x w): feature maps are on dim 1
--   * otherwise: plain elementwise multiply with the weight vector
-- NOTE(review): the pasted diff interleaved pre- and post-commit lines here;
-- this is the reconstructed post-commit version (old `dim() > 1` branch removed).
-- @param input torch.Tensor
-- @return self.output, same shape as input
function Diag:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   if input:dim() == 4 then -- batch mode: dim 2 indexes feature maps
      for i = 1, input:size(2) do
         self.output[{{},{i}}]:mul(self.weight[i])
      end
   elseif input:dim() == 3 then -- single sample: dim 1 indexes feature maps
      for i = 1, input:size(1) do
         self.output[{{i}}]:mul(self.weight[i])
      end
   else
      -- 1D case: elementwise product; assumes weight matches input shape
      self.output:cmul(self.weight)
   end
   return self.output
end

--- Backward pass: gradInput = gradOutput scaled per-feature by the diagonal weight.
-- Mirrors updateOutput's layout handling:
--   * 4D gradOutput (batch x nFeatures x h x w): feature maps on dim 2
--   * 3D gradOutput (nFeatures x h x w): feature maps on dim 1
--   * otherwise: elementwise multiply with the weight vector
-- NOTE(review): the pasted diff interleaved old/new lines and a collapsed hunk
-- hid this function's tail; the trailing `return self.gradInput` is reconstructed
-- from the nn.Module contract and the symmetric updateOutput — TODO confirm
-- against the full post-commit file.
-- @param input torch.Tensor (used only to pick the layout)
-- @param gradOutput torch.Tensor, gradient w.r.t. this module's output
-- @return self.gradInput, same shape as gradOutput
function Diag:updateGradInput(input, gradOutput)
   self.gradInput:resizeAs(gradOutput):copy(gradOutput)
   if input:dim() == 4 then -- batch mode: dim 2 indexes feature maps
      for i = 1, input:size(2) do
         self.gradInput[{{},{i}}]:mul(self.weight[i])
      end
   elseif input:dim() == 3 then -- single sample: dim 1 indexes feature maps
      for i = 1, input:size(1) do
         self.gradInput[{{i}}]:mul(self.weight[i])
      end
   else
      -- 1D case: elementwise product with the weight vector
      self.gradInput:cmul(self.weight)
   end
   return self.gradInput
end

--- Accumulate the weight gradient: for each feature map f,
-- gradWeight[f] += scale * <gradOutput[f], input[f]>.
-- 4D tensors (batch x nFeatures x h x w) index features on dim 2;
-- 3D tensors (nFeatures x h x w) index features on dim 1.
-- Inputs with fewer than 3 dims are ignored here, matching the original.
-- @param input torch.Tensor seen in the forward pass
-- @param gradOutput torch.Tensor, gradient w.r.t. this module's output
-- @param scale number, scaling factor for the accumulated gradient
function Diag:accGradParameters(input, gradOutput, scale)
   local nDim = input:dim()
   if nDim == 4 then -- batch mode: feature maps live on dim 2
      for f = 1, input:size(2) do
         local contrib = gradOutput[{{},{f}}]:dot(input[{{},{f}}])
         self.gradWeight[f] = self.gradWeight[f] + scale * contrib
      end
   elseif nDim == 3 then -- single sample: feature maps live on dim 1
      for f = 1, input:size(1) do
         local contrib = gradOutput[{{f}}]:dot(input[{{f}}])
         self.gradWeight[f] = self.gradWeight[f] + scale * contrib
      end
   end
end

21 changes: 16 additions & 5 deletions SpatialConvFistaL1.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ local SpatialConvFistaL1, parent = torch.class('unsupgpu.SpatialConvFistaL1','un
-- padh : zero padding vertical
-- lambda : sparsity coefficient
-- params : optim.FistaLS parameters
function SpatialConvFistaL1:__init(conntable, kw, kh, iw, ih, padw, padh, lambda, params)
function SpatialConvFistaL1:__init(conntable, kw, kh, iw, ih, padw, padh, lambda, batchSize)

-- parent.__init(self)

Expand All @@ -22,18 +22,29 @@ function SpatialConvFistaL1:__init(conntable, kw, kh, iw, ih, padw, padh, lambda
-----------------------------------------
-- L2 reconstruction cost with weighting
-----------------------------------------
local tt = torch.Tensor(inputFeatures,ih,iw)
local utt= tt:unfold(2,kh,1):unfold(3,kw,1)
local batchSize = batchSize or 1
local tt, utt
if batchSize > 1 then
tt = torch.Tensor(batchSize,inputFeatures,ih,iw)
utt= tt:unfold(3,kh,1):unfold(4,kw,1)
else
tt = torch.Tensor(inputFeatures,ih,iw)
utt= tt:unfold(2,kh,1):unfold(3,kw,1)
end
tt:zero()
utt:add(1)
tt:div(tt:max())
local Fcost = nn.WeightedMSECriterion(tt)
Fcost.sizeAverage = false;

parent.__init(self,D,Fcost,lambda,params)
parent.__init(self,D,Fcost,lambda)

-- this is going to be passed to optim.FistaLS
self.code:resize(outputFeatures, utt:size(2)+2*padw,utt:size(3)+2*padh)
if batchSize > 1 then
self.code:resize(batchSize, outputFeatures, utt:size(3)+2*padw,utt:size(4)+2*padh)
else
self.code:resize(outputFeatures, utt:size(2)+2*padw,utt:size(3)+2*padh)
end
self.code:fill(0)
end

0 comments on commit ff9e475

Please sign in to comment.