
Deterministic sampling and sampling steps #4

Open
FutureXiang opened this issue Nov 29, 2022 · 0 comments
FutureXiang commented Nov 29, 2022

Hi, I trained the EDM model with a simpler 35.7M-parameter UNet (the one proposed in the original DDPM paper) and compared the results against DDPM/DDIM.
I notice that $S_{churn} = 0$ gives deterministic sampling, while $\gamma_i = \sqrt{2}-1$ gives "max" stochastic sampling. So I introduce a parameter $\eta = \frac{S_{churn} / N}{\sqrt{2}-1}$ to interpolate between the two, i.e. $\gamma_i = (\sqrt{2}-1) \cdot \eta$. As in DDIM, $\eta = 0$ means deterministic and $\eta = 1$ means "max" stochastic.
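For concreteness, the reparameterization above can be sketched as follows (a minimal sketch; `gamma_schedule` is a hypothetical helper, not part of the EDM codebase, and it ignores the $S_{tmin}/S_{tmax}$ gating for simplicity):

```python
import numpy as np

def gamma_schedule(num_steps, eta):
    # Per-step churn amount gamma_i for the EDM stochastic sampler,
    # parameterized by eta in [0, 1].
    #   eta = 0  ->  gamma_i = 0            (deterministic)
    #   eta = 1  ->  gamma_i = sqrt(2) - 1  ("max" stochastic)
    # Equivalent to choosing S_churn = eta * (sqrt(2) - 1) * num_steps,
    # so that S_churn / num_steps = eta * (sqrt(2) - 1) <= sqrt(2) - 1.
    gamma = (np.sqrt(2.0) - 1.0) * eta
    return np.full(num_steps, gamma)
```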

I sweep different values of $\eta$ and different numbers of sampling steps and observe the FIDs:

| $\eta$ \ steps | steps=18 | steps=50 | steps=100 |
|---|---|---|---|
| $\eta=0.0$ | 3.39 | 3.64 | 3.68 |
| $\eta=0.5$ | 3.10 | 2.95 | 2.93 |
| $\eta=1.0$ | 3.12 | 2.84 | 2.97 |

The FID is supposed to decrease when using more sampling steps, right? So why does it get worse for deterministic sampling? It behaves normally at $\eta=0.5$, yet increases again from 50 steps to 100 steps at $\eta=1.0$. Why is the behavior so unstable and unpredictable?

To confirm it's not a bug on my side, I trained a model with your official codebase under a simpler setting close to DDPM (duration=100, augment=None, xflip=True; channel_mult=[1,2,2,2], num_blocks=2). The results are:

| $\eta$ \ steps | steps=18 | steps=50 |
|---|---|---|
| $\eta=0.0$ | 2.94 | 3.09 |
| $\eta=0.5$ | 2.80 | 2.75 |
| $\eta=1.0$ | 2.95 | 2.78 |

For deterministic sampling, the FID still gets worse with more steps, while for $\eta > 0$ it slightly improves as the step count increases.
If the relation between hyper-parameter settings and the resulting performance is not consistently predictable, then how can one obtain a good model on a new dataset? Only by brute-force grid search?

Could you please provide some explanation and thoughts?
Thanks a lot!
