-
-
Notifications
You must be signed in to change notification settings - Fork 32.6k
/
Copy pathstt.py
147 lines (125 loc) · 5.05 KB
/
stt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Support for the Google Cloud STT service."""
from __future__ import annotations
from collections.abc import AsyncGenerator, AsyncIterable
import logging
from google.api_core.exceptions import GoogleAPIError, Unauthenticated
from google.cloud import speech_v1
from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import (
CONF_SERVICE_ACCOUNT_INFO,
CONF_STT_MODEL,
DEFAULT_STT_MODEL,
DOMAIN,
STT_LANGUAGES,
)
# Logger for this module, named after the module's import path.
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Register the Google Cloud STT entity for a config entry.

    Reads the service account info stored in the entry's data, builds an
    async Speech client from it, and passes a single entity wrapping that
    client to ``async_add_entities``.
    """
    account_info = config_entry.data[CONF_SERVICE_ACCOUNT_INFO]
    stt_client = speech_v1.SpeechAsyncClient.from_service_account_info(account_info)
    entity = GoogleCloudSpeechToTextEntity(config_entry, stt_client)
    async_add_entities([entity])
class GoogleCloudSpeechToTextEntity(SpeechToTextEntity):
    """Speech-to-text entity that forwards audio to Google Cloud Speech."""

    def __init__(
        self,
        entry: ConfigEntry,
        client: speech_v1.SpeechAsyncClient,
    ) -> None:
        """Keep the config entry and client, and fill in entity attributes.

        The STT model comes from the entry options; ``DEFAULT_STT_MODEL`` is
        used when the option is absent.
        """
        self._entry = entry
        self._client = client
        self._model = entry.options.get(CONF_STT_MODEL, DEFAULT_STT_MODEL)

        # Entity name, unique id and device are all derived from the entry.
        self._attr_name = entry.title
        self._attr_unique_id = f"{entry.entry_id}"
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            manufacturer="Google",
            model="Cloud",
            entry_type=dr.DeviceEntryType.SERVICE,
        )

    @property
    def supported_languages(self) -> list[str]:
        """Return the languages accepted by this entity (``STT_LANGUAGES``)."""
        return STT_LANGUAGES

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return the accepted container formats: WAV and OGG."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return the accepted codecs: PCM and OPUS."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return the accepted bit rates: only ``BITRATE_16``."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return the accepted sample rates: only ``SAMPLERATE_16000``."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return the accepted channel layouts: only ``CHANNEL_MONO``."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Transcribe an audio stream through streaming recognition.

        Returns a ``SUCCESS`` result whose text is the concatenation of the
        first alternative of the first result of every response. If the call
        raises ``GoogleAPIError`` the error is logged and an ``ERROR`` result
        with no text is returned; for ``Unauthenticated`` errors
        ``async_start_reauth`` is also called on the config entry.
        """
        # OPUS audio is sent as OGG_OPUS; any other codec is sent as LINEAR16.
        if metadata.codec == AudioCodecs.OPUS:
            encoding = speech_v1.RecognitionConfig.AudioEncoding.OGG_OPUS
        else:
            encoding = speech_v1.RecognitionConfig.AudioEncoding.LINEAR16

        recognition_config = speech_v1.RecognitionConfig(
            encoding=encoding,
            sample_rate_hertz=metadata.sample_rate,
            language_code=metadata.language,
            model=self._model,
        )
        streaming_config = speech_v1.StreamingRecognitionConfig(
            config=recognition_config
        )

        async def _iter_requests() -> AsyncGenerator[
            speech_v1.StreamingRecognizeRequest
        ]:
            # The opening request carries the streaming config and nothing else.
            yield speech_v1.StreamingRecognizeRequest(streaming_config=streaming_config)
            # Every request after that carries only a chunk of audio.
            async for chunk in stream:
                yield speech_v1.StreamingRecognizeRequest(audio_content=chunk)

        pieces: list[str] = []
        try:
            responses = await self._client.streaming_recognize(
                requests=_iter_requests(),
                timeout=10,
            )
            async for response in responses:
                _LOGGER.debug("response: %s", response)
                results = response.results
                if not results:
                    continue
                alternatives = results[0].alternatives
                if alternatives:
                    pieces.append(alternatives[0].transcript)
        except GoogleAPIError as err:
            _LOGGER.error("Error occurred during Google Cloud STT call: %s", err)
            if isinstance(err, Unauthenticated):
                self._entry.async_start_reauth(self.hass)
            return SpeechResult(None, SpeechResultState.ERROR)

        return SpeechResult("".join(pieces), SpeechResultState.SUCCESS)