diff --git a/MANIFEST.in b/MANIFEST.in index 6f6187c0..8a103dd6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,4 @@ include faster_whisper/assets/silero_vad.onnx include requirements.txt include requirements.conversion.txt +include faster_whisper/assets/pyannote_vad_model.bin diff --git a/README.md b/README.md index e57edbf3..f7d54ee4 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,6 @@ segments, info = model.transcribe("audio.mp3", beam_size=5, language="en") * Python 3.8 or greater -Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package. ### GPU @@ -166,6 +165,35 @@ for segment in segments: segments, _ = model.transcribe("audio.mp3") segments = list(segments) # The transcription will actually run here. ``` + +### Multi-segment language detection + +To use the model directly for improved language detection, the following code snippet can be used: + +```python +from faster_whisper import WhisperModel +model = WhisperModel("medium", device="cuda", compute_type="float16") +language_info = model.detect_language_multi_segment("audio.mp3") +``` + +### Batched faster-whisper + + +The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX), licensed under the BSD-2-Clause license, and integrates its VAD model into this library. We modified this implementation and also replaced the feature extraction with a faster torch-based implementation. The batched version improves the speed by up to 10-12x compared to the OpenAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches, leading to faster inference. + +The following code snippet illustrates how to run inference with the batched version on an example audio file. Please also refer to the test scripts of batched faster-whisper. + +```python +from faster_whisper import WhisperModel, BatchedInferencePipeline + +model = WhisperModel("medium", device="cuda", compute_type="float16") +batched_model = BatchedInferencePipeline(model=model) +segments, info = batched_model.transcribe("audio.mp3", batch_size=16) + +for segment in segments: + print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text)) +``` + ### Faster Distil-Whisper The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. 
In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3) diff --git a/benchmark/wer_benchmark.py b/benchmark/wer_benchmark.py index bf0a1e0e..f7a0b792 100644 --- a/benchmark/wer_benchmark.py +++ b/benchmark/wer_benchmark.py @@ -1,5 +1,6 @@ import argparse import json +import os from datasets import load_dataset from evaluate import load @@ -26,7 +27,9 @@ # define the evaluation metric wer_metric = load("wer") -normalizer = EnglishTextNormalizer(json.load(open("normalizer.json"))) + +with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f: + normalizer = EnglishTextNormalizer(json.load(f)) def inference(batch): diff --git a/faster_whisper/__init__.py b/faster_whisper/__init__.py index 9b56a393..ad692778 100644 --- a/faster_whisper/__init__.py +++ b/faster_whisper/__init__.py @@ -1,5 +1,5 @@ from faster_whisper.audio import decode_audio -from faster_whisper.transcribe import WhisperModel +from faster_whisper.transcribe import BatchedInferencePipeline, WhisperModel from faster_whisper.utils import available_models, download_model, format_timestamp from faster_whisper.version import __version__ @@ -7,6 +7,7 @@ "available_models", "decode_audio", "WhisperModel", + "BatchedInferencePipeline", "download_model", "format_timestamp", "__version__", diff --git a/faster_whisper/assets/pyannote_vad_model.bin b/faster_whisper/assets/pyannote_vad_model.bin new file mode 100644 index 00000000..75c92f09 Binary files /dev/null and b/faster_whisper/assets/pyannote_vad_model.bin differ diff --git a/faster_whisper/audio.py b/faster_whisper/audio.py index a597fd83..7ae68d40 100644 --- a/faster_whisper/audio.py +++ b/faster_whisper/audio.py @@ -1,19 +1,7 @@ -"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV - -The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional -system dependencies. FFmpeg does not need to be installed on the system. - -However, the API is quite low-level so we need to manipulate audio frames directly. -""" - -import gc -import io -import itertools - from typing import BinaryIO, Union -import av -import numpy as np +import torch +import torchaudio def decode_audio( @@ -29,91 +17,42 @@ def decode_audio( split_stereo: Return separate left and right channels. Returns: - A float32 Numpy array. + A float32 Torch Tensor. If `split_stereo` is enabled, the function returns a 2-tuple with the separated left and right channels. """ - resampler = av.audio.resampler.AudioResampler( - format="s16", - layout="mono" if not split_stereo else "stereo", - rate=sampling_rate, - ) - - raw_buffer = io.BytesIO() - dtype = None - with av.open(input_file, mode="r", metadata_errors="ignore") as container: - frames = container.decode(audio=0) - frames = _ignore_invalid_frames(frames) - frames = _group_frames(frames, 500000) - frames = _resample_frames(frames, resampler) - - for frame in frames: - array = frame.to_ndarray() - dtype = array.dtype - raw_buffer.write(array) - - # It appears that some objects related to the resampler are not freed - # unless the garbage collector is manually run. - del resampler - gc.collect() - - audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype) - - # Convert s16 back to f32. 
- audio = audio.astype(np.float32) / 32768.0 + waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T + if audio_sf != sampling_rate: + waveform = torchaudio.functional.resample( + waveform, orig_freq=audio_sf, new_freq=sampling_rate + ) if split_stereo: - left_channel = audio[0::2] - right_channel = audio[1::2] - return left_channel, right_channel - - return audio - - -def _ignore_invalid_frames(frames): - iterator = iter(frames) - - while True: - try: - yield next(iterator) - except StopIteration: - break - except av.error.InvalidDataError: - continue - - -def _group_frames(frames, num_samples=None): - fifo = av.audio.fifo.AudioFifo() - - for frame in frames: - frame.pts = None # Ignore timestamp check. - fifo.write(frame) - - if num_samples is not None and fifo.samples >= num_samples: - yield fifo.read() - - if fifo.samples > 0: - yield fifo.read() - + return waveform[0], waveform[1] -def _resample_frames(frames, resampler): - # Add None to flush the resampler. - for frame in itertools.chain(frames, [None]): - yield from resampler.resample(frame) + return waveform.mean(0) def pad_or_trim(array, length: int, *, axis: int = -1): """ Pad or trim the audio array to N_SAMPLES, as expected by the encoder. """ + axis = axis % array.ndim if array.shape[axis] > length: - array = array.take(indices=range(length), axis=axis) + idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1) + return array[idx] if array.shape[axis] < length: - pad_widths = [(0, 0)] * array.ndim - pad_widths[axis] = (0, length - array.shape[axis]) - array = np.pad(array, pad_widths) + pad_widths = ( + [ + 0, + ] + * array.ndim + * 2 + ) + pad_widths[2 * axis] = length - array.shape[axis] + array = torch.nn.functional.pad(array, tuple(pad_widths[::-1])) return array diff --git a/faster_whisper/feature_extractor.py b/faster_whisper/feature_extractor.py index 0aa15070..6371d5ef 100644 --- a/faster_whisper/feature_extractor.py +++ b/faster_whisper/feature_extractor.py @@ -1,16 +1,21 @@ -import numpy as np +import torch # Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501 class FeatureExtractor: def __init__( self, + device: str = "auto", feature_size=80, sampling_rate=16000, hop_length=160, chunk_length=30, n_fft=400, ): + if device == "auto": + self.device = "cuda" if torch.cuda.is_available() else "cpu" + else: + self.device = device self.n_fft = n_fft self.hop_length = hop_length self.chunk_length = chunk_length @@ -22,21 +27,22 @@ def __init__( sampling_rate, n_fft, n_mels=feature_size ) - def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32): + @staticmethod + def get_mel_filters(sr, n_fft, n_mels=128): + """ + Implementation of librosa.filters.mel in Pytorch + """ # Initialize the weights n_mels = int(n_mels) - weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) # Center freqs of each FFT bin - fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr) + fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr) # 'Center freqs' of mel bands - uniformly spaced between limits min_mel = 0.0 max_mel = 45.245640471924965 - mels = np.linspace(min_mel, max_mel, n_mels + 2) - - mels = np.asanyarray(mels) + mels = torch.linspace(min_mel, max_mel, n_mels + 2) # Fill in the linear scale f_min = 0.0 @@ -46,125 +52,63 @@ def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32): # And now the nonlinear scale min_log_hz = 1000.0 # beginning of log region (Hz) min_log_mel = (min_log_hz - 
f_min) / f_sp # same (Mels) - logstep = np.log(6.4) / 27.0 # step size for log region + logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region # If we have vector data, vectorize log_t = mels >= min_log_mel - freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) + freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel)) mel_f = freqs - fdiff = np.diff(mel_f) - ramps = np.subtract.outer(mel_f, fftfreqs) + fdiff = torch.diff(mel_f) + ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1) - for i in range(n_mels): - # lower and upper slopes for all bins - lower = -ramps[i] / fdiff[i] - upper = ramps[i + 2] / fdiff[i + 1] + lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1) + upper = ramps[2:] / fdiff[1:].unsqueeze(1) - # .. then intersect them with each other and zero - weights[i] = np.maximum(0, np.minimum(lower, upper)) + # Intersect them with each other and zero, vectorized across all i + weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper)) # Slaney-style mel is scaled to be approx constant energy per channel enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]) - weights *= enorm[:, np.newaxis] + weights *= enorm.unsqueeze(1) return weights - def fram_wave(self, waveform, center=True): - """ - Transform a raw waveform into a list of smaller waveforms. - The window length defines how much of the signal is - contain in each frame (smalle waveform), while the hope length defines the step - between the beginning of each new frame. - Centering is done by reflecting the waveform which is first centered around - `frame_idx * hop_length`. + def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False): """ - frames = [] - for i in range(0, waveform.shape[0] + 1, self.hop_length): - half_window = (self.n_fft - 1) // 2 + 1 - if center: - start = i - half_window if i > half_window else 0 - end = ( - i + half_window - if i < waveform.shape[0] - half_window - else waveform.shape[0] - ) - - frame = waveform[start:end] - - if start == 0: - padd_width = (-i + half_window, 0) - frame = np.pad(frame, pad_width=padd_width, mode="reflect") - - elif end == waveform.shape[0]: - padd_width = (0, (i - waveform.shape[0] + half_window)) - frame = np.pad(frame, pad_width=padd_width, mode="reflect") - - else: - frame = waveform[i : i + self.n_fft] - frame_width = frame.shape[0] - if frame_width < waveform.shape[0]: - frame = np.lib.pad( - frame, - pad_width=(0, self.n_fft - frame_width), - mode="constant", - constant_values=0, - ) - - frames.append(frame) - return np.stack(frames, 0) - - def stft(self, frames, window): + Compute the log-Mel spectrogram of the provided audio. """ - Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal. - Should give the same results as `torch.stft`. 
- """ - frame_size = frames.shape[1] - fft_size = self.n_fft - - if fft_size is None: - fft_size = frame_size - - if fft_size < frame_size: - raise ValueError("FFT size must greater or equal the frame size") - # number of FFT bins to store - num_fft_bins = (fft_size >> 1) + 1 - - data = np.empty((len(frames), num_fft_bins), dtype=np.complex64) - fft_signal = np.zeros(fft_size) - for f, frame in enumerate(frames): - if window is not None: - np.multiply(frame, window, out=fft_signal[:frame_size]) - else: - fft_signal[:frame_size] = frame - data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins] - return data.T - - def __call__(self, waveform, padding=True, chunk_length=None): - """ - Compute the log-Mel spectrogram of the provided audio, gives similar results - whisper's original torch implementation with 1e-5 tolerance. - """ if chunk_length is not None: self.n_samples = chunk_length * self.sampling_rate self.nb_max_frames = self.n_samples // self.hop_length + if waveform.dtype is not torch.float32: + waveform = waveform.to(torch.float32) + + waveform = ( + waveform.to(self.device) + if self.device == "cuda" and not waveform.is_cuda + else waveform + ) + if padding: - waveform = np.pad(waveform, [(0, self.n_samples)]) + waveform = torch.nn.functional.pad(waveform, (0, self.n_samples)) - window = np.hanning(self.n_fft + 1)[:-1] + window = torch.hann_window(self.n_fft).to(waveform.device) - frames = self.fram_wave(waveform) - stft = self.stft(frames, window=window) - magnitudes = np.abs(stft[:, :-1]) ** 2 + stft = torch.stft( + waveform, self.n_fft, self.hop_length, window=window, return_complex=True + ) + magnitudes = stft[..., :-1].abs() ** 2 - filters = self.mel_filters - mel_spec = filters @ magnitudes + mel_spec = self.mel_filters.to(waveform.device) @ magnitudes - log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None)) - log_spec = np.maximum(log_spec, log_spec.max() - 8.0) + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) log_spec = (log_spec + 4.0) / 4.0 - return log_spec + # When the model is running on multiple GPUs, the output should be moved + # to the CPU since we don't know which GPU will handle the next job. 
+ return log_spec.cpu() if to_cpu else log_spec diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 6b95f016..2fa233af 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -2,24 +2,39 @@ import json import logging import os +import random import zlib +from collections import Counter, defaultdict from inspect import signature from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union import ctranslate2 import numpy as np import tokenizers +import torch + +from pyannote.audio import Model +from transformers import Pipeline +from transformers.pipelines.pt_utils import PipelineIterator from faster_whisper.audio import decode_audio, pad_or_trim from faster_whisper.feature_extractor import FeatureExtractor from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer -from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger +from faster_whisper.utils import ( + download_model, + format_timestamp, + get_assets_path, + get_end, + get_logger, +) from faster_whisper.vad import ( SpeechTimestampsMap, VadOptions, + VoiceActivitySegmentation, collect_chunks, get_speech_timestamps, + merge_chunks, ) @@ -37,13 +52,14 @@ class Segment(NamedTuple): end: float text: str tokens: List[int] - temperature: float avg_logprob: float compression_ratio: float no_speech_prob: float words: Optional[List[Word]] + temperature: Optional[float] = 1.0 +# Added additional parameters for multilingual videos and fixes below class TranscriptionOptions(NamedTuple): beam_size: int best_of: int @@ -52,6 +68,7 @@ class TranscriptionOptions(NamedTuple): repetition_penalty: float no_repeat_ngram_size: int log_prob_threshold: Optional[float] + log_prob_low_threshold: Optional[float] no_speech_threshold: Optional[float] compression_ratio_threshold: Optional[float] condition_on_previous_text: bool @@ -66,6 +83,8 @@ class TranscriptionOptions(NamedTuple): word_timestamps: bool prepend_punctuations: str append_punctuations: str + multilingual: bool + output_language: Optional[str] max_new_tokens: Optional[int] clip_timestamps: Union[str, List[float]] hallucination_silence_threshold: Optional[float] @@ -82,6 +101,583 @@ class TranscriptionInfo(NamedTuple): vad_options: VadOptions +# The code below is originally from HF pipeline and is used in whisper-x +# (https://github.com/m-bain/whisperX) and adapted for faster_whisper + + +class BatchedInferencePipeline(Pipeline): + """ + Huggingface Pipeline wrapper for WhisperModel. + Copyright (c) 2022, Max Bain + All rights reserved. 
+ Modified by Mobius Labs GmbH + """ + + def __init__( + self, + model, + use_vad_model: bool = True, + options: Optional[NamedTuple] = None, + tokenizer=None, + device: Union[int, str, "torch.device"] = -1, + chunk_length: int = 30, + vad_device: Union[int, str, "torch.device"] = "auto", + framework="pt", + language: Optional[str] = None, + **kwargs, + ): + self.model: WhisperModel = model + self.tokenizer = tokenizer + self.options = options + self.preset_language = language + self._batch_size = kwargs.pop("batch_size", None) + self._num_workers = 0 + self.use_vad_model = use_vad_model + self.vad_onset = 0.500 + self.vad_offset = 0.363 + self.vad_model_path = os.path.join(get_assets_path(), "pyannote_vad_model.bin") + self.vad_model = None + + ( + self._preprocess_params, + self._forward_params, + self._postprocess_params, + ) = self._sanitize_parameters(**kwargs) + self.call_count = 0 + self.framework = framework + if self.framework == "pt": + self.device = self.get_device(device) + else: + self.device = device + + if self.use_vad_model and self.vad_model is None: + self.vad_device = self.get_device(vad_device) + + # load vad model and perform VAD preprocessing if needed + self.vad_model = self.load_vad_model( + vad_onset=self.vad_onset, vad_offset=self.vad_offset + ) + self.chunk_length = chunk_length # VAD merging size + self.last_speech_timestamp = 0.0 + super(Pipeline, self).__init__() + + def _sanitize_parameters(self, **kwargs): + preprocess_kwargs = {} + if "tokenizer" in kwargs: + preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] + return preprocess_kwargs, {}, {} + + def get_device(self, device: Union[int, str, "torch.device"]): + """ + Converts the input device into a torch.device object. + + The input can be an integer, a string, or a `torch.device` object. + + The function handles a special case where the input device is "auto". + When "auto" is specified, the device will default to the + device of the model (self.model.device). If the model's device is also "auto", + it selects "cuda" if a CUDA-capable device is available; otherwise, it selects "cpu". 
+ """ + if isinstance(device, torch.device): + return device + elif isinstance(device, str): + if device == "auto" and self.model.device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + elif device == "auto": + device = self.model.device + return torch.device(device) + elif device < 0: + return torch.device("cpu") + else: + return torch.device(f"cuda:{device}") + + def preprocess(self, inputs): + audio = inputs["inputs"] + to_cpu = ( + self.model.model.device == "cuda" and len(self.model.model.device_index) > 1 + ) + features = self.model.feature_extractor(audio, padding=True, to_cpu=to_cpu)[ + :, : self.model.feature_extractor.nb_max_frames + ] + + inputs["features"] = features + del features + return inputs + + def _forward(self, model_inputs, **forward_params): + encoder_output, outputs = self.model.generate_segment_batched( + model_inputs["features"], self.tokenizer, forward_params + ) + + segment_size = encoder_output.shape[1] * 2 + segmented_outputs = [] + for segment_metadata, output in zip(model_inputs["seg_metadata"], outputs): + ( + subsegments, + seek, + single_timestamp_ending, + ) = self.model._split_segments_by_timestamps( + tokenizer=self.tokenizer, + tokens=output["tokens"], + time_offset=segment_metadata["start_time"], + segment_size=segment_size, + segment_duration=segment_metadata["end_time"] + - segment_metadata["start_time"], + seek=0, + ) + segmented_outputs.append( + [ + dict( + text=self.tokenizer.decode(subsegment["tokens"]), + avg_logprob=output["avg_logprob"], + no_speech_prob=output["no_speech_prob"], + tokens=subsegment["tokens"], + start=subsegment["start"], + end=subsegment["end"], + compression_ratio=get_compression_ratio( + self.tokenizer.decode(subsegment["tokens"]) + ), + ) + for subsegment in subsegments + ] + ) + if forward_params["word_timestamps"]: + self.last_speech_timestamp = self.model.add_word_timestamps( + segmented_outputs, + self.tokenizer, + encoder_output, + segment_size, + forward_params["prepend_punctuations"], + forward_params["append_punctuations"], + self.last_speech_timestamp, + ) + + return {"output": segmented_outputs} + + def __call__(self, inputs, options, batch_size=None, **kwargs): + if batch_size is None: + if self._batch_size is None: + batch_size = 1 + else: + batch_size = self._batch_size + + ( + preprocess_params, + forward_params, + postprocess_params, + ) = self._sanitize_parameters(**kwargs) + + # Fuse __init__ params and __call__ params without modifying the __init__ ones. + preprocess_params = { + **self._preprocess_params, + **preprocess_params, + } + options_dict = options._asdict() + forward_params = {**self._forward_params, **forward_params, **options_dict} + postprocess_params = {**self._postprocess_params, **postprocess_params} + + self.call_count += 1 + if ( + self.call_count > 10 + and self.framework == "pt" + and self.device.type == "cuda" + ): + logging.warning( + "You seem to be using the pipelines sequentially on GPU. 
Please use a Dataset" + ) + + return self.get_iterator( + inputs, + batch_size, + preprocess_params, + forward_params, + postprocess_params, + ) + + def postprocess(self, model_outputs): + return model_outputs + + def get_iterator( + self, + inputs, + batch_size: int, + preprocess_params=None, + forward_params=None, + postprocess_params=None, + ): + def stack(items): + return { + "inputs": [x["inputs"] for x in items], + "seg_metadata": [x["seg_metadata"] for x in items], + "features": torch.stack([x["features"] for x in items]), + } + + if "TOKENIZERS_PARALLELISM" not in os.environ: + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + dataset = PipelineIterator(inputs, self.preprocess, preprocess_params) + dataloader = torch.utils.data.DataLoader( + dataset, + num_workers=self._num_workers, + batch_size=batch_size, + collate_fn=stack, + ) + model_iterator = PipelineIterator( + dataloader, self.forward, forward_params, loader_batch_size=batch_size + ) + final_iterator = PipelineIterator( + model_iterator, self.postprocess, postprocess_params + ) + return final_iterator + + def get_language_and_tokenizer( + self, audio, task: Optional[str] = None, language: Optional[str] = None + ): + all_language_probs = None + language_probability = 1.0 + + if self.tokenizer is None: + if not language: + ( + language, + language_probability, + all_language_probs, + ) = self.model.detect_language(audio) + task = task or "transcribe" + self.tokenizer = Tokenizer( + self.model.hf_tokenizer, + self.model.model.is_multilingual, + task=task, + language=language, + ) + else: + if task is not None: + self.tokenizer.task = self.tokenizer.tokenizer.token_to_id( + f"<|{task}|>" + ) + + if language is not None: + self.tokenizer.language = self.tokenizer.tokenizer.token_to_id( + f"<|{language}|>" + ) + self.tokenizer.language_code = language + + return language, language_probability, task, all_language_probs + + @staticmethod + def audio_split(audio, segments, sampling_rate): + """Returns splitted audio chunks as iterator""" + + for seg in segments: + f1 = int(seg["start"] * sampling_rate) + f2 = int(seg["end"] * sampling_rate) + seg_metadata = { + "start_time": seg["start"], + "end_time": seg["end"], + "stitched_seg": seg["segments"], + } + yield {"inputs": audio[f1:f2], "seg_metadata": seg_metadata} + + def load_vad_model(self, vad_onset=0.500, vad_offset=0.363): + vad_model = Model.from_pretrained(self.vad_model_path) + hyperparameters = { + "onset": vad_onset, + "offset": vad_offset, + "min_duration_on": 0.1, + "min_duration_off": 0.1, + } + + vad_pipeline = VoiceActivitySegmentation( + segmentation=vad_model, device=torch.device(self.vad_device) + ) + vad_pipeline.instantiate(hyperparameters) + return vad_pipeline + + def transcribe( + self, + audio: Union[str, torch.Tensor, np.ndarray], + vad_segments: Optional[List[dict]] = None, + batch_size: int = 16, + language: Optional[str] = None, + task: str = None, + log_progress: bool = False, + beam_size: int = 5, + best_of: int = 5, + patience: float = 1, + length_penalty: float = 1, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + temperature: Union[float, List[float], Tuple[float, ...]] = [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + ], + compression_ratio_threshold: Optional[float] = 2.4, + log_prob_threshold: Optional[float] = -1.0, + log_prob_low_threshold: Optional[float] = None, + no_speech_threshold: Optional[float] = 0.6, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, + prefix: Optional[str] = None, + suppress_blank: bool = 
True, + suppress_tokens: Optional[List[int]] = [-1], + prepend_punctuations: str = "\"'“¿([{-", + append_punctuations: str = "\"'.。,,!!??::”)]}、", + max_new_tokens: Optional[int] = None, + hotwords: Optional[str] = None, + word_timestamps: bool = False, + without_timestamps: bool = True, + ) -> Tuple[Iterable[Segment], TranscriptionInfo]: + """Transcribe audio in chunks in a batched fashion and return with language info. + + Arguments: + audio: Audio file as a numpy array or path for batched transcription. + vad_segments: Optionally provide a list of dictionaries, each containing "start", "end", + and "segments" keys. + "start" and "end" keys specify the start and end of the voiced region within a + 30 sec boundary. An additional key "segments" contains all the start + and end of voiced regions within that 30 sec boundary as a list of tuples. + If no vad_segments are specified, the internal VAD model is used to segment the audio automatically. + batch_size: The maximum number of parallel requests to the model for decoding. + language: The language spoken in the audio. + task: Either "transcribe" or "translate". + log_progress: Whether to show a progress bar or not. + beam_size: Beam size to use for decoding. + best_of: Number of candidates when sampling with non-zero temperature. + patience: Beam search patience factor. + length_penalty: Exponential length penalty constant. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable). + temperature: Temperature for sampling. It can be a tuple of temperatures, + which will be successively used upon failures according to either + `compression_ratio_threshold` or `log_prob_threshold`. + compression_ratio_threshold: If the gzip compression ratio is above this value, + treat as failed. + log_prob_threshold: If the average log probability over sampled tokens is + below this value, treat as failed. + log_prob_low_threshold: This parameter alone is sufficient to skip an output text, + whereas log_prob_threshold also looks for an appropriate no_speech_threshold value. + This value should be less than log_prob_threshold. + no_speech_threshold: If the no_speech probability is higher than this value AND + the average log probability over sampled tokens is below `log_prob_threshold`, + consider the segment as silent. + initial_prompt: Optional text string or iterable of token ids to provide as a + prompt for the first window. + prefix: Optional text to provide as a prefix for the first window. + suppress_blank: Suppress blank outputs at the beginning of the sampling. + suppress_tokens: List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in `tokenizer.non_speech_tokens()`. + prepend_punctuations: If word_timestamps is True, merge these punctuation symbols + with the next word + append_punctuations: If word_timestamps is True, merge these punctuation symbols + with the previous word + max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set, + the maximum will be set by the default max_length. + hotwords: + Hotwords/hint phrases to the model. Has no effect if prefix is not None. + word_timestamps: Extract word-level timestamps using the cross-attention pattern + and dynamic time warping, and include the timestamps for each word in each segment. + Set as False. + without_timestamps: Only sample text tokens. 
+ + Static params: (Fixed for batched version) + max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0. + multilingual: If True, perform transcription on multilingual videos. Set as False. + output_language: Valid only if multilingual is set to True. + Specifies the string representing the output language. One of + 'en' (English) or 'hybrid' (code-switched transcription). set as None. + condition_on_previous_text: If True, the previous output of the model is provided + as a prompt for the next window; disabling may make the text inconsistent across + windows, but the model becomes less prone to getting stuck in a failure loop, + such as repetition looping or timestamps going out of sync. Set as False + prompt_reset_on_temperature: Resets prompt if temperature is above this value. + Arg has effect only if condition_on_previous_text is True. Set at 0.5 + #TODO: support "hallucination_silence_threshold" when "word_timestamps=True" + hallucination_silence_threshold: Optional[float] + When word_timestamps is True, skip silent periods longer than this threshold + (in seconds) when a possible hallucination is detected. set as None. + clip_timestamps: + Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to + process. The last end timestamp defaults to the end of the file. Set as "0". + + unused: + language_detection_threshold: If the maximum probability of the language tokens is + higher than this value, the language is detected. + language_detection_segments: Number of segments to consider for the language detection. + vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio + without speech. This step is using the Silero VAD model + https://github.com/snakers4/silero-vad. + vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available + parameters and default values in the class `VadOptions`). + chunk_length: The length of audio segments. If it is not None, it will overwrite the + default chunk_length of the FeatureExtractor. + + + Returns: + A tuple with: + + - a generator over transcribed batched segments. + - an instance of TranscriptionInfo. + """ + + sampling_rate = self.model.feature_extractor.sampling_rate + + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + elif not isinstance(audio, torch.Tensor): + audio = decode_audio(audio, sampling_rate=sampling_rate) + duration = audio.shape[0] / sampling_rate + + # if no segment split is provided, use vad_model and generate segments + if not vad_segments: + # run the audio if it is less than 30 sec even without vad_segments + if self.use_vad_model: + vad_segments = self.vad_model( + { + "waveform": audio.unsqueeze(0), + "sample_rate": 16000, + } + ) + vad_segments = merge_chunks( + vad_segments, + self.chunk_length, + onset=self.vad_onset, + offset=self.vad_offset, + ) + elif duration < self.chunk_length: + vad_segments = [ + {"start": 0.0, "end": duration, "segments": [(0.0, duration)]} + ] + else: + raise RuntimeError( + "No vad segments found. Set 'use_vad_model' to True while loading the model" + ) + if self.model.model.is_multilingual: + language = language or self.preset_language + elif language != "en": + if language is not None: + self.model.logger.warning( + f"English-only model is used, but {language} language is" + "chosen, setting language to 'en'." 
+ ) + language = "en" + + ( + language, + language_probability, + task, + all_language_probs, + ) = self.get_language_and_tokenizer(audio, task, language) + batch_size = batch_size or self._batch_size + + duration_after_vad = sum( + segment["end"] - segment["start"] for segment in vad_segments + ) + + # batched options: see the difference with default options in WhisperModel + batched_options = TranscriptionOptions( + beam_size=beam_size, + best_of=best_of, + patience=patience, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + log_prob_threshold=log_prob_threshold, + log_prob_low_threshold=log_prob_low_threshold, + no_speech_threshold=no_speech_threshold, + compression_ratio_threshold=compression_ratio_threshold, + temperatures=( + temperature if isinstance(temperature, (list, tuple)) else [temperature] + ), + initial_prompt=initial_prompt, + prefix=prefix, + suppress_blank=suppress_blank, + suppress_tokens=get_suppressed_tokens(self.tokenizer, suppress_tokens), + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + max_new_tokens=max_new_tokens, + hotwords=hotwords, + word_timestamps=word_timestamps, + hallucination_silence_threshold=None, + condition_on_previous_text=False, + clip_timestamps="0", + prompt_reset_on_temperature=0.5, + multilingual=False, + output_language=None, + without_timestamps=without_timestamps, + max_initial_timestamp=0.0, + ) + + info = TranscriptionInfo( + language=language, + language_probability=language_probability, + duration=duration, + duration_after_vad=duration_after_vad, + transcription_options=batched_options, + vad_options=None, + all_language_probs=all_language_probs, + ) + + segments = self._batched_segments_generator( + audio, + vad_segments, + sampling_rate, + batch_size, + batched_options, + log_progress, + ) + + return segments, info + + def _batched_segments_generator( + self, audio, vad_segments, sampling_rate, batch_size, options, log_progress + ): + seg_idx = 0 + total_segments = len(vad_segments) + for idx, out in enumerate( + self.__call__( + self.audio_split(audio, vad_segments, sampling_rate), + batch_size=batch_size, + options=options, + ) + ): + if log_progress: + percent_complete = ((idx + 1) / total_segments) * 100 + self.model.logger.info(f"Progress: {percent_complete:.2f}%...") + + responses = out["output"] + if batch_size == 1: + responses = responses[0] + + for response in responses: + seg_idx += 1 + segments = Segment( + seek=int(responses[-1]["end"] * self.model.frames_per_second), + id=seg_idx, + text=response["text"], + start=round(response["start"], 3), + end=round(response["end"], 3), + words=( + None + if not options.word_timestamps + else [Word(**word) for word in response["words"]] + ), + tokens=response["tokens"], + avg_logprob=response["avg_logprob"], + no_speech_prob=response["no_speech_prob"], + compression_ratio=response["compression_ratio"], + ) + yield segments + + # revert the tokenizer if multilingual inference is enabled + if self.preset_language is None: + self.tokenizer = None + self.last_speech_timestamp = 0.0 + + class WhisperModel: def __init__( self, @@ -89,7 +685,7 @@ def __init__( device: str = "auto", device_index: Union[int, List[int]] = 0, compute_type: str = "default", - cpu_threads: int = 0, + cpu_threads: int = 16, num_workers: int = 1, download_root: Optional[str] = None, local_files_only: bool = False, @@ -141,10 +737,12 @@ def __init__( local_files_only=local_files_only, 
cache_dir=download_root, ) - + self.device = device + # set the random seed to make sure consistency across runs + ctranslate2.set_random_seed(42) self.model = ctranslate2.models.Whisper( model_path, - device=device, + device=self.device, device_index=device_index, compute_type=compute_type, intra_threads=cpu_threads, @@ -163,15 +761,19 @@ def __init__( "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en") ) self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes) - self.feature_extractor = FeatureExtractor(**self.feat_kwargs) - self.num_samples_per_token = self.feature_extractor.hop_length * 2 + self.feature_extractor = FeatureExtractor( + **self.feat_kwargs, device=self.device + ) + self.input_stride = 2 + self.num_samples_per_token = ( + self.feature_extractor.hop_length * self.input_stride + ) self.frames_per_second = ( self.feature_extractor.sampling_rate // self.feature_extractor.hop_length ) self.tokens_per_second = ( self.feature_extractor.sampling_rate // self.num_samples_per_token ) - self.input_stride = 2 self.time_precision = 0.02 self.max_length = 448 @@ -200,7 +802,7 @@ def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict: def transcribe( self, - audio: Union[str, BinaryIO, np.ndarray], + audio: Union[str, BinaryIO, torch.Tensor, np.ndarray], language: Optional[str] = None, task: str = "transcribe", beam_size: int = 5, @@ -219,6 +821,7 @@ def transcribe( ], compression_ratio_threshold: Optional[float] = 2.4, log_prob_threshold: Optional[float] = -1.0, + log_prob_low_threshold: Optional[float] = None, no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, prompt_reset_on_temperature: float = 0.5, @@ -231,6 +834,8 @@ def transcribe( word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", + multilingual: bool = False, + output_language: Optional[str] = None, vad_filter: bool = False, vad_parameters: Optional[Union[dict, VadOptions]] = None, max_new_tokens: Optional[int] = None, @@ -263,6 +868,9 @@ def transcribe( treat as failed. log_prob_threshold: If the average log probability over sampled tokens is below this value, treat as failed. + log_prob_low_threshold: This parameter alone is sufficient to skip an output text, + wheras log_prob_threshold also looks for appropriate no_speech_threshold value. + This value should be less than log_prob_threshold. no_speech_threshold: If the no_speech probability is higher than this value AND the average log probability over sampled tokens is below `log_prob_threshold`, consider the segment as silent. @@ -277,7 +885,7 @@ def transcribe( prefix: Optional text to provide as a prefix for the first window. suppress_blank: Suppress blank outputs at the beginning of the sampling. suppress_tokens: List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in `tokenizer.non_speech_tokens()` + of symbols as defined in `tokenizer.non_speech_tokens()`. without_timestamps: Only sample text tokens. max_initial_timestamp: The initial timestamp cannot be later than this. word_timestamps: Extract word-level timestamps using the cross-attention pattern @@ -286,6 +894,12 @@ def transcribe( with the next word append_punctuations: If word_timestamps is True, merge these punctuation symbols with the previous word + multilingual: If True, perform transcription on multilingual videos + and return the transcript based + on the 'output_language' flag. 
+ output_language: Valid only if multilingual is set to True. + Specifies the string representing the output language. One of + 'en' (English) or 'hybrid' (code-switched transcription). vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio without speech. This step is using the Silero VAD model https://github.com/snakers4/silero-vad. @@ -313,9 +927,12 @@ def transcribe( - a generator over transcribed segments - an instance of TranscriptionInfo """ + sampling_rate = self.feature_extractor.sampling_rate - if not isinstance(audio, np.ndarray): + if isinstance(audio, np.ndarray): + audio = torch.from_numpy(audio) + elif not isinstance(audio, torch.Tensor): audio = decode_audio(audio, sampling_rate=sampling_rate) duration = audio.shape[0] / sampling_rate @@ -355,11 +972,22 @@ def transcribe( else: speech_chunks = None - features = self.feature_extractor(audio, chunk_length=chunk_length) + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + features = self.feature_extractor( + audio, chunk_length=chunk_length, to_cpu=to_cpu + ) encoder_output = None all_language_probs = None + # setting output_language for multilingual videos + if multilingual: + if output_language is None: + output_language = "en" + elif output_language not in ["en", "hybrid"]: + raise ValueError("Output language needs to be one of 'en'/'hybrid'.") + + # detecting the language if not provided if language is None: if not self.model.is_multilingual: language = "en" @@ -452,6 +1080,7 @@ def transcribe( repetition_penalty=repetition_penalty, no_repeat_ngram_size=no_repeat_ngram_size, log_prob_threshold=log_prob_threshold, + log_prob_low_threshold=log_prob_low_threshold, no_speech_threshold=no_speech_threshold, compression_ratio_threshold=compression_ratio_threshold, condition_on_previous_text=condition_on_previous_text, @@ -472,6 +1101,8 @@ def transcribe( word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, + multilingual=multilingual, + output_language=output_language, max_new_tokens=max_new_tokens, clip_timestamps=clip_timestamps, hallucination_silence_threshold=hallucination_silence_threshold, @@ -494,9 +1125,88 @@ def transcribe( ) return segments, info + def _split_segments_by_timestamps( + self, + tokenizer: Tokenizer, + tokens: List[int], + time_offset: float, + segment_size: int, + segment_duration: float, + seek: int, + ) -> List[List[int]]: + current_segments = [] + single_timestamp_ending = ( + len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] + ) + + consecutive_timestamps = [ + i + for i in range(len(tokens)) + if i > 0 + and tokens[i] >= tokenizer.timestamp_begin + and tokens[i - 1] >= tokenizer.timestamp_begin + ] + + if len(consecutive_timestamps) > 0: + slices = list(consecutive_timestamps) + if single_timestamp_ending: + slices.append(len(tokens)) + + last_slice = 0 + for current_slice in slices: + sliced_tokens = tokens[last_slice:current_slice] + start_timestamp_position = sliced_tokens[0] - tokenizer.timestamp_begin + end_timestamp_position = sliced_tokens[-1] - tokenizer.timestamp_begin + start_time = ( + time_offset + start_timestamp_position * self.time_precision + ) + end_time = time_offset + end_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=start_time, + end=end_time, + tokens=sliced_tokens, + ) + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech 
after the last timestamp. + seek += segment_size + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + last_timestamp_position = ( + tokens[last_slice - 1] - tokenizer.timestamp_begin + ) + seek += last_timestamp_position * self.input_stride + + else: + duration = segment_duration + timestamps = [ + token for token in tokens if token >= tokenizer.timestamp_begin + ] + if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: + last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin + duration = last_timestamp_position * self.time_precision + + current_segments.append( + dict( + seek=seek, + start=time_offset, + end=time_offset + duration, + tokens=tokens, + ) + ) + + seek += segment_size + + return current_segments, seek, single_timestamp_ending + def generate_segments( self, - features: np.ndarray, + features: torch.Tensor, tokenizer: Tokenizer, options: TranscriptionOptions, encoder_output: Optional[ctranslate2.StorageView] = None, @@ -578,6 +1288,28 @@ def generate_segments( ) previous_tokens = all_tokens[prompt_reset_since:] + + if encoder_output is None: + encoder_output = self.encode(segment) + + # Perform language detection at every segment to update task based on output language, + # if the language is english, task is transcribe, + # else the task is translate to english (default) + # or transcribe if 'output_language' is 'hybrid'. + if options.multilingual: + results = self.model.detect_language(encoder_output) + language_token, language_probability = results[0][0] + language = language_token[2:-2] + if options.output_language == "en" and language != "en": + task = "translate" + else: + task = "transcribe" + + # Update tokenizer based on task and language + tokenizer.task = tokenizer.tokenizer.token_to_id(f"<|{task}|>") + tokenizer.language = tokenizer.tokenizer.token_to_id(language_token) + tokenizer.language_code = language + # Update prompt based on task and language prompt = self.get_prompt( tokenizer, previous_tokens, @@ -614,6 +1346,18 @@ def generate_segments( options.no_speech_threshold, ) + # Skip if the logprob is very low (below the threshold value), + # despite no_speech_prob being low (ex: Too ambiguous outputs) + if options.log_prob_low_threshold: + if avg_logprob < options.log_prob_low_threshold: + should_skip = True + self.logger.debug( + "log prob low threshold is met (%f > %f)", + avg_logprob, + options.log_prob_low_threshold, + ) + + if should_skip: # fast-forward to the next segment boundary seek += segment_size continue @@ -621,7 +1365,6 @@ def generate_segments( tokens = result.sequences_ids[0] previous_seek = seek - current_segments = [] # anomalous words are very long/short/improbable def word_anomaly_score(word: dict) -> float: @@ -647,83 +1390,22 @@ def is_segment_anomaly(segment: Optional[dict]) -> bool: def next_words_segment(segments: List[dict]) -> Optional[dict]: return next((s for s in segments if s["words"]), None) - single_timestamp_ending = ( - len(tokens) >= 2 - and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1] + ( + current_segments, + seek, + single_timestamp_ending, + ) = self._split_segments_by_timestamps( + tokenizer=tokenizer, + tokens=tokens, + time_offset=time_offset, + segment_size=segment_size, + segment_duration=segment_duration, + seek=seek, ) - consecutive_timestamps = [ - i - for i in range(len(tokens)) - if i > 0 - and tokens[i] >= tokenizer.timestamp_begin - and tokens[i - 1] >= tokenizer.timestamp_begin - ] - - if len(consecutive_timestamps) > 0: - slices = 
list(consecutive_timestamps) - if single_timestamp_ending: - slices.append(len(tokens)) - - last_slice = 0 - for current_slice in slices: - sliced_tokens = tokens[last_slice:current_slice] - start_timestamp_position = ( - sliced_tokens[0] - tokenizer.timestamp_begin - ) - end_timestamp_position = ( - sliced_tokens[-1] - tokenizer.timestamp_begin - ) - start_time = ( - time_offset + start_timestamp_position * self.time_precision - ) - end_time = ( - time_offset + end_timestamp_position * self.time_precision - ) - - current_segments.append( - dict( - seek=seek, - start=start_time, - end=end_time, - tokens=sliced_tokens, - ) - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. - seek += segment_size - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - last_timestamp_position = ( - tokens[last_slice - 1] - tokenizer.timestamp_begin - ) - seek += last_timestamp_position * self.input_stride - - else: - duration = segment_duration - timestamps = [ - token for token in tokens if token >= tokenizer.timestamp_begin - ] - if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin: - last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin - duration = last_timestamp_position * self.time_precision - - current_segments.append( - dict( - seek=seek, - start=time_offset, - end=time_offset + duration, - tokens=tokens, - ) - ) - - seek += segment_size - if options.word_timestamps: self.add_word_timestamps( - current_segments, + [current_segments], tokenizer, encoder_output, segment_size, @@ -731,7 +1413,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: options.append_punctuations, last_speech_timestamp=last_speech_timestamp, ) - if not single_timestamp_ending: last_word_end = get_end(current_segments) if last_word_end is not None and last_word_end > time_offset: @@ -788,7 +1469,6 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: last_word_end = get_end(current_segments) if last_word_end is not None: last_speech_timestamp = last_word_end - for segment in current_segments: tokens = segment["tokens"] text = tokenizer.decode(tokens) @@ -830,12 +1510,13 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: prompt_reset_since = len(all_tokens) - def encode(self, features: np.ndarray) -> ctranslate2.StorageView: + def encode(self, features: torch.Tensor) -> ctranslate2.StorageView: # When the model is running on multiple GPUs, the encoder output should be moved # to the CPU since we don't know which GPU will handle the next job. 
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 - features = np.expand_dims(features, 0) + if features.ndim == 2: + features = features.unsqueeze(0) features = get_ctranslate2_storage(features) return self.model.encode(features, to_cpu=to_cpu) @@ -1014,115 +1695,127 @@ def add_word_timestamps( prepend_punctuations: str, append_punctuations: str, last_speech_timestamp: float, - ) -> None: + ) -> float: if len(segments) == 0: return - text_tokens_per_segment = [ - [token for token in segment["tokens"] if token < tokenizer.eot] - for segment in segments - ] + text_tokens = [] + text_tokens_per_segment = [] + for segment in segments: + segment_tokens = [ + [token for token in subsegment["tokens"] if token < tokenizer.eot] + for subsegment in segment + ] + text_tokens.append(list(itertools.chain.from_iterable(segment_tokens))) + text_tokens_per_segment.append(segment_tokens) - text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) - alignment = self.find_alignment( + alignments = self.find_alignment( tokenizer, text_tokens, encoder_output, num_frames ) - word_durations = np.array([word["end"] - word["start"] for word in alignment]) - word_durations = word_durations[word_durations.nonzero()] - median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 - median_duration = min(0.7, float(median_duration)) - max_duration = median_duration * 2 - - # hack: truncate long words at sentence boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. - if len(word_durations) > 0: - sentence_end_marks = ".。!!??" - # ensure words at sentence boundaries - # are not longer than twice the median word duration. - for i in range(1, len(alignment)): - if alignment[i]["end"] - alignment[i]["start"] > max_duration: - if alignment[i]["word"] in sentence_end_marks: - alignment[i]["end"] = alignment[i]["start"] + max_duration - elif alignment[i - 1]["word"] in sentence_end_marks: - alignment[i]["start"] = alignment[i]["end"] - max_duration - - merge_punctuations(alignment, prepend_punctuations, append_punctuations) - - time_offset = ( - segments[0]["seek"] - * self.feature_extractor.hop_length - / self.feature_extractor.sampling_rate - ) - - word_index = 0 - - for segment, text_tokens in zip(segments, text_tokens_per_segment): - saved_tokens = 0 - words = [] - - while word_index < len(alignment) and saved_tokens < len(text_tokens): - timing = alignment[word_index] + median_max_durations = [] + for alignment in alignments: + word_durations = np.array( + [word["end"] - word["start"] for word in alignment] + ) + word_durations = word_durations[word_durations.nonzero()] + median_duration = ( + np.median(word_durations) if len(word_durations) > 0 else 0.0 + ) + median_duration = min(0.7, float(median_duration)) + max_duration = median_duration * 2 - if timing["word"]: - words.append( - dict( - word=timing["word"], - start=round(time_offset + timing["start"], 2), - end=round(time_offset + timing["end"], 2), - probability=timing["probability"], + # hack: truncate long words at sentence boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(word_durations) > 0: + sentence_end_marks = ".。!!??" + # ensure words at sentence boundaries + # are not longer than twice the median word duration. 
+ for i in range(1, len(alignment)): + if alignment[i]["end"] - alignment[i]["start"] > max_duration: + if alignment[i]["word"] in sentence_end_marks: + alignment[i]["end"] = alignment[i]["start"] + max_duration + elif alignment[i - 1]["word"] in sentence_end_marks: + alignment[i]["start"] = alignment[i]["end"] - max_duration + + merge_punctuations(alignment, prepend_punctuations, append_punctuations) + median_max_durations.append((median_duration, max_duration)) + + for segment_idx, segment in enumerate(segments): + word_index = 0 + time_offset = segment[0]["start"] + median_duration, max_duration = median_max_durations[segment_idx] + for subsegment_idx, subsegment in enumerate(segment): + saved_tokens = 0 + words = [] + + while word_index < len(alignments[segment_idx]) and saved_tokens < len( + text_tokens_per_segment[segment_idx][subsegment_idx] + ): + timing = alignments[segment_idx][word_index] + + if timing["word"]: + words.append( + dict( + word=timing["word"], + start=round(time_offset + timing["start"], 2), + end=round(time_offset + timing["end"], 2), + probability=timing["probability"], + ) ) - ) - saved_tokens += len(timing["tokens"]) - word_index += 1 + saved_tokens += len(timing["tokens"]) + word_index += 1 + + # hack: truncate long words at segment boundaries. + # a better segmentation algorithm based on VAD should be able to replace this. + if len(words) > 0: + # ensure the first and second word after a pause is not longer than + # twice the median word duration. + if words[0][ + "end" + ] - last_speech_timestamp > median_duration * 4 and ( + words[0]["end"] - words[0]["start"] > max_duration + or ( + len(words) > 1 + and words[1]["end"] - words[0]["start"] > max_duration * 2 + ) + ): + if ( + len(words) > 1 + and words[1]["end"] - words[1]["start"] > max_duration + ): + boundary = max( + words[1]["end"] / 2, words[1]["end"] - max_duration + ) + words[0]["end"] = words[1]["start"] = boundary + words[0]["start"] = max(0, words[0]["end"] - max_duration) - # hack: truncate long words at segment boundaries. - # a better segmentation algorithm based on VAD should be able to replace this. - if len(words) > 0: - # ensure the first and second word after a pause is not longer than - # twice the median word duration. - if words[0]["end"] - last_speech_timestamp > median_duration * 4 and ( - words[0]["end"] - words[0]["start"] > max_duration - or ( - len(words) > 1 - and words[1]["end"] - words[0]["start"] > max_duration * 2 - ) - ): + # prefer the segment-level start timestamp if the first word is too long. if ( - len(words) > 1 - and words[1]["end"] - words[1]["start"] > max_duration + subsegment["start"] < words[0]["end"] + and subsegment["start"] - 0.5 > words[0]["start"] ): - boundary = max( - words[1]["end"] / 2, words[1]["end"] - max_duration + words[0]["start"] = max( + 0, + min(words[0]["end"] - median_duration, subsegment["start"]), ) - words[0]["end"] = words[1]["start"] = boundary - words[0]["start"] = max(0, words[0]["end"] - max_duration) + else: + subsegment["start"] = words[0]["start"] - # prefer the segment-level start timestamp if the first word is too long. - if ( - segment["start"] < words[0]["end"] - and segment["start"] - 0.5 > words[0]["start"] - ): - words[0]["start"] = max( - 0, min(words[0]["end"] - median_duration, segment["start"]) - ) - else: - segment["start"] = words[0]["start"] - - # prefer the segment-level end timestamp if the last word is too long. 
- if ( - segment["end"] > words[-1]["start"] - and segment["end"] + 0.5 < words[-1]["end"] - ): - words[-1]["end"] = max( - words[-1]["start"] + median_duration, segment["end"] - ) - else: - segment["end"] = words[-1]["end"] - - last_speech_timestamp = segment["end"] + # prefer the segment-level end timestamp if the last word is too long. + if ( + subsegment["end"] > words[-1]["start"] + and subsegment["end"] + 0.5 < words[-1]["end"] + ): + words[-1]["end"] = max( + words[-1]["start"] + median_duration, subsegment["end"] + ) + else: + subsegment["end"] = words[-1]["end"] - segment["words"] = words + last_speech_timestamp = subsegment["end"] + segments[segment_idx][subsegment_idx]["words"] = words + return last_speech_timestamp def find_alignment( self, @@ -1135,51 +1828,332 @@ def find_alignment( if len(text_tokens) == 0: return [] - result = self.model.align( + results = self.model.align( encoder_output, tokenizer.sot_sequence, - [text_tokens], + text_tokens, num_frames, median_filter_width=median_filter_width, - )[0] + ) + return_list = [] + for result, text_token in zip(results, text_tokens): + text_token_probs = result.text_token_probs + alignments = result.alignments + text_indices = np.array([pair[0] for pair in alignments]) + time_indices = np.array([pair[1] for pair in alignments]) + + words, word_tokens = tokenizer.split_to_word_tokens( + text_token + [tokenizer.eot] + ) + if len(word_tokens) <= 1: + # return on eot only + # >>> np.pad([], (1, 0)) + # array([0.]) + # This results in crashes when we lookup jump_times with float, like + # IndexError: arrays used as indices must be of integer (or boolean) type + return [] + word_boundaries = np.pad( + np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0) + ) + if len(word_boundaries) <= 1: + return [] - text_token_probs = result.text_token_probs + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype( + bool + ) + jump_times = time_indices[jumps] / self.tokens_per_second + start_times = jump_times[word_boundaries[:-1]] + end_times = jump_times[word_boundaries[1:]] + word_probabilities = [ + np.mean(text_token_probs[i:j]) + for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) + ] + + return_list.append( + [ + dict( + word=word, + tokens=tokens, + start=start, + end=end, + probability=probability, + ) + for word, tokens, start, end, probability in zip( + words, word_tokens, start_times, end_times, word_probabilities + ) + ] + ) + return return_list + + def generate_segment_batched( + self, + features: torch.Tensor, + tokenizer: Tokenizer, + options: dict, + ): + batch_size = features.shape[0] + all_tokens = [] + prompt_reset_since = 0 + + if options["initial_prompt"] is not None: + initial_prompt = " " + options["initial_prompt"].strip() + initial_prompt_tokens = tokenizer.encode(initial_prompt) + all_tokens.extend(initial_prompt_tokens) + previous_tokens = all_tokens[prompt_reset_since:] + prompt = self.get_prompt( + tokenizer, + previous_tokens, + without_timestamps=options["without_timestamps"], + prefix=options["prefix"], + ) - alignments = result.alignments - text_indices = np.array([pair[0] for pair in alignments]) - time_indices = np.array([pair[1] for pair in alignments]) + encoder_output = self.encode(features) - words, word_tokens = tokenizer.split_to_word_tokens( - text_tokens + [tokenizer.eot] + result = self.model.generate( + encoder_output, + [prompt] * batch_size, + beam_size=options["beam_size"], + patience=options["patience"], + length_penalty=options["length_penalty"], + 
max_length=self.max_length,
+            suppress_blank=options["suppress_blank"],
+            suppress_tokens=options["suppress_tokens"],
+            return_scores=True,
+            return_no_speech_prob=True,
         )
-        if len(word_tokens) <= 1:
-            # return on eot only
-            # >>> np.pad([], (1, 0))
-            # array([0.])
-            # This results in crashes when we lookup jump_times with float, like
-            # IndexError: arrays used as indices must be of integer (or boolean) type
-            return []
-        word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
-        if len(word_boundaries) <= 1:
-            return []
-        jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
-        jump_times = time_indices[jumps] / self.tokens_per_second
-        start_times = jump_times[word_boundaries[:-1]]
-        end_times = jump_times[word_boundaries[1:]]
-        word_probabilities = [
-            np.mean(text_token_probs[i:j])
-            for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+        output = []
+        for res in result:
+            output.append({})
+            # return scores
+            seq_len = len(res.sequences_ids[0])
+            cum_logprob = res.scores[0] * (seq_len ** options["length_penalty"])
+            output[-1]["avg_logprob"] = cum_logprob / (seq_len + 1)
+
+            # return no speech prob
+            output[-1]["no_speech_prob"] = res.no_speech_prob
+            output[-1]["tokens"] = res.sequences_ids[0]
+
+        return encoder_output, output
+
+    def detect_language(self, audio: torch.Tensor):
+        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
+        segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
+            :, : self.feature_extractor.nb_max_frames
         ]
+        encoder_output = self.encode(segment)
+        results = self.model.detect_language(encoder_output)
+        language_token, language_probability = results[0][0]
+        language = language_token[2:-2]
+        self.logger.info(
+            f"Detected language: {language} ({language_probability:.2f}) in first 30s of audio..."
+        )
+        all_language_probs = [(token[2:-2], prob) for (token, prob) in results[0]]
+        return language, language_probability, all_language_probs

-        return [
-            dict(
-                word=word, tokens=tokens, start=start, end=end, probability=probability
+    def detect_language_multi_segment(
+        self, audio: Union[str, BinaryIO, torch.Tensor], params: Optional[dict] = None
+    ):
+        """
+        Detect language based on N highly-confident segments of a language.
+        """
+        # speech_percentage_threshold is used to decide whether the audio is silence or not.
+        # The default is 0.02 (2.0%), i.e., if less than 2.0% of the audio contains speech
+        # after VAD filtering, the audio is treated as silence.
+        if not params:
+            params = {
+                "multilingual": False,
+                "speech_percentage_threshold": 0.02,
+                "language_detection_segments": 4,
+                "vad_filter": True,
+                "vad_min_silence_duration": 2500,
+                "language_threshold": 0.7,
+            }
+
+        if params.get("multilingual", False):
+            logging.warning(
+                "lang_id is not supported for multilingual audio, detecting the major language."
) - for word, tokens, start, end, probability in zip( - words, word_tokens, start_times, end_times, word_probabilities + + speech_percentage_threshold = params.get("speech_percentage_threshold", 0.02) + language_threshold = params.get("language_threshold", 0.7) + num_detection_segments = params.get("language_detection_segments", 4) + vad_filter_enabled = params.get("vad_filter", True) + vad_params = dict( + min_silence_duration_ms=params.get("vad_min_silence_duration", 2500) + ) + + if vad_filter_enabled: + vad_params = VadOptions(**vad_params) + + # decode audio if it is not decoded already + sampling_rate = self.feature_extractor.sampling_rate + if not isinstance(audio, torch.Tensor): + audio: torch.Tensor = decode_audio(audio, sampling_rate=sampling_rate) + + # calculate duration of audio as number of seconds + # audio.shape[0] is the number of samples in the audio + # sampling_rate is the number of samples per second + # if we divide the number of samples by the number of samples per second, + # we get the duration in seconds + duration = audio.shape[0] / sampling_rate + + # Check if vad is enabled, and collect voiced segments + if vad_filter_enabled: + # get chunks of audio that contain speech + speech_chunks = get_speech_timestamps(audio, vad_params) + # merge chunks of audio that contain speech into a single array + audio = collect_chunks(audio, speech_chunks) + + # calculate new duration of audio without silence + duration_vad = audio.shape[0] / sampling_rate + + logging.debug( + f"Lang ID: VAD filter removed {duration - duration_vad} sec of audio" ) - ] + + # if the audio after VAD is less than 2% of the original audio, consider it as silence + if duration_vad / duration < speech_percentage_threshold: + return {"language_code": None, "language_confidence": 1.0} + + # update duration to be the duration after VAD + duration = duration_vad + + # if the duration of the audio is less than 1 second, consider it as silence + if duration < 1.0: + return {"language_code": None, "language_confidence": 1.0} + + # number of feature frames in 30 seconds of audio is 3000 + nb_max_frames = self.feature_extractor.nb_max_frames + + # extract features from audio with padding (default) + to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1 + features = self.feature_extractor(audio, to_cpu=to_cpu) + + # number of segments in the audio + num_segments = features.shape[-1] // nb_max_frames + # more number of segments than possible with the duration of file + if num_detection_segments > num_segments: + logging.warning( + f"Lang ID: Can not have more segments, setting {num_segments} segments." + ) + num_detection_segments = num_segments + + # create a list of indices to randomly select segments from + indices = list(range(num_detection_segments)) + + # fix seed to get deterministic results + random.seed(0) + random.shuffle(indices) + + detected_languages = [] + all_language_probabilities = defaultdict(list) + confident_language_probabilities = defaultdict(list) + num_confident_segments_per_language = defaultdict(int) + + # Iterate over the randomly selected indices of the segments. + # + # For each segment, extract features and detect language. + # + # If the language is confident, add it to the list of confident segments for that language. + # + # If the number of confident segments for a language + # is greater than or equal to the number of detection segments, + # return the language and the average probability of the language. 
+        #
+        # If we are unable to get a sufficient number of confident predictions,
+        # return the most frequently detected language with maximum probability.
+        #
+        # We need a sufficient number of confident predictions per language, not in total.
+
+        for i in indices:
+            segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
+            try:
+                encoder_output = self.encode(segment_features)
+                results = self.model.detect_language(encoder_output)[0]
+
+            except ValueError as e:  # or RuntimeError
+                logging.error(f"Inference error: {e}")
+                continue  # skip this segment, otherwise `results` would be undefined or stale
+
+            # results is the list of classes (languages) and their probabilities (descending),
+            # e.g. [('<|de|>', 0.482177734375), ('<|en|>', 0.283447265625), ...]
+
+            # take the top language token and probability
+            # and parse the language token to strip out the markers,
+            # e.g. '<|de|>' -> 'de'
+
+            language_token = results[0][0]
+            language = language_token[2:-2]
+
+            language_probability = results[0][1]
+
+            detected_languages.append(language)
+            all_language_probabilities[language].append(language_probability)
+
+            # only consider if the language prediction is confident
+            if language_probability > language_threshold:
+                num_confident_segments_per_language[language] += 1
+
+                # Add language and probability to the list of languages when it is confident
+                confident_language_probabilities[language].append(language_probability)
+
+                # return the language when a sufficient number of confident segments is reached
+                if (
+                    num_confident_segments_per_language[language]
+                    >= num_detection_segments
+                ):
+                    # Considering the average probability of only confident segments
+                    mean = sum(confident_language_probabilities[language]) / len(
+                        confident_language_probabilities[language]
+                    )
+                    return {
+                        "language_code": language,
+                        "language_confidence": mean,
+                    }
+
+        # if we are unable to get a sufficient number of confident predictions,
+        # return the most frequently detected language.
+        # if there is a tie, return the one with maximum average probability.
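+        # Illustrative example of the fallback selection below (hypothetical values,
+        # not taken from any real run):
+        #   detected_languages = ["en", "en", "de"]
+        #   all_language_probabilities = {"en": [0.6, 0.5], "de": [0.9]}
+        #   key_func("en") = (2, 0.55), key_func("de") = (1, 0.9) -> "en" is selected,
+        #   since frequency is compared first and average probability only breaks ties.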
+ counter = Counter(detected_languages) + + # Define the key function to select frequent language with attached probabilities + def key_func(language): + # Calculate the frequency of the language + frequency = counter[language] + + # Calculate the average probability of the language + prob_avg = sum(all_language_probabilities[language]) / len( + all_language_probabilities[language] + ) + + return frequency, prob_avg + + if detected_languages: + # Use the key function to find the language with maximum frequency and probability + max_language = max(detected_languages, key=key_func) + max_probability = sum(all_language_probabilities[max_language]) / len( + all_language_probabilities[max_language] + ) + + # Do additional checks for silence for non-confident case + # calculate RMS amplitude and DC offset + dc_offset = audio.mean() + audio_minus_dc_offset = audio - dc_offset + is_silent = ( + torch.all(audio.abs() < 0.01) + or torch.sqrt(torch.mean(audio_minus_dc_offset**2)) < 0.01 + ) + + if is_silent: + return {"language_code": None, "language_confidence": 1.0} + + return { + "language_code": max_language, + "language_confidence": max_probability, + } + + # Language is not detected for any segment and none of prev conditions met + return {"language_code": None, "language_confidence": 1.0} def restore_speech_timestamps( @@ -1217,9 +2191,12 @@ def restore_speech_timestamps( yield segment -def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView: - segment = np.ascontiguousarray(segment) - segment = ctranslate2.StorageView.from_array(segment) +def get_ctranslate2_storage(segment: torch.Tensor) -> ctranslate2.StorageView: + segment = segment.contiguous() + segment = ctranslate2.StorageView.from_array( + segment if segment.is_cuda else segment.numpy() + ) # torch cpu tensors don't implement __array_interface__ + # https://github.com/pytorch/pytorch/issues/51156 return segment @@ -1263,9 +2240,11 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> if previous["word"].startswith(" ") and previous["word"].strip() in prepended: # prepend it to the following word following["word"] = previous["word"] + following["word"] - following["tokens"] = previous["tokens"] + following["tokens"] + if "tokens" in alignment[0].keys(): + following["tokens"] = previous["tokens"] + following["tokens"] + previous["tokens"] = [] previous["word"] = "" - previous["tokens"] = [] + else: j = i i -= 1 @@ -1279,9 +2258,11 @@ def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> if not previous["word"].endswith(" ") and following["word"] in appended: # append it to the previous word previous["word"] = previous["word"] + following["word"] - previous["tokens"] = previous["tokens"] + following["tokens"] + if "tokens" in alignment[0].keys(): + previous["tokens"] = previous["tokens"] + following["tokens"] + following["tokens"] = [] following["word"] = "" - following["tokens"] = [] + else: i = j j += 1 diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py index 99dfb401..3881fd81 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -2,9 +2,17 @@ import functools import os -from typing import List, NamedTuple, Optional +from abc import ABC +from collections.abc import Callable +from typing import List, NamedTuple, Optional, Union import numpy as np +import torch + +from pyannote.audio.core.io import AudioFile +from pyannote.audio.pipelines import VoiceActivityDetection +from pyannote.audio.pipelines.utils import PipelineModel +from 
pyannote.core import Annotation, Segment, SlidingWindowFeature from faster_whisper.utils import get_assets_path @@ -35,7 +43,7 @@ class VadOptions(NamedTuple): def get_speech_timestamps( - audio: np.ndarray, + audio: torch.Tensor, vad_options: Optional[VadOptions] = None, **kwargs, ) -> List[dict]: @@ -176,12 +184,12 @@ def get_speech_timestamps( return speeches -def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray: +def collect_chunks(audio: torch.Tensor, chunks: List[dict]) -> torch.Tensor: """Collects and concatenates audio chunks.""" if not chunks: - return np.array([], dtype=np.float32) + return torch.tensor([], dtype=torch.float32) - return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) + return torch.cat([audio[chunk["start"] : chunk["end"]] for chunk in chunks]) class SpeechTimestampsMap: @@ -276,3 +284,313 @@ def __call__(self, x, state, context, sr: int): context = x[..., -64:] return out, state, context + + +# BSD 2-Clause License + +# Copyright (c) 2024, Max Bain + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +# The code below is copied from whisper-x (https://github.com/m-bain/whisperX) +# and adapted for faster_whisper. +class SegmentX: + def __init__(self, start, end, speaker=None): + self.start = start + self.end = end + self.speaker = speaker + + +class VoiceActivitySegmentation(VoiceActivityDetection, ABC): + """Pipeline wrapper class for Voice Activity Segmentation based on VAD scores.""" + + def __init__( + self, + segmentation: PipelineModel = "pyannote/segmentation", + device: Optional[Union[str, torch.device]] = None, + fscore: bool = False, + use_auth_token: Optional[str] = None, + **inference_kwargs, + ): + """Initialize the pipeline with the model name and the optional device. + + Args: + dict parameters of VoiceActivityDetection class from pyannote: + segmentation (PipelineModel): Loaded model name. + device (torch.device or None): Device to perform the segmentation. + fscore (bool): Flag indicating whether to compute F-score during inference. + use_auth_token (str or None): Optional authentication token for model access. + inference_kwargs (dict): Additional arguments from VoiceActivityDetection pipeline. 
+ """ + super().__init__( + segmentation=segmentation, + device=device, + fscore=fscore, + use_auth_token=use_auth_token, + **inference_kwargs, + ) + + def apply( + self, file: AudioFile, hook: Optional[Callable] = None + ) -> SlidingWindowFeature: + """Apply voice activity detection on the audio file. + + Args: + file (AudioFile): Processed file. + hook (callable): Hook called with signature: hook("step_name", step_artefact, file=file) + + Returns: + segmentations (SlidingWindowFeature): Voice activity segmentation. + """ + # setup hook (e.g. for debugging purposes) + hook = self.setup_hook(file, hook=hook) + + # apply segmentation model if needed + # output shape is (num_chunks, num_frames, 1) + if self.training: + if self.CACHED_SEGMENTATION in file: + segmentations = file[self.CACHED_SEGMENTATION] + else: + segmentations = self._segmentation(file) + file[self.CACHED_SEGMENTATION] = segmentations + else: + segmentations: SlidingWindowFeature = self._segmentation(file) + + return segmentations + + +class BinarizeVadScores: + """Binarize detection scores using hysteresis thresholding. + + Reference: + Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of + RNN-based Voice Activity Detection", InterSpeech 2015. + + Modified by Max Bain to include WhisperX's min-cut operation + https://arxiv.org/abs/2303.00747 + + """ + + def __init__( + self, + onset: float = 0.5, + offset: Optional[float] = None, + min_duration_on: float = 0.0, + min_duration_off: float = 0.0, + pad_onset: float = 0.0, + pad_offset: float = 0.0, + max_duration: float = float("inf"), + ): + """Initializes the parameters for Binarizing the VAD scores. + + Args: + onset (float, optional): + Onset threshold. Defaults to 0.5. + offset (float, optional): + Offset threshold. Defaults to `onset`. + min_duration_on (float, optional): + Remove active regions shorter than that many seconds. Defaults to 0s. + min_duration_off (float, optional): + Fill inactive regions shorter than that many seconds. Defaults to 0s. + pad_onset (float, optional): + Extend active regions by moving their start time by that many seconds. + Defaults to 0s. + pad_offset (float, optional): + Extend active regions by moving their end time by that many seconds. + Defaults to 0s. + max_duration (float): + The maximum length of an active segment. + """ + super().__init__() + + self.onset = onset + self.offset = offset or onset + + self.pad_onset = pad_onset + self.pad_offset = pad_offset + + self.min_duration_on = min_duration_on + self.min_duration_off = min_duration_off + + self.max_duration = max_duration + + def __get_active_regions(self, scores: SlidingWindowFeature) -> Annotation: + """Extract active regions from VAD scores. + + Args: + scores (SlidingWindowFeature): Detection scores. + + Returns: + active (Annotation): Active regions. 
+ """ + num_frames, num_classes = scores.data.shape + frames = scores.sliding_window + timestamps = [frames[i].middle for i in range(num_frames)] + # annotation meant to store 'active' regions + active = Annotation() + for k, k_scores in enumerate(scores.data.T): + label = k if scores.labels is None else scores.labels[k] + + # initial state + start = timestamps[0] + is_active = k_scores[0] > self.onset + curr_scores = [k_scores[0]] + curr_timestamps = [start] + t = start + # optionally add `strict=False` for python 3.10 or later + for t, y in zip(timestamps[1:], k_scores[1:]): + # currently active + if is_active: + curr_duration = t - start + if curr_duration > self.max_duration: + search_after = len(curr_scores) // 2 + # divide segment + min_score_div_idx = search_after + np.argmin( + curr_scores[search_after:] + ) + min_score_t = curr_timestamps[min_score_div_idx] + region = Segment( + start - self.pad_onset, min_score_t + self.pad_offset + ) + active[region, k] = label + start = curr_timestamps[min_score_div_idx] + curr_scores = curr_scores[min_score_div_idx + 1 :] + curr_timestamps = curr_timestamps[min_score_div_idx + 1 :] + # switching from active to inactive + elif y < self.offset: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + start = t + is_active = False + curr_scores = [] + curr_timestamps = [] + curr_scores.append(y) + curr_timestamps.append(t) + # currently inactive + else: + # switching from inactive to active + if y > self.onset: + start = t + is_active = True + + # if active at the end, add final region + if is_active: + region = Segment(start - self.pad_onset, t + self.pad_offset) + active[region, k] = label + + return active + + def __call__(self, scores: SlidingWindowFeature) -> Annotation: + """Binarize detection scores. + + Args: + scores (SlidingWindowFeature): Detection scores. + + Returns: + active (Annotation): Binarized scores. + """ + active = self.__get_active_regions(scores) + # because of padding, some active regions might be overlapping: merge them. + # also: fill same speaker gaps shorter than min_duration_off + if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0: + if self.max_duration < float("inf"): + raise NotImplementedError("This would break current max_duration param") + active = active.support(collar=self.min_duration_off) + + # remove tracks shorter than min_duration_on + if self.min_duration_on > 0: + for segment, track in list(active.itertracks()): + if segment.duration < self.min_duration_on: + del active[segment, track] + + return active + + +def merge_chunks( + segments, + chunk_length, + onset: float = 0.5, + offset: Optional[float] = None, + edge_padding: float = 0.1, +): + """ + Merge operation described in whisper-x paper + """ + curr_end = 0 + merged_segments = [] + seg_idxs = [] + speaker_idxs = [] + + assert chunk_length > 0 + binarize = BinarizeVadScores(max_duration=chunk_length, onset=onset, offset=offset) + segments = binarize(segments) + segments_list = [] + for speech_turn in segments.get_timeline(): + segments_list.append( + SegmentX( + max(0.0, speech_turn.start - edge_padding), + speech_turn.end + edge_padding, + "UNKNOWN", + ) + ) # 100ms edge padding to account for edge errors + + if len(segments_list) == 0: + print("No active speech found in audio") + return [] + + # Make sur the starting point is the start of the segment. 
+ curr_start = segments_list[0].start + + for idx, seg in enumerate(segments_list): + # if any segment start timing is less than previous segment end timing, + # reset the edge padding. Similarly for end timing. + if idx > 0: + if seg.start < segments_list[idx - 1].end: + seg.start += edge_padding + if idx < len(segments_list) - 1: + if seg.end > segments_list[idx + 1].start: + seg.end -= edge_padding + + if seg.end - curr_start > chunk_length and curr_end - curr_start > 0: + merged_segments.append( + { + "start": curr_start, + "end": curr_end, + "segments": seg_idxs, + } + ) + curr_start = seg.start + seg_idxs = [] + speaker_idxs = [] + curr_end = seg.end + seg_idxs.append((seg.start, seg.end)) + speaker_idxs.append(seg.speaker) + # add final + merged_segments.append( + { + "start": curr_start, + "end": curr_end, + "segments": seg_idxs, + } + ) + return merged_segments diff --git a/requirements.txt b/requirements.txt index b1497ab4..6516f96c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,8 @@ -av>=11.0,<13 ctranslate2>=4.0,<5 huggingface_hub>=0.13 tokenizers>=0.13,<1 -onnxruntime>=1.14,<2 +onnxruntime>=1.14,<2 +transformers +pyannote-audio>=3.1.1 +torch>=2.1.1 +torchaudio>=2.1.2 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 1a1ee1d1..0c0f4248 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,3 +11,8 @@ def data_dir(): @pytest.fixture def jfk_path(data_dir): return os.path.join(data_dir, "jfk.flac") + + +@pytest.fixture +def physcisworks_path(data_dir): + return os.path.join(data_dir, "physicsworks.wav") diff --git a/tests/data/physicsworks.wav b/tests/data/physicsworks.wav new file mode 100644 index 00000000..885b6c1c Binary files /dev/null and b/tests/data/physicsworks.wav differ diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 7fa27b11..96eb68c3 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -1,6 +1,6 @@ import os -from faster_whisper import WhisperModel, decode_audio +from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio from faster_whisper.tokenizer import Tokenizer from faster_whisper.transcribe import get_suppressed_tokens @@ -39,6 +39,50 @@ def test_transcribe(jfk_path): assert segment.text == "".join(word.word for word in segment.words) assert segment.start == segment.words[0].start assert segment.end == segment.words[-1].end + batched_model = BatchedInferencePipeline(model=model, use_vad_model=False) + result, info = batched_model.transcribe(jfk_path, word_timestamps=True) + assert info.language == "en" + assert info.language_probability > 0.7 + segments = [] + for segment in result: + segments.append( + {"start": segment.start, "end": segment.end, "text": segment.text} + ) + + assert len(segments) == 1 + assert segment.text == ( + " And so my fellow Americans ask not what your country can do for you, " + "ask what you can do for your country." 
+ ) + + +def test_batched_transcribe(physcisworks_path): + model = WhisperModel("tiny") + batched_model = BatchedInferencePipeline(model=model) + result, info = batched_model.transcribe(physcisworks_path, batch_size=16) + assert info.language == "en" + assert info.language_probability > 0.7 + segments = [] + for segment in result: + segments.append( + {"start": segment.start, "end": segment.end, "text": segment.text} + ) + # number of near 30 sec segments + assert len(segments) == 8 + + result, info = batched_model.transcribe( + physcisworks_path, + batch_size=16, + without_timestamps=False, + word_timestamps=True, + ) + segments = [] + for segment in result: + assert segment.words is not None + segments.append( + {"start": segment.start, "end": segment.end, "text": segment.text} + ) + assert len(segments) > 8 def test_prefix_with_timestamps(jfk_path): @@ -101,6 +145,13 @@ def test_stereo_diarization(data_dir): assert transcription == "The horizon seems extremely distant." +def test_multisegment_lang_id(physcisworks_path): + model = WhisperModel("tiny") + language_info = model.detect_language_multi_segment(physcisworks_path) + assert language_info["language_code"] == "en" + assert language_info["language_confidence"] > 0.8 + + def test_suppressed_tokens_minus_1(): model = WhisperModel("tiny.en")