
Commit

Fix grad sharing problem
apaszke committed Mar 14, 2017
1 parent 15dd5e5 commit d4d0036
Showing 2 changed files with 10 additions and 22 deletions.
model.py (10 changes: 0 additions, 10 deletions)
@@ -59,16 +59,6 @@ def __init__(self, num_inputs, action_space):
         self.lstm.bias_hh.data.fill_(0)
 
         self.train()
-        self.__dummy_backprob()
-
-    def __dummy_backprob(self):
-        # See: https://discuss.pytorch.org/t/problem-on-variable-grad-data/957/7
-        # An ugly hack until there is a better solution.
-        inputs = Variable(torch.randn(1, 1, 42, 42))
-        hx, cx = Variable(torch.randn(1, 256)), Variable(torch.randn(1, 256))
-        outputs = self((inputs, (hx, cx)))
-        loss = (outputs[0].mean() + outputs[1].mean()) * 0.0
-        loss.backward()
 
     def forward(self, inputs):
         inputs, (hx, cx) = inputs
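Note on the model.py change: the deleted __dummy_backprob ran a zero-weighted forward/backward pass in the constructor, apparently just so that every parameter's .grad tensor existed before the old sharing code in train.py touched param.grad.data (gradients stay None until the first real backward pass; see the linked discuss.pytorch.org thread). A minimal sketch of the problem being worked around, using a hypothetical nn.Linear stand-in rather than ActorCritic and assuming a recent PyTorch (no Variable wrapper needed):

import torch
import torch.nn as nn

model = nn.Linear(42, 1)   # hypothetical stand-in for ActorCritic

# Before any backward pass a parameter has no gradient tensor at all, so the
# old "shared_param.grad.data = param.grad.data" line would raise an
# AttributeError on the .data access.
print(model.weight.grad)   # None

# The deleted hack: a throwaway backward pass with a zero-scaled loss, run
# only to force allocation of the .grad tensors.
loss = model(torch.randn(1, 42)).sum() * 0.0
loss.backward()

print(model.weight.grad.shape)   # torch.Size([1, 42]), all zeros

The train.py changes below make the hack unnecessary by attaching gradients to the shared model only after the worker has produced real ones.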
train.py (22 changes: 10 additions, 12 deletions)
@@ -11,6 +11,13 @@
 from torchvision import datasets, transforms
 
 
+def ensure_shared_grads(model, shared_model):
+    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
+        if shared_param.grad is not None:
+            return
+        shared_param._grad = param.grad
+
+
 def train(rank, args, shared_model):
     torch.manual_seed(args.seed + rank)
 
@@ -19,10 +26,6 @@ def train(rank, args, shared_model):
 
     model = ActorCritic(env.observation_space.shape[0], env.action_space)
 
-    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
-        # Use gradients from the local model
-        shared_param.grad.data = param.grad.data
-
     optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
 
     model.train()
@@ -102,14 +105,9 @@ def train(rank, args, shared_model):
                 log_probs[i] * Variable(gae) - 0.01 * entropies[i]
 
         optimizer.zero_grad()
 
         (policy_loss + 0.5 * value_loss).backward()
+        torch.nn.utils.clip_grad_norm(model.parameters(), 40)
+
-        global_norm = 0
-        for param in model.parameters():
-            global_norm += param.grad.data.pow(2).sum()
-        global_norm = math.sqrt(global_norm)
-        ratio = 40 / global_norm
-        if ratio < 1:
-            for param in model.parameters():
-                param.grad.data.mul_(ratio)
+        ensure_shared_grads(model, shared_model)
         optimizer.step()
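Note on the clipping change: the deleted block computed the global L2 norm of all gradients by hand and rescaled them in place whenever that norm exceeded 40. torch.nn.utils.clip_grad_norm(model.parameters(), 40) does the same job (later PyTorch releases rename it clip_grad_norm_ and add a small epsilon against division by zero). A sketch of the hand-rolled equivalent, written for a recent PyTorch (hence the .item() call):

import math
import torch

def clip_by_global_norm(parameters, max_norm=40.0):
    # Equivalent of the deleted code: global L2 norm over all gradients,
    # in-place rescale only when it exceeds max_norm. No epsilon is added,
    # mirroring the original, so an all-zero gradient would divide by zero.
    grads = [p.grad for p in parameters if p.grad is not None]
    total_norm = math.sqrt(sum(g.pow(2).sum().item() for g in grads))
    ratio = max_norm / total_norm
    if ratio < 1:
        for g in grads:
            g.mul_(ratio)
    return total_norm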

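Note on the grad sharing fix itself: ensure_shared_grads points the shared model's .grad attributes at the worker's gradient tensors the first time the worker has real gradients, and the early return leaves them alone on later calls. Because the tensors are aliased rather than copied, every subsequent backward pass in the worker refreshes the gradients that the Adam optimizer built over shared_model.parameters() reads before stepping. A minimal sketch of that aliasing, with hypothetical nn.Linear modules standing in for the two ActorCritic instances and assuming a recent PyTorch:

import torch
import torch.nn as nn

def ensure_shared_grads(model, shared_model):
    # Same logic as the commit: attach the worker's gradient tensors to the
    # shared parameters once, then never re-link them.
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad

local_model = nn.Linear(4, 2)    # hypothetical worker copy
shared_model = nn.Linear(4, 2)   # hypothetical shared model

# Gradients only exist after a real backward pass in the worker.
local_model(torch.randn(1, 4)).sum().backward()
ensure_shared_grads(local_model, shared_model)

# No copy is made: the shared parameter's .grad is the very tensor the worker
# keeps writing into on each backward pass.
print(shared_model.weight.grad is local_model.weight.grad)   # True

Assigning the private _grad attribute mirrors the commit; on current PyTorch versions, assigning .grad directly would work as well.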