Skip to content

Commit

Permalink
Only calculate randn in some samplers when it's actually being used.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jul 17, 2023
1 parent ee8f8ee commit 3a150ba
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions comfy/k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None,
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
Expand Down Expand Up @@ -172,9 +172,9 @@ def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None,
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
Expand All @@ -201,9 +201,9 @@ def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None,
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
eps = torch.randn_like(x) * s_noise
sigma_hat = sigmas[i] * (gamma + 1)
if gamma > 0:
eps = torch.randn_like(x) * s_noise
x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
denoised = model(x, sigma_hat * s_in, **extra_args)
d = to_d(x, sigma_hat, denoised)
Expand Down

0 comments on commit 3a150ba

Please sign in to comment.