From 641b751fea14d36b8c6e566d816da658fea357b9 Mon Sep 17 00:00:00 2001 From: Cliff Dyer Date: Mon, 8 Jul 2024 17:26:37 -0400 Subject: [PATCH] Custom hosts (#70) * Support custom deepgram hostnames. This is useful to support self-hosted deployments or deepgram in-house development of the SDK itself. * Add options for constructing Deepgram with different hosts * Redact api keys --- src/lib.rs | 109 ++++++++++++++++++++++++++++--- src/redacted.rs | 18 +++++ src/transcription/live.rs | 90 +++++++++++++++++-------- src/transcription/prerecorded.rs | 32 ++++++++- 4 files changed, 211 insertions(+), 38 deletions(-) create mode 100644 src/redacted.rs diff --git a/src/lib.rs b/src/lib.rs index f6635679..7b14578f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,23 +8,27 @@ use std::io; +use redacted::RedactedString; use reqwest::{ header::{HeaderMap, HeaderValue}, RequestBuilder, }; use serde::de::DeserializeOwned; use thiserror::Error; +use url::Url; pub mod billing; pub mod invitations; pub mod keys; pub mod members; pub mod projects; +mod redacted; +mod response; pub mod scopes; pub mod transcription; pub mod usage; -mod response; +static DEEPGRAM_BASE_URL: &str = "https://api.deepgram.com"; /// A client for the Deepgram API. /// @@ -32,7 +36,9 @@ mod response; #[derive(Debug, Clone)] pub struct Deepgram { #[cfg_attr(not(feature = "live"), allow(unused))] - api_key: String, + api_key: Option, + #[cfg_attr(not(any(feature = "live", feature = "prerecorded")), allow(unused))] + base_url: Url, client: reqwest::Client, } @@ -81,6 +87,8 @@ type Result = std::result::Result; impl Deepgram { /// Construct a new Deepgram client. /// + /// The client will be pointed at Deepgram's hosted API. + /// /// Create your first API key on the [Deepgram Console][console]. /// /// [console]: https://console.deepgram.com/ @@ -89,6 +97,88 @@ impl Deepgram { /// /// Panics under the same conditions as [`reqwest::Client::new`]. pub fn new>(api_key: K) -> Self { + let api_key = Some(api_key.as_ref().to_owned()); + Self::inner_constructor(DEEPGRAM_BASE_URL.try_into().unwrap(), api_key) + } + + /// Construct a new Deepgram client with the specified base URL. + /// + /// When using a self-hosted instance of deepgram, this will be the + /// host portion of your own instance. For instance, if you would + /// query your deepgram instance at `http://deepgram.internal/v1/listen`, + /// the base_url will be `http://deepgram.internal`. + /// + /// Admin features, such as billing, usage, and key management will + /// still go through the hosted site at `https://api.deepgram.com`. + /// + /// Self-hosted instances do not in general authenticate incoming + /// requests, so unlike in [`Deepgram::new`], so no api key needs to be + /// provided. The SDK will not include an `Authorization` header in its + /// requests. If an API key is required, consider using + /// [`Deepgram::with_base_url_and_api_key`]. + /// + /// [console]: https://console.deepgram.com/ + /// + /// # Example: + /// + /// ``` + /// # use deepgram::Deepgram; + /// let deepgram = Deepgram::with_base_url( + /// "http://localhost:8080", + /// ); + /// ``` + /// + /// # Panics + /// + /// Panics under the same conditions as [`reqwest::Client::new`], or if `base_url` + /// is not a valid URL. + pub fn with_base_url(base_url: U) -> Self + where + U: TryInto, + U::Error: std::fmt::Debug, + { + let base_url = base_url.try_into().expect("base_url must be a valid Url"); + Self::inner_constructor(base_url, None) + } + + /// Construct a new Deepgram client with the specified base URL and + /// API Key. + /// + /// When using a self-hosted instance of deepgram, this will be the + /// host portion of your own instance. For instance, if you would + /// query your deepgram instance at `http://deepgram.internal/v1/listen`, + /// the base_url will be `http://deepgram.internal`. + /// + /// Admin features, such as billing, usage, and key management will + /// still go through the hosted site at `https://api.deepgram.com`. + /// + /// [console]: https://console.deepgram.com/ + /// + /// # Example: + /// + /// ``` + /// # use deepgram::Deepgram; + /// let deepgram = Deepgram::with_base_url_and_api_key( + /// "http://localhost:8080", + /// "apikey12345", + /// ); + /// ``` + /// + /// # Panics + /// + /// Panics under the same conditions as [`reqwest::Client::new`], or if `base_url` + /// is not a valid URL. + pub fn with_base_url_and_api_key(base_url: U, api_key: K) -> Self + where + U: TryInto, + U::Error: std::fmt::Debug, + K: AsRef, + { + let base_url = base_url.try_into().expect("base_url must be a valid Url"); + Self::inner_constructor(base_url, Some(api_key.as_ref().to_owned())) + } + + fn inner_constructor(base_url: Url, api_key: Option) -> Self { static USER_AGENT: &str = concat!( env!("CARGO_PKG_NAME"), "/", @@ -98,17 +188,18 @@ impl Deepgram { let authorization_header = { let mut header = HeaderMap::new(); - header.insert( - "Authorization", - HeaderValue::from_str(&format!("Token {}", api_key.as_ref())) - .expect("Invalid API key"), - ); + if let Some(api_key) = &api_key { + header.insert( + "Authorization", + HeaderValue::from_str(&format!("Token {}", api_key)).expect("Invalid API key"), + ); + } header }; - let api_key = api_key.as_ref().to_owned(); Deepgram { - api_key, + api_key: api_key.map(RedactedString), + base_url, client: reqwest::Client::builder() .user_agent(USER_AGENT) .default_headers(authorization_header) diff --git a/src/redacted.rs b/src/redacted.rs new file mode 100644 index 00000000..bc6c3267 --- /dev/null +++ b/src/redacted.rs @@ -0,0 +1,18 @@ +use std::{fmt, ops::Deref}; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct RedactedString(pub String); + +impl fmt::Debug for RedactedString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str("***") + } +} + +impl Deref for RedactedString { + type Target = str; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/src/transcription/live.rs b/src/transcription/live.rs index 1f43ec37..e2e693e3 100644 --- a/src/transcription/live.rs +++ b/src/transcription/live.rs @@ -30,6 +30,8 @@ use crate::{Deepgram, DeepgramError, Result}; use super::Transcription; +static LIVE_LISTEN_URL_PATH: &str = "v1/listen"; + #[derive(Debug)] pub struct StreamRequestBuilder<'a, S, E> where @@ -40,6 +42,7 @@ where encoding: Option, sample_rate: Option, channels: Option, + stream_url: Url, } #[derive(Debug, Serialize, Deserialize)] @@ -95,7 +98,18 @@ impl Transcription<'_> { encoding: None, sample_rate: None, channels: None, + stream_url: self.listen_stream_url(), + } + } + + fn listen_stream_url(&self) -> Url { + let mut url = self.0.base_url.join(LIVE_LISTEN_URL_PATH).unwrap(); + match url.scheme() { + "http" | "ws" => url.set_scheme("ws").unwrap(), + "https" | "wss" => url.set_scheme("wss").unwrap(), + _ => panic!("base_url must have a scheme of http, https, ws, or wss"), } + url } } @@ -201,42 +215,43 @@ where E: Send + std::fmt::Debug, { pub async fn start(self) -> Result>> { - let StreamRequestBuilder { - config, - source, - encoding, - sample_rate, - channels, - } = self; - let mut source = source - .ok_or(DeepgramError::NoSource)? - .map(|res| res.map(|bytes| Message::binary(Vec::from(bytes.as_ref())))); - // This unwrap is safe because we're parsing a static. - let mut base = Url::parse("wss://api.deepgram.com/v1/listen").unwrap(); + let mut url = self.stream_url; { - let mut pairs = base.query_pairs_mut(); - if let Some(encoding) = encoding { - pairs.append_pair("encoding", &encoding); + let mut pairs = url.query_pairs_mut(); + if let Some(encoding) = &self.encoding { + pairs.append_pair("encoding", encoding); } - if let Some(sample_rate) = sample_rate { + if let Some(sample_rate) = self.sample_rate { pairs.append_pair("sample_rate", &sample_rate.to_string()); } - if let Some(channels) = channels { + if let Some(channels) = self.channels { pairs.append_pair("channels", &channels.to_string()); } } - let request = Request::builder() - .method("GET") - .uri(base.to_string()) - .header("authorization", format!("token {}", config.api_key)) - .header("sec-websocket-key", client::generate_key()) - .header("host", "api.deepgram.com") - .header("connection", "upgrade") - .header("upgrade", "websocket") - .header("sec-websocket-version", "13") - .body(())?; + let mut source = self + .source + .ok_or(DeepgramError::NoSource)? + .map(|res| res.map(|bytes| Message::binary(Vec::from(bytes.as_ref())))); + + let request = { + let builder = Request::builder() + .method("GET") + .uri(url.to_string()) + .header("sec-websocket-key", client::generate_key()) + .header("host", "api.deepgram.com") + .header("connection", "upgrade") + .header("upgrade", "websocket") + .header("sec-websocket-version", "13"); + + let builder = if let Some(api_key) = self.config.api_key.as_deref() { + builder.header("authorization", format!("token {}", api_key)) + } else { + builder + }; + builder.body(())? + }; let (ws_stream, _) = tokio_tungstenite::connect_async(request).await?; let (mut write, mut read) = ws_stream.split(); let (mut tx, rx) = mpsc::channel::>(1); @@ -288,3 +303,24 @@ where Ok(rx) } } + +#[cfg(test)] +mod tests { + #[test] + fn test_stream_url() { + let dg = crate::Deepgram::new("token"); + assert_eq!( + dg.transcription().listen_stream_url().to_string(), + "wss://api.deepgram.com/v1/listen", + ); + } + + #[test] + fn test_stream_url_custom_host() { + let dg = crate::Deepgram::with_base_url_and_api_key("http://localhost:8080", "token"); + assert_eq!( + dg.transcription().listen_stream_url().to_string(), + "ws://localhost:8080/v1/listen", + ); + } +} diff --git a/src/transcription/prerecorded.rs b/src/transcription/prerecorded.rs index 23713617..e8315f59 100644 --- a/src/transcription/prerecorded.rs +++ b/src/transcription/prerecorded.rs @@ -5,6 +5,7 @@ //! [api]: https://developers.deepgram.com/api-reference/#transcription-prerecorded use reqwest::RequestBuilder; +use url::Url; use super::Transcription; use crate::send_and_translate_response; @@ -17,7 +18,7 @@ use audio_source::AudioSource; use options::{Options, SerializableOptions}; use response::{CallbackResponse, Response}; -static DEEPGRAM_API_URL_LISTEN: &str = "https://api.deepgram.com/v1/listen"; +static DEEPGRAM_API_URL_LISTEN: &str = "v1/listen"; impl Transcription<'_> { /// Sends a request to Deepgram to transcribe pre-recorded audio. @@ -195,7 +196,7 @@ impl Transcription<'_> { let request_builder = self .0 .client - .post(DEEPGRAM_API_URL_LISTEN) + .post(self.listen_url()) .query(&SerializableOptions(options)); source.fill_body(request_builder) @@ -267,4 +268,31 @@ impl Transcription<'_> { self.make_prerecorded_request_builder(source, options) .query(&[("callback", callback)]) } + + fn listen_url(&self) -> Url { + self.0.base_url.join(DEEPGRAM_API_URL_LISTEN).unwrap() + } +} + +#[cfg(test)] +mod tests { + use crate::Deepgram; + + #[test] + fn listen_url() { + let dg = Deepgram::new("token"); + assert_eq!( + &dg.transcription().listen_url().to_string(), + "https://api.deepgram.com/v1/listen" + ); + } + + #[test] + fn listen_url_custom_host() { + let dg = Deepgram::with_base_url("http://localhost:8888/abc/"); + assert_eq!( + &dg.transcription().listen_url().to_string(), + "http://localhost:8888/abc/v1/listen" + ); + } }