forked from rhasspy/wyoming-satellite
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'rhasspy:master' into master
- Loading branch information
Showing
14 changed files
with
467 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
--- | ||
name: test | ||
|
||
# yamllint disable-line rule:truthy | ||
on: | ||
workflow_dispatch: | ||
pull_request: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
wyoming==1.5.2 | ||
wyoming==1.5.3 | ||
zeroconf==0.88.0 | ||
pyring-buffer==1.0.0 | ||
webrtc-noise-gain==1.2.3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,64 @@ | ||
"""Shared code for Wyoming satellite tests.""" | ||
from wyoming.audio import AudioChunk | ||
import asyncio | ||
import io | ||
from collections.abc import Iterable | ||
from typing import Optional | ||
|
||
from wyoming.audio import AudioChunk, AudioStart, AudioStop | ||
from wyoming.client import AsyncClient | ||
from wyoming.event import Event | ||
|
||
AUDIO_START = AudioStart(rate=16000, width=2, channels=1) | ||
AUDIO_STOP = AudioStop() | ||
|
||
AUDIO_CHUNK = AudioChunk( | ||
rate=16000, width=2, channels=1, audio=bytes([255] * 960) # 30ms | ||
) | ||
|
||
|
||
class FakeStreamReaderWriter: | ||
def __init__(self) -> None: | ||
self._undrained_data = bytes() | ||
self._value = bytes() | ||
self._data_ready = asyncio.Event() | ||
|
||
def write(self, data: bytes) -> None: | ||
self._undrained_data += data | ||
|
||
def writelines(self, data: Iterable[bytes]) -> None: | ||
for line in data: | ||
self.write(line) | ||
|
||
async def drain(self) -> None: | ||
self._value += self._undrained_data | ||
self._undrained_data = bytes() | ||
self._data_ready.set() | ||
self._data_ready.clear() | ||
|
||
async def readline(self) -> bytes: | ||
while b"\n" not in self._value: | ||
await self._data_ready.wait() | ||
|
||
with io.BytesIO(self._value) as value_io: | ||
data = value_io.readline() | ||
self._value = self._value[len(data) :] | ||
return data | ||
|
||
async def readexactly(self, n: int) -> bytes: | ||
while len(self._value) < n: | ||
await self._data_ready.wait() | ||
|
||
data = self._value[:n] | ||
self._value = self._value[n:] | ||
return data | ||
|
||
|
||
class MicClient(AsyncClient): | ||
async def read_event(self) -> Optional[Event]: | ||
# Send 30ms of audio every 30ms | ||
await asyncio.sleep(AUDIO_CHUNK.seconds) | ||
return AUDIO_CHUNK.event() | ||
|
||
async def write_event(self, event: Event) -> None: | ||
# Output only | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
import asyncio | ||
import logging | ||
from typing import Final, Optional | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from wyoming.asr import Transcript | ||
from wyoming.audio import AudioChunk, AudioStart, AudioStop | ||
from wyoming.client import AsyncClient | ||
from wyoming.event import Event, async_read_event | ||
from wyoming.pipeline import PipelineStage, RunPipeline | ||
from wyoming.satellite import RunSatellite, StreamingStarted, StreamingStopped | ||
from wyoming.tts import Synthesize | ||
from wyoming.wake import Detection | ||
|
||
from wyoming_satellite import ( | ||
EventSettings, | ||
MicSettings, | ||
SatelliteSettings, | ||
SndSettings, | ||
WakeSettings, | ||
WakeStreamingSatellite, | ||
) | ||
|
||
from .shared import ( | ||
AUDIO_CHUNK, | ||
AUDIO_START, | ||
AUDIO_STOP, | ||
FakeStreamReaderWriter, | ||
MicClient, | ||
) | ||
|
||
_LOGGER = logging.getLogger() | ||
|
||
TIMEOUT: Final = 1 | ||
|
||
|
||
class WakeClient(AsyncClient): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self._event_ready = asyncio.Event() | ||
self._event: Optional[Event] = None | ||
self._detected: bool = False | ||
|
||
async def read_event(self) -> Optional[Event]: | ||
await self._event_ready.wait() | ||
self._event_ready.clear() | ||
return self._event | ||
|
||
async def write_event(self, event: Event) -> None: | ||
if AudioChunk.is_type(event.type): | ||
if not self._detected: | ||
self._detected = True | ||
self._event = Detection().event() | ||
self._event_ready.set() | ||
|
||
|
||
class SndClient(AsyncClient): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.synthesize = asyncio.Event() | ||
self.audio_start = asyncio.Event() | ||
self.audio_chunk = asyncio.Event() | ||
self.audio_stop = asyncio.Event() | ||
|
||
async def read_event(self) -> Optional[Event]: | ||
# Input only | ||
pass | ||
|
||
async def write_event(self, event: Event) -> None: | ||
if AudioChunk.is_type(event.type): | ||
self.audio_chunk.set() | ||
elif Synthesize.is_type(event.type): | ||
self.synthesize.set() | ||
elif AudioStart.is_type(event.type): | ||
self.audio_start.set() | ||
elif AudioStop.is_type(event.type): | ||
self.audio_stop.set() | ||
|
||
|
||
class EventClient(AsyncClient): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.detection = asyncio.Event() | ||
self.streaming_started = asyncio.Event() | ||
self.streaming_stopped = asyncio.Event() | ||
self.transcript = asyncio.Event() | ||
self.synthesize = asyncio.Event() | ||
self.audio_start = asyncio.Event() | ||
self.audio_chunk = asyncio.Event() | ||
self.audio_stop = asyncio.Event() | ||
|
||
async def read_event(self) -> Optional[Event]: | ||
# Input only | ||
return None | ||
|
||
async def write_event(self, event: Event) -> None: | ||
if Detection.is_type(event.type): | ||
self.detection.set() | ||
elif StreamingStarted.is_type(event.type): | ||
self.streaming_started.set() | ||
elif StreamingStopped.is_type(event.type): | ||
self.streaming_stopped.set() | ||
elif Transcript.is_type(event.type): | ||
self.transcript.set() | ||
elif Synthesize.is_type(event.type): | ||
self.synthesize.set() | ||
elif AudioChunk.is_type(event.type): | ||
self.audio_chunk.set() | ||
elif AudioStart.is_type(event.type): | ||
self.audio_start.set() | ||
elif AudioStop.is_type(event.type): | ||
self.audio_stop.set() | ||
|
||
|
||
# ----------------------------------------------------------------------------- | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_wake_satellite() -> None: | ||
mic_client = MicClient() | ||
snd_client = SndClient() | ||
wake_client = WakeClient() | ||
event_client = EventClient() | ||
|
||
with patch( | ||
"wyoming_satellite.satellite.SatelliteBase._make_mic_client", | ||
return_value=mic_client, | ||
), patch( | ||
"wyoming_satellite.satellite.SatelliteBase._make_snd_client", | ||
return_value=snd_client, | ||
), patch( | ||
"wyoming_satellite.satellite.SatelliteBase._make_wake_client", | ||
return_value=wake_client, | ||
), patch( | ||
"wyoming_satellite.satellite.SatelliteBase._make_event_client", | ||
return_value=event_client, | ||
): | ||
satellite = WakeStreamingSatellite( | ||
SatelliteSettings( | ||
mic=MicSettings(uri="test"), | ||
snd=SndSettings(uri="test"), | ||
wake=WakeSettings(uri="test"), | ||
event=EventSettings(uri="test"), | ||
) | ||
) | ||
|
||
async def event_from_satellite() -> Optional[Event]: | ||
return await async_read_event(server_io) | ||
|
||
satellite_task = asyncio.create_task(satellite.run(), name="satellite") | ||
|
||
# Fake server connection | ||
server_io = FakeStreamReaderWriter() | ||
await satellite.set_server("test", server_io) # type: ignore | ||
|
||
# Start satellite | ||
await satellite.event_from_server(RunSatellite().event()) | ||
|
||
# Trigger detection | ||
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) | ||
assert event is not None | ||
assert Detection.is_type(event.type), event | ||
|
||
# Pipeline should start | ||
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) | ||
assert event is not None | ||
assert RunPipeline.is_type(event.type), event | ||
run_pipeline = RunPipeline.from_event(event) | ||
assert run_pipeline.start_stage == PipelineStage.ASR | ||
assert run_pipeline.end_stage == PipelineStage.TTS | ||
|
||
# Event service should have received detection | ||
await asyncio.wait_for(event_client.detection.wait(), timeout=TIMEOUT) | ||
|
||
# Server should be receiving audio now | ||
assert satellite.is_streaming, "Not streaming" | ||
for _ in range(5): | ||
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) | ||
assert event is not None | ||
assert AudioChunk.is_type(event.type) | ||
|
||
# Event service should have received streaming start | ||
await asyncio.wait_for(event_client.streaming_started.wait(), timeout=TIMEOUT) | ||
|
||
# Send transcript | ||
await satellite.event_from_server(Transcript(text="test").event()) | ||
|
||
# Event service should have received transcript | ||
await asyncio.wait_for(event_client.transcript.wait(), timeout=TIMEOUT) | ||
|
||
# Wait for streaming to stop | ||
while satellite.is_streaming: | ||
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT) | ||
assert event is not None | ||
assert AudioChunk.is_type(event.type) | ||
|
||
# Event service should have received streaming stop | ||
await asyncio.wait_for(event_client.streaming_stopped.wait(), timeout=TIMEOUT) | ||
|
||
# Fake a TTS response | ||
await satellite.event_from_server(Synthesize(text="test").event()) | ||
|
||
# Event service should have received synthesize | ||
await asyncio.wait_for(event_client.synthesize.wait(), timeout=TIMEOUT) | ||
|
||
# Audio start, chunk, stop | ||
await satellite.event_from_server(AUDIO_START.event()) | ||
await asyncio.wait_for(snd_client.audio_start.wait(), timeout=TIMEOUT) | ||
await asyncio.wait_for(event_client.audio_start.wait(), timeout=TIMEOUT) | ||
|
||
# Event service does not get audio chunks, just start/stop | ||
await satellite.event_from_server(AUDIO_CHUNK.event()) | ||
await asyncio.wait_for(snd_client.audio_chunk.wait(), timeout=TIMEOUT) | ||
|
||
await satellite.event_from_server(AUDIO_STOP.event()) | ||
await asyncio.wait_for(snd_client.audio_stop.wait(), timeout=TIMEOUT) | ||
await asyncio.wait_for(event_client.audio_stop.wait(), timeout=TIMEOUT) | ||
|
||
# Stop satellite | ||
await satellite.stop() | ||
await satellite_task |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
[tox] | ||
env_list = | ||
py{39,310,311} | ||
minversion = 4.12.1 | ||
|
||
[testenv] | ||
description = run the tests with pytest | ||
package = wheel | ||
wheel_build_env = .pkg | ||
deps = | ||
pytest>=7,<8 | ||
pytest-asyncio<1 | ||
commands = | ||
pytest {tty:--color=yes} {posargs} |
Oops, something went wrong.