Fix bug for VITS resuming training (#108)
* Fix bug for VITS resuming training. Related issue #94
lmxue authored Jan 17, 2024
1 parent a840088 commit 4125584
Showing 3 changed files with 94 additions and 16 deletions.
44 changes: 43 additions & 1 deletion egs/tts/VITS/README.md
@@ -74,14 +74,56 @@ We provide the default hyperparameters in the `exp_config.json`. They can work on s
}
```

### Run
### Train From Scratch

Run `run.sh` at the training stage (set `--stage 2`) and specify an experiment name. The TensorBoard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName]
```

### Train From Existing Source

We support training from an existing source for various purposes: you can resume training from a checkpoint, or fine-tune a model from another checkpoint.

With `--resume true`, training will resume from the **latest checkpoint** of the current `[YourExptName]` by default. For example, to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`, run:

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
--resume true
```

You can also choose a **specific checkpoint** to retrain from with the `--resume_from_ckpt_path` argument. For example, to resume training from the checkpoint `Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]`, run:

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
    --resume true \
    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]"
```

To **fine-tune from another checkpoint**, set `--resume_type` to `"finetune"`. For example, to fine-tune the model from the checkpoint `Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]`, run:

```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
    --resume true \
    --resume_from_ckpt_path "Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]" \
--resume_type "finetune"
```

> **NOTE:** `--resume_type` defaults to `"resume"`; you do not need to specify it when resuming training.
>
> The difference between `"resume"` and `"finetune"` is that the `"finetune"` will **only** load the pretrained model weights from the checkpoint, while the `"resume"` will load all the training states (including optimizer, scheduler, etc.) from the checkpoint.

Here are some example scenarios to better understand how to use these arguments:

| Scenario | `--resume` | `--resume_from_ckpt_path` | `--resume_type` |
| ------ | -------- | ----------------------- | ------------- |
| You want to train from scratch | no | no | no |
| The machine breaks down during training and you want to resume training from the latest checkpoint | `true` | no | no |
| You find the latest model is overfitting and you want to re-train from the checkpoint before | `true` | `SpecificCheckpoint Path` | no |
| You want to fine-tune a model from another checkpoint | `true` | `SpecificCheckpoint Path` | `"finetune"` |
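The difference between the two resume types can be sketched in a few lines (a minimal illustration, not Amphion's actual implementation; the state keys shown are assumptions):

```python
# Minimal sketch of "resume" vs. "finetune" semantics. The checkpoint is
# assumed to be a dict with keys "model", "optimizer", "scheduler", and
# "random_states" (hypothetical names for illustration).

def select_state_keys(checkpoint_state, resume_type):
    """Return the subset of a saved training state to restore."""
    if resume_type == "resume":
        # Restore everything so training continues exactly where it stopped.
        return dict(checkpoint_state)
    elif resume_type == "finetune":
        # Keep only the model weights; optimizer/scheduler start fresh.
        return {"model": checkpoint_state["model"]}
    raise ValueError(f"Unsupported resume type: {resume_type}")
```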


> **NOTE:** `CUDA_VISIBLE_DEVICES` defaults to `"0"`. You can change it when running `run.sh` by specifying, e.g., `--gpu "0,1,2,3"`.

38 changes: 33 additions & 5 deletions egs/tts/VITS/run.sh
@@ -18,7 +18,7 @@ cd $work_dir

######## Parse the Given Parameters from the Command ###########
# options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir:,name:,stage: -- "$@")
options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,name:,stage: -- "$@")
options=$(getopt -o c:n:s --long gpu:,config:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,name:,stage: -- "$@")
eval set -- "$options"

while true; do
@@ -32,6 +32,13 @@ while true; do
# Visible GPU machines. The default value is "0".
--gpu) shift; gpu=$1 ; shift ;;

# [Only for Training] Resume configuration
--resume) shift; resume=$1 ; shift ;;
# [Only for Training] The specific checkpoint path that you want to resume from.
--resume_from_ckpt_path) shift; resume_from_ckpt_path=$1 ; shift ;;
# [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
--resume_type) shift; resume_type=$1 ; shift ;;

# [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
--infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
# [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
@@ -81,10 +88,31 @@ if [ $running_stage -eq 2 ]; then
fi
echo "Experimental Name: $exp_name"

CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/tts/train.py \
--config $exp_config \
--exp_name $exp_name \
--log_level debug
# Set default values for the resume arguments
if [ -z "$resume_from_ckpt_path" ]; then
resume_from_ckpt_path=""
fi

if [ -z "$resume_type" ]; then
resume_type="resume"
fi

if [ "$resume" = true ]; then
echo "Resume from the existing experiment..."
CUDA_VISIBLE_DEVICES="$gpu" accelerate launch "${work_dir}"/bins/tts/train.py \
--config "$exp_config" \
--exp_name "$exp_name" \
--log_level info \
--resume \
--checkpoint_path "$resume_from_ckpt_path" \
--resume_type "$resume_type"
else
echo "Start a new experiment..."
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/tts/train.py \
--config $exp_config \
--exp_name $exp_name \
--log_level debug
fi
fi

######## Inference ###########
28 changes: 18 additions & 10 deletions models/tts/base/tts_trainer.py
@@ -179,11 +179,6 @@ def _check_resume(self):
open(os.path.join(self.ckpt_path, "ckpts.json"), "r")
)

self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
if self.accelerator.is_main_process:
os.makedirs(self.checkpoint_dir, exist_ok=True)
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")

def _init_accelerator(self):
self.exp_dir = os.path.join(
os.path.abspath(self.cfg.log_dir), self.args.exp_name
@@ -292,7 +287,7 @@ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"
it will load the checkpoint specified by checkpoint_path.
**Only use this method after** ``accelerator.prepare()``.
"""
if checkpoint_path is None:
if checkpoint_path is None or checkpoint_path == "":
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
checkpoint_path = ls[0]
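The sort key above suggests checkpoint directories are named like `epoch-0012_step-048000_loss-0.2500` (an assumption inferred from the split indices): `split("_")[-3]` isolates the `epoch-…` field and `split("-")[-1]` extracts its number. The same selection logic, as a standalone sketch:

```python
# Hypothetical helper mirroring the sort key in _load_model; it assumes
# checkpoint directories are named "epoch-XXXX_step-XXXX_loss-XXXX".

def latest_checkpoint(names):
    """Return the checkpoint name with the highest epoch number."""
    def epoch_of(name):
        # "epoch-0012_step-048000_loss-0.2500" -> "epoch-0012" -> 12
        return int(name.split("_")[-3].split("-")[-1])
    return max(names, key=epoch_of)
```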
@@ -303,11 +298,24 @@ def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
elif resume_type == "finetune":
self.model.load_state_dict(
torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
)
self.model.cuda(self.accelerator.device)
if isinstance(self.model, dict):
for idx, sub_model in enumerate(self.model.keys()):
if idx == 0:
ckpt_name = "pytorch_model.bin"
else:
ckpt_name = "pytorch_model_{}.bin".format(idx)

self.model[sub_model].load_state_dict(
torch.load(os.path.join(checkpoint_path, ckpt_name))
)
self.model[sub_model].cuda(self.accelerator.device)
else:
self.model.load_state_dict(
torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
)
self.model.cuda(self.accelerator.device)
self.logger.info("Load model weights for finetune SUCCESS!")

else:
raise ValueError("Unsupported resume type: {}".format(resume_type))

