
stage1 training issue #175

Closed
zhiqiuiyiye opened this issue Dec 28, 2023 · 6 comments
@zhiqiuiyiye

Hi, I'm training StyleTTS2 on a new language, Thai. At epoch 7 the losses became NaN, and the generator loss kept increasing during training. I'd like to know what could cause this problem. Here is my training-loss log:

[Screenshots: training log and loss curves]

@yl4579
Owner

yl4579 commented Jan 8, 2024

Sorry for the late reply. I was quite busy recently. Have you checked #10 and #11? Did you use mixed precision as well?

@zhiqiuiyiye
Author

zhiqiuiyiye commented Jan 9, 2024

Thanks for your reply. I've fixed this issue; it may have been caused by too small a batch size.
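For anyone hitting the same symptom, a minimal, generic guard (not part of the StyleTTS2 code; `step_is_finite` is a made-up helper for illustration) can catch the moment a loss goes non-finite so the optimizer step can be skipped or logged instead of corrupting the weights:

```python
import math

def step_is_finite(*losses):
    """Return False if any loss value is NaN or Inf, so the update can be skipped."""
    return all(math.isfinite(float(v)) for v in losses)

# Hypothetical usage inside a training loop:
g_loss, d_loss = 1.5, 0.02
if step_is_finite(g_loss, d_loss):
    pass  # safe to call optimizer.step()
else:
    print("non-finite loss, skipping step")
```

Checking every step makes the first bad batch visible immediately, which helps narrow the cause down to batch size, mixed precision, or a specific sample.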

Akito-UzukiP pushed a commit to Akito-UzukiP/StyleTTS2 that referenced this issue Jan 13, 2024
@akshatgarg99

Were you able to fix it? I was trying to train but ran into some issues. Can we discuss?

@RillmentGames

RillmentGames commented Feb 7, 2024

Same issue with batch size 2: the generator loss can reach about 100 and then it NaNs. (EDIT: this didn't work!) I have a preliminary solution, still testing, but based on #11 (comment) it seems to be discriminator overfitting. So I am trying to force the discriminators' weight decay to a high value to prevent overfitting, in train_first:

# Raise weight decay on both discriminators (MPD and MSD) to curb overfitting
for module in ["mpd", "msd"]:
    for g in optimizer.optimizers[module].param_groups:
        g['weight_decay'] = 0.1

and also lowering the feature-matching loss gain by pre-multiplying it by 0.5, in losses.py:

class GeneratorLoss(torch.nn.Module):

    def __init__(self, mpd, msd):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd
        
    def forward(self, y, y_hat):
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        
        # 0.5x gain on the feature-matching terms
        loss_gen_all = loss_gen_s + loss_gen_f + 0.5*loss_fm_s + 0.5*loss_fm_f + loss_rel
        
        return loss_gen_all.mean()

At first I tried decay = 0.01 and gains 1.0, 1.0, but that only delayed the problem.
Then I tried decay = 1.0 and gains 0.1, 0.1, which seemed to prevent NaN, but the audio quality wasn't good.
So now I am trying decay = 0.1 and gains 0.5, 0.5. I should be able to report back the results in a few days.

@RillmentGames

No, that didn't work :( The loss made some strange moves and eventually ended in NaN.

[Screenshot: loss curves ending in NaN]

@RillmentGames

Integrating PhaseAug and using batch_percentage=1.0 with batch size 2 fixed it for me.
PhaseAug addresses the overfitting issue by randomly rotating the phase of each frequency bin.
The generator loss still creeps up, but very slowly now, and the audio quality becomes quite nice after 2 epochs:

...
    aug = PhaseAug()
    gl = GeneratorLoss(model.mpd, model.msd, aug).to(device)
    dl = DiscriminatorLoss(model.mpd, model.msd, aug).to(device)
...
class GeneratorLoss(torch.nn.Module):

    def __init__(self, mpd, msd, aug):
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd
        self.aug = aug  
        
    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat)  # <-- augment both real and generated audio
        y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
        y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
        loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
        loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
        loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
        loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)

        loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
        
        loss_gen_all = loss_gen_s + loss_gen_f + 1.0*loss_fm_s + 1.0*loss_fm_f + loss_rel
        
        return loss_gen_all.mean()
    
class DiscriminatorLoss(torch.nn.Module):

    def __init__(self, mpd, msd, aug):
        super(DiscriminatorLoss, self).__init__()
        self.aug = aug
        self.mpd = mpd
        self.msd = msd
        
    def forward(self, y, y_hat):
        y, y_hat = self.aug.forward_sync(y, y_hat.detach())  # <-- augment here; detach keeps generator grads out
        # MPD
        y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
        loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
        # MSD
        y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
        loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
        
        loss_rel = discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)


        d_loss = loss_disc_s + loss_disc_f + loss_rel
        
        return d_loss.mean()
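As a toy illustration of the idea behind PhaseAug (this is not the library's actual implementation; `rotate_phases` is a made-up helper operating on a plain list of complex bins), multiplying each frequency bin by e^{j*phi} with a random phi changes the bin's phase while leaving its magnitude untouched, so the discriminator can no longer key on the exact phase of the training audio:

```python
import cmath
import random

def rotate_phases(spectrum, seed=0):
    """Multiply each complex frequency bin by e^{j*phi} with a random phi.

    Every bin's phase changes while its magnitude is preserved -- the core
    trick behind phase-rotation augmentation for GAN discriminators.
    """
    rng = random.Random(seed)
    return [z * cmath.exp(1j * rng.uniform(-cmath.pi, cmath.pi)) for z in spectrum]

spectrum = [1 + 0j, 0.5 + 0.5j, 0 - 2j]
augmented = rotate_phases(spectrum)
# Magnitudes are unchanged; only the phases move.
assert all(abs(abs(a) - abs(z)) < 1e-9 for a, z in zip(augmented, spectrum))
```

The real PhaseAug works on audio tensors and applies the rotation consistently to the real and generated pair (hence the `forward_sync` call above), but the magnitude-preserving property is the same.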

@yl4579 yl4579 closed this as completed Mar 7, 2024