audio: add transcriptions to OpenAI backend
stintel committed May 16, 2024
1 parent 6980366 commit 0d98896
Showing 8 changed files with 230 additions and 6 deletions.
28 changes: 28 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -19,7 +19,7 @@ path = "src/client/embed.rs"
[dependencies]
anyhow = { version = "1.0.83", features = ["backtrace"] }
async-stream = "0.3.5"
axum = "0.7.5"
axum = { version = "0.7.5", features = ["multipart"] }
axum-prometheus = "0.6.1"
axum-tracing-opentelemetry = "0.18.1"
bytemuck = "1.16.0"
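
The new "multipart" feature is what gates axum's Multipart extractor used later in src/routes/audio.rs. As a rough illustration only (not part of this commit, with placeholder field handling), a handler built on that extractor looks roughly like this:

use axum::extract::Multipart;
use axum::http::StatusCode;

// Minimal sketch of the extractor enabled by axum's "multipart" feature.
async fn upload(mut multipart: Multipart) -> Result<(), StatusCode> {
    // Each call to next_field() yields one part of the multipart/form-data body.
    while let Some(field) = multipart
        .next_field()
        .await
        .map_err(|_| StatusCode::BAD_REQUEST)?
    {
        let name = field.name().unwrap_or("<unnamed>").to_string();
        let data = field.bytes().await.map_err(|_| StatusCode::BAD_REQUEST)?;
        tracing::debug!("field {name}: {} bytes", data.len());
    }
    Ok(())
}
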
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ Point your clients at AI Router and use any combination of Triton Inference Serv
| Inference Type | OpenAI backend | Triton backend |
| :--------------------------- | :----------------: | :----------------: |
| Audio > Create Speech | :white_check_mark: | :x: |
| Audio > Create Transcription | :x: | :x: |
| Audio > Create Transcription | :white_check_mark: | :x: |
| Audio > Create Translation | :x: | :x: |
| Chat | :white_check_mark: | :white_check_mark: |
| Embeddings | :white_check_mark: | :white_check_mark: |
5 changes: 5 additions & 0 deletions ai-router.toml.example
@@ -69,6 +69,11 @@ max_input = 512

# OpenAI tts-1 example
[models.audio_speech."tts-1"]

# Audio Transcriptions

# OpenAI whisper-1 example
[models.audio_transcriptions."whisper-1"]
backend = "openai"

# Chat completions
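
For context on how a table such as [models.audio_transcriptions."whisper-1"] is consumed: the snake_case key lines up with the new AiRouterModelType::AudioTranscriptions variant added in src/config.rs below. The standalone sketch that follows only approximates the real config types; ModelEntry and its field names are assumptions inferred from this diff, not the project's actual structs.

use std::collections::HashMap;

use serde::Deserialize;

#[derive(Debug, Deserialize, Eq, Hash, PartialEq)]
#[serde(rename_all = "snake_case")]
enum AiRouterModelType {
    AudioSpeech,
    AudioTranscriptions,
    ChatCompletions,
    Embeddings,
}

// Assumed shape of a per-model entry; the real struct in this repository may differ.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    backend: Option<String>,
    backend_model: Option<String>,
}

#[derive(Debug, Deserialize)]
struct Config {
    models: HashMap<AiRouterModelType, HashMap<String, ModelEntry>>,
}

fn main() -> Result<(), toml::de::Error> {
    let cfg: Config = toml::from_str(
        r#"
[models.audio_transcriptions."whisper-1"]
backend = "openai"
"#,
    )?;
    println!("{cfg:#?}");
    Ok(())
}
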
32 changes: 30 additions & 2 deletions src/backend/openai/routes/audio.rs
@@ -1,7 +1,13 @@
use axum::response::{IntoResponse, Response};
use axum::{
response::{IntoResponse, Response},
Json,
};
use openai_dive::v1::{
api::Client,
resources::audio::{AudioSpeechParameters, AudioSpeechResponseFormat},
resources::audio::{
AudioOutputFormat, AudioSpeechParameters, AudioSpeechResponseFormat,
AudioTranscriptionParameters,
},
};

use crate::errors::{transform_openai_dive_apierror, AiRouterError};
@@ -31,3 +37,25 @@ pub async fn speech(

Ok(([("content-type", content_type)], response.bytes).into_response())
}

pub async fn transcriptions(
client: &Client,
parameters: AudioTranscriptionParameters,
) -> Result<Response, AiRouterError<String>> {
let response_format = parameters.response_format.clone();

let response = client
.audio()
.create_transcription(parameters)
.await
.map_err(|e| transform_openai_dive_apierror(&e))?;

match response_format {
None | Some(AudioOutputFormat::Json | AudioOutputFormat::VerboseJson) => {
Ok(Json(response).into_response())
}
Some(AudioOutputFormat::Srt | AudioOutputFormat::Text | AudioOutputFormat::Vtt) => {
Ok(response.into_response())
}
}
}
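
Outside the router, the same openai_dive call can be exercised directly. The snippet below is a hedged, standalone sketch rather than part of this commit: the file path and model name are placeholders, and the error is flattened into anyhow instead of going through transform_openai_dive_apierror.

use openai_dive::v1::api::Client;
use openai_dive::v1::resources::audio::{
    AudioOutputFormat, AudioTranscriptionFile, AudioTranscriptionParameters,
};

async fn transcribe_example(client: &Client) -> anyhow::Result<()> {
    // Same parameter struct the new route fills in from the multipart body.
    let parameters = AudioTranscriptionParameters {
        file: AudioTranscriptionFile::File("./sample.wav".to_string()),
        language: Some("en".to_string()),
        model: "whisper-1".to_string(),
        prompt: None,
        response_format: Some(AudioOutputFormat::Json),
        temperature: None,
        timestamp_granularities: None,
    };

    let response = client
        .audio()
        .create_transcription(parameters)
        .await
        .map_err(|e| anyhow::anyhow!("transcription request failed: {e:?}"))?;

    println!("{response:?}");
    Ok(())
}
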
1 change: 1 addition & 0 deletions src/config.rs
@@ -21,6 +21,7 @@ pub enum AiRouterBackendType {
#[serde(rename_all = "snake_case")]
pub enum AiRouterModelType {
AudioSpeech,
AudioTranscriptions,
ChatCompletions,
Embeddings,
}
162 changes: 160 additions & 2 deletions src/routes/audio.rs
@@ -1,9 +1,18 @@
use std::io::Read;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;

use axum::extract::State as AxumState;
use anyhow::Context;
use axum::extract::{Multipart, State as AxumState};
use axum::response::Response;
use axum::Json;
use openai_dive::v1::resources::audio::AudioSpeechParameters;
use bytes::Bytes;
use openai_dive::v1::resources::audio::{
AudioOutputFormat, AudioSpeechParameters, AudioTranscriptionBytes, AudioTranscriptionFile,
AudioTranscriptionParameters, TimestampGranularity,
};
use tracing::instrument;

use crate::backend::openai::routes as openai_routes;
use crate::config::AiRouterModelType;
@@ -45,3 +54,152 @@ pub async fn speech(
parameters.model.clone(),
))
}

#[instrument(level = "debug", skip(state, multipart))]
pub async fn transcriptions(
AxumState(state): AxumState<Arc<State>>,
// Multipart must be the last argument
// <https://github.com/tokio-rs/axum/discussions/1600>
multipart: Multipart,
) -> Result<Response, AiRouterError<String>> {
let mut parameters = match build_transcription_parameters(multipart).await {
Ok(o) => o,
Err(e) => {
return Err(e);
}
};

if let Some(models) = state
.config
.models
.get(&AiRouterModelType::AudioTranscriptions)
{
if let Some(model) = models.get(&parameters.model) {
if let Some(backend_model) = model.backend_model.clone() {
parameters.model = backend_model;
}

let model_backend = model.backend.as_ref().map_or("default", |m| m);

let Some(backend) = state.backends.get(model_backend) else {
return Err(AiRouterError::InternalServerError::<String>(format!(
"backend {model_backend} not found"
)));
};

match &backend.client {
BackendTypes::OpenAI(c) => {
return openai_routes::audio::transcriptions(c, parameters).await
}
BackendTypes::Triton(_c) => {
return Err(AiRouterError::InternalServerError::<String>(String::from(
"audio transcriptions to Triton backend not implemented yet",
)));
}
}
}
}

return Err(AiRouterError::ModelNotFound::<String>(parameters.model));
}

#[instrument(level = "debug", skip(multipart))]
pub async fn build_transcription_parameters(
mut multipart: Multipart,
) -> Result<AudioTranscriptionParameters, AiRouterError<String>> {
let mut parameters = AudioTranscriptionParameters {
file: AudioTranscriptionFile::File(String::new()),
language: None,
model: String::new(),
prompt: None,
response_format: None,
temperature: None,
timestamp_granularities: None,
};

let mut timestamp_granularities: Vec<TimestampGranularity> = Vec::new();

while let Ok(Some(field)) = multipart.next_field().await {
tracing::trace!("{field:#?}");
let field_name = field
.name()
.ok_or(AiRouterError::InternalServerError::<String>(String::from(
"failed to read field name",
)))?
.to_string();

if field_name == "file" {
let filename: String =
String::from(field.file_name().context("failed to read field filename")?);

let extension = Path::new(&filename)
.extension()
.context("failed to get extension of uploaded file")?;

tracing::debug!("filename: {filename} - extension: {extension:?}");

if extension != "wav" {
return Err(AiRouterError::InternalServerError::<String>(format!(
"extension {extension:?} not supported",
)));
}

let field_data_vec = get_field_data_vec(&field.bytes().await?, &field_name)?;
let bytes = AudioTranscriptionBytes {
bytes: Bytes::copy_from_slice(&field_data_vec),
filename,
};

parameters.file = AudioTranscriptionFile::Bytes(bytes);
} else if field_name == "language" {
let field_data_vec = get_field_data_vec(&field.bytes().await?, &field_name)?;

parameters.language = Some(String::from_utf8(field_data_vec)?);
} else if field_name == "model" {
let field_data_vec = get_field_data_vec(&field.bytes().await?, &field_name)?;

parameters.model = String::from_utf8(field_data_vec)?;
} else if field_name == "prompt" {
let field_data_vec = get_field_data_vec(&field.bytes().await?, &field_name)?;

parameters.prompt = Some(String::from_utf8(field_data_vec)?);
} else if field_name == "response_format" {
let field_data_vec = get_field_data_vec(&field.bytes().await?, &field_name)?;
let response_format = String::from_utf8(field_data_vec)?;
let response_format = AudioOutputFormat::from_str(&response_format)?;

parameters.response_format = Some(response_format);
} else if field_name == "temperature" {
let field_data_vec = get_field_data_vec(&field.bytes().await?, &field_name)?;

parameters.temperature = Some(String::from_utf8(field_data_vec)?.parse()?);
} else if field_name == "timestamp_granularities[]" {
let field_data_vec = get_field_data_vec(&field.bytes().await?, &field_name)?;
let granularity = String::from_utf8(field_data_vec)?;
let granularity = TimestampGranularity::from_str(&granularity)?;

timestamp_granularities.push(granularity);
}
}

if !timestamp_granularities.is_empty()
&& parameters.response_format == Some(AudioOutputFormat::VerboseJson)
{
parameters.timestamp_granularities = Some(timestamp_granularities);
}

Ok(parameters)
}

fn get_field_data_vec(data: &Bytes, name: &str) -> Result<Vec<u8>, AiRouterError<String>> {
let field_data_vec: Vec<u8> = match data.bytes().collect() {
Ok(o) => o,
Err(e) => {
return Err(AiRouterError::InternalServerError::<String>(format!(
"failed to read {name} field: {e}",
)));
}
};

Ok(field_data_vec)
}
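
To exercise the new endpoint end to end, a client has to send exactly the multipart fields this parser looks for: file and model, plus optional language, prompt, response_format, temperature, and timestamp_granularities[]. Below is a hedged client-side sketch using reqwest with its "multipart" feature; the URL, port, and file name are placeholders and not part of this commit.

use reqwest::multipart::{Form, Part};

async fn post_transcription() -> anyhow::Result<()> {
    // Only .wav uploads are accepted by build_transcription_parameters above.
    let audio = std::fs::read("sample.wav")?;

    let form = Form::new()
        .text("model", "whisper-1")
        .text("response_format", "verbose_json")
        // Granularities are only forwarded when response_format is verbose_json.
        .text("timestamp_granularities[]", "word")
        .part(
            "file",
            Part::bytes(audio).file_name("sample.wav").mime_str("audio/wav")?,
        );

    let response = reqwest::Client::new()
        .post("http://127.0.0.1:3000/v1/audio/transcriptions")
        .multipart(form)
        .send()
        .await?;

    println!("{}", response.text().await?);
    Ok(())
}
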
4 changes: 4 additions & 0 deletions src/startup.rs
@@ -26,6 +26,10 @@ pub async fn run_server(config_file: &AiRouterConfigFile) -> anyhow::Result<()>

let app = Router::new()
.route("/v1/audio/speech", post(routes::audio::speech))
.route(
"/v1/audio/transcriptions",
post(routes::audio::transcriptions),
)
.route("/v1/chat/completions", post(routes::chat::completion))
.route("/v1/completions", post(routes::completions::completion))
.route("/v1/embeddings", post(routes::embeddings::embed))
