Skip to content

Commit

Permalink
Added multiprocessing for cpu processing
Browse files — browse the repository at this point in the history
  • Loading branch information
joiemoie committed Jan 22, 2024
1 parent 72ff979 commit 57df6dd
Showing 1 changed file with 64 additions and 35 deletions.
99 changes: 64 additions & 35 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import json
import logging
import multiprocessing
import os
import zlib

Expand Down Expand Up @@ -78,6 +79,51 @@ class TranscriptionInfo(NamedTuple):
vad_options: VadOptions


# Performs the preprocessing on its own process to make use of all CPU cores
def cpu_preprocessing(
    logger,
    feature_extractor,
    # Fixed annotation: the body uses ``audio.shape[0]`` immediately, so only a
    # decoded waveform array is valid here; the caller decodes str/BinaryIO
    # inputs before dispatching to this function.
    audio: np.ndarray,
    vad_filter: bool = False,
    vad_parameters: Optional[Union[dict, VadOptions]] = None,
) -> Tuple[np.ndarray, float, float, Optional[List[dict]]]:
    """Run CPU-bound audio preprocessing (optional VAD + feature extraction).

    Module-level so it can be dispatched to a ``multiprocessing.Pool`` worker.
    NOTE(review): when run through a pool, ``logger`` and ``feature_extractor``
    must be picklable — confirm for the concrete objects passed by the caller.

    Args:
      logger: logger used for progress and debug messages.
      feature_extractor: object exposing ``sampling_rate`` and, when called
        with a waveform array, returning the model input features.
      audio: decoded waveform as a numpy array (first axis is samples).
      vad_filter: when True, remove non-speech chunks before extraction.
      vad_parameters: options forwarded to ``get_speech_timestamps``; the
        caller normalizes plain dicts to ``VadOptions`` beforehand.

    Returns:
      Tuple ``(features, duration, duration_after_vad, speech_chunks)`` where
      ``speech_chunks`` is None when ``vad_filter`` is False.
    """
    sampling_rate = feature_extractor.sampling_rate
    # Durations in seconds, derived from the sample count.
    duration = audio.shape[0] / sampling_rate
    duration_after_vad = duration

    logger.info("Processing audio with duration %s", format_timestamp(duration))

    if vad_filter:
        # Keep only the detected speech regions and recompute the duration.
        speech_chunks = get_speech_timestamps(audio, vad_parameters)
        audio = collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        logger.info(
            "VAD filter removed %s of audio",
            format_timestamp(duration - duration_after_vad),
        )

        # Guard the expensive join behind the level check.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "VAD filter kept the following audio segments: %s",
                ", ".join(
                    "[%s -> %s]"
                    % (
                        format_timestamp(chunk["start"] / sampling_rate),
                        format_timestamp(chunk["end"] / sampling_rate),
                    )
                    for chunk in speech_chunks
                ),
            )

    else:
        speech_chunks = None

    features = feature_extractor(audio)

    return features, duration, duration_after_vad, speech_chunks


class WhisperModel:
def __init__(
self,
Expand Down Expand Up @@ -156,6 +202,7 @@ def __init__(
self.input_stride = 2
self.time_precision = 0.02
self.max_length = 448
self.cpu_pool = multiprocessing.Pool()

@property
def supported_languages(self) -> List[str]:
Expand Down Expand Up @@ -271,49 +318,29 @@ def transcribe(
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""
sampling_rate = self.feature_extractor.sampling_rate

if not isinstance(audio, np.ndarray):
audio = decode_audio(audio, sampling_rate=sampling_rate)

duration = audio.shape[0] / sampling_rate
duration_after_vad = duration

self.logger.info(
"Processing audio with duration %s", format_timestamp(duration)
)
audio = decode_audio(
audio, sampling_rate=self.feature_extractor.sampling_rate
)

if vad_filter:
if vad_parameters is None:
vad_parameters = VadOptions()
elif isinstance(vad_parameters, dict):
vad_parameters = VadOptions(**vad_parameters)
speech_chunks = get_speech_timestamps(audio, vad_parameters)
audio = collect_chunks(audio, speech_chunks)
duration_after_vad = audio.shape[0] / sampling_rate

self.logger.info(
"VAD filter removed %s of audio",
format_timestamp(duration - duration_after_vad),
)

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
"VAD filter kept the following audio segments: %s",
", ".join(
"[%s -> %s]"
% (
format_timestamp(chunk["start"] / sampling_rate),
format_timestamp(chunk["end"] / sampling_rate),
)
for chunk in speech_chunks
),
)

else:
speech_chunks = None

features = self.feature_extractor(audio)
# Spawns a new process to run preprocessing on CPU
features, duration, duration_after_vad, speech_chunks = self.cpu_pool.apply(
cpu_preprocessing,
(
self.logger,
self.feature_extractor,
audio,
vad_filter,
vad_parameters,
),
)

encoder_output = None
all_language_probs = None
Expand Down Expand Up @@ -384,7 +411,9 @@ def transcribe(
segments = self.generate_segments(features, tokenizer, options, encoder_output)

if speech_chunks:
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
segments = restore_speech_timestamps(
segments, speech_chunks, self.feature_extractor.sampling_rate
)

info = TranscriptionInfo(
language=language,
Expand Down

0 comments on commit 57df6dd

Please sign in to comment.