Skip to content

Commit

Permalink
Add an optimizer with shared statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
ikostrikov2 committed Mar 22, 2017
1 parent b0c1560 commit 85c7efd
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This is a PyTorch implementation of Asynchronous Advantage Actor Critic (A3C) from ["Asynchronous Methods for Deep Reinforcement Learning"](https://arxiv.org/pdf/1602.01783v1.pdf).

This implementation is inspired by [Universe Starter Agent](https://github.com/openai/universe-starter-agent).
As in the starter agent, I don't share parameters of the optimizers between threads. If you want to have the same optimizer as in the original paper by DeepMind, you might want to check [this implementation.](https://github.com/rarilurelo/pytorch_a3c)
In contrast to the starter agent, it uses an optimizer with shared statistics as in the original paper.

## Contributions

Expand Down
13 changes: 12 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import sys

import torch
import torch.optim as optim
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from envs import create_atari_env
from model import ActorCritic
from train import train
from test import test
import my_optim

# Based on
# https://github.com/pytorch/examples/tree/master/mnist_hogwild
# Training settings
Expand All @@ -32,6 +35,8 @@
help='maximum length of an episode (default: 10000)')
parser.add_argument('--env-name', default='PongDeterministic-v3', metavar='ENV',
help='environment to train on (default: PongDeterministic-v3)')
parser.add_argument('--no-shared', default=False, metavar='O',
help='use an optimizer without shared momentum.')


if __name__ == '__main__':
Expand All @@ -44,14 +49,20 @@
env.observation_space.shape[0], env.action_space)
shared_model.share_memory()

if args.no_shared:
optimizer = None
else:
optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
optimizer.share_memory()

processes = []

p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
p.start()
processes.append(p)

for rank in range(0, args.num_processes):
p = mp.Process(target=train, args=(rank, args, shared_model))
p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
p.start()
processes.append(p)
for p in processes:
Expand Down
24 changes: 24 additions & 0 deletions my_optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import math

import torch
import torch.optim as optim

class SharedAdam(optim.Adam):
    """Adam optimizer whose state tensors live in shared memory.

    Intended for Hogwild-style training: construct it in the parent process,
    call ``share_memory()`` once, then let every worker process step the same
    optimizer so the first/second moment estimates (and the step counter) are
    shared across processes.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay)

        # Eagerly materialize per-parameter state so it exists before
        # share_memory() is called (stock Adam creates state lazily in step()).
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # Keep the step count in a 1-element tensor, not a Python int:
                # an int is copied per process and cannot be shared, which
                # desynchronizes the bias corrections from the shared moments.
                state['step'] = torch.zeros(1)
                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()

    def share_memory(self):
        """Move every optimizer state tensor (including the step counter)
        into shared memory so worker processes see a single copy."""
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

    def step(self, closure=None):
        """Perform a single optimization step.

        Reimplements Adam's update so the bias corrections read the shared
        tensor step counter (the stock implementation assumes a
        process-local counter and lazily-initialized state).

        Args:
            closure (callable, optional): re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                # In-place add on the shared tensor: visible to all workers.
                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running averages.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step = state['step'].item()
                bias_correction1 = 1 - beta1 ** step
                bias_correction2 = 1 - beta2 ** step
                step_size = (group['lr'] * math.sqrt(bias_correction2)
                             / bias_correction1)

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ def ensure_shared_grads(model, shared_model):
shared_param._grad = param.grad


def train(rank, args, shared_model):
def train(rank, args, shared_model, optimizer=None):
torch.manual_seed(args.seed + rank)

env = create_atari_env(args.env_name)
env.seed(args.seed + rank)

model = ActorCritic(env.observation_space.shape[0], env.action_space)

optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
if optimizer is None:
optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

model.train()

Expand Down

0 comments on commit 85c7efd

Please sign in to comment.