diff --git a/valle/bin/trainer.py b/valle/bin/trainer.py
index 44349a0..8525ad6 100644
--- a/valle/bin/trainer.py
+++ b/valle/bin/trainer.py
@@ -692,8 +692,23 @@ def train_one_epoch(
 
             set_batch_count(model, params.batch_idx_train)
         except:  # noqa
+            # Save the broken batch
+            logging.warning(f"Hit a broken batch of training data. Cut ID: {batch['utt_id']} Text: {batch['text']} - Skipping...")
             display_and_save_batch(batch, params=params)
-            raise
+            # Clean up batch data from memory and GPU
+            del batch["text_tokens"]
+            del batch["text_tokens_lens"]
+            del batch["audio_features"]
+            del batch["audio_features_lens"]
+            del batch
+            try:
+                del loss
+                del loss_info
+            except UnboundLocalError:
+                pass
+            torch.cuda.empty_cache()
+            # Continue training
+            continue
 
         if params.average_period > 0:
             if (
@@ -1101,6 +1116,10 @@ def scan_pessimistic_batches_for_oom(
     elif params.dtype in ["float16", "fp16"]:
         dtype = torch.float16
 
+    scaler = GradScaler(
+        enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0
+    )
+
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
@@ -1111,7 +1130,7 @@ def scan_pessimistic_batches_for_oom(
                     batch=batch,
                     is_training=True,
                 )
-            loss.backward()
+            scaler.scale(loss).backward()
             optimizer.zero_grad()
         except Exception as e:
             if "CUDA out of memory" in str(e):
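
For reference outside the patch itself: a minimal sketch of the same skip-and-recover pattern in isolation, assuming generic placeholder names (`compute_loss`, `train_dl`, `model`, and `optimizer` here are illustrative, not the repo's actual helpers). The key steps mirror the first hunk: log the bad batch, drop references to its tensors, empty the CUDA cache, and continue instead of re-raising.

import logging

import torch


def train_one_epoch(model, optimizer, compute_loss, train_dl):
    """Toy training loop that skips batches which raise instead of aborting."""
    model.train()
    for batch_idx, batch in enumerate(train_dl):
        try:
            loss = compute_loss(model, batch)  # placeholder loss function
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        except Exception:  # broad catch mirrors the patch; narrow it where possible
            logging.warning(f"Skipping broken batch {batch_idx} ...")
            # Drop references so the offending tensors can be garbage-collected.
            del batch
            try:
                del loss
            except UnboundLocalError:
                pass
            # Release cached GPU memory before moving on to the next batch.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            continue

The GradScaler hunks serve a related goal: with fp16 enabled, scan_pessimistic_batches_for_oom now scales the loss before backward(), presumably so that the pre-flight OOM check exercises the same mixed-precision backward path as the actual training step rather than an unscaled half-precision backward.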