-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain_Q.m
60 lines (43 loc) · 1.42 KB
/
train_Q.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
function train_Q(Q, setting_str, episode_limit, time_limit, plot_flag,...
    epsilon, if_eps_decay, decay_rate, gamma, alpha)
%TRAIN_Q Run Q-learning episodes and save the learned parameters to disk.
%   TRAIN_Q(Q, SETTING_STR, EPISODE_LIMIT, TIME_LIMIT, PLOT_FLAG, EPSILON,
%   IF_EPS_DECAY, DECAY_RATE, GAMMA, ALPHA) calls Q_learning once per
%   episode until EPISODE_LIMIT episodes have run or the elapsed wall-clock
%   time (measured with tic/toc, in seconds) exceeds TIME_LIMIT.
%
%   Inputs:
%     Q             - action-value table, threaded through every Q_learning
%                     call (presumably max_states_num-by-action_num like N;
%                     confirm against the caller).
%     setting_str   - suffix used in the names of the CSV files written.
%     episode_limit - maximum number of episodes; also the length of the
%                     saved score vector and the time scale of epsilon decay.
%     time_limit    - wall-clock budget in seconds, checked after each episode.
%     plot_flag     - forwarded to Q_learning; when true, all figures are
%                     closed after each episode (except the one that breaks
%                     out on the time limit).
%     epsilon       - exploration rate forwarded to Q_learning.
%     if_eps_decay  - when true, epsilon is recomputed every episode as
%                     exp(-decay_rate * episode / episode_limit).
%     decay_rate    - decay constant used when if_eps_decay is true.
%     gamma, alpha  - forwarded to Q_learning (presumably discount factor and
%                     learning rate; confirm in Q_learning).
%
%   Side effects:
%     Every 1000 episodes, and once more at the end, the function writes
%     Q_<setting_str>.csv, score_<setting_str>.csv, N_<setting_str>.csv and
%     Policy_1_<setting_str>.csv into ./Parameters. The folder is created
%     if it is missing. Finally, it opens a figure showing N as a scaled
%     image.

max_states_num = 3^6 * 2;
action_num = 6;
checkpoint_every = 1000;   % episodes between intermediate saves

% N is passed into and returned from Q_learning every episode.
N = zeros(max_states_num, action_num);
% Episodes that never run (early stop on time_limit) keep a score of 0.
score = zeros(episode_limit, 1);
K = decay_rate / episode_limit;
start_time = tic;

for episode = 1 : episode_limit
    % epsilon annealing
    if if_eps_decay
        % NOTE(review): the decay starts from 1, so the EPSILON argument is
        % ignored whenever if_eps_decay is true. Confirm whether a scaled
        % schedule starting from the passed-in epsilon was intended.
        epsilon = 1 * exp(- K * episode);
    end
    fprintf('\nepisode = %d\n', episode);
    num_env_cars = randi(5) + 20;   % uniformly 21..25 cars per episode
    [Q, score(episode), collision_flag, N] ...
        = Q_learning(Q, N, num_env_cars, plot_flag, epsilon, gamma, alpha);
    plot_score(score(episode), episode, collision_flag, epsilon);

    if mod(episode, checkpoint_every) == 0
        save_parameters(Q, score, N, setting_str);
    end
    if toc(start_time) > time_limit
        break;
    end
    if plot_flag
        close all;
    end
end % for episode

save_parameters(Q, score, N, setting_str);

% imagesc (not image) so colors span the actual range of visit counts;
% image() would index the colormap directly and saturate on large counts.
figure;
imagesc(N);
colorbar;
end


function save_parameters(Q, score, N, setting_str)
%SAVE_PARAMETERS Write Q, score, N and the greedy policy into ./Parameters.
%   Paths are built with fullfile instead of cd-ing into the folder. An
%   error while writing therefore cannot leave the caller's current folder
%   changed. The folder is created if it does not exist yet.
out_dir = 'Parameters';
if ~exist(out_dir, 'dir')
    mkdir(out_dir);
end
csvwrite(fullfile(out_dir, ['Q_' setting_str '.csv']), Q);
csvwrite(fullfile(out_dir, ['score_' setting_str '.csv']), score);
csvwrite(fullfile(out_dir, ['N_' setting_str '.csv']), N);
% Greedy policy: column index of the largest entry in each row of Q.
[~, Policy_1] = max(Q, [], 2);
csvwrite(fullfile(out_dir, ['Policy_1_' setting_str '.csv']), Policy_1);
end