diff --git a/Diag.lua b/Diag.lua index dfa3b96..816c94c 100644 --- a/Diag.lua +++ b/Diag.lua @@ -14,21 +14,29 @@ end function Diag:updateOutput(input) self.output:resizeAs(input):copy(input) - if input:dim() > 1 then - for i=1,input:size(1) do - self.output[{{i}}]:mul(self.weight[i]) + if input:dim() == 4 then -- batch mode + for i=1,input:size(2) do + self.output[{{},{i}}]:mul(self.weight[i]) end + elseif input:dim() == 3 then -- no batch + for i=1,input:size(1) do + self.output[{{i}}]:mul(self.weight[i]) + end else - self.output:cmul(self.weight) + self.output:cmul(self.weight) end return self.output end function Diag:updateGradInput(input, gradOutput) self.gradInput:resizeAs(gradOutput):copy(gradOutput) - if input:dim() > 1 then + if input:dim() == 4 then -- batch mode + for i=1,input:size(2) do + self.gradInput[{{},{i}}]:mul(self.weight[i]) + end + elseif input:dim() == 3 then for i=1,input:size(1) do - self.gradInput[{{i}}]:mul(self.weight[i]) + self.gradInput[{{i}}]:mul(self.weight[i]) end else self.gradInput:cmul(self.weight) @@ -37,8 +45,14 @@ function Diag:updateGradInput(input, gradOutput) end function Diag:accGradParameters(input, gradOutput, scale) + if input:dim() == 4 then -- batch mode + for i=1,input:size(2) do + self.gradWeight[i] = self.gradWeight[i] + scale*gradOutput[{{},{i}}]:dot(input[{{},{i}}]) + end + elseif input:dim() == 3 then for i=1,input:size(1) do self.gradWeight[i] = self.gradWeight[i] + scale*gradOutput[{{i}}]:dot(input[{{i}}]) end + end end diff --git a/SpatialConvFistaL1.lua b/SpatialConvFistaL1.lua index 7dcce27..6a74ea8 100644 --- a/SpatialConvFistaL1.lua +++ b/SpatialConvFistaL1.lua @@ -8,7 +8,7 @@ local SpatialConvFistaL1, parent = torch.class('unsupgpu.SpatialConvFistaL1','un -- padh : zero padding vertical -- lambda : sparsity coefficient -- params : optim.FistaLS parameters -function SpatialConvFistaL1:__init(conntable, kw, kh, iw, ih, padw, padh, lambda, params) +function SpatialConvFistaL1:__init(conntable, kw, kh, iw, ih, padw, padh, lambda, batchSize) -- parent.__init(self) @@ -22,18 +22,29 @@ function SpatialConvFistaL1:__init(conntable, kw, kh, iw, ih, padw, padh, lambda ----------------------------------------- -- L2 reconstruction cost with weighting ----------------------------------------- - local tt = torch.Tensor(inputFeatures,ih,iw) - local utt= tt:unfold(2,kh,1):unfold(3,kw,1) + local batchSize = batchSize or 1 + local tt, utt + if batchSize > 1 then + tt = torch.Tensor(batchSize,inputFeatures,ih,iw) + utt= tt:unfold(3,kh,1):unfold(4,kw,1) + else + tt = torch.Tensor(inputFeatures,ih,iw) + utt= tt:unfold(2,kh,1):unfold(3,kw,1) + end tt:zero() utt:add(1) tt:div(tt:max()) local Fcost = nn.WeightedMSECriterion(tt) Fcost.sizeAverage = false; - parent.__init(self,D,Fcost,lambda,params) + parent.__init(self,D,Fcost,lambda) -- this is going to be passed to optim.FistaLS - self.code:resize(outputFeatures, utt:size(2)+2*padw,utt:size(3)+2*padh) + if batchSize > 1 then + self.code:resize(batchSize, outputFeatures, utt:size(3)+2*padw,utt:size(4)+2*padh) + else + self.code:resize(outputFeatures, utt:size(2)+2*padw,utt:size(3)+2*padh) + end self.code:fill(0) end