forked from tpbarron/pytorch-a2c
commit 3e2d230 (0 parents)
Showing 5 changed files with 325 additions and 0 deletions.
README.md
@@ -0,0 +1,14 @@
# pytorch-a3c

This is a PyTorch implementation of Asynchronous Advantage Actor Critic (A3C) from ["Asynchronous Methods for Deep Reinforcement Learning"](https://arxiv.org/pdf/1602.01783v1.pdf).

This implementation is inspired by [Universe Starter Agent](https://github.com/openai/universe-starter-agent).

## Contributions

Contributions are very welcome. If you know how to make this code better, don't hesitate to send a pull request.

## Usage
```
python main.py --env-name "PongDeterministic-v3" --num-processes 16
```
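For reference, main.py (shown further below) also exposes the remaining hyperparameters as flags; the invocation below simply spells out their defaults:
```
python main.py --env-name "PongDeterministic-v3" --num-processes 16 \
    --lr 0.0001 --gamma 0.99 --tau 1.00 --seed 1 --num-steps 20
```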
envs.py
@@ -0,0 +1,66 @@
import gym
import numpy as np
import universe
from gym.spaces.box import Box
from universe import vectorized
from universe.wrappers import Unvectorize, Vectorize

import cv2


# Taken from https://github.com/openai/universe-starter-agent
def create_atari_env(env_id):
    env = gym.make(env_id)
    if len(env.observation_space.shape) > 1:
        env = Vectorize(env)
        env = AtariRescale42x42(env)
        env = NormalizedEnv(env)
        env = Unvectorize(env)
    return env


def _process_frame42(frame):
    frame = frame[34:34 + 160, :160]
    # Resize by half, then down to 42x42 (essentially mipmapping). If
    # we resize directly we lose pixels that, when mapped to 42x42,
    # aren't close enough to the pixel boundary.
    frame = cv2.resize(frame, (80, 80))
    frame = cv2.resize(frame, (42, 42))
    frame = frame.mean(2)
    frame = frame.astype(np.float32)
    frame *= (1.0 / 255.0)
    frame = np.reshape(frame, [1, 42, 42])
    return frame


class AtariRescale42x42(vectorized.ObservationWrapper):

    def __init__(self, env=None):
        super(AtariRescale42x42, self).__init__(env)
        self.observation_space = Box(0.0, 1.0, [1, 42, 42])

    def _observation(self, observation_n):
        return [_process_frame42(observation) for observation in observation_n]


class NormalizedEnv(vectorized.ObservationWrapper):

    def __init__(self, env=None):
        super(NormalizedEnv, self).__init__(env)
        self.state_mean = 0
        self.state_std = 0
        self.alpha = 0.999
        self.num_steps = 0

    def _observation(self, observation_n):
        for observation in observation_n:
            self.num_steps += 1
            self.state_mean = self.state_mean * self.alpha + \
                observation.mean() * (1 - self.alpha)
            self.state_std = self.state_std * self.alpha + \
                observation.std() * (1 - self.alpha)

        unbiased_mean = self.state_mean / (1 - pow(self.alpha, self.num_steps))
        unbiased_std = self.state_std / (1 - pow(self.alpha, self.num_steps))

        return [(observation - unbiased_mean) / (unbiased_std + 1e-8) for observation in observation_n]
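NormalizedEnv keeps exponential running estimates of the observation mean and standard deviation (alpha = 0.999) and debiases them by 1 / (1 - alpha^t), analogous to the bias correction used in Adam. A minimal sanity check of the wrapper stack, assuming gym, universe and the Atari dependencies are installed and the file above is saved as envs.py:
```
# Sketch: the wrapped env should emit normalised 1x42x42 float frames.
from envs import create_atari_env

env = create_atari_env("PongDeterministic-v3")
state = env.reset()
print(state.shape)  # expected: (1, 42, 42), per AtariRescale42x42 above
```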
main.py
@@ -0,0 +1,51 @@
from __future__ import print_function

import argparse
import os
import sys

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from envs import create_atari_env
from model import ActorCritic
from train import train

# Based on
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
# Training settings
parser = argparse.ArgumentParser(description='A3C')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                    help='learning rate (default: 0.0001)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for rewards (default: 0.99)')
parser.add_argument('--tau', type=float, default=1.00, metavar='T',
                    help='parameter for GAE (default: 1.00)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--num-processes', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
parser.add_argument('--num-steps', type=int, default=20, metavar='NS',
                    help='number of forward steps in A3C (default: 20)')
parser.add_argument('--env-name', default='PongDeterministic-v3', metavar='ENV',
                    help='environment to train on (default: PongDeterministic-v3)')


if __name__ == '__main__':
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.set_num_threads(1)

    env = create_atari_env(args.env_name)
    shared_model = ActorCritic(env.observation_space.shape[0], env.action_space)
    shared_model.share_memory()
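    # share_memory() moves the parameters into shared memory, so every worker
    # process spawned below reads and updates the same underlying tensors
    # (lock-free, Hogwild-style, as in the mnist_hogwild example linked above).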

    processes = []
    for rank in range(args.num_processes):
        p = mp.Process(target=train, args=(rank, args, shared_model))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
model.py
@@ -0,0 +1,72 @@
import math

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


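# Return a tensor shaped like `weights` in which each output unit's incoming
# weight vector (each row of a PyTorch weight matrix) has L2 norm `std`.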
def normalized_columns_initializer(weights, std=1.0):
    out = torch.randn(weights.size())
    out *= std / torch.sqrt(out.pow(2).sum(1).expand_as(out))
    return out


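# Xavier/Glorot-style uniform initialisation: weights are drawn from U(-b, b)
# with b = sqrt(6 / (fan_in + fan_out)); biases start at zero.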
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
        w_bound = np.sqrt(6. / (fan_in + fan_out))
        m.weight.data.uniform_(-w_bound, w_bound)
        m.bias.data.fill_(0)


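# Three stride-2 convolutions reduce the 1x42x42 input frame to 32x4x4 = 512
# features, which feed a 256-unit LSTMCell; separate linear heads then produce
# the scalar value estimate and the action logits.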
class ActorCritic(torch.nn.Module):

    def __init__(self, num_inputs, action_space):
        super(ActorCritic, self).__init__()
        self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2)
        self.conv2 = nn.Conv2d(32, 32, 3, stride=2)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=2)

        self.lstm = nn.LSTMCell(32 * 4 * 4, 256)

        num_outputs = action_space.n
        self.critic_linear = nn.Linear(256, 1)
        self.actor_linear = nn.Linear(256, num_outputs)

        self.apply(weights_init)
        self.actor_linear.weight.data = normalized_columns_initializer(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.critic_linear.weight.data = normalized_columns_initializer(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

        self.train()

    def forward(self, inputs):
        inputs, (hx, cx) = inputs
        x = F.elu(self.conv1(inputs))
        x = F.elu(self.conv2(x))
        x = F.elu(self.conv3(x))

        x = x.view(-1, 32 * 4 * 4)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx

        return self.critic_linear(x), self.actor_linear(x), (hx, cx)
train.py
@@ -0,0 +1,122 @@
import math
import os
import sys

import torch
import torch.nn.functional as F
import torch.optim as optim
from envs import create_atari_env
from model import ActorCritic
from torch.autograd import Variable
from torchvision import datasets, transforms


def train(rank, args, shared_model):
    torch.manual_seed(args.seed + rank)

    env = create_atari_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], env.action_space)

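    # Point the shared model's gradient tensors at this worker's local
    # gradients, so that optimizer.step() on the shared parameters (below)
    # applies gradients computed from this worker's rollouts.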
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        # Use gradients from the local model
        shared_param.grad.data = param.grad.data

    optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()
    pid = os.getpid()

    values = []
    log_probs = []

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True

    running_reward = 0
    num_updates = 0
    while True:
        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            cx = Variable(torch.zeros(1, 256))
            hx = Variable(torch.zeros(1, 256))
        else:
            cx = Variable(cx.data)
            hx = Variable(hx.data)

        values = []
        log_probs = []
        rewards = []
        entropies = []

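        # Roll out up to num_steps environment steps with the current policy,
        # recording values, log-probabilities, rewards and entropies for the
        # update below.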
        for step in range(args.num_steps):
            value, logit, (hx, cx) = model(
                (Variable(state.unsqueeze(0)), (hx, cx)))
            prob = F.softmax(logit)
            log_prob = F.log_softmax(logit)
            entropy = -(log_prob * prob).sum(1)
            entropies.append(entropy)

            action = prob.multinomial().data
            log_prob = log_prob.gather(1, Variable(action))

            state, reward, done, _ = env.step(action.numpy())
            reward_sum += reward
            reward = max(min(reward, 1), -1)
            if done:
                running_reward = running_reward * 0.9 + reward_sum * 0.1
                num_updates += 1

                if rank == 0:
                    print("Agent {2}, episodes {0}, running reward {1:.2f}, current reward {3}".format(
                        num_updates, running_reward / (1 - pow(0.9, num_updates)), rank, reward_sum))
                reward_sum = 0
                state = env.reset()

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        R = torch.zeros(1, 1)
        if not done:
            value, _, _ = model((Variable(state.unsqueeze(0)), (hx, cx)))
            R = value.data

        values.append(Variable(R))
        policy_loss = 0
        value_loss = 0
        R = Variable(R)
        gae = torch.zeros(1, 1)
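        # Walk the rollout backwards: R accumulates the discounted return for
        # the value loss, while gae accumulates the Generalized Advantage
        # Estimate, delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and
        # gae = gamma * tau * gae + delta_t. The 0.01 * entropy term is an
        # entropy bonus that encourages exploration.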
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - \
                log_probs[i] * Variable(gae) - 0.01 * entropies[i]

        optimizer.zero_grad()
        (policy_loss + 0.5 * value_loss).backward()

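        # Clip gradients by their global L2 norm: if the norm exceeds 40,
        # scale every gradient by 40 / norm (effectively what
        # torch.nn.utils.clip_grad_norm does with max_norm=40).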
        global_norm = 0
        for param in model.parameters():
            global_norm += param.grad.data.pow(2).sum()
        global_norm = math.sqrt(global_norm)
        ratio = 40 / global_norm
        if ratio < 1:
            for param in model.parameters():
                param.grad.data.mul_(ratio)
        optimizer.step()