refactor: Migrate RemoteWhisperTranscriber to OpenAI SDK. (#6149)
* Migrate RemoteWhisperTranscriber to OpenAI SDK

* Migrate RemoteWhisperTranscriber to OpenAI SDK

* Remove unnecessary imports

* Add release notes

* Fix api_key serialization

* Fix linting

* Apply suggestions from code review

Co-authored-by: ZanSara <[email protected]>

* Add additional tests for api_key

* Adapt .run() to take ByteStream inputs

* Update docstrings

* Rework implementation to use io.BytesIO

* Update error message

* Add default file name

---------

Co-authored-by: ZanSara <[email protected]>
awinml and ZanSara authored Oct 26, 2023
1 parent 26a2204 commit 5f35e7d
Showing 3 changed files with 266 additions and 247 deletions.
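In practice, the migrated component reads audio from ByteStream objects and delegates the HTTP call to the OpenAI SDK, as the diff below shows. A minimal usage sketch, assuming the preview API at this commit (the file name is illustrative, and the ByteStream constructor is assumed to take data and metadata keyword arguments):

    import os

    from haystack.preview.components.audio.whisper_remote import RemoteWhisperTranscriber
    from haystack.preview.dataclasses import ByteStream

    os.environ["OPENAI_API_KEY"] = "sk-..."  # the component falls back to this variable

    # Wrap the raw audio bytes in a ByteStream; the "file_path" metadata entry becomes
    # the file name sent to the API (without it, the component defaults to "audio_input.wav").
    with open("meeting.wav", "rb") as f:  # hypothetical input file
        stream = ByteStream(data=f.read(), metadata={"file_path": "meeting.wav"})

    transcriber = RemoteWhisperTranscriber(language="en")
    result = transcriber.run(streams=[stream])
    print(result["documents"][0].text)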
177 changes: 89 additions & 88 deletions haystack/preview/components/audio/whisper_remote.py
@@ -1,20 +1,17 @@
-from typing import List, Optional, Dict, Any, Union, BinaryIO, Literal, get_args, Sequence
-
-import os
-import json
+import io
 import logging
-from pathlib import Path
-
-from haystack.preview.utils import request_with_retry
-from haystack.preview import component, Document, default_to_dict
+import os
+from typing import Any, Dict, List, Optional
+
+import openai
+
+from haystack.preview import Document, component, default_from_dict, default_to_dict
+from haystack.preview.dataclasses import ByteStream
 
 logger = logging.getLogger(__name__)
 
-OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_OPENAI_TIMEOUT_SEC", 600))
-
-WhisperRemoteModel = Literal["whisper-1"]
+
+API_BASE_URL = "https://api.openai.com/v1"
 
 
 @component
@@ -30,108 +27,112 @@ class RemoteWhisperTranscriber:
 
     def __init__(
         self,
-        api_key: str,
-        model_name: WhisperRemoteModel = "whisper-1",
-        api_base: str = "https://api.openai.com/v1",
-        whisper_params: Optional[Dict[str, Any]] = None,
+        api_key: Optional[str] = None,
+        model_name: str = "whisper-1",
+        organization: Optional[str] = None,
+        api_base_url: str = API_BASE_URL,
+        **kwargs,
     ):
         """
         Transcribes a list of audio files into a list of Documents.
+
         :param api_key: OpenAI API key.
         :param model_name: Name of the model to use. It now accepts only `whisper-1`.
-        :param api_base: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
+        :param organization: The OpenAI-Organization ID, defaults to `None`. For more details, see OpenAI
+            [documentation](https://platform.openai.com/docs/api-reference/requesting-organization).
+        :param api_base_url: OpenAI base URL, defaults to `"https://api.openai.com/v1"`.
+        :param kwargs: Other parameters to use for the model. These parameters are all sent directly to the OpenAI
+            endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/audio) for more
+            details. Some of the supported parameters:
+            - `language`: The language of the input audio. Supplying the input language in ISO-639-1 format
+              will improve accuracy and latency.
+            - `prompt`: An optional text to guide the model's style or continue a previous audio segment.
+              The prompt should match the audio language.
+            - `response_format`: The format of the transcript output, in one of these options: json, text, srt,
+              verbose_json, or vtt. Defaults to "json". Currently only "json" is supported.
+            - `temperature`: The sampling temperature, between 0 and 1. Higher values like 0.8 will make the
+              output more random, while lower values like 0.2 will make it more focused and deterministic. If
+              set to 0, the model will use log probability to automatically increase the temperature until
+              certain thresholds are hit.
         """
-        if model_name not in get_args(WhisperRemoteModel):
-            raise ValueError(
-                f"Model name not recognized. Choose one among: " f"{', '.join(get_args(WhisperRemoteModel))}."
-            )
-        if not api_key:
-            raise ValueError("API key is None.")
-
-        self.model_name = model_name
-        self.api_key = api_key
-        self.api_base = api_base
-        self.whisper_params = whisper_params or {}
-
-    @component.output_types(documents=List[Document])
-    def run(self, audio_files: List[Path], whisper_params: Optional[Dict[str, Any]] = None):
-        """
-        Transcribe the audio files into a list of Documents, one for each input file.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
-        [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe
-        :returns: a list of Documents, one for each file. The content of the document is the transcription text,
-            while the document's metadata contains all the other values returned by the Whisper model, such as the
-            alignment data. Another key called `audio_file` contains the path to the audio file used for the
-            transcription.
-        """
-        if whisper_params is None:
-            whisper_params = self.whisper_params
-
-        documents = self.transcribe(audio_files, **whisper_params)
-        return {"documents": documents}
-
-    def transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Document]:
-        """
-        Transcribe the audio files into a list of Documents, one for each input file.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
-        [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe
-        :returns: a list of transcriptions.
-        """
-        transcriptions = self._raw_transcribe(audio_files=audio_files, **kwargs)
-        documents = []
-        for audio, transcript in zip(audio_files, transcriptions):
-            content = transcript.pop("text")
-            if not isinstance(audio, (str, Path)):
-                audio = "<<binary stream>>"
-            doc = Document(text=content, metadata={"audio_file": audio, **transcript})
-            documents.append(doc)
-        return documents
-
-    def _raw_transcribe(self, audio_files: Sequence[Union[str, Path, BinaryIO]], **kwargs) -> List[Dict[str, Any]]:
-        """
-        Transcribe the given audio files. Returns a list of strings.
-
-        For the supported audio formats, languages, and other parameters, see the
-        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
-        [github repo](https://github.com/openai/whisper).
-
-        :param audio_files: a list of paths or binary streams to transcribe.
-        :param kwargs: any other parameters that Whisper API can understand.
-        :returns: a list of transcriptions as they are produced by the Whisper API (JSON).
-        """
-        translate = kwargs.pop("translate", False)
-        url = f"{self.api_base}/audio/{'translations' if translate else 'transcriptions'}"
-        data = {"model": self.model_name, **kwargs}
-        headers = {"Authorization": f"Bearer {self.api_key}"}
-
-        transcriptions = []
-        for audio_file in audio_files:
-            if isinstance(audio_file, (str, Path)):
-                audio_file = open(audio_file, "rb")
-
-            request_files = ("file", (audio_file.name, audio_file, "application/octet-stream"))
-            response = request_with_retry(
-                method="post", url=url, data=data, headers=headers, files=[request_files], timeout=OPENAI_TIMEOUT
-            )
-            transcription = json.loads(response.content)
-
-            transcriptions.append(transcription)
-        return transcriptions
+        # if the user does not provide the API key, check if it is set in the module client
+        api_key = api_key or openai.api_key
+        if api_key is None:
+            try:
+                api_key = os.environ["OPENAI_API_KEY"]
+            except KeyError as e:
+                raise ValueError(
+                    "RemoteWhisperTranscriber expects an OpenAI API key. "
+                    "Set the OPENAI_API_KEY environment variable (recommended) or pass it explicitly."
+                ) from e
+        openai.api_key = api_key
+
+        self.organization = organization
+        self.model_name = model_name
+        self.api_base_url = api_base_url
+
+        # Only response_format = "json" is supported
+        whisper_params = kwargs
+        if whisper_params.get("response_format") != "json":
+            logger.warning(
+                "RemoteWhisperTranscriber only supports 'response_format: json'. This parameter will be overwritten."
+            )
+        whisper_params["response_format"] = "json"
+        self.whisper_params = whisper_params
+
+        if organization is not None:
+            openai.organization = organization
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
-        to the constructor.
+        Serialize this component to a dictionary.
+        This method overrides the default serializer in order to
+        avoid leaking the `api_key` value passed to the constructor.
         """
         return default_to_dict(
-            self, model_name=self.model_name, api_base=self.api_base, whisper_params=self.whisper_params
+            self,
+            model_name=self.model_name,
+            organization=self.organization,
+            api_base_url=self.api_base_url,
+            **self.whisper_params,
        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "RemoteWhisperTranscriber":
+        """
+        Deserialize this component from a dictionary.
+        """
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(self, streams: List[ByteStream]):
+        """
+        Transcribe the audio files into a list of Documents, one for each input file.
+
+        For the supported audio formats, languages, and other parameters, see the
+        [Whisper API documentation](https://platform.openai.com/docs/guides/speech-to-text) and the official Whisper
+        [github repo](https://github.com/openai/whisper).
+
+        :param streams: a list of ByteStream objects to transcribe.
+        :returns: a list of Documents, one for each file. The content of the document is the transcription text.
+        """
+        documents = []
+
+        for stream in streams:
+            file = io.BytesIO(stream.data)
+            try:
+                file.name = stream.metadata["file_path"]
+            except KeyError:
+                file.name = "audio_input.wav"
+
+            content = openai.Audio.transcribe(file=file, model=self.model_name, **self.whisper_params)
+            doc = Document(text=content["text"], metadata=stream.metadata)
+            documents.append(doc)
+
+        return {"documents": documents}
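The to_dict/from_dict pair above is what the "Fix api_key serialization" bullet refers to: the key is never written into the serialized dictionary, so a restored component must find it in the environment (or in the openai module) again. A sketch of the expected round trip, with the dictionary layout assumed from default_to_dict's usual type/init_parameters convention:

    transcriber = RemoteWhisperTranscriber(api_key="sk-...", language="en")

    data = transcriber.to_dict()
    # Only model_name, organization, api_base_url and the whisper kwargs are
    # serialized; the api_key is deliberately absent.
    assert "api_key" not in data["init_parameters"]

    # Restoring therefore only works when OPENAI_API_KEY is set (or openai.api_key
    # is already populated); otherwise __init__ raises the ValueError shown above.
    restored = RemoteWhisperTranscriber.from_dict(data)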
4 changes: 4 additions & 0 deletions (new release notes file)
@@ -0,0 +1,4 @@
+---
+preview:
+  - |
+    Migrate RemoteWhisperTranscriber to OpenAI SDK.
[Diff of the third changed file (the updated tests) not shown.]
