diff --git a/.gitignore b/.gitignore
index d07de3b3..7eb4c4f5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ __pycache__/
 # Distribution / Packaging
 venv/
 *.egg-info
+build/
 
 # Unit Test
 .pytest_cache/
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index da23d50c..b190056f 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -8,13 +8,24 @@
 from collections import Counter, defaultdict
 from inspect import signature
 from math import ceil
-from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import (
+    BinaryIO,
+    Iterable,
+    List,
+    Literal,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 
 import ctranslate2
 import numpy as np
 import tokenizers
 import torch
+from pydantic import BaseModel
 from tqdm import tqdm
 
 from faster_whisper.audio import decode_audio, pad_or_trim
@@ -51,36 +62,73 @@ class Segment(NamedTuple):
     temperature: Optional[float] = 1.0
 
 
+class _TranscriptionOptionsBase(BaseModel):
+    task: Literal["transcribe", "translate"] = "transcribe"
+    language: Optional[str] = None
+    beam_size: int = 5
+    best_of: int = 5
+    patience: float = 1
+    length_penalty: float = 1
+    repetition_penalty: float = 1
+    no_repeat_ngram_size: int = 0
+    temperature: Union[float, List[float], Tuple[float, ...]] = [
+        0.0,
+        0.2,
+        0.4,
+        0.6,
+        0.8,
+        1.0,
+    ]
+    compression_ratio_threshold: Optional[float] = 2.4
+    log_prob_threshold: Optional[float] = -1.0
+    log_prob_low_threshold: Optional[float] = None
+    no_speech_threshold: Optional[float] = 0.6
+    initial_prompt: Optional[Union[str, Iterable[int]]] = None
+    prefix: Optional[str] = None
+    suppress_blank: bool = True
+    suppress_tokens: Optional[List[int]] = [-1]
+    word_timestamps: bool = False
+    prepend_punctuations: str = "\"'“¿([{-"
+    append_punctuations: str = "\"'.。,,!!??::”)]}、"
+    vad_parameters: Optional[Union[dict, VadOptions]] = None
+    max_new_tokens: Optional[int] = None
+    chunk_length: Optional[int] = None
+    hotwords: Optional[str] = None
+
+
 # Added additional parameters for multilingual videos and fixes below
-class TranscriptionOptions(NamedTuple):
-    beam_size: int
-    best_of: int
-    patience: float
-    length_penalty: float
-    repetition_penalty: float
-    no_repeat_ngram_size: int
-    log_prob_threshold: Optional[float]
-    log_prob_low_threshold: Optional[float]
-    no_speech_threshold: Optional[float]
-    compression_ratio_threshold: Optional[float]
+class TranscriptionOptions(_TranscriptionOptionsBase):
     condition_on_previous_text: bool
     prompt_reset_on_temperature: float
-    temperatures: List[float]
-    initial_prompt: Optional[Union[str, Iterable[int]]]
-    prefix: Optional[str]
-    suppress_blank: bool
-    suppress_tokens: Optional[List[int]]
+    temperatures: Sequence[float]
     without_timestamps: bool
     max_initial_timestamp: float
-    word_timestamps: bool
-    prepend_punctuations: str
-    append_punctuations: str
     multilingual: bool
     output_language: Optional[str]
-    max_new_tokens: Optional[int]
-    clip_timestamps: Union[str, List[float]]
+    clip_timestamps: Optional[Union[str, List[dict], List[float]]] = None
     hallucination_silence_threshold: Optional[float]
-    hotwords: Optional[str]
+
+
+class WhisperModelTranscriptionOptions(_TranscriptionOptionsBase):
+    condition_on_previous_text: bool = True
+    prompt_reset_on_temperature: float = 0.5
+    without_timestamps: bool = False
+    max_initial_timestamp: float = 1.0
+    multilingual: bool = False
+    output_language: Optional[str] = None
+    vad_filter: bool = False
+    clip_timestamps: Union[str, List[float]] = "0"
+    hallucination_silence_threshold: Optional[float] = None
+    language_detection_threshold: Optional[float] = None
+    language_detection_segments: int = 1
+
+
+class BatchTranscriptionOptions(_TranscriptionOptionsBase):
+    log_progress: bool = False
+    without_timestamps: bool = True
+    vad_filter: bool = True
+    clip_timestamps: Optional[List[dict]] = None
+    batch_size: int = 16
 
 
 class TranscriptionInfo(NamedTuple):
@@ -207,42 +255,7 @@ def get_language_and_tokenizer(
     def transcribe(
         self,
         audio: Union[str, BinaryIO, torch.Tensor, np.ndarray],
-        language: Optional[str] = None,
-        task: str = None,
-        log_progress: bool = False,
-        beam_size: int = 5,
-        best_of: int = 5,
-        patience: float = 1,
-        length_penalty: float = 1,
-        repetition_penalty: float = 1,
-        no_repeat_ngram_size: int = 0,
-        temperature: Union[float, List[float], Tuple[float, ...]] = [
-            0.0,
-            0.2,
-            0.4,
-            0.6,
-            0.8,
-            1.0,
-        ],
-        compression_ratio_threshold: Optional[float] = 2.4,
-        log_prob_threshold: Optional[float] = -1.0,
-        log_prob_low_threshold: Optional[float] = None,
-        no_speech_threshold: Optional[float] = 0.6,
-        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
-        prefix: Optional[str] = None,
-        suppress_blank: bool = True,
-        suppress_tokens: Optional[List[int]] = [-1],
-        without_timestamps: bool = True,
-        word_timestamps: bool = False,
-        prepend_punctuations: str = "\"'“¿([{-",
-        append_punctuations: str = "\"'.。,,!!??::”)]}、",
-        vad_filter: bool = True,
-        vad_parameters: Optional[Union[dict, VadOptions]] = None,
-        max_new_tokens: Optional[int] = None,
-        chunk_length: Optional[int] = None,
-        clip_timestamps: Optional[List[dict]] = None,
-        batch_size: int = 16,
-        hotwords: Optional[str] = None,
+        options: BatchTranscriptionOptions = BatchTranscriptionOptions(),
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """transcribe audio in chunks in batched fashion and return with language info.
@@ -341,88 +354,79 @@ def transcribe(
         audio = decode_audio(audio, sampling_rate=sampling_rate)
 
         duration = audio.shape[0] / sampling_rate
 
-        chunk_length = chunk_length or self.model.feature_extractor.chunk_length
+        chunk_length = options.chunk_length or self.model.feature_extractor.chunk_length
         # if no segment split is provided, use vad_model and generate segments
-        if not clip_timestamps:
-            if vad_filter:
-                if vad_parameters is None:
-                    vad_parameters = VadOptions(
+        if not options.clip_timestamps:
+            if options.vad_filter:
+                if options.vad_parameters is None:
+                    options.vad_parameters = VadOptions(
                         max_speech_duration_s=chunk_length,
                         min_silence_duration_ms=160,
                     )
-                elif isinstance(vad_parameters, dict):
-                    if "max_speech_duration_s" in vad_parameters.keys():
-                        vad_parameters.pop("max_speech_duration_s")
+                elif isinstance(options.vad_parameters, dict):
+                    if "max_speech_duration_s" in options.vad_parameters.keys():
+                        options.vad_parameters.pop("max_speech_duration_s")
 
-                    vad_parameters = VadOptions(
-                        **vad_parameters, max_speech_duration_s=chunk_length
+                    options.vad_parameters = VadOptions(
+                        **options.vad_parameters, max_speech_duration_s=chunk_length
                     )
 
-                active_segments = get_speech_timestamps(audio, vad_parameters)
-                clip_timestamps = merge_segments(active_segments, vad_parameters)
+                active_segments = get_speech_timestamps(audio, options.vad_parameters)
+                options.clip_timestamps = merge_segments(
+                    active_segments, options.vad_parameters
+                )
             # run the audio if it is less than 30 sec even without clip_timestamps
             elif duration < chunk_length:
-                clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
+                options.clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
             else:
                 raise RuntimeError(
                     "No clip timestamps found. "
                     "Set 'vad_filter' to True or provide 'clip_timestamps'."
                 )
 
         if self.model.model.is_multilingual:
-            language = language or self.preset_language
-        elif language != "en":
-            if language is not None:
+            options.language = options.language or self.preset_language
+        elif options.language != "en":
+            if options.language is not None:
                 self.model.logger.warning(
-                    f"English-only model is used, but {language} language is"
+                    f"English-only model is used, but {options.language} language is"
                     " chosen, setting language to 'en'."
                 )
-            language = "en"
+            options.language = "en"
 
         (
             language,
             language_probability,
             task,
             all_language_probs,
-        ) = self.get_language_and_tokenizer(audio, task, language)
+        ) = self.get_language_and_tokenizer(audio, options.task, options.language)
 
         duration_after_vad = (
-            sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
+            sum(
+                (segment["end"] - segment["start"])
+                for segment in options.clip_timestamps
+            )
             / sampling_rate
         )
 
         # batched options: see the difference with default options in WhisperModel
         batched_options = TranscriptionOptions(
-            beam_size=beam_size,
-            best_of=best_of,
-            patience=patience,
-            length_penalty=length_penalty,
-            repetition_penalty=repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            log_prob_threshold=log_prob_threshold,
-            log_prob_low_threshold=log_prob_low_threshold,
-            no_speech_threshold=no_speech_threshold,
-            compression_ratio_threshold=compression_ratio_threshold,
+            **options.model_dump(),
             temperatures=(
-                temperature if isinstance(temperature, (list, tuple)) else [temperature]
+                options.temperature
+                if isinstance(options.temperature, (list, tuple))
+                else [options.temperature]
             ),
-            initial_prompt=initial_prompt,
-            prefix=prefix,
-            suppress_blank=suppress_blank,
-            suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens),
-            prepend_punctuations=prepend_punctuations,
-            append_punctuations=append_punctuations,
-            max_new_tokens=max_new_tokens,
-            hotwords=hotwords,
-            word_timestamps=word_timestamps,
             hallucination_silence_threshold=None,
             condition_on_previous_text=False,
-            clip_timestamps="0",
             prompt_reset_on_temperature=0.5,
             multilingual=False,
             output_language=None,
-            without_timestamps=without_timestamps,
             max_initial_timestamp=0.0,
         )
+        batched_options.clip_timestamps = None
+        batched_options.suppress_tokens = get_suppressed_tokens(
+            self.tokenizer, options.suppress_tokens
+        )
 
         info = TranscriptionInfo(
             language=language,
@@ -434,7 +438,7 @@ def transcribe(
             all_language_probs=all_language_probs,
         )
 
-        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
+        audio_chunks, chunks_metadata = collect_chunks(audio, options.clip_timestamps)
         to_cpu = (
             self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
         )
@@ -457,9 +461,9 @@ def transcribe(
         segments = self._batched_segments_generator(
             features,
             chunks_metadata,
-            batch_size,
+            options.batch_size,
             batched_options,
-            log_progress,
+            options.log_progress,
         )
 
         return segments, info
@@ -473,7 +477,7 @@ def _batched_segments_generator(
             results = self.forward(
                 features[i : i + batch_size],
                 chunks_metadata[i : i + batch_size],
-                **options._asdict(),
+                **options.model_dump(),
             )
 
             for result in results:
@@ -630,48 +634,7 @@ def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
     def transcribe(
         self,
         audio: Union[str, BinaryIO, torch.Tensor, np.ndarray],
-        language: Optional[str] = None,
-        task: str = "transcribe",
-        beam_size: int = 5,
-        best_of: int = 5,
-        patience: float = 1,
-        length_penalty: float = 1,
-        repetition_penalty: float = 1,
-        no_repeat_ngram_size: int = 0,
-        temperature: Union[float, List[float], Tuple[float, ...]] = [
-            0.0,
-            0.2,
-            0.4,
-            0.6,
-            0.8,
-            1.0,
-        ],
-        compression_ratio_threshold: Optional[float] = 2.4,
-        log_prob_threshold: Optional[float] = -1.0,
-        log_prob_low_threshold: Optional[float] = None,
-        no_speech_threshold: Optional[float] = 0.6,
-        condition_on_previous_text: bool = True,
-        prompt_reset_on_temperature: float = 0.5,
-        initial_prompt: Optional[Union[str, Iterable[int]]] = None,
-        prefix: Optional[str] = None,
-        suppress_blank: bool = True,
-        suppress_tokens: Optional[List[int]] = [-1],
-        without_timestamps: bool = False,
-        max_initial_timestamp: float = 1.0,
-        word_timestamps: bool = False,
-        prepend_punctuations: str = "\"'“¿([{-",
-        append_punctuations: str = "\"'.。,,!!??::”)]}、",
-        multilingual: bool = False,
-        output_language: Optional[str] = None,
-        vad_filter: bool = False,
-        vad_parameters: Optional[Union[dict, VadOptions]] = None,
-        max_new_tokens: Optional[int] = None,
-        chunk_length: Optional[int] = None,
-        clip_timestamps: Union[str, List[float]] = "0",
-        hallucination_silence_threshold: Optional[float] = None,
-        hotwords: Optional[str] = None,
-        language_detection_threshold: Optional[float] = None,
-        language_detection_segments: int = 1,
+        options: WhisperModelTranscriptionOptions = WhisperModelTranscriptionOptions(),
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
@@ -769,12 +732,12 @@ def transcribe(
             "Processing audio with duration %s", format_timestamp(duration)
         )
 
-        if vad_filter and clip_timestamps == "0":
-            if vad_parameters is None:
-                vad_parameters = VadOptions()
-            elif isinstance(vad_parameters, dict):
-                vad_parameters = VadOptions(**vad_parameters)
-            speech_chunks = get_speech_timestamps(audio, vad_parameters)
+        if options.vad_filter and options.clip_timestamps == "0":
+            if options.vad_parameters is None:
+                options.vad_parameters = VadOptions()
+            elif isinstance(options.vad_parameters, dict):
+                options.vad_parameters = VadOptions(**options.vad_parameters)
+            speech_chunks = get_speech_timestamps(audio, options.vad_parameters)
             audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
             audio = torch.cat(audio_chunks, dim=0)
             duration_after_vad = audio.shape[0] / sampling_rate
@@ -802,34 +765,34 @@ def transcribe(
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
 
         features = self.feature_extractor(
-            audio, chunk_length=chunk_length, to_cpu=to_cpu
+            audio, chunk_length=options.chunk_length, to_cpu=to_cpu
         )
 
         encoder_output = None
         all_language_probs = None
 
         # setting output_language for multilingual videos
-        if multilingual:
-            if output_language is None:
-                output_language = "en"
-            elif output_language not in ["en", "hybrid"]:
+        if options.multilingual:
+            if options.output_language is None:
+                options.output_language = "en"
+            elif options.output_language not in ["en", "hybrid"]:
                 raise ValueError("Output language needs to be one of 'en'/'hybrid'.")
 
         # detecting the language if not provided
-        if language is None:
+        if options.language is None:
             if not self.model.is_multilingual:
-                language = "en"
+                options.language = "en"
                 language_probability = 1
             else:
                 if (
-                    language_detection_segments is None
-                    or language_detection_segments < 1
+                    options.language_detection_segments is None
+                    or options.language_detection_segments < 1
                 ):
-                    language_detection_segments = 1
+                    options.language_detection_segments = 1
                 start_timestamp = (
-                    float(clip_timestamps.split(",")[0])
-                    if isinstance(clip_timestamps, str)
-                    else clip_timestamps[0]
+                    float(options.clip_timestamps.split(",")[0])
+                    if isinstance(options.clip_timestamps, str)
+                    else options.clip_timestamps[0]
                 )
                 content_frames = (
                     features.shape[-1] - self.feature_extractor.nb_max_frames
@@ -842,7 +805,7 @@ def transcribe(
                 end_frames = min(
                     seek
                     + self.feature_extractor.nb_max_frames
-                    * language_detection_segments,
+                    * options.language_detection_segments,
                     content_frames,
                 )
                 detected_language_info = {}
@@ -859,20 +822,20 @@
                         (token[2:-2], prob) for (token, prob) in results
                     ]
                     # Get top language token and probability
-                    language, language_probability = all_language_probs[0]
+                    options.language, language_probability = all_language_probs[0]
                     if (
-                        language_detection_threshold is None
-                        or language_probability > language_detection_threshold
+                        options.language_detection_threshold is None
+                        or language_probability > options.language_detection_threshold
                     ):
                         break
-                    detected_language_info.setdefault(language, []).append(
+                    detected_language_info.setdefault(options.language, []).append(
                         language_probability
                     )
                     seek += segment.shape[-1]
                 else:
                     # If no language detected for all segments, the majority vote of the highest
                     # projected languages for all segments is used to determine the language.
-                    language = max(
+                    options.language = max(
                         detected_language_info,
                         key=lambda lang: len(detected_language_info[lang]),
                     )
@@ -880,14 +843,14 @@
             self.logger.info(
                 "Detected language '%s' with probability %.2f",
-                language,
+                options.language,
                 language_probability,
             )
         else:
-            if not self.model.is_multilingual and language != "en":
+            if not self.model.is_multilingual and options.language != "en":
                 self.logger.warning(
                     "The current model is English-only but the language parameter is set to '%s'; "
-                    "using 'en' instead." % language
+                    "using 'en' instead." % options.language
                 )
-                language = "en"
+                options.language = "en"
@@ -896,59 +859,36 @@
 
         tokenizer = Tokenizer(
             self.hf_tokenizer,
             self.model.is_multilingual,
-            task=task,
-            language=language,
+            task=options.task,
+            language=options.language,
         )
 
-        options = TranscriptionOptions(
-            beam_size=beam_size,
-            best_of=best_of,
-            patience=patience,
-            length_penalty=length_penalty,
-            repetition_penalty=repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            log_prob_threshold=log_prob_threshold,
-            log_prob_low_threshold=log_prob_low_threshold,
-            no_speech_threshold=no_speech_threshold,
-            compression_ratio_threshold=compression_ratio_threshold,
-            condition_on_previous_text=condition_on_previous_text,
-            prompt_reset_on_temperature=prompt_reset_on_temperature,
+        _options = TranscriptionOptions(
+            **options.model_dump(),
             temperatures=(
-                temperature if isinstance(temperature, (list, tuple)) else [temperature]
-            ),
-            initial_prompt=initial_prompt,
-            prefix=prefix,
-            suppress_blank=suppress_blank,
-            suppress_tokens=(
-                get_suppressed_tokens(tokenizer, suppress_tokens)
-                if suppress_tokens
-                else suppress_tokens
+                options.temperature
+                if isinstance(options.temperature, (list, tuple))
+                else [options.temperature]
             ),
-            without_timestamps=without_timestamps,
-            max_initial_timestamp=max_initial_timestamp,
-            word_timestamps=word_timestamps,
-            prepend_punctuations=prepend_punctuations,
-            append_punctuations=append_punctuations,
-            multilingual=multilingual,
-            output_language=output_language,
-            max_new_tokens=max_new_tokens,
-            clip_timestamps=clip_timestamps,
-            hallucination_silence_threshold=hallucination_silence_threshold,
-            hotwords=hotwords,
+        )
+        _options.suppress_tokens = (
+            get_suppressed_tokens(tokenizer, options.suppress_tokens)
+            if options.suppress_tokens
+            else options.suppress_tokens
         )
 
-        segments = self.generate_segments(features, tokenizer, options, encoder_output)
+        segments = self.generate_segments(features, tokenizer, _options, encoder_output)
 
         if speech_chunks:
             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
 
         info = TranscriptionInfo(
-            language=language,
+            language=options.language,
             language_probability=language_probability,
             duration=duration,
             duration_after_vad=duration_after_vad,
-            transcription_options=options,
-            vad_options=vad_parameters,
+            transcription_options=_options,
+            vad_options=options.vad_parameters,
             all_language_probs=all_language_probs,
         )
 
         return segments, info
@@ -1043,16 +983,14 @@ def generate_segments(
         content_duration = float(content_frames * self.feature_extractor.time_per_frame)
 
         if isinstance(options.clip_timestamps, str):
-            options = options._replace(
-                clip_timestamps=[
-                    float(ts)
-                    for ts in (
-                        options.clip_timestamps.split(",")
-                        if options.clip_timestamps
-                        else []
-                    )
-                ]
-            )
+            options.clip_timestamps = [
+                float(ts)
+                for ts in (
+                    options.clip_timestamps.split(",")
+                    if options.clip_timestamps
+                    else []
+                )
+            ]
 
         seek_points: List[int] = [
             round(ts * self.frames_per_second) for ts in options.clip_timestamps
         ]
diff --git a/requirements.txt b/requirements.txt
index 71fc482e..27367532 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,5 @@ tokenizers>=0.13,<1
 onnxruntime>=1.14,<2
 torch>=2.1.1
 av>=11
-tqdm
\ No newline at end of file
+tqdm
+pydantic>=2.9
\ No newline at end of file
diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index 08cc3cc7..bf368019 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -2,7 +2,11 @@
 
 from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
 from faster_whisper.tokenizer import Tokenizer
-from faster_whisper.transcribe import get_suppressed_tokens
+from faster_whisper.transcribe import (
+    BatchTranscriptionOptions,
+    WhisperModelTranscriptionOptions,
+    get_suppressed_tokens,
+)
 
 
 def test_supported_languages():
@@ -12,7 +16,9 @@ def test_supported_languages():
 
 def test_transcribe(jfk_path):
     model = WhisperModel("tiny")
-    segments, info = model.transcribe(jfk_path, word_timestamps=True)
+    segments, info = model.transcribe(
+        jfk_path, WhisperModelTranscriptionOptions(word_timestamps=True)
+    )
     assert info.all_language_probs is not None
 
     assert info.language == "en"
@@ -41,7 +47,7 @@ def test_transcribe(jfk_path):
     assert segment.end == segment.words[-1].end
 
     batched_model = BatchedInferencePipeline(model=model)
     result, info = batched_model.transcribe(
-        jfk_path, word_timestamps=True, vad_filter=False
+        jfk_path, BatchTranscriptionOptions(word_timestamps=True, vad_filter=False)
     )
     assert info.language == "en"
     assert info.language_probability > 0.7
@@ -61,7 +67,9 @@ def test_transcribe(jfk_path):
 
 def test_batched_transcribe(physcisworks_path):
     model = WhisperModel("tiny")
     batched_model = BatchedInferencePipeline(model=model)
-    result, info = batched_model.transcribe(physcisworks_path, batch_size=16)
+    result, info = batched_model.transcribe(
+        physcisworks_path, BatchTranscriptionOptions(batch_size=16)
+    )
     assert info.language == "en"
     assert info.language_probability > 0.7
     segments = []
@@ -74,9 +82,11 @@ def test_batched_transcribe(physcisworks_path):
 
     result, info = batched_model.transcribe(
         physcisworks_path,
-        batch_size=16,
-        without_timestamps=False,
-        word_timestamps=True,
+        BatchTranscriptionOptions(
+            batch_size=16,
+            without_timestamps=False,
+            word_timestamps=True,
+        ),
     )
     segments = []
     for segment in result:
@@ -89,7 +99,9 @@ def test_batched_transcribe(physcisworks_path):
 
 def test_prefix_with_timestamps(jfk_path):
     model = WhisperModel("tiny")
-    segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
+    segments, _ = model.transcribe(
+        jfk_path, WhisperModelTranscriptionOptions(prefix="And so my fellow Americans")
+    )
     segments = list(segments)
 
     assert len(segments) == 1
@@ -109,8 +121,10 @@ def test_vad(jfk_path):
     model = WhisperModel("tiny")
     segments, info = model.transcribe(
         jfk_path,
-        vad_filter=True,
-        vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
+        WhisperModelTranscriptionOptions(
+            vad_filter=True,
+            vad_parameters=dict(min_silence_duration_ms=500, speech_pad_ms=200),
+        ),
    )
    segments = list(segments)
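
Usage sketch (illustrative only, not part of the patch): after this refactor, callers build a
pydantic options model and pass it as the second argument to `transcribe` instead of
individual keyword arguments. The class names and defaults below come from the diff above;
"audio.wav" is a placeholder path.

    from faster_whisper import BatchedInferencePipeline, WhisperModel
    from faster_whisper.transcribe import (
        BatchTranscriptionOptions,
        WhisperModelTranscriptionOptions,
    )

    model = WhisperModel("tiny")

    # Sequential API: one options object replaces the former ~30 keyword arguments.
    segments, info = model.transcribe(
        "audio.wav",
        WhisperModelTranscriptionOptions(word_timestamps=True, vad_filter=True),
    )

    # Batched API: BatchTranscriptionOptions additionally carries batch_size,
    # log_progress, and VAD-derived clip_timestamps.
    batched = BatchedInferencePipeline(model=model)
    segments, info = batched.transcribe(
        "audio.wav",
        BatchTranscriptionOptions(batch_size=16, without_timestamps=False),
    )

Note that `transcribe` mutates the options model in place (e.g. `options.language` and
`options.clip_timestamps` are reassigned), so passing a fresh options instance per call,
rather than reusing one object or relying on the shared default instance, avoids state
leaking between calls.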