Fix a problem with shared adam
ikostrikov2 committed Mar 27, 2017
1 parent e19ac39 commit 5d9b07d
Showing 2 changed files with 42 additions and 4 deletions.
43 changes: 42 additions & 1 deletion my_optim.py
@@ -1,4 +1,5 @@
+import math
 import torch
 import torch.optim as optim

 class SharedAdam(optim.Adam):
@@ -12,13 +13,53 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
-                state['step'] = 0
+                state['step'] = torch.zeros(1)
                 state['exp_avg'] = p.data.new().resize_as_(p.data).zero_()
                 state['exp_avg_sq'] = p.data.new().resize_as_(p.data).zero_()

     def share_memory(self):
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
+                state['step'].share_memory_()
                 state['exp_avg'].share_memory_()
                 state['exp_avg_sq'].share_memory_()
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                if group['weight_decay'] != 0:
+                    grad = grad.add(group['weight_decay'], p.data)
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+
+                denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step'][0]
+                bias_correction2 = 1 - beta2 ** state['step'][0]
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.addcdiv_(-step_size, exp_avg, denom)
+
+        return loss
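Why the fix matters: a plain Python int for state['step'] is process-local, so once the rest of the optimizer state sits in shared memory each A3C worker would bump its own copy of the step count and the bias corrections above would drift apart; a one-element tensor can itself be placed in shared memory via share_memory_(). Below is a minimal sketch of how SharedAdam is wired into multiprocess training. The model, worker loop, and hyperparameters are illustrative stand-ins rather than code from this repository, and the sketch targets the PyTorch API of the commit's era (hence Variable).

    import torch
    import torch.multiprocessing as mp
    import torch.nn as nn
    from torch.autograd import Variable

    from my_optim import SharedAdam


    def worker(model, optimizer):
        # Gradients are computed locally in each process, but optimizer.step()
        # updates the 'step', 'exp_avg' and 'exp_avg_sq' tensors that live in
        # shared memory, so all workers advance the same Adam state.
        criterion = nn.MSELoss()
        for _ in range(10):
            x = Variable(torch.randn(4, 8))
            y = Variable(torch.randn(4, 1))
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()


    if __name__ == '__main__':
        model = nn.Linear(8, 1)   # stand-in for the real A3C model
        model.share_memory()      # share the parameters across processes
        optimizer = SharedAdam(model.parameters(), lr=1e-3)
        optimizer.share_memory()  # works now that state['step'] is a tensor

        processes = [mp.Process(target=worker, args=(model, optimizer))
                     for _ in range(2)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()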
3 changes: 0 additions & 3 deletions train.py
@@ -31,9 +31,6 @@ def train(rank, args, shared_model, optimizer=None):

     model.train()

-    values = []
-    log_probs = []
-
     state = env.reset()
     state = torch.from_numpy(state)
     done = True
