Commit

Deterministically evaluate in a separate thread
ikostrikov2 committed Feb 17, 2017
1 parent 29ef0f4 commit 721a3c3
Showing 5 changed files with 96 additions and 14 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -4,6 +4,7 @@ This is a PyTorch implementation of Asynchronous Advantage Actor Critic (A3C) fr

This implementation is inspired by [Universe Starter Agent](https://github.com/openai/universe-starter-agent).


## Contributions

Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request.
@@ -13,6 +14,15 @@ Contributions are very welcome. If you know how to make this code better, don't
python main.py --env-name "PongDeterministic-v3" --num-processes 16
```

In addition to the 16 training processes, this command starts a separate evaluation process.
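
A minimal sketch of that launch pattern follows; it mirrors what `main.py` does, the `launch` wrapper is purely illustrative, `args` stands for the argparse namespace that `main.py` builds, and `ActorCritic`, `create_atari_env`, `train` and `test` are the modules in this repository:

```
import torch.multiprocessing as mp

from envs import create_atari_env
from model import ActorCritic
from train import train
from test import test


def launch(args):
    # Hogwild-style setup: every process works on one set of shared parameters.
    env = create_atari_env(args.env_name)
    shared_model = ActorCritic(env.observation_space.shape[0], env.action_space)
    shared_model.share_memory()

    processes = []

    # One evaluation process; rank == num_processes keeps its seed
    # distinct from every training process.
    p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
    p.start()
    processes.append(p)

    # num_processes training processes update the shared model asynchronously.
    for rank in range(args.num_processes):
        p = mp.Process(target=train, args=(rank, args, shared_model))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
```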

## Results

With 16 training processes, PongDeterministic-v3 converges in about 15 minutes.
![PongDeterministic-v3](images/PongReward.png)

BreakoutDeterministic-v3 takes several hours or more.

## Todo

- [ ] Deterministic evaluation in a separate thread
Binary file added images/PongReward.png
11 changes: 9 additions & 2 deletions main.py
@@ -11,7 +11,7 @@
from envs import create_atari_env
from model import ActorCritic
from train import train

from test import test
# Based on
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
# Training settings
@@ -28,6 +28,8 @@
help='how many training processes to use (default: 4)')
parser.add_argument('--num-steps', type=int, default=20, metavar='NS',
help='number of forward steps in A3C (default: 20)')
parser.add_argument('--max-episode-length', type=int, default=10000, metavar='M',
help='maximum length of an episode (default: 10000)')
parser.add_argument('--env-name', default='PongDeterministic-v3', metavar='ENV',
help='environment to train on (default: PongDeterministic-v3)')

@@ -44,7 +46,12 @@
shared_model.share_memory()

processes = []
for rank in range(args.num_processes):

p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
p.start()
processes.append(p)

for rank in range(0, args.num_processes):
p = mp.Process(target=train, args=(rank, args, shared_model))
p.start()
processes.append(p)
72 changes: 72 additions & 0 deletions test.py
@@ -0,0 +1,72 @@
import math
import os
import sys

import torch
import torch.nn.functional as F
import torch.optim as optim
from envs import create_atari_env
from model import ActorCritic
from torch.autograd import Variable
from torchvision import datasets, transforms
import time
from collections import deque


def test(rank, args, shared_model):
torch.manual_seed(args.seed + rank)

env = create_atari_env(args.env_name)
env.seed(args.seed + rank)

model = ActorCritic(env.observation_space.shape[0], env.action_space)

model.eval()

state = env.reset()
state = torch.from_numpy(state)
reward_sum = 0
done = True

start_time = time.time()

    # a quick hack to prevent the agent from getting stuck
actions = deque(maxlen=100)
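    # If the evaluator picks the same action 100 times in a row, the check
    # inside the loop below ends the episode early.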
episode_length = 0
while True:
episode_length += 1
# Sync with the shared model
if done:
model.load_state_dict(shared_model.state_dict())
cx = Variable(torch.zeros(1, 256), volatile=True)
hx = Variable(torch.zeros(1, 256), volatile=True)
else:
cx = Variable(cx.data, volatile=True)
hx = Variable(hx.data, volatile=True)
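        # All Variables here are volatile: this is inference only, so autograd
        # history is never recorded (pre-0.4 PyTorch API).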

value, logit, (hx, cx) = model(
(Variable(state.unsqueeze(0), volatile=True), (hx, cx)))
prob = F.softmax(logit)
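        # Greedy (argmax) action selection, which is what makes this evaluation deterministic.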
action = prob.max(1)[1].data.numpy()

state, reward, done, _ = env.step(action[0, 0])
done = done or episode_length >= args.max_episode_length
reward_sum += reward

        # a quick hack to prevent the agent from getting stuck
actions.append(action[0, 0])
if actions.count(actions[0]) == actions.maxlen:
done = True

if done:
print("Time {}, episode reward {}, episode length {}".format(
time.strftime("%Hh %Mm %Ss",
time.gmtime(time.time() - start_time)),
reward_sum, episode_length))
reward_sum = 0
episode_length = 0
actions.clear()
state = env.reset()
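            # Start the next evaluation episode roughly once a minute.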
time.sleep(60)

state = torch.from_numpy(state)
17 changes: 5 additions & 12 deletions train.py
@@ -26,19 +26,17 @@ def train(rank, args, shared_model):
optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

model.train()
pid = os.getpid()

values = []
log_probs = []

state = env.reset()
state = torch.from_numpy(state)
reward_sum = 0
done = True

running_reward = 0
num_updates = 0
episode_length = 0
while True:
episode_length += 1
# Sync with the shared model
model.load_state_dict(shared_model.state_dict())
if done:
@@ -65,16 +63,11 @@
log_prob = log_prob.gather(1, Variable(action))

state, reward, done, _ = env.step(action.numpy())
reward_sum += reward
done = done or episode_length >= args.max_episode_length
reward = max(min(reward, 1), -1)
if done:
running_reward = running_reward * 0.9 + reward_sum * 0.1
num_updates += 1

if rank == 0:
print("Agent {2}, episodes {0}, running reward {1:.2f}, current reward {3}".format(
num_updates, running_reward / (1 - pow(0.9, num_updates)), rank, reward_sum))
reward_sum = 0
if done:
episode_length = 0
state = env.reset()

state = torch.from_numpy(state)
