-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun2_dqn_2way-single-intersection.py
executable file
·66 lines (50 loc) · 1.97 KB
/
run2_dqn_2way-single-intersection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gym
import argparse
from datetime import datetime
from stable_baselines3.dqn.dqn import DQN
import os
import sys
# SUMO ships its Python helpers (e.g. the `traci` package imported below) in
# $SUMO_HOME/tools; make them importable, or abort early with a clear message
# when the variable is missing.
if 'SUMO_HOME' not in os.environ:
    sys.exit("Please declare the environment variable 'SUMO_HOME'")
tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
sys.path.append(tools)
from sumo_rl import SumoEnvironment
import traci
# Scenario and artefact locations (relative to the working directory).
NET_FILE = 'nets/2way-single-intersection/single-intersection.net.xml'
ROUTE_FILE = 'nets/2way-single-intersection/single-intersection-vhvh-type-test.rou.xml'
MODEL_PATH = 'outputs/last_saved_dqn_2way'
TOTAL_TIMESTEPS = 100000


def _parse_args():
    """Parse the command line.

    Done before any SUMO/DQN object is created so that ``-h`` or a bad flag
    exits without launching the SUMO GUI.
    """
    prs = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                  description="DQN 2-way Single-Intersection")
    prs.add_argument("-pretrain", action="store_true", default=False,
                     help="Load the model saved at {} and continue training it.".format(MODEL_PATH))
    return prs.parse_args()


def _make_env(experiment_time):
    """Create the single-agent SUMO environment (GUI enabled) for this run.

    Args:
        experiment_time: timestamp string embedded in the output CSV name.
    """
    return SumoEnvironment(net_file=NET_FILE,
                           route_file=ROUTE_FILE,
                           out_csv_name='outputs/{}_DQN_2way'.format(experiment_time),
                           single_agent=True,
                           use_gui=True,
                           num_seconds=100000,
                           max_depart_delay=0)


def _make_model(env):
    """Build a fresh DQN agent bound to ``env`` with this experiment's hyperparameters."""
    return DQN(
        env=env,
        policy="MlpPolicy",
        learning_rate=0.01,
        learning_starts=0,
        train_freq=1,
        target_update_interval=100,
        exploration_initial_eps=0.05,
        exploration_final_eps=0.01,
        verbose=1,
    )


def main():
    """Train a DQN agent (fresh, or resumed with ``-pretrain``) and optionally save it."""
    args = _parse_args()
    # Same "YYYY-MM-DD HH:MM:SS" format as str(datetime.now()) without microseconds.
    # NOTE: ':' is not allowed in Windows file names, so the CSV name only works on POSIX.
    experiment_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    env = _make_env(experiment_time)
    try:
        if args.pretrain:
            # Only build the agent we actually use; loading binds it to the new env.
            model = DQN.load(MODEL_PATH, env=env)
        else:
            model = _make_model(env)
        model.learn(total_timesteps=TOTAL_TIMESTEPS)
        save_model = input('Do you want to save model (Y/N) ?')
        if save_model.strip().upper() == 'Y':
            model.save(MODEL_PATH)
    finally:
        # Always shut SUMO down, even if training fails or is interrupted.
        env.close()


if __name__ == '__main__':
    main()