diff --git a/doctr/models/recognition/transformer/pytorch.py b/doctr/models/recognition/transformer/pytorch.py index 16719a33d1..c548ead3cc 100644 --- a/doctr/models/recognition/transformer/pytorch.py +++ b/doctr/models/recognition/transformer/pytorch.py @@ -59,6 +59,7 @@ def __init__( dim_feedforward=dff, dropout=dropout, activation='relu', + batch_first=True, ) for _ in range(num_layers) ]) @@ -79,13 +80,11 @@ def forward( x += self.pos_encoding[:, :seq_len, :] x = self.dropout(x) - # Batch first = False in decoder - x = x.permute(1, 0, 2) + # Batch first = True in decoder for i in range(self.num_layers): x = self.dec_layers[i]( tgt=x, memory=enc_output, tgt_mask=look_ahead_mask, memory_mask=padding_mask ) # shape (batch_size, target_seq_len, d_model) - x = x.permute(1, 0, 2) return x