From 54c07fa3ae23b44888f18573046510a1f8286c13 Mon Sep 17 00:00:00 2001
From: Xingchen Song(宋星辰)
Date: Tue, 19 Dec 2023 18:55:16 +0800
Subject: [PATCH] [train] fix epoch when loading from previous checkpoint
 (#2252)

---
 wenet/bin/train.py      | 4 +++-
 wenet/utils/executor.py | 1 +
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/wenet/bin/train.py b/wenet/bin/train.py
index 16ebc9e07..951b419f3 100644
--- a/wenet/bin/train.py
+++ b/wenet/bin/train.py
@@ -122,7 +122,9 @@ def main():
         scaler = torch.cuda.amp.GradScaler()
 
     # Start training loop
-    start_epoch = configs["init_infos"].get('epoch', -1) + 1
+    tag = configs["init_infos"].get("tag", "init")
+    start_epoch = configs["init_infos"].get('epoch', -1) + int("epoch_" in tag)
+    configs.pop("init_infos", None)
     final_epoch = None
     for epoch in range(start_epoch, configs.get('max_epoch', 100)):
         train_dataset.set_epoch(epoch)
diff --git a/wenet/utils/executor.py b/wenet/utils/executor.py
index 4459b9ea2..7cf148758 100644
--- a/wenet/utils/executor.py
+++ b/wenet/utils/executor.py
@@ -83,6 +83,7 @@ def train(self, model, optimizer, scheduler, train_data_loader,
                         and (batch_idx + 1) % info_dict["accum_grad"] == 0:
                     total_loss, num_seen_utts = self.cv(
                         model, cv_data_loader, configs)
+                    model.train()
                     info_dict.update({
                         "tag": "step_{}".format(self.step),
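
A note on the resume logic: checkpoints in this flow carry a tag such as
"epoch_N" (written after epoch N completed) or "step_N"/"init" (written
mid-epoch or at initialization). The old `+ 1` unconditionally advanced the
epoch, so resuming from a step checkpoint skipped the remainder of the
interrupted epoch; the new `int("epoch_" in tag)` term advances only past
completed epochs. A minimal Python sketch of that rule, assuming the
tag/epoch fields shown in the diff (`resume_epoch` is a hypothetical helper
name, not part of the patch):

def resume_epoch(init_infos: dict) -> int:
    """Epoch to resume training from, given checkpoint metadata (sketch)."""
    tag = init_infos.get("tag", "init")
    last_epoch = init_infos.get("epoch", -1)
    # Advance past last_epoch only when the checkpoint marks a finished epoch.
    return last_epoch + int("epoch_" in tag)

# Expected behavior under the assumptions above:
assert resume_epoch({"tag": "epoch_3", "epoch": 3}) == 4   # epoch 3 finished
assert resume_epoch({"tag": "step_500", "epoch": 3}) == 3  # mid-epoch, redo 3

The wenet/utils/executor.py hunk complements this: the in-loop `self.cv(...)`
pass presumably leaves the model in eval mode, so calling `model.train()`
immediately afterwards restores dropout and batch-norm training behavior for
the batches that follow.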