-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCCritic.m
73 lines (73 loc) · 2.79 KB
/
CCritic.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
classdef CCritic < handle
    % CCritic  Linear critic: value estimate z_j = v_ji * z_i, trained with an
    % error term that subtracts an exponential moving average J of the reward.
    %
    % Usage (as implied by the methods below):
    %   critic = CCritic({alpha_v, gamma1, gamma2, input_dim, v_init_range});
    %   [delta, params] = critic.train(z_i, reward, iters, do_update);
    properties
        alpha_v;       % learning rate of the weight update
        gamma1;        % step size of the moving average J of reward
        gamma2;        % factor applied to the current estimate z_j in delta (presumably a discount factor)
        input_dim;     % number of input features (length of v_ji)
        v_ji;          % 1-by-input_dim weight row vector
        v_init_range;  % initial weights are drawn from [-v_init_range, v_init_range]
        norm_param;    % [norm(v_ji,'fro'), delta] from the most recent update
        z_i_prev;      % input from the previous train() call ([] before the first call)
        z_j_prev;      % value estimate from the previous train() call
        J;             % exponential moving average of reward
    end
    methods
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%% constructor
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function obj = CCritic(PARAM)
            % CCritic  Build the critic from a cell array of settings.
            %   PARAM = {alpha_v, gamma1, gamma2, input_dim, v_init_range}
            obj.alpha_v      = PARAM{1};
            obj.gamma1       = PARAM{2};
            obj.gamma2       = PARAM{3};
            obj.input_dim    = PARAM{4};
            obj.v_init_range = PARAM{5};
            obj.norm_param   = zeros(1,2);
            % Uniform random weights in [-v_init_range, v_init_range].
            obj.v_ji = (2*rand(1,obj.input_dim)-1)*obj.v_init_range;
            % No previous input yet; train() skips the update until one exists.
            obj.z_i_prev = [];
            obj.z_j_prev = 0;
            obj.J = 0;
            % obj.J = -40;  % disabled alternative initial value for J
        end
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%% forward pass
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function z_j = forward(this,z_i)
            % forward  Linear value estimate z_j = v_ji * z_i.
            %   z_i is expected to be an input_dim-by-1 column vector so that
            %   z_j is a scalar.
            z_j = this.v_ji * z_i;
        end
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%% weight update
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [delta,params] = update(this,z_j,reward)
            % update  One weight update based on the previous input z_i_prev.
            %   z_j    - value estimate for the current input (from forward)
            %   reward - reward for the transition from z_i_prev to z_i
            %   delta  - reward - J + gamma2*z_j - z_j_prev, computed after J
            %            has been moved toward reward with step gamma1
            %   params - [norm(v_ji,'fro'), delta] after the update; the same
            %            pair is stored in norm_param
            % Requires z_i_prev to be non-empty (train() guarantees this).
            this.J = (1-this.gamma1) * this.J + this.gamma1*reward;
            delta = reward - this.J + this.gamma2 * z_j - this.z_j_prev;
            % z_j_prev = v_ji * z_i_prev, so its gradient w.r.t. v_ji is z_i_prev'.
            % this.v_ji = (1-0.005*this.alpha_v)*this.v_ji;  % disabled weight decay
            this.v_ji = this.v_ji + this.alpha_v * delta * this.z_i_prev';
            % Compute the norm once so params and norm_param always agree.
            params = [norm(this.v_ji,'fro'), delta];
            this.norm_param = params;
        end
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%% one training step
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [delta,params] = train(this,z_i,reward,iters,do_update)
            % train  Evaluate the critic on z_i and, when allowed, update it.
            %   z_i       - current input (input_dim-by-1 column vector)
            %   reward    - reward passed on to update()
            %   iters     - iteration counter; no update happens when iters <= 1
            %   do_update - logical flag that enables the weight update
            %   delta     - delta from update(), or 0 when no update was made
            %   params    - norm_param, i.e. [norm(v_ji,'fro'), delta] from the
            %               most recent update ([0 0] if none has happened yet)
            % Always records z_i and its value estimate for the next call.
            delta = 0;
            % this.alpha_v=(0.2)*exp(-iters/100000)+0.01;  % disabled schedule
            % this.alpha_v=0.5;                            % disabled constant
            z_j = this.forward(z_i);
            % Also skip when no previous input has been recorded (e.g. the
            % first call arrives with iters > 1). Otherwise update() would
            % add an empty matrix to v_ji and fail.
            if iters > 1 && do_update && ~isempty(this.z_i_prev)
                delta = this.update(z_j,reward);
            end
            params = this.norm_param;
            this.z_i_prev = z_i;
            this.z_j_prev = z_j;
        end
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%% end
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    end
end