diff --git a/whisper/decoding.py b/whisper/decoding.py index 2592ba9b6..457ee7ccb 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -469,9 +469,7 @@ def apply(self, logits: Tensor, tokens: Tensor): ] if timestamps.numel() > 0: # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last - logits[k, self.tokenizer.timestamp_begin : timestamps[-1]] = -np.inf - - # to force that timestamps are strictly increasing + # also force each segment to have a nonzero length, to prevent infinite looping if last_was_timestamp and not penultimate_was_timestamp: timestamp_last = timestamps[-1] else: