Skip to content

Commit

Permalink
Fixed audio routes
Browse files Browse the repository at this point in the history
  • Loading branch information
mbsantiago committed Jan 24, 2024
1 parent 4f98162 commit 37ea904
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 25 deletions.
47 changes: 25 additions & 22 deletions back/src/whombat/api/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
"load_clip_bytes",
]

CHUNK_SIZE = 1024 * 1024
CHUNK_SIZE = 512 * 1024
HEADER_FORMAT = "<4si4s4sihhiihh4si"


def load_audio(
Expand Down Expand Up @@ -131,7 +132,6 @@ def load_clip_bytes(
samplerate = sf_file.samplerate * time_expansion
channels = sf_file.channels
bit_depth = BIT_DEPTH_MAP.get(sf_file.subtype)
bytes_per_sample = channels * bit_depth // 8

if bit_depth is None:
raise NotImplementedError(
Expand All @@ -152,35 +152,38 @@ def load_clip_bytes(
end_position = f.tell()

data_size = end_position - start_position
filesize = data_size + 44

header = generate_wav_header(
int(samplerate * speed),
channels,
data_size,
bit_depth,
)
header_size = len(header)
filesize = data_size + header_size

bytes_to_load = end - start if end else CHUNK_SIZE

header = b""
if start < 44:
if start < header_size:
start = 0
bytes_to_load -= 44
header = generate_wav_header(
int(samplerate * speed),
channels,
data_size,
bit_depth,
)
bytes_to_load -= header_size
else:
header = b""

if bytes_to_load < 0:
return header, 0, 44, filesize

start_sample = int(start_time * samplerate)
start_offset = max(int((start - 44) / bytes_per_sample), 0)
return header, 0, header_size, filesize

sf_file.seek(start_sample + start_offset)
current_position = f.tell()
current_position = start_position + start - header_size
bytes_to_load = min(
bytes_to_load,
end_position - current_position,
)

max_bytes_to_load = int(end_position - current_position)
audio_bytes = f.read(min(bytes_to_load, max_bytes_to_load))
f.seek(current_position)
audio_bytes = f.read(bytes_to_load)

data = bytes(header + audio_bytes)
end = start + len(data)
end = start + len(data) - 1
return data, start, end, filesize


Expand Down Expand Up @@ -219,7 +222,7 @@ def generate_wav_header(
block_align = channels * bit_depth // 8

return struct.pack(
"<4si4s4sihhiihh4si", # Format string
HEADER_FORMAT,
b"RIFF", # RIFF chunk id
data_size + 36, # Size of the entire file minus 8 bytes
b"WAVE", # RIFF chunk id
Expand Down
1 change: 1 addition & 0 deletions back/src/whombat/routes/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ async def stream_recording_audio(

headers = {
"Content-Range": f"bytes {start}-{end}/{filesize}",
"Content-Length": f"{len(data)}",
"Accept-Ranges": "bytes",
}
return Response(
Expand Down
4 changes: 2 additions & 2 deletions back/src/whombat/system/boot.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ async def whombat_init(settings: Settings, _: FastAPI):
if await is_first_run(settings):
print_first_run_message(settings)

if settings.open_on_startup:
if settings.open_on_startup and not settings.debug:
webbrowser.open(
f"http://{settings.backend_host}:{settings.backend_port}/first/"
)
return

print_ready_message(settings)

if settings.open_on_startup:
if settings.open_on_startup and not settings.debug:
webbrowser.open(
f"http://{settings.backend_host}:{settings.backend_port}/"
)
77 changes: 76 additions & 1 deletion back/tests/test_api/test_audio.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from whombat.api.audio import load_clip_bytes
import struct

from whombat.api.audio import HEADER_FORMAT, CHUNK_SIZE, load_clip_bytes


def test_load_clip_bytes(random_wav_factory):
Expand Down Expand Up @@ -104,3 +106,76 @@ def test_stream_a_whole_audio_file(random_wav_factory):
break

assert b"".join(parts) == full_bytes


def test_stream_a_whole_audio_file_with_non_1_speed(random_wav_factory):
path = random_wav_factory(
duration=0.5,
samplerate=384_000,
channels=1,
bit_depth=32,
)

speed = 0.1

with open(path, "rb") as f:
full_bytes = f.read()

true_filesize = path.stat().st_size

start = 0
filesize = None
parts = []
while True:
end = start + CHUNK_SIZE

if filesize is not None and end > filesize:
end = filesize

part, start, end, filesize = load_clip_bytes(
path=path,
start=start,
end=end,
speed=speed,
)
parts.append(part)
start = end

assert filesize == true_filesize

if not part or start >= filesize:
break

streamed = b"".join(parts)

assert streamed[44:] == full_bytes[44:]

fields = [
"riff",
"size",
"wave",
"fmt ",
"fmt_size",
"format",
"channels",
"samplerate",
"byte_rate",
"block_align",
"bit_depth",
"data",
"data_size",
]

orig_header = struct.unpack(HEADER_FORMAT, full_bytes[:44])
streamed_header = struct.unpack(HEADER_FORMAT, streamed[:44])

for (field, h1, h2) in zip(fields, orig_header, streamed_header):
if field == "samplerate":
assert int(h1 * speed) == h2
continue

if field == "byte_rate":
assert int(h1 * speed) == h2
continue

assert h1 == h2

0 comments on commit 37ea904

Please sign in to comment.