forked from torch/optim
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsparsecoding.lua
127 lines (110 loc) · 3.82 KB
/
sparsecoding.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
require 'kex'
-- L1 FISTA Solution
-- L1 solution with a linear dictionary ||Ax-b||^2 + \lambda ||x||_1
-- D : dictionary, each column is a dictionary element
-- params: set of params to pass to FISTA and possibly temp allocation (**optional**)
-- check optim.FistaLS function for details.
-- returns fista : a table with the following entries
-- fista.run(x,lambda) : run L1 sparse coding algorithm with input x and lambda.
-- The following entries will be allocated and reused by each call to fista.run(x,lambda)
-- fista.reconstruction: reconstructed input.
-- fista.gradf : gradient of L2 part of the problem wrt x
-- fista.code : the solution of L1 problem
-- The following entries just point to data passed to fista.run(x)
-- fista.input : points to the tensor 'x' used in the last fista.run(x,lambda)
-- fista.lambda : the lambda value used in the last fista.run(x,lambda)
-- Build an L1 sparse coding solver for a fixed linear dictionary D.
--
-- The returned table holds closures for the smooth part f(x) = ||Dx - input||^2,
-- the non-smooth part g(x) = lambda * ||x||_1, and the shrinkage step pl(x, L),
-- plus a run() entry point that hands them to optim.FistaLS.
--
-- D      : dictionary; D:size(1) is the input size, D:size(2) the code size
-- params : (optional) table of FISTA parameters; missing entries are filled
--          in place with the defaults below and the table is passed through
--          to optim.FistaLS
-- returns: the fista table (run, f, g, pl, reconstruction, gradf, gradg,
--          code, input, lambda)
function optim.FistaL1(D, params)

   -- this is for keeping parameters related to fista algorithm
   local params = params or {}
   -- this is for temporary variables and such
   local fista = {}

   -- related to FISTA
   params.L = params.L or 0.1
   params.Lstep = params.Lstep or 1.5
   params.maxiter = params.maxiter or 50
   params.maxline = params.maxline or 20
   params.errthres = params.errthres or 1e-4

   -- temporary buffers, allocated once and resized/reused on every call
   fista.reconstruction = torch.Tensor()
   fista.gradf = torch.Tensor()
   fista.gradg = torch.Tensor()
   fista.code = torch.Tensor()

   -- these will be assigned in run(x)
   -- fista.input points to the last input that was run
   -- fista.lambda is the lambda value from the last run
   fista.input = nil
   fista.lambda = nil

   -- CREATE FUNCTION CLOSURES

   -- Smooth function: squared reconstruction error ||Dx - input||^2.
   -- x    : code, either 1D (single sample) or 2D (one sample per row)
   -- mode : (optional) string; when it contains 'dx' the gradient is computed
   -- returns fval, reconstruction            when no gradient is requested
   -- returns fval, gradf, reconstruction     when mode contains 'dx'
   -- NOTE: in the 'dx' case the reconstruction buffer is overwritten with
   -- 2*(Dx - input), i.e. it no longer holds the reconstruction itself.
   fista.f = function (x, mode)

      local reconstruction = fista.reconstruction
      local input = fista.input

      -- -------------------
      -- function evaluation
      if x:dim() == 1 then
         reconstruction:resize(D:size(1))
         reconstruction:addmv(0, 1, D, x)
      elseif x:dim() == 2 then
         reconstruction:resize(x:size(1), D:size(1))
         reconstruction:addmm(0, 1, x, D:t())
      else
         -- fail loudly instead of silently reusing a stale reconstruction
         error(' I do not know how to handle ' .. x:dim() .. ' dimensional input')
      end
      local fval = input:dist(reconstruction)^2

      -- ----------------------
      -- derivative calculation
      if mode and mode:match('dx') then
         local gradf = fista.gradf
         -- reuse the reconstruction buffer for the scaled residual
         reconstruction:add(-1, input):mul(2)
         gradf:resizeAs(x)
         if input:dim() == 1 then
            gradf:addmv(0, 1, D:t(), reconstruction)
         else
            gradf:addmm(0, 1, reconstruction, D)
         end
         ---------------------------------------
         -- return function value and derivative
         return fval, gradf, reconstruction
      end

      ------------------------
      -- return function value
      return fval, reconstruction
   end

   -- Non-smooth function: lambda * ||x||_1.
   -- x    : code tensor
   -- mode : (optional) string; when it contains 'dx' a subgradient
   --        lambda * sign(x) is also returned
   -- returns fval, or fval and gradg when mode contains 'dx'
   fista.g = function (x, mode)

      local fval = fista.lambda * x:norm(1)

      if mode and mode:match('dx') then
         local gradg = fista.gradg
         gradg:resizeAs(x)
         -- copy first so the sign is taken of x, not of the stale buffer
         gradg:copy(x):sign():mul(fista.lambda)
         return fval, gradg
      end
      return fval
   end

   -- argmin_x Q(x,y), just shrinkage for L1.
   -- Calls x:shrinkage with threshold lambda/L (shrinkage is not a core
   -- torch method; presumably provided by the 'kex' package required above).
   fista.pl = function (x, L)
      x:shrinkage(fista.lambda / L)
   end

   -- Run L1 sparse coding on input x with sparsity weight lam.
   -- x        : input, 1D (single sample) or 2D (one sample per row)
   -- lam      : L1 penalty coefficient
   -- codeinit : (optional) initial code; copied into fista.code. When it is
   --            absent the code starts from zeros.
   -- returns whatever optim.FistaLS returns
   fista.run = function(x, lam, codeinit)

      local code = fista.code
      fista.input = x
      fista.lambda = lam

      -- resize code, maybe a different number of dimensions
      if codeinit then
         code:resizeAs(codeinit)
         code:copy(codeinit)
      else
         if x:dim() == 1 then
            code:resize(D:size(2))
         elseif x:dim() == 2 then
            code:resize(x:size(1), D:size(2))
         else
            error(' I do not know how to handle ' .. x:dim() .. ' dimensional input')
         end
         -- fill with zeros, initial point
         code:fill(0)
      end

      -- return the result of optim.FistaLS call.
      return optim.FistaLS(fista.f, fista.g, fista.pl, fista.code, params)
   end

   return fista
end