From bcfa670054ec920293450d856a245524b7123b51 Mon Sep 17 00:00:00 2001
From: mamba
Date: Thu, 23 Nov 2023 22:31:47 +0800
Subject: [PATCH 1/3] fix bug

---
 wenet/utils/train_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index e0401f475..4c2aba516 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -587,7 +587,7 @@ def log_per_step(writer, info_dict):
 
     if (batch_idx + 1) % log_interval == 0:
         log_str = '{} Batch {}/{} loss {:.6f} '.format(
-            tag, epoch, batch_idx + 1, loss_dict['loss'].item())
+            tag, epoch, batch_idx + 1, loss_dict['loss'].item() * accum_grad)
         for name, value in loss_dict.items():
             if name != 'loss' and value is not None:
                 log_str += '{} {:.6f} '.format(name, value.item())

From acc5ec839470e5ccc6a2747fb7ca37eee80b9ebb Mon Sep 17 00:00:00 2001
From: mamba
Date: Fri, 24 Nov 2023 15:36:49 +0800
Subject: [PATCH 2/3] fix bug

---
 wenet/transformer/attention.py | 5 +++--
 wenet/utils/train_utils.py     | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py
index 4842e64b6..ba6d8e0f8 100644
--- a/wenet/transformer/attention.py
+++ b/wenet/transformer/attention.py
@@ -195,9 +195,10 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         n_feat (int): The number of features.
         dropout_rate (float): Dropout rate.
     """
-    def __init__(self, n_head, n_feat, dropout_rate):
+    def __init__(self, n_head: int, n_feat: int, dropout_rate: float,
+                 key_bias: bool = True):
         """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate)
+        super().__init__(n_head, n_feat, dropout_rate, key_bias)
         # linear transformation for positional encoding
         self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
         # these two learnable bias are used in matrix c and matrix d
diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index 6952cb190..b7c719727 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -593,7 +593,7 @@ def log_per_step(writer, info_dict):
     if tag == "TRAIN" and rank == 0 and writer is not None:
         if (train_engine == "deepspeed" and is_gradient_accumulation_boundary) or \
                 (train_engine == "torch_ddp" and (batch_idx + 1) % accum_grad == 0):
-            writer.add_scalar('train/train_loss', loss_dict['loss'].item(), step + 1)
+            writer.add_scalar('train/train_loss', loss_dict['loss'].item() * accum_grad, step + 1)
             writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step + 1)
 
     if (batch_idx + 1) % log_interval == 0:

From 9d105a94791bb4d869f1059626a387c464aa352d Mon Sep 17 00:00:00 2001
From: mamba
Date: Fri, 24 Nov 2023 16:12:27 +0800
Subject: [PATCH 3/3] fix lint

---
 wenet/utils/train_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/wenet/utils/train_utils.py b/wenet/utils/train_utils.py
index b7c719727..45411cf5c 100644
--- a/wenet/utils/train_utils.py
+++ b/wenet/utils/train_utils.py
@@ -593,7 +593,8 @@ def log_per_step(writer, info_dict):
     if tag == "TRAIN" and rank == 0 and writer is not None:
         if (train_engine == "deepspeed" and is_gradient_accumulation_boundary) or \
                 (train_engine == "torch_ddp" and (batch_idx + 1) % accum_grad == 0):
-            writer.add_scalar('train/train_loss', loss_dict['loss'].item() * accum_grad, step + 1)
+            writer.add_scalar('train/train_loss',
+                              loss_dict['loss'].item() * accum_grad, step + 1)
             writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step + 1)
 
     if (batch_idx + 1) % log_interval == 0: