Skip to content

Commit

Permalink
refactor!: provide protocol plugin only (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl authored Dec 29, 2024
1 parent a965922 commit c16e8fc
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 132 deletions.
113 changes: 0 additions & 113 deletions hivemind_listener/__init__.py
Original file line number Diff line number Diff line change
@@ -1,113 +0,0 @@
import click
from ovos_config import Configuration
from ovos_plugin_manager.stt import OVOSSTTFactory
from ovos_plugin_manager.tts import OVOSTTSFactory
from ovos_plugin_manager.vad import OVOSVADFactory
from ovos_utils.xdg_utils import xdg_data_home

from hivemind_core.database import ClientDatabase
from hivemind_core.scripts import get_db_kwargs
from hivemind_websocket_protocol import HiveMindWebsocketProtocol
from hivemind_core.service import HiveMindService
from hivemind_core.protocol import HiveMindListenerProtocol, ClientCallbacks
from ovos_bus_client.hpm import OVOSProtocol
from hivemind_listener.protocol import PluginOptions, AudioBinaryProtocol
from hivemind_listener.transformers import (DialogTransformersService,
MetadataTransformersService,
UtteranceTransformersService)


# NOTE: the transformer help strings below are built at import time, so
# get_available_plugins() is called once per service when this module loads.
@click.command()
@click.option('--wakeword', default="hey_mycroft", type=str,
              help="Specify the wake word for the listener. Default is 'hey_mycroft'.")
@click.option('--stt-plugin', default=None, type=str, help="Specify the STT plugin to use.")
@click.option('--tts-plugin', default=None, type=str, help="Specify the TTS plugin to use.")
@click.option('--vad-plugin', default=None, type=str, help="Specify the VAD plugin to use.")
@click.option("--dialog-transformers", multiple=True, type=str,
              help="dialog transformer plugins to load. "
                   f"Installed plugins: {DialogTransformersService.get_available_plugins() or None}")
@click.option("--utterance-transformers", multiple=True, type=str,
              help="utterance transformer plugins to load. "
                   f"Installed plugins: {UtteranceTransformersService.get_available_plugins() or None}")
@click.option("--metadata-transformers", multiple=True, type=str,
              help="metadata transformer plugins to load. "
                   f"Installed plugins: {MetadataTransformersService.get_available_plugins() or None}")
@click.option("--ovos_bus_address", help="Open Voice OS bus address", type=str, default="127.0.0.1")
@click.option("--ovos_bus_port", help="Open Voice OS bus port number", type=int, default=8181)
@click.option("--host", help="HiveMind host", type=str, default="0.0.0.0")
@click.option("--port", help="HiveMind port number", type=int, required=False)
@click.option("--ssl", help="use wss://", type=bool, default=False)
@click.option("--cert_dir", help="HiveMind SSL certificate directory", type=str, default=f"{xdg_data_home()}/hivemind")
@click.option("--cert_name", help="HiveMind SSL certificate file name", type=str, default="hivemind")
@click.option("--db-backend", type=click.Choice(['redis', 'json', 'sqlite'], case_sensitive=False), default='json',
              help="Select the database backend to use. Options: redis, sqlite, json.")
@click.option("--db-name", type=str, default="clients",
              help="[json/sqlite] The name for the database file. ~/.cache/hivemind-core/{name}")
@click.option("--db-folder", type=str, default="hivemind-core",
              help="[json/sqlite] The subfolder where database files are stored. ~/.cache/{db_folder}")
@click.option("--redis-host", default="localhost", help="[redis] Host for Redis. Default is localhost.")
@click.option("--redis-port", default=6379, help="[redis] Port for Redis. Default is 6379.")
@click.option("--redis-password", required=False, help="[redis] Password for Redis. Default None")
def run_hivemind_listener(wakeword, stt_plugin, tts_plugin, vad_plugin,
                          dialog_transformers, utterance_transformers, metadata_transformers,
                          ovos_bus_address: str, ovos_bus_port: int, host: str, port: int,
                          ssl: bool, cert_dir: str, cert_name: str,
                          db_backend, db_name, db_folder,
                          redis_host, redis_port, redis_password
                          ):
    """
    Run the HiveMind Listener with configurable plugins.
    If a plugin is not specified, the defaults from mycroft.conf will be used.
    mycroft.conf will be loaded as usual for plugin settings
    """
    # NOTE: the docstring above doubles as the `--help` text rendered by click,
    # so implementation notes live in comments instead.

    # Keyword arguments for ClientDatabase, derived from the --db-* / --redis-* options.
    kwargs = get_db_kwargs(db_backend, db_name, db_folder, redis_host, redis_port, redis_password)

    # Where the agent protocol reaches the OVOS bus; falsy values fall back
    # to the same defaults the CLI options declare.
    ovos_bus_config = {
        "host": ovos_bus_address or "127.0.0.1",
        "port": ovos_bus_port or 8181,
    }

    # Network-side settings handed to the websocket protocol.
    # --port has no click default, so None falls back to 5678 here.
    websocket_config = {
        "host": host,
        "port": port or 5678,
        "ssl": bool(ssl),
        "cert_dir": cert_dir,
        "cert_name": cert_name,
    }

    # Configure wakeword, TTS, STT, and VAD plugins:
    # CLI overrides are written into the loaded configuration before the
    # factories read it, so unspecified plugins keep their mycroft.conf value.
    config = Configuration()
    if stt_plugin:
        config["stt"]["module"] = stt_plugin
    if tts_plugin:
        config["tts"]["module"] = tts_plugin
    if vad_plugin:
        config["listener"]["VAD"]["module"] = vad_plugin

    # Plugins are stored on the class (not an instance) so the protocol object
    # later created by the service picks them up.
    AudioBinaryProtocol.plugins = PluginOptions(
        wakeword=wakeword,
        stt=OVOSSTTFactory.create(config),
        tts=OVOSTTSFactory.create(config),
        vad=OVOSVADFactory.create(config),
        dialog_transformers=dialog_transformers,
        utterance_transformers=utterance_transformers,
        metadata_transformers=metadata_transformers
    )

    # Start the service
    click.echo(f"Starting HiveMind Listener with wakeword '{wakeword}'...")
    service = HiveMindService(agent_protocol=OVOSProtocol,
                              agent_config=ovos_bus_config,
                              network_protocol=HiveMindWebsocketProtocol,
                              network_config=websocket_config,
                              hm_protocol=HiveMindListenerProtocol,
                              binary_data_protocol=AudioBinaryProtocol,
                              # stop the per-client listener when that client disconnects
                              callbacks=ClientCallbacks(on_disconnect=AudioBinaryProtocol.stop_listener),
                              db=ClientDatabase(**kwargs))

    service.run()


# Allow running the module directly as a script.
if __name__ == "__main__":
    run_hivemind_listener()
68 changes: 61 additions & 7 deletions hivemind_listener/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from queue import Queue
from shutil import which
from tempfile import NamedTemporaryFile
from typing import Dict
from typing import List, Tuple, Optional, Union
from typing import Dict, Any, List, Tuple, Optional, Union

import pybase64
import speech_recognition as sr
Expand All @@ -23,6 +22,7 @@
from ovos_plugin_manager.tts import OVOSTTSFactory
from ovos_plugin_manager.vad import OVOSVADFactory
from ovos_plugin_manager.wakewords import OVOSWakeWordFactory
from ovos_simple_listener import SimpleListener, ListenerCallbacks
from ovos_utils.fakebus import FakeBus
from ovos_utils.log import LOG

Expand All @@ -31,8 +31,7 @@
from hivemind_listener.transformers import (DialogTransformersService,
MetadataTransformersService,
UtteranceTransformersService)
from hivemind_plugin_manager.protocols import BinaryDataHandlerProtocol
from ovos_simple_listener import SimpleListener, ListenerCallbacks
from hivemind_plugin_manager.protocols import BinaryDataHandlerProtocol, ClientCallbacks


def bytes2audiodata(data: bytes) -> sr.AudioData:
Expand Down Expand Up @@ -183,13 +182,47 @@ class AudioBinaryProtocol(BinaryDataHandlerProtocol):
utterance_transformers: Optional[UtteranceTransformersService] = None
metadata_transformers: Optional[MetadataTransformersService] = None
dialog_transformers: Optional[DialogTransformersService] = None

config: Dict[str, Any] = field(default_factory=dict)
hm_protocol: Optional['AudioReceiverProtocol'] = None
callbacks: Optional[ClientCallbacks] = None
listeners = {}

def __post_init__(self):
# Configure wakeword, TTS, STT, and VAD plugins
if self.plugins is None:
self.plugins = PluginOptions()

if not self.config:
LOG.warning("No config passed to AudioBinaryProtocol, "
"reading mycroft.conf to select plugins")
# use regular mycroft.conf
from ovos_config import Configuration
config = Configuration()
self.config["stt"] = config["stt"]
self.config["tts"] = config["tts"]
self.config["vad"] = config["listener"]["VAD"]
self.config["wakeword"] = config["listener"]["wake_word"]
self.config["hotwords"] = config["hotwords"]
self.config["utterance_transformers"] = list(config.get("utterance_transformers", {}))
self.config["dialog_transformers"] = list(config.get("dialog_transformers", {}))
self.config["metadata_transformers"] = list(config.get("metadata_transformers", {}))

LOG.debug(f"Loading STT '{self.config['stt']['module']}': {self.config['stt']}")
stt = OVOSSTTFactory.create(self.config["stt"])
LOG.debug(f"Loading TTS '{self.config['tts']['module']}': {self.config['tts']}")
tts = OVOSTTSFactory.create(self.config["tts"])
LOG.debug(f"Loading VAD '{self.config['vad']['module']}': {self.config['vad']}")
vad = OVOSVADFactory.create(self.config["vad"])

self.plugins = PluginOptions(
wakeword=self.config["wakeword"], # TODO - allow per client
stt=stt,
tts=tts,
vad=vad,
dialog_transformers=self.config.get("dialog_transformers", []),
utterance_transformers=self.config.get("utterance_transformers", []),
metadata_transformers=self.config.get("metadata_transformers", [])
)

if self.utterance_transformers is None:
self.utterance_transformers = UtteranceTransformersService(
self.agent_protocol.bus, self.plugins.utterance_transformers)
Expand All @@ -199,6 +232,23 @@ def __post_init__(self):
if self.metadata_transformers is None:
self.metadata_transformers = MetadataTransformersService(
self.agent_protocol.bus, self.plugins.metadata_transformers)

# ensure client audio listener is closed when client disconnects
if not self.callbacks:
self.callbacks = ClientCallbacks(on_disconnect=AudioBinaryProtocol.stop_listener)
else:
original = self.callbacks.on_disconnect

def wrapper(c):
    """Run the original on_disconnect callback, then always stop the listener.

    ``c`` is forwarded unchanged to both ``original`` and
    ``AudioBinaryProtocol.stop_listener``.
    """
    # try/finally already propagates anything `original` raises, so the
    # previous bare `except: raise` clause was a no-op; the listener is
    # stopped whether or not the original callback fails.
    try:
        original(c)
    finally:
        AudioBinaryProtocol.stop_listener(c)

self.callbacks.on_disconnect = wrapper

# agent protocol payloads with binary audio results
self.agent_protocol.bus.on("recognizer_loop:b64_audio", self.handle_audio_b64)
self.agent_protocol.bus.on("recognizer_loop:b64_transcribe", self.handle_transcribe_b64)
Expand Down Expand Up @@ -226,10 +276,14 @@ def on_msg(m: str):

bus.on("message", on_msg)

# TODO allow different per client
ww_cfg = self.config["hotwords"][self.plugins.wakeword]
LOG.debug(f"Loading client Wake Word '{self.plugins.wakeword}': {ww_cfg}")

AudioBinaryProtocol.listeners[client.peer] = SimpleListener(
mic=FakeMicrophone(),
vad=self.plugins.vad,
wakeword=OVOSWakeWordFactory.create_hotword(self.plugins.wakeword), # TODO allow different per client
wakeword=OVOSWakeWordFactory.create_hotword(self.plugins.wakeword, self.config["hotwords"]),
stt=self.plugins.stt,
callbacks=AudioCallbacks(bus)
)
Expand Down
11 changes: 2 additions & 9 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
ovos-simple-listener
hivemind_bus_client>=0.1.0,<1.0.0
ovos-plugin-manager<1.0.0
hivemind-core>=2.0.0,<3.0.0
hivemind-plugin-manager>=0.0.2,<1.0.0
ovos-bus-client>=1.3.0,<2.0.0
click
pybase64

# TODO - for backwards compat, will be removed in a future release
hivemind-websocket-protocol>=0.0.1,<1.0.0
hivemind-plugin-manager>=0.3.0,<1.0.0
pybase64
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

def get_version():
"""Find the version of the package"""
version = None
version_file = os.path.join(BASEDIR, "hivemind_listener", "version.py")
major, minor, build, alpha = (None, None, None, None)
with open(version_file) as f:
Expand Down Expand Up @@ -39,6 +38,7 @@ def required(requirements_file):
]
return [pkg for pkg in requirements if pkg.strip() and not pkg.startswith("#")]

PLUGIN_ENTRY_POINT = 'hivemind-audio-binary-protocol-plugin=hivemind_listener.protocol:AudioBinaryProtocol'

setup(
name="hivemind-listener",
Expand All @@ -52,6 +52,6 @@ def required(requirements_file):
author_email="[email protected]",
description="Mesh Networking utilities for OpenVoiceOS",
entry_points={
"console_scripts": ["hivemind-listener=hivemind_listener:run_hivemind_listener"]
},
'hivemind.binary.protocol': PLUGIN_ENTRY_POINT
}
)

0 comments on commit c16e8fc

Please sign in to comment.