Commit

Merge branch 'master' into patch/add-project

jhj0517 authored Nov 18, 2024
2 parents 95e88a1 + be9fb36 commit e207f58
Showing 8 changed files with 482 additions and 576 deletions.
29 changes: 9 additions & 20 deletions README.md
@@ -75,9 +75,9 @@ segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")
GPU execution requires the following NVIDIA libraries to be installed:

* [cuBLAS for CUDA 12](https://developer.nvidia.com/cublas)
* [cuDNN 8 for CUDA 12](https://developer.nvidia.com/cudnn)
* [cuDNN 9 for CUDA 12](https://developer.nvidia.com/cudnn)

**Note**: Latest versions of `ctranslate2` support CUDA 12 only. For CUDA 11, the current workaround is downgrading to the `3.24.0` version of `ctranslate2` (This can be done with `pip install --force-reinstall ctranslate2==3.24.0` or specifying the version in a `requirements.txt`).
**Note**: The latest versions of `ctranslate2` only support CUDA 12 and cuDNN 9. For CUDA 11 and cuDNN 8, the current workaround is downgrading to the `3.24.0` version of `ctranslate2`; for CUDA 12 and cuDNN 8, downgrade to the `4.4.0` version of `ctranslate2` (this can be done with `pip install --force-reinstall ctranslate2==4.4.0` or by specifying the version in a `requirements.txt`).
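
For example, a `requirements.txt` for the CUDA 12 + cuDNN 8 case could pin the version like this (a minimal sketch, not from the original instructions):

```text
faster-whisper
ctranslate2==4.4.0
```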

There are multiple ways to install the NVIDIA libraries mentioned above. The recommended way is described in the official NVIDIA documentation, but we also suggest other installation methods below.

@@ -89,20 +89,18 @@ There are multiple ways to install the NVIDIA libraries mentioned above. The rec

#### Use Docker

The libraries (cuBLAS, cuDNN) are installed in these official NVIDIA CUDA Docker images: `nvidia/cuda:12.0.0-runtime-ubuntu20.04` or `nvidia/cuda:12.0.0-runtime-ubuntu22.04`.
The libraries (cuBLAS, cuDNN) are installed in this official NVIDIA CUDA Docker image: `nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04`.
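
As an illustration (a sketch, not part of the upstream instructions), such an image can serve as the base of a GPU-enabled container:

```bash
# Requires the NVIDIA Container Toolkit on the host; the tag matches the image above.
docker run --gpus all --rm -it nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04
```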

#### Install with `pip` (Linux only)

On Linux these libraries can be installed with `pip`. Note that `LD_LIBRARY_PATH` must be set before launching Python.

```bash
pip install nvidia-cublas-cu12 nvidia-cudnn-cu12
pip install nvidia-cublas-cu12 nvidia-cudnn-cu12==9.*

export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__))'`
```

**Note**: Version 9+ of `nvidia-cudnn-cu12` appears to cause issues due to its reliance on cuDNN 9 (Faster-Whisper does not currently support cuDNN 9). Ensure your version of the Python package is for cuDNN 8.

#### Download the libraries from Purfview's repository (Windows & Linux)

Purfview's [whisper-standalone-win](https://github.com/Purfview/whisper-standalone-win) provides the required NVIDIA libraries for Windows & Linux in a [single archive](https://github.com/Purfview/whisper-standalone-win/releases/tag/libs). Decompress the archive and place the libraries in a directory included in the `PATH`.
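
For instance (paths are placeholders, not from the upstream instructions), on Linux the decompressed directory could be exposed via `LD_LIBRARY_PATH`; on Windows, add the folder to `PATH` instead:

```bash
# Assumes the archive was decompressed to ~/nvidia-libs (placeholder path).
export LD_LIBRARY_PATH=$HOME/nvidia-libs:$LD_LIBRARY_PATH
```
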
@@ -166,24 +164,13 @@ segments, _ = model.transcribe("audio.mp3")
segments = list(segments) # The transcription will actually run here.
```
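
Since the result is a generator, segments can also be consumed one at a time instead of materialized into a list; a minimal sketch (the print format is illustrative):

```python
for segment in segments:
    # Each segment carries start/end timestamps (in seconds) and the recognized text.
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```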

### multi-segment language detection

To directly use the model for improved language detection, the following code snippet can be used:

```python
from faster_whisper import WhisperModel
model = WhisperModel("medium", device="cuda", compute_type="float16")
language_info = model.detect_language_multi_segment("audio.mp3")
```

### Batched faster-whisper

The following code snippet illustrates how to run inference with batched version on an example audio file. Please also refer to the test scripts of batched faster whisper.
### Batched Transcription
The following code snippet illustrates how to run batched transcription on an example audio file. `BatchedInferencePipeline.transcribe` is a drop-in replacement for `WhisperModel.transcribe`.

```python
from faster_whisper import WhisperModel, BatchedInferencePipeline

model = WhisperModel("medium", device="cuda", compute_type="float16")
model = WhisperModel("turbo", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)
segments, info = batched_model.transcribe("audio.mp3", batch_size=16)

@@ -238,6 +225,7 @@ segments, _ = model.transcribe(
vad_parameters=dict(min_silence_duration_ms=500),
)
```
The VAD filter is enabled by default for batched transcription.
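
A small sketch of how this default could be overridden, assuming `BatchedInferencePipeline.transcribe` accepts the same `vad_filter`/`vad_parameters` keywords as `WhisperModel.transcribe` above:

```python
# Disable VAD filtering in the batched pipeline (keyword assumed to match
# the non-batched API shown above).
segments, info = batched_model.transcribe("audio.mp3", batch_size=16, vad_filter=False)
```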

### Logging

@@ -310,6 +298,7 @@ model = faster_whisper.WhisperModel("username/whisper-large-v3-ct2")
If you are comparing the performance against other Whisper implementations, you should make sure to run the comparison with similar settings. In particular:

* Verify that the same transcription options are used, especially the same beam size. For example in openai/whisper, `model.transcribe` uses a default beam size of 1 but here we use a default beam size of 5 (see the sketch after this list).
* Transcription speed is closely affected by the number of words in the transcript, so ensure that other implementations have a similar WER (Word Error Rate) to this one.
* When running on CPU, make sure to set the same number of threads. Many frameworks will read the environment variable `OMP_NUM_THREADS`, which can be set when running your script:

```bash
OMP_NUM_THREADS=4 python3 my_script.py
```
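
For the beam-size point above, a comparison run might pin both implementations to the same value, e.g. (illustrative):

```python
# Match openai/whisper's default beam size of 1 for an apples-to-apples comparison.
segments, info = model.transcribe("audio.mp3", beam_size=1)
```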
21 changes: 6 additions & 15 deletions faster_whisper/audio.py
@@ -14,7 +14,6 @@

import av
import numpy as np
import torch


def decode_audio(
@@ -72,9 +71,9 @@ def decode_audio(
if split_stereo:
left_channel = audio[0::2]
right_channel = audio[1::2]
return torch.from_numpy(left_channel), torch.from_numpy(right_channel)
return left_channel, right_channel

return torch.from_numpy(audio)
return audio


def _ignore_invalid_frames(frames):
@@ -113,20 +112,12 @@ def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
"""
Pad or trim the Mel features array to 3000, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
return array[idx]
array = array.take(indices=range(length), axis=axis)

if array.shape[axis] < length:
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)

return array
206 changes: 161 additions & 45 deletions faster_whisper/feature_extractor.py
@@ -1,21 +1,15 @@
import torch
import numpy as np


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
class FeatureExtractor:
def __init__(
self,
device: str = "auto",
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
):
if device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
@@ -25,24 +19,21 @@ def __init__(
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(
sampling_rate, n_fft, n_mels=feature_size
)
).astype("float32")

@staticmethod
def get_mel_filters(sr, n_fft, n_mels=128):
"""
Implementation of librosa.filters.mel in Pytorch
"""
# Initialize the weights
n_mels = int(n_mels)

# Center freqs of each FFT bin
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965

mels = torch.linspace(min_mel, max_mel, n_mels + 2)
mels = np.linspace(min_mel, max_mel, n_mels + 2)

# Fill in the linear scale
f_min = 0.0
@@ -52,30 +43,159 @@ def get_mel_filters(sr, n_fft, n_mels=128):
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region
logstep = np.log(6.4) / 27.0 # step size for log region

# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

mel_f = freqs
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))

fdiff = torch.diff(mel_f)
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
fdiff = np.diff(freqs)
ramps = freqs.reshape(-1, 1) - fftfreqs.reshape(1, -1)

lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
upper = ramps[2:] / fdiff[1:].unsqueeze(1)
lower = -ramps[:-2] / np.expand_dims(fdiff[:-1], axis=1)
upper = ramps[2:] / np.expand_dims(fdiff[1:], axis=1)

# Intersect them with each other and zero, vectorized across all i
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
weights = np.maximum(np.zeros_like(lower), np.minimum(lower, upper))

# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm.unsqueeze(1)
enorm = 2.0 / (freqs[2 : n_mels + 2] - freqs[:n_mels])
weights *= np.expand_dims(enorm, axis=1)

return weights

def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
@staticmethod
def stft(
input_array: np.ndarray,
n_fft: int,
hop_length: int = None,
win_length: int = None,
window: np.ndarray = None,
center: bool = True,
mode: str = "reflect",
normalized: bool = False,
onesided: bool = None,
return_complex: bool = None,
):
# Default initialization for hop_length and win_length
hop_length = hop_length if hop_length is not None else n_fft // 4
win_length = win_length if win_length is not None else n_fft
input_is_complex = np.iscomplexobj(input_array)

# Determine if the output should be complex
return_complex = (
return_complex
if return_complex is not None
else (input_is_complex or (window is not None and np.iscomplexobj(window)))
)

if not return_complex and return_complex is None:
raise ValueError(
"stft requires the return_complex parameter for real inputs."
)

# Input checks
if not np.issubdtype(input_array.dtype, np.floating) and not input_is_complex:
raise ValueError(
"stft: expected an array of floating point or complex values,"
f" got {input_array.dtype}"
)

if input_array.ndim > 2 or input_array.ndim < 1:
raise ValueError(
f"stft: expected a 1D or 2D array, but got {input_array.ndim}D array"
)

# Handle 1D input
if input_array.ndim == 1:
input_array = np.expand_dims(input_array, axis=0)
input_array_1d = True
else:
input_array_1d = False

# Center padding if required
if center:
pad_amount = n_fft // 2
input_array = np.pad(
input_array, ((0, 0), (pad_amount, pad_amount)), mode=mode
)

batch, length = input_array.shape

# Additional input checks
if n_fft <= 0 or n_fft > length:
raise ValueError(
f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}"
)

if hop_length <= 0:
raise ValueError(
f"stft: expected hop_length > 0, but got hop_length={hop_length}"
)

if win_length <= 0 or win_length > n_fft:
raise ValueError(
f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}"
)

if window is not None:
if window.ndim != 1 or window.shape[0] != win_length:
raise ValueError(
f"stft: expected a 1D window array of size equal to win_length={win_length}, "
f"but got window with size {window.shape}"
)

# Handle padding of the window if necessary
if win_length < n_fft:
left = (n_fft - win_length) // 2
window_ = np.zeros(n_fft, dtype=window.dtype)
window_[left : left + win_length] = window
else:
window_ = window

# Calculate the number of frames
n_frames = 1 + (length - n_fft) // hop_length

# Time to columns
input_array = np.lib.stride_tricks.as_strided(
input_array,
(batch, n_frames, n_fft),
(
input_array.strides[0],
hop_length * input_array.strides[1],
input_array.strides[1],
),
)

if window_ is not None:
input_array = input_array * window_

# FFT and transpose
complex_fft = input_is_complex
onesided = onesided if onesided is not None else not complex_fft

if normalized:
norm = "ortho"
else:
norm = None

if complex_fft:
if onesided:
raise ValueError(
"Cannot have onesided output if window or input is complex"
)
output = np.fft.fft(input_array, n=n_fft, axis=-1, norm=norm)
else:
output = np.fft.rfft(input_array, n=n_fft, axis=-1, norm=norm)

output = output.transpose((0, 2, 1))

if input_array_1d:
output = output.squeeze(0)

return output if return_complex else np.real(output)

def __call__(self, waveform: np.ndarray, padding=160, chunk_length=None):
"""
Compute the log-Mel spectrogram of the provided audio.
"""
@@ -84,31 +204,27 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length

if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)

waveform = (
waveform.to(self.device)
if self.device == "cuda" and not waveform.is_cuda
else waveform
)
if waveform.dtype is not np.float32:
waveform = waveform.astype(np.float32)

if padding:
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
waveform = np.pad(waveform, (0, padding))

window = torch.hann_window(self.n_fft).to(waveform.device)
window = np.hanning(self.n_fft + 1)[:-1].astype("float32")

stft = torch.stft(
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2
stft = self.stft(
waveform,
self.n_fft,
self.hop_length,
window=window,
return_complex=True,
).astype("complex64")
magnitudes = np.abs(stft[..., :-1]) ** 2

mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
mel_spec = self.mel_filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

# When the model is running on multiple GPUs, the output should be moved
# to the CPU since we don't know which GPU will handle the next job.
return log_spec.cpu() if to_cpu else log_spec
return log_spec
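
As a quick usage sketch of the NumPy-based extractor (the constructor and call signatures follow the diff above; the random waveform is purely illustrative):

```python
import numpy as np

from faster_whisper.feature_extractor import FeatureExtractor

extractor = FeatureExtractor()  # defaults: 80 mel bins, 16 kHz sampling, hop length 160
waveform = np.random.randn(16000 * 5).astype(np.float32)  # 5 seconds of noise
log_mel = extractor(waveform)  # float32 log-Mel spectrogram of shape (n_mels, n_frames)
print(log_mel.shape)
```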
