From a94d4c4ef4c7b0892b9647b0319f815c01e5d057 Mon Sep 17 00:00:00 2001 From: fg-mindee Date: Fri, 22 Oct 2021 17:46:29 +0200 Subject: [PATCH] fix: Fixed TransformerDecoder for PyTorch 1.10 --- doctr/models/recognition/transformer/pytorch.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/doctr/models/recognition/transformer/pytorch.py b/doctr/models/recognition/transformer/pytorch.py index 16719a33d1..c548ead3cc 100644 --- a/doctr/models/recognition/transformer/pytorch.py +++ b/doctr/models/recognition/transformer/pytorch.py @@ -59,6 +59,7 @@ def __init__( dim_feedforward=dff, dropout=dropout, activation='relu', + batch_first=True, ) for _ in range(num_layers) ]) @@ -79,13 +80,11 @@ def forward( x += self.pos_encoding[:, :seq_len, :] x = self.dropout(x) - # Batch first = False in decoder - x = x.permute(1, 0, 2) + # Batch first = True in decoder for i in range(self.num_layers): x = self.dec_layers[i]( tgt=x, memory=enc_output, tgt_mask=look_ahead_mask, memory_mask=padding_mask ) # shape (batch_size, target_seq_len, d_model) - x = x.permute(1, 0, 2) return x