Skip to content

Commit

Permalink
give an initial conv
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 1, 2022
1 parent f39b3b1 commit e274fb3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class Unet(nn.Module):
def __init__(
self,
dim,
init_dim = None,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
Expand All @@ -200,7 +201,10 @@ def __init__(
super().__init__()
self.channels = channels

dims = [channels, *map(lambda m: dim * m, dim_mults)]
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding = 3)

dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))

if with_time_emb:
Expand Down Expand Up @@ -251,6 +255,8 @@ def __init__(
)

def forward(self, x, time):
x = self.init_conv(x)

t = self.time_mlp(time) if exists(self.time_mlp) else None

h = []
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.10.0',
version = '0.10.1',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit e274fb3

Please sign in to comment.