a3c -> a2c
tpbarron committed Jun 18, 2017
1 parent 842ec4d commit e0c169a
Showing 4 changed files with 24 additions and 114 deletions.
main.py (36 changes: 7 additions & 29 deletions)
@@ -13,7 +13,6 @@
 from model import ActorCritic
 from train import train
 from test import test
-import my_optim
 
 # Based on
 # https://github.com/pytorch/examples/tree/master/mnist_hogwild
@@ -27,45 +26,24 @@
                     help='parameter for GAE (default: 1.00)')
 parser.add_argument('--seed', type=int, default=1, metavar='S',
                     help='random seed (default: 1)')
-parser.add_argument('--num-processes', type=int, default=4, metavar='N',
-                    help='how many training processes to use (default: 4)')
 parser.add_argument('--num-steps', type=int, default=20, metavar='NS',
                     help='number of forward steps in A3C (default: 20)')
+parser.add_argument('--num-updates', type=int, default=100, metavar='NU',
+                    help='number of updates between tests (default: 100)')
 parser.add_argument('--max-episode-length', type=int, default=10000, metavar='M',
                     help='maximum length of an episode (default: 10000)')
-parser.add_argument('--env-name', default='PongDeterministic-v3', metavar='ENV',
+parser.add_argument('--env-name', default='PongDeterministic-v4', metavar='ENV',
                     help='environment to train on (default: PongDeterministic-v3)')
-parser.add_argument('--no-shared', default=False, metavar='O',
-                    help='use an optimizer without shared momentum.')
 
 
 if __name__ == '__main__':
     os.environ['OMP_NUM_THREADS'] = '1'
 
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
 
     env = create_atari_env(args.env_name)
-    shared_model = ActorCritic(
-        env.observation_space.shape[0], env.action_space)
-    shared_model.share_memory()
+    model = ActorCritic(env.observation_space.shape[0], env.action_space)
 
-    if args.no_shared:
-        optimizer = None
-    else:
-        optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
-        optimizer.share_memory()
-
-    processes = []
-
-    p = mp.Process(target=test, args=(args.num_processes, args, shared_model))
-    p.start()
-    processes.append(p)
-
-    for rank in range(0, args.num_processes):
-        p = mp.Process(target=train, args=(rank, args, shared_model, optimizer))
-        p.start()
-        processes.append(p)
-    for p in processes:
-        p.join()
+    while True:
+        train(args, model)
+        test(args, model)
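
Taken together, the main.py hunks swap the Hogwild-style mp.Process workers for a plain sequential driver: one model, trained and evaluated in the same process. A minimal sketch of what the resulting entry point amounts to, assuming the upstream pytorch-a3c module layout (envs.create_atari_env, model.ActorCritic) and showing only a subset of the flags; defaults not visible in this diff (e.g. --lr) are illustrative:

import argparse
import os

import torch

from envs import create_atari_env   # module layout assumed from upstream pytorch-a3c
from model import ActorCritic
from train import train
from test import test

parser = argparse.ArgumentParser(description='A2C')
parser.add_argument('--lr', type=float, default=0.0001)        # default assumed, not shown in the diff
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--num-steps', type=int, default=20)
parser.add_argument('--num-updates', type=int, default=100)
parser.add_argument('--max-episode-length', type=int, default=10000)
parser.add_argument('--env-name', default='PongDeterministic-v4')
# ...plus the GAE/discount flags that train() expects, omitted here.

if __name__ == '__main__':
    os.environ['OMP_NUM_THREADS'] = '1'   # keep OpenMP from oversubscribing the single process
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    env = create_atari_env(args.env_name)  # only needed here to size the network
    model = ActorCritic(env.observation_space.shape[0], env.action_space)

    # A2C: alternate a block of in-process training updates with an evaluation pass.
    while True:
        train(args, model)
        test(args, model)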
my_optim.py (65 changes: 0 additions & 65 deletions)

This file was deleted.
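
my_optim.py existed for the A3C setup: with several worker processes stepping one shared model, the Adam state (step counter and both moment buffers) also has to live in shared memory, which is what a SharedAdam-style optimizer provides. A single-process A2C loop updates one local model, so plain optim.Adam (now built from model.parameters() in train.py) is enough. The deleted file is not reproduced in this diff; for reference, the upstream pytorch-a3c version it presumably mirrored looks roughly like this sketch, not the exact deleted code:

import math

import torch
import torch.optim as optim


class SharedAdam(optim.Adam):
    """Adam whose per-parameter state lives in shared memory (Hogwild/A3C style)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        super(SharedAdam, self).__init__(params, lr, betas, eps, weight_decay)
        # Pre-allocate the Adam state so it can be moved into shared memory
        # before the worker processes fork.
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'] = torch.zeros(1)
                state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
                state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()

    def share_memory(self):
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                state['step'].share_memory_()
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()

    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                # Update biased first and second moment estimates in place.
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                bias_correction1 = 1 - beta1 ** state['step'][0]
                bias_correction2 = 1 - beta2 ** state['step'][0]
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)
        return loss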

test.py (11 changes: 6 additions & 5 deletions)
@@ -13,11 +13,11 @@
 from collections import deque
 
 
-def test(rank, args, shared_model):
-    torch.manual_seed(args.seed + rank)
+def test(args, model):
+    torch.manual_seed(args.seed)
 
     env = create_atari_env(args.env_name)
-    env.seed(args.seed + rank)
+    env.seed(args.seed)
 
     model = ActorCritic(env.observation_space.shape[0], env.action_space)
 
@@ -37,7 +37,7 @@ def test(rank, args, shared_model):
         episode_length += 1
         # Sync with the shared model
         if done:
-            model.load_state_dict(shared_model.state_dict())
+            # model.load_state_dict(shared_model.state_dict())
             cx = Variable(torch.zeros(1, 256), volatile=True)
             hx = Variable(torch.zeros(1, 256), volatile=True)
         else:
@@ -67,6 +67,7 @@ def test(rank, args, shared_model):
             episode_length = 0
             actions.clear()
             state = env.reset()
-            time.sleep(60)
+            return
+            # time.sleep(60)
 
         state = torch.from_numpy(state)
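
Because test() is now called inline from the driver loop rather than running forever in its own process, it returns after reporting the first finished episode instead of sleeping 60 seconds and starting another one. The volatile=True flag on the recurrent state is the pre-0.4 PyTorch way of disabling autograd during evaluation; on a later PyTorch the same evaluation pass would sit inside torch.no_grad(), roughly as in this sketch (the ActorCritic forward signature is assumed from the upstream code; this is a modernized illustration, not part of the commit):

import torch


def evaluate_episode(env, model, max_episode_length=10000):
    """Run one greedy evaluation episode and return (total_reward, steps)."""
    state = torch.from_numpy(env.reset()).unsqueeze(0)
    hx = torch.zeros(1, 256)   # LSTM hidden state
    cx = torch.zeros(1, 256)   # LSTM cell state
    reward_sum, steps, done = 0.0, 0, False
    with torch.no_grad():      # replaces Variable(..., volatile=True)
        while not done and steps < max_episode_length:
            value, logit, (hx, cx) = model((state, (hx, cx)))
            action = logit.max(1, keepdim=True)[1].item()   # greedy action
            obs, reward, done, _ = env.step(action)
            state = torch.from_numpy(obs).unsqueeze(0)
            reward_sum += reward
            steps += 1
    return reward_sum, steps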
train.py (26 changes: 11 additions & 15 deletions)
@@ -10,36 +10,32 @@
 from torch.autograd import Variable
 from torchvision import datasets, transforms
 
 
-def ensure_shared_grads(model, shared_model):
-    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
-        if shared_param.grad is not None:
-            return
-        shared_param._grad = param.grad
-
-
-def train(rank, args, shared_model, optimizer=None):
-    torch.manual_seed(args.seed + rank)
+def train(args, model, optimizer=None):
+    torch.manual_seed(args.seed)
 
     env = create_atari_env(args.env_name)
-    env.seed(args.seed + rank)
+    print ("env: ", env.observation_space.shape, env.action_space)
+    env.seed(args.seed)
 
-    model = ActorCritic(env.observation_space.shape[0], env.action_space)
 
     if optimizer is None:
-        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
+        optimizer = optim.Adam(model.parameters(), lr=args.lr)
 
     model.train()
 
     state = env.reset()
+    print ("state: ", state.shape)
     state = torch.from_numpy(state)
     done = True
 
     episode_length = 0
-    while True:
+    u = 0
+    while u < args.num_updates:
+        # print ("update: ", u)
         episode_length += 1
         # Sync with the shared model
-        model.load_state_dict(shared_model.state_dict())
+        # model.load_state_dict(shared_model.state_dict())
         if done:
             cx = Variable(torch.zeros(1, 256))
             hx = Variable(torch.zeros(1, 256))
@@ -107,5 +103,5 @@ def train(rank, args, shared_model, optimizer=None):
         (policy_loss + 0.5 * value_loss).backward()
         torch.nn.utils.clip_grad_norm(model.parameters(), 40)
 
-        ensure_shared_grads(model, shared_model)
         optimizer.step()
+        u += 1
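
Dropping ensure_shared_grads is the substantive change: with a single process, .backward() already leaves the gradients on the only model there is, so there is no shared copy to push them into, and the loop now runs args.num_updates updates per call instead of forever. The loss construction around these lines is untouched by the diff; for context, the n-step actor-critic update with GAE that the unchanged middle of train.py presumably performs looks roughly like this sketch, modeled on the upstream pytorch-a3c code (tau corresponds to the --gae lambda argument; the coefficients are illustrative):

import torch
from torch.autograd import Variable


def a2c_update(model, optimizer, rewards, values, log_probs, entropies,
               gamma=0.99, tau=1.00, entropy_coef=0.01, value_loss_coef=0.5,
               max_grad_norm=40):
    """One n-step actor-critic update from a finished rollout.

    rewards, log_probs and entropies hold one entry per environment step;
    values holds one extra trailing entry with the bootstrap value of the
    state after the last step (zero if the episode ended there).
    """
    R = Variable(values[-1].data)   # bootstrap return
    gae = torch.zeros(1, 1)
    policy_loss = 0
    value_loss = 0
    for i in reversed(range(len(rewards))):
        # Discounted n-step return and the value regression term.
        R = gamma * R + rewards[i]
        advantage = R - values[i]
        value_loss = value_loss + 0.5 * advantage.pow(2)

        # Generalized Advantage Estimation for the policy gradient term.
        delta_t = rewards[i] + gamma * values[i + 1].data - values[i].data
        gae = gae * gamma * tau + delta_t

        policy_loss = policy_loss - log_probs[i] * Variable(gae) \
            - entropy_coef * entropies[i]

    optimizer.zero_grad()
    (policy_loss + value_loss_coef * value_loss).backward()
    torch.nn.utils.clip_grad_norm(model.parameters(), max_grad_norm)
    optimizer.step()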
