Commit

initial commit

LazyDok committed Jan 23, 2018
1 parent f7626fc commit b6a24f8
Showing 7 changed files with 84 additions and 21 deletions.
6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default.

57 changes: 44 additions & 13 deletions 1_dqn_cartpole.py
@@ -6,9 +6,11 @@
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip

# hyper parameters
EPISODES = 1000 # number of episodes
EPISODES = 2000 # number of episodes
EPS_START = 0.9 # e-greedy threshold start value
EPS_END = 0.05 # e-greedy threshold end value
EPS_DECAY = 200 # e-greedy threshold decay
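
These three constants drive an exponentially decaying e-greedy exploration threshold. A minimal sketch of the schedule they imply, assuming the exp(-steps_done / EPS_DECAY) form typical of DQN examples from this era:

import math

def eps_threshold(steps_done):
    # exploration probability after steps_done environment steps
    return EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)

# eps_threshold(0) == 0.9, eps_threshold(200) ~= 0.36, eps_threshold(1000) ~= 0.056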
@@ -41,28 +43,36 @@ def sample(self, batch_size):
def __len__(self):
return len(self.memory)
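
The rest of ReplayMemory is collapsed in this hunk. A minimal sketch of the full class, consistent with the visible sample/__len__ and with the single-tuple memory.push((...)) call later in the file; the capacity-eviction detail is an assumption:

import random

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []

    def push(self, transition):
        # store one (state, action, next_state, reward) tuple,
        # dropping the oldest entry once capacity is exceeded
        self.memory.append(transition)
        if len(self.memory) > self.capacity:
            del self.memory[0]

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)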


class Network(nn.Module):
def __init__(self):
nn.Module.__init__(self)
self.l1 = nn.Linear(4, HIDDEN_LAYER)
# self.bn1 = nn.BatchNorm1d(HIDDEN_LAYER)
self.l2 = nn.Linear(HIDDEN_LAYER, 2)

def forward(self, x):
# x = F.relu(self.bn1(self.l1(x)))
x = F.relu(self.l1(x))
x = self.l2(x)
return x

env = gym.make('CartPole-v0')
env = gym.make('CartPole-v0').unwrapped

model = Network()
if use_cuda:
model.cuda()
memory = ReplayMemory(10000)
optimizer = optim.Adam(model.parameters(), LR)
steps_done = 0
ed = []

# def plot_durations(d):
# plt.figure(2)
# plt.clf()
# plt.title('Training...')
# plt.xlabel('Episode')
# plt.ylabel('Duration')
# plt.plot(d)
#
# plt.savefig('test2.png')

def select_action(state, train=True):
global steps_done
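
The body of select_action is collapsed below this point. A plausible sketch of the e-greedy choice it implements, assuming the decay schedule above, that the collapsed header imports random and math, and the PyTorch 0.3-era Variable/FloatTensor/LongTensor helpers set up earlier in this file:

def select_action(state, train=True):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    steps_done += 1
    if not train or sample > eps_threshold:
        # exploit: greedy action from the network
        return model(Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1)
    else:
        # explore: uniformly random action (CartPole has 2 actions)
        return LongTensor([[random.randrange(2)]])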
@@ -81,16 +91,22 @@ def run_episode(episode, env):
state = env.reset()
steps = 0
while True:
env.render()
# env.render()
action = select_action(FloatTensor([state]))
next_state, reward, done, _ = env.step(action[0, 0])

# negative reward when attempt ends
if done:
if steps < 30:
reward -= 10
elif steps > 200:
reward += 5
else:
reward = -1
if steps > 100:
reward += 1
if steps > 200:
reward += 1
if steps > 300:
reward += 1

memory.push((FloatTensor([state]),
action, # action is already a tensor
@@ -102,12 +118,15 @@ def run_episode(episode, env):
state = next_state
steps += 1

if done:
if done or steps >= 1000:
ed.append(steps)
print("[Episode {:>5}] steps: {:>5}".format(episode, steps))
if sum(ed[-10:])/10 > 800:
return True
break
return False

def learn():

if len(memory) < BATCH_SIZE:
return
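
The rest of learn() sits in the collapsed region. A minimal sketch of a standard DQN update consistent with the visible pieces (the BATCH_SIZE guard, memory.sample, the Adam optimizer), assuming the collapsed header imports torch, that a GAMMA discount constant is defined in the collapsed hyperparameter block, and Huber loss; the batch_* names are illustrative:

    # sample a random minibatch of stored transitions
    transitions = memory.sample(BATCH_SIZE)
    batch_state, batch_action, batch_next_state, batch_reward = zip(*transitions)

    batch_state = Variable(torch.cat(batch_state))
    batch_action = Variable(torch.cat(batch_action))
    batch_reward = Variable(torch.cat(batch_reward))
    batch_next_state = Variable(torch.cat(batch_next_state))

    # Q(s, a) for the actions actually taken
    current_q_values = model(batch_state).gather(1, batch_action).squeeze(1)
    # plain DQN: the same network provides max_a' Q(s', a')
    max_next_q_values = model(batch_next_state).detach().max(1)[0]
    expected_q_values = batch_reward + (GAMMA * max_next_q_values)

    # Huber loss between current and target Q values
    loss = F.smooth_l1_loss(current_q_values, expected_q_values)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()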

@@ -137,16 +156,28 @@ def learn():
def botPlay():
state = env.reset()
steps = 0
frames = []
while True:
env.render()
frame = env.render(mode='rgb_array')
frames.append(frame)
action = select_action(FloatTensor([state]))
next_state, reward, done, _ = env.step(action[0, 0])

state = next_state
steps += 1

if done:
if done or steps >= 1000:
break

clip = ImageSequenceClip(frames, fps=20)
clip.write_gif('test2.gif', fps=20)

for e in range(EPISODES):
run_episode(e, env)
complete = run_episode(e, env)

if complete:
print('complete...!')
break

# plot_durations(ed)
# botPlay()
42 changes: 34 additions & 8 deletions 2_double_dqn_cartpole.py
@@ -6,6 +6,8 @@
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip

# hyper parameters
EPISODES = 1000 # number of episodes
@@ -41,7 +43,6 @@ def sample(self, batch_size):
def __len__(self):
return len(self.memory)


class Network(nn.Module):
def __init__(self):
nn.Module.__init__(self)
@@ -63,6 +64,17 @@ def forward(self, x):
memory = ReplayMemory(10000)
optimizer = optim.Adam(model.parameters(), LR)
steps_done = 0
ed = []

# def plot_durations(d):
# plt.figure(2)
# plt.clf()
# plt.title('Training...')
# plt.xlabel('Episode')
# plt.ylabel('Duration')
# plt.plot(d)
#
# plt.savefig('test.png')

def select_action(state, train=True):
global steps_done
@@ -81,7 +93,7 @@ def run_episode(episode, env):
state = env.reset()
steps = 0
while True:
env.render()
# env.render()
action = select_action(FloatTensor([state]))
next_state, reward, done, _ = env.step(action[0, 0])

@@ -108,12 +120,15 @@ def run_episode(episode, env):
state = next_state
steps += 1

if done:
if done or steps >= 1000:
ed.append(steps)
print("[Episode {:>5}] steps: {:>5}".format(episode, steps))
if sum(ed[-10:])/10 > 800:
return True
break
return False

def learn():

if len(memory) < BATCH_SIZE:
return
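
This is where the second file's learn(), collapsed here, differs from the first: a separate target network provides the bootstrapped value. A minimal sketch of the canonical double-DQN target, assuming the same batch layout as in the first file and the target instance synced in the loop below (the collapsed code may instead use the simpler target-network max):

    # online network selects the best next action ...
    next_actions = model(batch_next_state).detach().max(1)[1].unsqueeze(1)
    # ... and the target network evaluates that selection
    max_next_q_values = target(batch_next_state).detach().gather(1, next_actions).squeeze(1)
    expected_q_values = batch_reward + (GAMMA * max_next_q_values)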

@@ -143,23 +158,34 @@ def learn():
def botPlay():
state = env.reset()
steps = 0
frames = []
while True:
env.render()
frame = env.render(mode='rgb_array')
frames.append(frame)
action = select_action(FloatTensor([state]))
next_state, reward, done, _ = env.step(action[0, 0])

state = next_state
steps += 1

if done:
if done or steps >= 1000:
break

clip = ImageSequenceClip(frames, fps=20)
clip.write_gif('test.gif', fps=20)

for e in range(EPISODES):
run_episode(e, env)
complete = run_episode(e, env)

if complete:
break

if (e+1) % 5 == 0:
mp = list(target.parameters())
mcp = list(model.parameters())
n = len(mp)
for i in range(0, n):
mp[i].data[:] = mcp[i].data[:]

# plot_durations(ed)
# botPlay()
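
The parameter-copy loop above performs a hard update of the target network every five episodes. An equivalent, more idiomatic PyTorch form with the same behavior:

if (e + 1) % 5 == 0:
    # copy all online-network weights into the target network at once
    target.load_state_dict(model.state_dict())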
Binary file added img/1_dqn_cartpole.gif
Binary file added img/1_dqn_cartpole.png
Binary file added img/2_double_dqn_cartpole.gif
Binary file added img/2_double_dqn_cartpole.png
