diff --git a/README.md b/README.md
index 13ed0d05..45a579d9 100644
--- a/README.md
+++ b/README.md
@@ -164,17 +164,6 @@ segments, _ = model.transcribe("audio.mp3")
 segments = list(segments)  # The transcription will actually run here.
 ```
 
-### Multi-Segment Language Detection
-
-To directly use the model for improved language detection, the following code snippet can be used:
-
-```python
-from faster_whisper import WhisperModel
-
-model = WhisperModel("turbo", device="cuda", compute_type="float16")
-language_info = model.detect_language_multi_segment("audio.mp3")
-```
-
 ### Batched Transcription
 
 The following code snippet illustrates how to run batched transcription on an example audio file. `BatchedInferencePipeline.transcribe` is a drop-in replacement for `WhisperModel.transcribe`
diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index 80e5d92c..d3d2bdf7 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -2,10 +2,8 @@
 import json
 import logging
 import os
-import random
 import zlib
 
-from collections import Counter, defaultdict
 from dataclasses import asdict, dataclass
 from inspect import signature
 from math import ceil
@@ -194,45 +192,11 @@ def forward(self, features, chunks_metadata, **forward_params):
 
         return segmented_outputs
 
-    def get_language_and_tokenizer(
-        self, audio, task: Optional[str] = None, language: Optional[str] = None
-    ):
-        all_language_probs = None
-        language_probability = 1.0
-
-        if self.tokenizer is None:
-            if not language:
-                (
-                    language,
-                    language_probability,
-                    all_language_probs,
-                ) = self.model.detect_language(audio)
-            task = task or "transcribe"
-            self.tokenizer = Tokenizer(
-                self.model.hf_tokenizer,
-                self.model.model.is_multilingual,
-                task=task,
-                language=language,
-            )
-        else:
-            if task is not None:
-                self.tokenizer.task = self.tokenizer.tokenizer.token_to_id(
-                    f"<|{task}|>"
-                )
-
-            if language is not None:
-                self.tokenizer.language = self.tokenizer.tokenizer.token_to_id(
-                    f"<|{language}|>"
-                )
-                self.tokenizer.language_code = language
-
-        return language, language_probability, task, all_language_probs
-
     def transcribe(
         self,
         audio: Union[str, BinaryIO, np.ndarray],
         language: Optional[str] = None,
-        task: str = None,
+        task: str = "transcribe",
         log_progress: bool = False,
         beam_size: int = 5,
         best_of: int = 5,
@@ -267,6 +231,8 @@
         clip_timestamps: Optional[List[dict]] = None,
         batch_size: int = 16,
         hotwords: Optional[str] = None,
+        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """transcribe audio in chunks in batched fashion and return with language info.
 
@@ -326,6 +292,9 @@
             batch_size: the maximum number of parallel requests to model for decoding.
             hotwords:
                 Hotwords/hint phrases to the model. Has no effect if prefix is not None.
+            language_detection_threshold: If the maximum probability of the language tokens is
+                higher than this value, the language is detected.
+            language_detection_segments: Number of segments to consider for the language detection.
 
         Static params: (Fixed for batched version)
             max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
@@ -390,28 +359,68 @@
                     "No clip timestamps found. "
                     "Set 'vad_filter' to True or provide 'clip_timestamps'."
                 )
 
-        if self.model.model.is_multilingual:
-            language = language or self.preset_language
-        elif language != "en":
-            if language is not None:
-                self.model.logger.warning(
-                    f"English-only model is used, but {language} language is"
-                    " chosen, setting language to 'en'."
-                )
-            language = "en"
-
-        (
-            language,
-            language_probability,
-            task,
-            all_language_probs,
-        ) = self.get_language_and_tokenizer(audio, task, language)
 
         duration_after_vad = (
             sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
             / sampling_rate
         )
 
+        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
+        features = (
+            [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
+            if duration_after_vad
+            else []
+        )
+
+        all_language_probs = None
+        # detecting the language if not provided
+        if language is None:
+            if not self.model.model.is_multilingual:
+                language = "en"
+                language_probability = 1
+            else:
+                (
+                    language,
+                    language_probability,
+                    all_language_probs,
+                ) = self.model.detect_language(
+                    features=np.concatenate(
+                        features
+                        + [
+                            np.full((self.model.model.n_mels, 1), -1.5, dtype="float32")
+                        ],
+                        axis=1,
+                    ),  # add a dummy feature to account for empty audio
+                    language_detection_segments=language_detection_segments,
+                    language_detection_threshold=language_detection_threshold,
+                )
+
+                self.model.logger.info(
+                    "Detected language '%s' with probability %.2f",
+                    language,
+                    language_probability,
+                )
+        else:
+            if not self.model.model.is_multilingual and language != "en":
+                self.model.logger.warning(
+                    "The current model is English-only but the language parameter is set to '%s'; "
+                    "using 'en' instead." % language
+                )
+                language = "en"
+
+            language_probability = 1
+
+        self.tokenizer = Tokenizer(
+            self.model.hf_tokenizer,
+            self.model.model.is_multilingual,
+            task=task,
+            language=language,
+        )
+
+        features = (
+            np.stack([pad_or_trim(feature) for feature in features]) if features else []
+        )
+
         # batched options: see the difference with default options in WhisperModel
         batched_options = TranscriptionOptions(
             beam_size=beam_size,
@@ -456,23 +465,6 @@
             all_language_probs=all_language_probs,
         )
 
-        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
-        features = (
-            np.stack(
-                [
-                    pad_or_trim(
-                        self.model.feature_extractor(chunk)[
-                            ...,
-                            : chunk.shape[0] // self.model.feature_extractor.hop_length,
-                        ]
-                    )
-                    for chunk in audio_chunks
-                ]
-            )
-            if duration_after_vad
-            else []
-        )
-
         segments = self._batched_segments_generator(
             features,
             chunks_metadata,
@@ -518,9 +510,6 @@ def _batched_segments_generator(
                 pbar.update(1)
 
         pbar.close()
-        # revert the tokenizer if multilingual inference is enabled
-        if self.preset_language is None:
-            self.tokenizer = None
         self.last_speech_timestamp = 0.0
 
@@ -835,11 +824,6 @@
                 language = "en"
                 language_probability = 1
             else:
-                if (
-                    language_detection_segments is None
-                    or language_detection_segments < 1
-                ):
-                    language_detection_segments = 1
                 start_timestamp = (
                     float(clip_timestamps.split(",")[0])
                     if isinstance(clip_timestamps, str)
@@ -851,41 +835,15 @@
                     if start_timestamp * self.frames_per_second < content_frames
                     else 0
                 )
-                end_frames = min(
-                    seek
-                    + self.feature_extractor.nb_max_frames
-                    * language_detection_segments,
-                    content_frames,
+                (
+                    language,
+                    language_probability,
+                    all_language_probs,
+                ) = self.detect_language(
+                    features=features[..., seek:],
+                    language_detection_segments=language_detection_segments,
+                    language_detection_threshold=language_detection_threshold,
                 )
 
-                detected_language_info = {}
-                while seek <= end_frames:
-                    segment = features[
-                        :, seek : seek + self.feature_extractor.nb_max_frames
-                    ]
-                    encoder_output = self.encode(pad_or_trim(segment))
-                    # results is a list of tuple[str, float] with language names and
-                    # probabilities.
-                    results = self.model.detect_language(encoder_output)[0]
-                    # Parse language names to strip out markers
-                    all_language_probs = [
-                        (token[2:-2], prob) for (token, prob) in results
-                    ]
-                    # Get top language token and probability
-                    language, language_probability = all_language_probs[0]
-                    if language_probability > language_detection_threshold:
-                        break
-                    detected_language_info.setdefault(language, []).append(
-                        language_probability
-                    )
-                    seek += segment.shape[-1]
-                else:
-                    # If no language detected for all segments, the majority vote of the highest
-                    # projected languages for all segments is used to determine the language.
-                    language = max(
-                        detected_language_info,
-                        key=lambda lang: len(detected_language_info[lang]),
-                    )
-                    language_probability = max(detected_language_info[language])
 
             self.logger.info(
                 "Detected language '%s' with probability %.2f",
                 language,
                 language_probability,
             )
@@ -1782,223 +1740,80 @@
 
         return encoder_output, output
 
-    def detect_language(self, audio: np.ndarray):
-        segment = self.feature_extractor(audio)[
-            :, : self.feature_extractor.nb_max_frames
-        ]
-        encoder_output = self.encode(pad_or_trim(segment))
-        results = self.model.detect_language(encoder_output)
-        language_token, language_probability = results[0][0]
-        language = language_token[2:-2]
-        self.logger.info(
-            f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
-        )
-        all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]]
-        return language, language_probability, all_language_probs
-
-    def detect_language_multi_segment(
-        self, audio: Union[str, BinaryIO, np.ndarray], params: Optional[dict] = None
-    ):
-        """
-        Detect language based on N highly-confident segments of a language.
+    def detect_language(
+        self,
+        audio: Optional[np.ndarray] = None,
+        features: Optional[np.ndarray] = None,
+        vad_filter: bool = False,
+        vad_parameters: Union[dict, VadOptions] = None,
+        language_detection_segments: int = 1,
+        language_detection_threshold: float = 0.5,
+    ) -> Tuple[str, float, List[Tuple[str, float]]]:
         """
-        # The threshold is used to decide if the audio is silence or not.
-        # The default is 0.02 (2.0%) i.e, if more than 2.0% of the audio is silent,
-        # the audio is considered as silence.
-        if not params:
-            params = {
-                "multilingual": False,
-                "speech_percentage_threshold": 0.02,
-                "language_detection_segments": 4,
-                "vad_filter": True,
-                "vad_min_silence_duration": 2500,
-                "language_threshold": 0.7,
-            }
-
-        if params.get("multilingual", False):
-            logging.warning(
-                "lang_id is not supported for multilingual audios, detecting the major language."
-            )
-
-        speech_percentage_threshold = params.get("speech_percentage_threshold", 0.02)
-        language_threshold = params.get("language_threshold", 0.7)
-        num_detection_segments = params.get("language_detection_segments", 4)
-        vad_filter_enabled = params.get("vad_filter", True)
-        vad_params = dict(
-            min_silence_duration_ms=params.get("vad_min_silence_duration", 2500)
-        )
-
-        if vad_filter_enabled:
-            vad_params = VadOptions(**vad_params)
+        Use Whisper to detect the language of the input audio or features.
+
-        # decode audio if it is not decoded already
-        sampling_rate = self.feature_extractor.sampling_rate
-        if not isinstance(audio, np.ndarray):
-            audio: np.ndarray = decode_audio(audio, sampling_rate=sampling_rate)
-
-        # calculate duration of audio as number of seconds
-        # audio.shape[0] is the number of samples in the audio
-        # sampling_rate is the number of samples per second
-        # if we divide the number of samples by the number of samples per second,
-        # we get the duration in seconds
-        duration = audio.shape[0] / sampling_rate
-
-        # Check if vad is enabled, and collect voiced segments
-        if vad_filter_enabled:
-            # get chunks of audio that contain speech
-            speech_chunks = get_speech_timestamps(audio, vad_params)
-            # merge chunks of audio that contain speech into a single array
-            audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
-            audio = np.concatenate(audio_chunks, axis=0)
-
-            # calculate new duration of audio without silence
-            duration_vad = audio.shape[0] / sampling_rate
-
-            logging.debug(
-                f"Lang ID: VAD filter removed {duration - duration_vad} sec of audio"
-            )
-
-            # if the audio after VAD is less than 2% of the original audio, consider it as silence
-            if duration_vad / duration < speech_percentage_threshold:
-                return {"language_code": None, "language_confidence": 1.0}
-
-            # update duration to be the duration after VAD
-            duration = duration_vad
-
-        # if the duration of the audio is less than 1 second, consider it as silence
-        if duration < 1.0:
-            return {"language_code": None, "language_confidence": 1.0}
+        Arguments:
+            audio: Input audio signal, must be a 1D float array sampled at 16 kHz.
+            features: Input Mel spectrogram features, must be a float array with
+                shape (n_mels, n_frames). If `audio` is provided, the features will be ignored.
+                Either `audio` or `features` must be provided.
+            vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
+                without speech. This step is using the Silero VAD model.
+            vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
+                parameters and default values in the class `VadOptions`).
+            language_detection_threshold: If the maximum probability of the language tokens is
+                higher than this value, the language is detected.
+            language_detection_segments: Number of segments to consider for the language detection.
 
-        # number of feature frames in 30 seconds of audio is 3000
-        nb_max_frames = self.feature_extractor.nb_max_frames
+        Returns:
+            language: Detected language.
+            language_probability: Probability of the detected language.
+            all_language_probs: List of tuples with all language names and probabilities.
+        """
+        assert (
+            audio is not None or features is not None
+        ), "Either `audio` or `features` must be provided."
 
-        # extract features from audio with padding (default)
-        features = self.feature_extractor(audio)
+        if audio is not None:
+            if vad_filter:
+                speech_chunks = get_speech_timestamps(audio, vad_parameters)
+                audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+                audio = np.concatenate(audio_chunks, axis=0)
 
-        # number of segments in the audio
-        num_segments = features.shape[-1] // nb_max_frames
-        # more number of segments than possible with the duration of file
-        if num_detection_segments > num_segments:
-            logging.warning(
-                f"Lang ID: Can not have more segments, setting {num_segments} segments."
-            )
-            num_detection_segments = num_segments
-
-        # create a list of indices to randomly select segments from
-        indices = list(range(num_detection_segments))
-
-        # fix seed to get deterministic results
-        random.seed(0)
-        random.shuffle(indices)
-
-        detected_languages = []
-        all_language_probabilities = defaultdict(list)
-        confident_language_probabilities = defaultdict(list)
-        num_confident_segments_per_language = defaultdict(int)
-
-        # Iterate over the randomly selected indices of the segments.
-        #
-        # For each segment, extract features and detect language.
-        #
-        # If the language is confident, add it to the list of confident segments for that language.
-        #
-        # If the number of confident segments for a language
-        # is greater than or equal to the number of detection segments,
-        # return the language and the average probability of the language.
-        #
-        # If we are unable to get sufficient number of confident predcitions,
-        # return the most frequently detected language with maximum probability.
-        #
-        # We need to get sufficient number of confident predictions per language, not in total.
-
-        for i in indices:
-            segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
-            try:
-                encoder_output = self.encode(pad_or_trim(segment_features))
-                results = self.model.detect_language(encoder_output)[0]
-
-            except ValueError as e:  # or RuntimeError
-                logging.error(f"Inference error:{e}")
-
-            # results is the list of classes (languages) and their probabilities (descending),
-            # for eg: [('<|de|>', 0.482177734375),('<|en|>', 0.283447265625),...]
-
-            # take top language token and probability
-            # and parse language token to strip out markers
-            # for eg: '<|de|>' -> 'de'
-
-            language_token = results[0][0]
-            language = language_token[2:-2]
-
-            language_probability = results[0][1]
-
-            detected_languages.append(language)
-            all_language_probabilities[language].append(language_probability)
-
-            # only consider if the language prediction is confident
-            if language_probability > language_threshold:
-                num_confident_segments_per_language[language] += 1
-
-                # Add language and probability to the list of languages when it is confident
-                confident_language_probabilities[language].append(language_probability)
-
-                # return the language when sufficient number of confident segments is achieved
-                if (
-                    num_confident_segments_per_language[language]
-                    >= num_detection_segments
-                ):
-                    # Considering the average probability of only confident segments
-                    mean = sum(confident_language_probabilities[language]) / len(
-                        confident_language_probabilities[language]
-                    )
-                    return {
-                        "language_code": language,
-                        "language_confidence": mean,
-                    }
-
-        # if we are unable to get sufficient number of confident predictions,
-        # return the most frequently detected language.
-        # if there is a tie, return the one with maximum average probability.
-        counter = Counter(detected_languages)
-
-        # Define the key function to select frequent language with attached probabilities
-        def key_func(language):
-            # Calculate the frequency of the language
-            frequency = counter[language]
-
-            # Calculate the average probability of the language
-            prob_avg = sum(all_language_probabilities[language]) / len(
-                all_language_probabilities[language]
-            )
+            audio = audio[
-            return frequency, prob_avg
+                : language_detection_segments * self.feature_extractor.n_samples
+            ]
+            features = self.feature_extractor(audio)
 
-        if detected_languages:
-            # Use the key function to find the language with maximum frequency and probability
-            max_language = max(detected_languages, key=key_func)
-            max_probability = sum(all_language_probabilities[max_language]) / len(
-                all_language_probabilities[max_language]
+        features = features[
+            ..., : language_detection_segments * self.feature_extractor.nb_max_frames
+        ]
-            )
-
-            # Do additional checks for silence for non-confident case
-            # calculate RMS amplitude and DC offset
-            dc_offset = audio.mean()
-            audio_minus_dc_offset = audio - dc_offset
-            is_silent = (
-                all(np.abs(audio) < 0.1)
-                or np.sqrt(np.mean(audio_minus_dc_offset**2)) < 0.01
+
+        detected_language_info = {}
+        for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
+            encoder_output = self.encode(
+                pad_or_trim(features[..., i : i + self.feature_extractor.nb_max_frames])
+            )
+            # results is a list of tuple[str, float] with language names and probabilities.
+            results = self.model.detect_language(encoder_output)[0]
+
+            # Parse language names to strip out markers
+            all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
+            # Get top language token and probability
+            language, language_probability = all_language_probs[0]
+            if language_probability > language_detection_threshold:
+                break
+            detected_language_info.setdefault(language, []).append(language_probability)
+        else:
+            # If no language is detected above the threshold in any segment, the majority
+            # vote of each segment's top predicted language determines the language.
+            language = max(
+                detected_language_info,
+                key=lambda lang: len(detected_language_info[lang]),
             )
+            language_probability = max(detected_language_info[language])
-
-            if is_silent:
-                return {"language_code": None, "language_confidence": 1.0}
-
-            return {
-                "language_code": max_language,
-                "language_confidence": max_probability,
-            }
-
-        # Language is not detected for any segment and none of prev conditions met
-        return {"language_code": None, "language_confidence": 1.0}
+
+        return language, language_probability, all_language_probs
 
 
 def restore_speech_timestamps(
diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index e25af3ac..710b06a2 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -1,5 +1,7 @@
 import os
 
+import numpy as np
+
 from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
 from faster_whisper.tokenizer import Tokenizer
 from faster_whisper.transcribe import get_suppressed_tokens
@@ -87,6 +89,15 @@ def test_batched_transcribe(physcisworks_path):
     assert len(segments) > 7
 
 
+def test_empty_audio():
+    audio = np.asarray([], dtype="float32")
+    model = WhisperModel("tiny")
+    pipeline = BatchedInferencePipeline(model=model)
+    assert list(model.transcribe(audio)[0]) == []
+    assert list(pipeline.transcribe(audio)[0]) == []
+    model.detect_language(audio)
+
+
 def test_prefix_with_timestamps(jfk_path):
     model = WhisperModel("tiny")
     segments, _ = model.transcribe(jfk_path, prefix="And so my fellow Americans")
@@ -147,13 +158,6 @@ def test_stereo_diarization(data_dir):
     assert transcription == "The horizon seems extremely distant."
 
 
-def test_multisegment_lang_id(physcisworks_path):
-    model = WhisperModel("tiny")
-    language_info = model.detect_language_multi_segment(physcisworks_path)
-    assert language_info["language_code"] == "en"
-    assert language_info["language_confidence"] > 0.8
-
-
 def test_suppressed_tokens_minus_1():
     model = WhisperModel("tiny.en")
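
Usage sketch (not part of the patch): how the unified `detect_language` introduced in this diff can be called in place of the removed `detect_language_multi_segment`. The model size, audio path, and argument values below are illustrative placeholders; the API surface matches the signature added above.

```python
from faster_whisper import WhisperModel, decode_audio

model = WhisperModel("tiny")
audio = decode_audio("audio.mp3", sampling_rate=16000)  # placeholder input file

# Multi-segment detection is now an argument of detect_language rather than a
# separate method: up to 4 thirty-second windows are scanned, stopping early
# once a window's top language probability exceeds the threshold; otherwise a
# majority vote over each window's top predicted language decides.
language, language_probability, all_language_probs = model.detect_language(
    audio=audio,
    vad_filter=True,
    language_detection_segments=4,
    language_detection_threshold=0.5,
)
print(language, language_probability)
```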