Skip to content

Commit

Permalink
Update audio.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BBC-Esq authored Oct 26, 2024
1 parent 6d786b6 commit ca50320
Showing 1 changed file with 7 additions and 28 deletions.
35 changes: 7 additions & 28 deletions faster_whisper/audio.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly.
"""

import gc
import io
import itertools

from typing import BinaryIO, Union

import av
Expand All @@ -23,15 +19,12 @@ def decode_audio(
split_stereo: bool = False,
):
"""Decodes the audio.
Args:
input_file: Path to the input file or a file-like object.
sampling_rate: Resample the audio to this sample rate.
split_stereo: Return separate left and right channels.
Returns:
A float32 Numpy array.
If `split_stereo` is enabled, the function returns a 2-tuple with the
separated left and right channels.
"""
Expand All @@ -40,16 +33,14 @@ def decode_audio(
layout="mono" if not split_stereo else "stereo",
rate=sampling_rate,
)

audio_chunks = [] # Initialize a list to hold NumPy arrays
dtype = None

with av.open(input_file, mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)

frames = ignore_invalid_frames(frames)
frames = group_frames(frames, 500000)
frames = resample_frames(frames, resampler)
for frame in frames:
array = frame.to_ndarray()
dtype = array.dtype
Expand All @@ -71,13 +62,11 @@ def decode_audio(
left_channel = audio[0::2]
right_channel = audio[1::2]
return torch.from_numpy(left_channel), torch.from_numpy(right_channel)

return torch.from_numpy(audio)


def _ignore_invalid_frames(frames):
def ignore_invalid_frames(frames):
iterator = iter(frames)

while True:
try:
yield next(iterator)
Expand All @@ -87,21 +76,18 @@ def _ignore_invalid_frames(frames):
continue


def _group_frames(frames, num_samples=None):
def group_frames(frames, num_samples=None):
fifo = av.audio.fifo.AudioFifo()

for frame in frames:
frame.pts = None # Ignore timestamp check.
fifo.write(frame)

if num_samples is not None and fifo.samples >= num_samples:
yield fifo.read()

if fifo.samples > 0:
yield fifo.read()


def _resample_frames(frames, resampler):
def resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)
Expand All @@ -117,14 +103,7 @@ def pad_or_trim(array, length: int, *, axis: int = -1):
return array[idx]

if array.shape[axis] < length:
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths = ([0] * array.ndim * 2)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))

return array

0 comments on commit ca50320

Please sign in to comment.