This repository has been archived by the owner on Dec 31, 2024. It is now read-only.

Add speech router and transcribe endpoint #23

Merged · 2 commits · Dec 5, 2023
4 changes: 3 additions & 1 deletion Dockerfile
@@ -7,7 +7,9 @@ WORKDIR /app
COPY pyproject.toml poetry.lock ./
COPY src ./src

-RUN pip install poetry && \
+RUN apt-get update && \
+    apt-get install -y ffmpeg && \
+    pip install poetry && \
Comment on lines +11 to +12
Contributor
Suggested change:
-    apt-get install -y ffmpeg && \
-    pip install poetry && \
+    apt-get install -y ffmpeg
+RUN pip install poetry && \

I'd do separate RUN layers for apt and pip/poetry so the apt layer can stay cached when the poetry install breaks (way more likely than apt breaking); a sketch follows the diff.

poetry config virtualenvs.create false && \
poetry install --no-dev --no-interaction --no-ansi && \
rm -rf /root/.cache
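For illustration, the reviewer's split might look like this; a sketch only, reusing the PR's install flags (the apt list cleanup is an extra touch, not part of the PR):

```dockerfile
# System dependencies in their own layer: rebuilt only when this line changes.
RUN apt-get update && \
    apt-get install -y ffmpeg && \
    rm -rf /var/lib/apt/lists/*

# Python tooling in a second layer: if the poetry step fails or changes,
# Docker reuses the cached apt layer above instead of rebuilding it.
RUN pip install poetry && \
    poetry config virtualenvs.create false && \
    poetry install --no-dev --no-interaction --no-ansi && \
    rm -rf /root/.cache
```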
29 changes: 28 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -17,6 +17,7 @@ openai = "^1.3.7"
psycopg2-binary = "^2.9.9"
boto3 = "^1.33.4"
python-multipart = "^0.0.6"
ffmpeg-python = "^0.2.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.3"
4 changes: 4 additions & 0 deletions src/linguaweb_api/core/config.py
@@ -44,6 +44,10 @@ class Settings(pydantic_settings.BaseSettings): # type: ignore[valid-type, misc
"tts-1",
json_schema_extra={"env": "OPENAI_TTS_MODEL"},
)
OPENAI_STT_MODEL: openai_constants.STTModels = pydantic.Field(
"whisper-1",
json_schema_extra={"env": "OPENAI_STT_MODEL"},
)

S3_ENDPOINT_URL: str | None = pydantic.Field(
None,
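The new field reads from the environment at settings construction time; a minimal sketch of overriding it, assuming standard pydantic-settings behavior where the field name doubles as the env var name:

```python
# Hedged sketch: overriding the STT model via the environment.
# The variable must be set before settings are first constructed.
import os

os.environ["OPENAI_STT_MODEL"] = "whisper-1"  # currently the only STTModels member

from linguaweb_api.core import config  # imported after the env var is set

settings = config.get_settings()
print(settings.OPENAI_STT_MODEL)  # STTModels.WHISPER1
```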
2 changes: 2 additions & 0 deletions src/linguaweb_api/main.py
@@ -8,6 +8,7 @@
from linguaweb_api.microservices import sql
from linguaweb_api.routers.admin import views as admin_views
from linguaweb_api.routers.health import views as health_views
from linguaweb_api.routers.speech import views as speech_views
from linguaweb_api.routers.words import views as words_views

settings = config.get_settings()
@@ -47,6 +48,7 @@
base_router = fastapi.APIRouter(prefix="/api/v1")
base_router.include_router(admin_views.router)
base_router.include_router(health_views.router)
base_router.include_router(speech_views.router)
base_router.include_router(words_views.router)
app.include_router(base_router)

25 changes: 24 additions & 1 deletion src/linguaweb_api/microservices/openai.py
@@ -1,6 +1,7 @@
"""This module contains interactions with OpenAI models."""
import abc
import logging
import pathlib
from typing import Any, Literal, TypedDict

import fastapi
@@ -13,6 +14,7 @@
OPENAI_API_KEY = settings.OPENAI_API_KEY
OPENAI_GPT_MODEL = settings.OPENAI_GPT_MODEL
OPENAI_TTS_MODEL = settings.OPENAI_TTS_MODEL
OPENAI_STT_MODEL = settings.OPENAI_STT_MODEL
OPENAI_VOICE = settings.OPENAI_VOICE
LOGGER_NAME = settings.LOGGER_NAME

@@ -96,9 +98,30 @@ async def run(self, text: str) -> bytes:
The model's response.
"""
response = self.client.audio.speech.create(
-            model=OPENAI_TTS_MODEL,
+            model=OPENAI_TTS_MODEL.value,
voice=OPENAI_VOICE.value,
input=text,
)

return b"".join(response.iter_bytes())


class SpeechToText(OpenAIBaseClass):
"""A class for running the Speech-To-Text models."""

async def run(self, audio_file: pathlib.Path | str) -> str:
"""Runs the Speech-To-Text model.

Args:
audio_file: The audio to convert to text.

Returns:
The model's response.
"""
with pathlib.Path(audio_file).open("rb") as audio:
return self.client.audio.transcriptions.create(
model=OPENAI_STT_MODEL.value,
file=audio,
response_format="text",
) # type: ignore[return-value] # response_format overrides output type.
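A usage sketch for the new class; "sample.mp3" is a hypothetical path, and run() calls the real OpenAI API, so a valid OPENAI_API_KEY is assumed:

```python
# Hedged sketch: transcribing a local audio file with SpeechToText.
import asyncio

from linguaweb_api.microservices import openai


async def main() -> None:
    # run() opens the file itself, so a str or pathlib.Path is enough.
    transcription = await openai.SpeechToText().run("sample.mp3")
    print(transcription)


asyncio.run(main())
```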
6 changes: 6 additions & 0 deletions src/linguaweb_api/microservices/openai_constants.py
@@ -23,6 +23,12 @@ class TTSModels(str, enum.Enum):
TTS1 = "tts-1"


class STTModels(str, enum.Enum):
"""Supported Speech-To-Text models."""

WHISPER1 = "whisper-1"


class GPTModels(str, enum.Enum):
"""Supported GPT models."""

1 change: 1 addition & 0 deletions src/linguaweb_api/routers/speech/__init__.py
@@ -0,0 +1 @@
"""Endpoint definitions for the speech router."""
66 changes: 66 additions & 0 deletions src/linguaweb_api/routers/speech/controller.py
@@ -0,0 +1,66 @@
"""Speech router controller."""
import logging
import pathlib
import tempfile

import fastapi
import ffmpeg
from fastapi import status

from linguaweb_api.core import config
from linguaweb_api.microservices import openai

settings = config.get_settings()
LOGGER_NAME = settings.LOGGER_NAME

logger = logging.getLogger(LOGGER_NAME)

TARGET_FILE_FORMAT = ".mp3"


async def transcribe(audio: fastapi.UploadFile) -> str:
"""Transcribes audio using OpenAI's Whisper.

Args:
audio: The audio file.

Returns:
            str: The transcription of the audio.
"""
logger.debug("Transcribing audio.")
with tempfile.TemporaryDirectory() as temp_dir:
target_path = pathlib.Path(temp_dir) / f"audio{TARGET_FILE_FORMAT}"
_convert_audio(audio, temp_dir, target_path)
return await openai.SpeechToText().run(target_path)


def _convert_audio(
audio: fastapi.UploadFile,
directory: str,
target_path: pathlib.Path,
) -> None:
"""Converts the audio to the target format.

Args:
audio: The audio file.
directory: The directory to save the audio file to.
target_path: The path to save the audio file to.
"""
if audio.filename is None:
raise fastapi.HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The audio file must have a filename.",
)

extension = pathlib.Path(audio.filename).suffix
if extension == TARGET_FILE_FORMAT:
logger.debug("Audio is already in the correct format.")
with target_path.open("wb") as target_file:
target_file.write(audio.file.read())
else:
logger.debug("Converting audio to correct format.")
audio_path = pathlib.Path(directory) / f"audio{extension}"
with audio_path.open("wb") as audio_file:
audio_file.write(audio.file.read())
ffmpeg.input(str(audio_path)).output(str(target_path)).run()
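The conversion step wraps ffmpeg-python's fluent API; isolated as a sketch with hypothetical file names (the ffmpeg binary must be on PATH, which is why the Dockerfile installs it):

```python
# Minimal sketch of the same conversion _convert_audio performs.
import ffmpeg

(
    ffmpeg.input("recording.webm")  # hypothetical source file
    .output("recording.mp3")  # target format inferred from the suffix
    .run(overwrite_output=True, quiet=True)  # overwrite (-y) and silence ffmpeg's output
)
```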
42 changes: 42 additions & 0 deletions src/linguaweb_api/routers/speech/views.py
@@ -0,0 +1,42 @@
"""View definitions for the speech router."""
import logging

import fastapi
from fastapi import status

from linguaweb_api.core import config
from linguaweb_api.routers.speech import controller

settings = config.get_settings()
LOGGER_NAME = settings.LOGGER_NAME

logger = logging.getLogger(LOGGER_NAME)

router = fastapi.APIRouter(prefix="/speech", tags=["speech"])


@router.post(
"/transcribe",
response_model=str,
status_code=status.HTTP_200_OK,
summary="Transcribes audio.",
description="Endpoint that uses OpenAI's Whisper API to transcribe audio.",
responses={
status.HTTP_400_BAD_REQUEST: {
"description": "The audio file must have a filename.",
},
},
)
async def transcribe(audio: fastapi.UploadFile = fastapi.File(...)) -> str:
"""Transcribes audio using OpenAI's Whisper API.

Args:
audio: The audio file.

Returns:
The transcription of the audio as a string.
"""
    logger.debug("Transcribing audio.")
    transcription = await controller.transcribe(audio)
    logger.debug("Transcribed audio.")
    return transcription
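To exercise the endpoint from a client, a sketch using requests (not a project dependency) against a hypothetical local server:

```python
# Hedged sketch: posting an audio file to the new transcribe endpoint.
import requests

with open("sample.mp3", "rb") as audio:  # hypothetical local file
    response = requests.post(
        "http://localhost:8000/api/v1/speech/transcribe",
        files={"audio": ("sample.mp3", audio, "audio/mpeg")},
    )

response.raise_for_status()
print(response.json())  # the transcription string
```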
1 change: 1 addition & 0 deletions src/linguaweb_api/routers/words/schemas.py
@@ -6,6 +6,7 @@ class WordData(pydantic.BaseModel):
"""Word data, without the word itself."""

id: int
word: str
description: str
synonyms: list[str]
antonyms: list[str]
2 changes: 2 additions & 0 deletions tests/endpoint/conftest.py
@@ -35,6 +35,8 @@ class Endpoints(str, enum.Enum):
GET_AUDIO = f"{API_ROOT}/words/download/{{audio_id}}"
POST_CHECK_WORD = f"{API_ROOT}/words/check/{{word_id}}"

POST_SPEECH_TRANSCRIBE = f"{API_ROOT}/speech/transcribe"

GET_HEALTH = f"{API_ROOT}/health"


67 changes: 67 additions & 0 deletions tests/endpoint/test_speech.py
@@ -0,0 +1,67 @@
"""Tests for the speech endpoints."""
import array
import tempfile
import wave
from collections.abc import Generator
from typing import Any

import ffmpeg
import pytest
import pytest_mock
from fastapi import status, testclient

from linguaweb_api.microservices import openai
from tests.endpoint import conftest


@pytest.fixture()
def wav_file() -> Generator[str, Any, None]:
"""Returns a path to a temporary wav file."""
with tempfile.NamedTemporaryFile(suffix=".wav") as f:
wav = wave.open(f, "w")
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(44100)
        # One second of 16-bit mono silence (44100 zero-valued samples).
        wav.writeframes(array.array("h", [0] * 44100).tobytes())
wav.close()
yield f.name


@pytest.fixture()
def mp3_file(wav_file: str) -> Generator[str, Any, None]:
"""Returns a path to a temporary mp3 file."""
with tempfile.NamedTemporaryFile(suffix=".mp3") as f:
ffmpeg.input(wav_file).output(f.name).overwrite_output().run()
yield f.name


@pytest.fixture()
def files(wav_file: str, mp3_file: str) -> dict[str, str]:
"""Workaround for pytest.mark.parametrize not supporting fixtures."""
return {"wav": wav_file, "mp3": mp3_file}


@pytest.mark.parametrize("file_type", ["wav", "mp3"])
def test_transcribe(
mocker: pytest_mock.MockerFixture,
client: testclient.TestClient,
endpoints: conftest.Endpoints,
files: dict[str, str],
file_type: str,
) -> None:
"""Tests the transcribe endpoint."""
expected_transcription = "Expected transcription"
mock_stt_run = mocker.patch.object(
openai.SpeechToText,
"run",
return_value=expected_transcription,
)

response = client.post(
endpoints.POST_SPEECH_TRANSCRIBE,
files={"audio": open(files[file_type], "rb")}, # noqa: SIM115, PTH123
)

mock_stt_run.assert_called_once()
assert response.status_code == status.HTTP_200_OK
assert response.json() == expected_transcription