allow for mixed precision training with fp16 flag
lucidrains committed Sep 9, 2020
1 parent 88f83d0 commit 4bf2891
Showing 3 changed files with 49 additions and 12 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -66,7 +66,8 @@ trainer = Trainer(
train_lr = 2e-5,
train_num_steps = 100000, # total training steps
gradient_accumulate_every = 2, # gradient accumulation steps
ema_decay = 0.995 # exponential moving average decay
ema_decay = 0.995, # exponential moving average decay
fp16 = True # turn on mixed precision training with apex
)

trainer.train()
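
The README change is just the new fp16 keyword on Trainer. A minimal sketch of the full call, assuming NVIDIA Apex is installed and the diffusion model and image folder are set up as elsewhere in the README (the surrounding setup is not part of this diff):

trainer = Trainer(
    diffusion,                        # a GaussianDiffusion instance, built as in the README above
    'path/to/your/images',            # folder of training images
    train_lr = 2e-5,
    train_num_steps = 100000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    fp16 = True                       # turn on mixed precision training with apex
)

trainer.train()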
denoising_diffusion_pytorch/denoising_diffusion_pytorch.py (56 changes: 46 additions & 10 deletions)
@@ -16,6 +16,12 @@
from tqdm import tqdm
from einops import rearrange

try:
from apex import amp
APEX_AVAILABLE = True
except:
APEX_AVAILABLE = False

# constants

SAVE_AND_SAMPLE_EVERY = 1000
@@ -37,6 +43,13 @@ def cycle(dl):
for data in dl:
yield data

def loss_backwards(fp16, loss, optimizer, **kwargs):
if fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(**kwargs)
else:
loss.backward(**kwargs)
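
This helper is the core of the change: when fp16 is on, the loss is routed through Apex's amp.scale_loss context so gradients are loss-scaled before the backward pass; otherwise it falls back to a plain .backward(). A minimal sketch of how it plugs into a gradient-accumulation loop, mirroring the Trainer.train() changes further down in this diff (names are illustrative, not a complete Trainer):

backwards = partial(loss_backwards, fp16)              # fp16 is the flag handed to Trainer

for _ in range(gradient_accumulate_every):
    data = next(dl).cuda()
    loss = model(data)
    backwards(loss / gradient_accumulate_every, opt)    # scaled backward when fp16, plain otherwise

opt.step()
opt.zero_grad()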

# small helper modules

class EMA():
@@ -107,7 +120,7 @@ def forward(self, x):
# building block modules

class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 32):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.block = nn.Sequential(
nn.Conv2d(dim, dim_out, 3, padding=1),
@@ -118,7 +131,7 @@ def forward(self, x):
return self.block(x)

class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim, groups = 32):
def __init__(self, dim, dim_out, *, time_emb_dim, groups = 8):
super().__init__()
self.mlp = nn.Sequential(
Mish(),
@@ -157,7 +170,7 @@ def forward(self, x):
# model

class Unet(nn.Module):
def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):
def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 8):
super().__init__()
dims = [3, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
@@ -178,6 +191,7 @@ def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):

self.downs.append(nn.ModuleList([
ResnetBlock(dim_in, dim_out, time_emb_dim = dim),
ResnetBlock(dim_out, dim_out, time_emb_dim = dim),
Residual(Rezero(LinearAttention(dim_out))),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -192,6 +206,7 @@ def __init__(self, dim, out_dim = None, dim_mults=(1, 2, 4, 8), groups = 32):

self.ups.append(nn.ModuleList([
ResnetBlock(dim_out * 2, dim_in, time_emb_dim = dim),
ResnetBlock(dim_in, dim_in, time_emb_dim = dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in) if not is_last else nn.Identity()
]))
@@ -208,8 +223,9 @@ def forward(self, x, time):

h = []

for resnet, attn, downsample in self.downs:
for resnet, resnet2, attn, downsample in self.downs:
x = resnet(x, t)
x = resnet2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
@@ -218,9 +234,10 @@ def forward(self, x, time):
x = self.mid_attn(x)
x = self.mid_block2(x, t)

for resnet, attn, upsample in self.ups:
for resnet, resnet2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, t)
x = resnet2(x, t)
x = attn(x)
x = upsample(x)

@@ -417,23 +434,40 @@ def __init__(
train_lr = 2e-5,
train_num_steps = 100000,
gradient_accumulate_every = 2,
fp16 = False
):
super().__init__()
self.model = diffusion_model
self.ema = EMA(ema_decay)
self.ema_model = copy.deepcopy(self.model)

self.image_size = image_size
self.gradient_accumulate_every = gradient_accumulate_every
self.train_num_steps = train_num_steps

self.ema = EMA(ema_decay)
self.ema_model = copy.deepcopy(self.model)

self.ds = Dataset(folder, image_size)
self.dl = cycle(data.DataLoader(self.ds, batch_size = train_batch_size, shuffle=True, pin_memory=True))
self.opt = Adam(diffusion_model.parameters(), lr=train_lr)

self.step = 0

assert not fp16 or fp16 and APEX_AVAILABLE, 'Apex must be installed in order for mixed precision training to be turned on'

self.fp16 = fp16
if fp16:
(self.model, self.ema_model), self.opt = amp.initialize([self.model, self.ema_model], self.opt, opt_level='O1')

self.reset_parameters()
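
A note on the Apex wiring above: amp.initialize is handed both the online model and the EMA copy so that the two stay consistent under mixed precision, and the optimizer it returns is the one later driven through loss_backwards. Stripped of the Trainer, the same pattern (Apex's documented flow; opt_level 'O1' is its standard mixed-precision mode) looks roughly like this sketch:

model = diffusion_model.cuda()
opt = Adam(model.parameters(), lr = 2e-5)
model, opt = amp.initialize(model, opt, opt_level = 'O1')   # sketch only, per the Apex docs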

def reset_parameters(self):
self.ema_model.load_state_dict(self.model.state_dict())

def step_ema(self):
if self.step < 2000:
self.reset_parameters()
return
self.ema.update_model_average(self.ema_model, self.model)
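
step_ema adds a warm-up: for the first 2000 steps the EMA weights are simply reset to the online weights, and only afterwards does the exponential moving average take over. Assuming EMA.update_model_average follows the usual formula (its definition sits outside this hunk), the per-parameter behaviour amounts to the sketch below:

# illustrative only; decay is ema_decay (0.995 by default), start is the 2000 hard-coded above
def ema_step(step, ema_w, online_w, decay = 0.995, start = 2000):
    if step < start:
        return online_w.clone()                        # warm-up: EMA simply mirrors the online model
    return decay * ema_w + (1 - decay) * online_w      # afterwards: standard EMA update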

def save(self, milestone):
data = {
'step': self.step,
@@ -450,18 +484,20 @@ def load(self, milestone):
self.ema_model.load_state_dict(data['ema'])

def train(self):
backwards = partial(loss_backwards, self.fp16)

while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()
loss = self.model(data)
print(f'{self.step}: {loss.item()}')
(loss / self.gradient_accumulate_every).backward()
backwards(loss / self.gradient_accumulate_every, self.opt)

self.opt.step()
self.opt.zero_grad()

if self.step % UPDATE_EMA_EVERY == 0:
self.ema.update_model_average(self.ema_model, self.model)
self.step_ema()

if self.step % SAVE_AND_SAMPLE_EVERY == 0:
milestone = self.step // SAVE_AND_SAMPLE_EVERY
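
For readers without Apex: PyTorch 1.6, released shortly before this commit, ships native mixed precision under torch.cuda.amp, which covers the same ground without the external dependency. This is not what the commit uses, but a rough equivalent of the loss_backwards path would be the following (the forward pass would additionally be wrapped in torch.cuda.amp.autocast()):

# sketch only -- a native-AMP alternative, not part of this commit
scaler = torch.cuda.amp.GradScaler()

def loss_backwards_native(fp16, loss, optimizer, **kwargs):
    if fp16:
        scaler.scale(loss).backward(**kwargs)   # scale the loss, then backprop
    else:
        loss.backward(**kwargs)

# the optimizer step then becomes: scaler.step(optimizer); scaler.update()
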
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.2.2',
version = '0.2.4',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
