Skip to content

Commit

Permalink
Merge pull request #612 from zxzxwu/lc3
Browse files Browse the repository at this point in the history
Replace liblc3 wasm library
  • Loading branch information
zxzxwu authored Dec 19, 2024
2 parents 62e4670 + a27f55a commit 80d60aa
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 199 deletions.
266 changes: 68 additions & 198 deletions apps/lea_unicast/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,22 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct
import wave

import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
try:
import lc3 # type: ignore # pylint: disable=E0401
except ImportError as e:
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e

import click
import aiohttp.web
Expand All @@ -45,6 +44,7 @@
from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket


# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
Expand All @@ -54,6 +54,7 @@
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
DEFAULT_PCM_BYTES_PER_SAMPLE = 2


def _sink_pac_record() -> pacs.PacRecord:
Expand Down Expand Up @@ -100,153 +101,8 @@ def _source_pac_record() -> pacs.PacRecord:
)


# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)


class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3


MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)

DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2


encoders: List[int] = []
decoders: List[int] = []


def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]


def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]


def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''

input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels

# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore

output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)

for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)

if res != 0:
logging.error(f"Parsing failed, res={res}")

# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)


def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''

input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)

# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore

output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels

for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)

if res != 0:
logging.error(f"Parsing failed, res={res}")

# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
decoder: lc3.Decoder | None = None
encoding_config: bap.CodecSpecificConfiguration | None = None


async def lc3_source_task(
Expand All @@ -256,42 +112,58 @@ async def lc3_source_task(
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'

pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8

frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
logger.info(
"lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
filename,
sdu_length,
frame_duration_us / 1000,
)
with wave.open(filename, 'rb') as wav:
bits_per_sample = wav.getsampwidth() * 8
packet_sequence_number = 0

encoder: lc3.Encoder | None = None

while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)

iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
if not encoder:
if (
encoding_config
and (frame_duration := encoding_config.frame_duration)
and (sampling_frequency := encoding_config.sampling_frequency)
and (
audio_channel_allocation := encoding_config.audio_channel_allocation
)
):
logger.info("Use %s", encoding_config)
encoder = lc3.Encoder(
frame_duration_us=frame_duration.us,
sample_rate_hz=sampling_frequency.hz,
num_channels=audio_channel_allocation.channel_count,
input_sample_rate_hz=wav.getframerate(),
)
else:
sdu = encoder.encode(
pcm=wav.readframes(encoder.get_frame_samples()),
num_bytes=sdu_length,
bit_depth=bits_per_sample,
)

iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds())
await asyncio.sleep(sleep_time.total_seconds() * 0.9)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -410,7 +282,7 @@ class Speaker:

def __init__(
self,
device_config_path: Optional[str],
device_config_path: str | None,
ui_port: int,
transport: str,
lc3_input_file_path: str,
Expand Down Expand Up @@ -490,12 +362,12 @@ def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
not isinstance(codec_config, bap.CodecSpecificConfiguration)
or codec_config.frame_duration is None
or codec_config.audio_channel_allocation is None
or decoder is None
or not pdu.iso_sdu_fragment
):
return
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
pcm = decoder.decode(
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))

Expand Down Expand Up @@ -537,16 +409,14 @@ def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
):
return
if ase.role == ascs.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
global encoding_config
encoding_config = codec_config
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
global decoder
decoder = lc3.Decoder(
frame_duration_us=codec_config.frame_duration.us,
sample_rate_hz=codec_config.sampling_frequency.hz,
num_channels=codec_config.audio_channel_allocation.channel_count,
)

for ase in ascs_service.ase_state_machines.values():
Expand Down Expand Up @@ -585,7 +455,7 @@ def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) ->

# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
speaker()


Expand Down
Binary file removed apps/lea_unicast/liblc3.wasm
Binary file not shown.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ development = [
"types-appdirs >= 1.4.3",
"types-invoke >= 1.7.3",
"types-protobuf >= 4.21.0",
"wasmtime == 20.0.0",
]
avatar = [
"pandora-avatar == 0.0.10",
Expand All @@ -66,6 +65,9 @@ documentation = [
"mkdocs-material >= 8.5.6",
"mkdocstrings[python] >= 0.19.0",
]
lc3 = [
"lc3 @ git+https://github.com/google/liblc3.git",
]

[project.scripts]
bumble-ble-rpa-tool = "bumble.apps.ble_rpa_tool:main"
Expand Down

0 comments on commit 80d60aa

Please sign in to comment.